use crate::{Error, Response, Url};
use budget::Budget;
use std::{sync::Arc, time::Duration};
pub(crate) use classify::Action;
use classify::ReqRep;
/// Shared, thread-safe classifier callback: receives a [`ReqRep`] describing a
/// finished attempt and returns the [`Action`] to take for it.
type ClassifyFn = Arc<dyn for<'a> Fn(ReqRep<'a>) -> Action + Send + Sync>;
/// Shared, thread-safe scope predicate: given the request URL and method,
/// returns `true` when the retry policy applies to that request.
type ScopeFn = Arc<dyn Fn(&Url, &http::Method) -> bool + Send + Sync>;
/// Starts a retry [`Builder`] scoped to requests whose URL host equals `host`.
///
/// A URL without a host is compared as the empty string. The comparison is
/// whatever `S: PartialEq<&str>` implements; the request method is ignored.
pub fn for_host<S>(host: S) -> Builder
where
    S: for<'a> PartialEq<&'a str> + Send + Sync + 'static,
{
    scoped(move |url: &Url, _method| {
        let request_host = url.host_str().unwrap_or_default();
        host == request_host
    })
}
pub fn never() -> Builder {
scoped(|_, _| false).no_budget()
}
/// Module-private shorthand for [`Builder::scoped`]: builds a retry builder
/// restricted to the requests for which `func` returns `true`.
fn scoped<F: Fn(&Url, &http::Method) -> bool + Send + Sync + 'static>(func: F) -> Builder {
    Builder::scoped(func)
}
/// Configuration for a retry policy; converted into a `Policy` by
/// `Builder::into_policy`.
pub struct Builder {
/// Retry percentage handed to `Budget::new` when the policy is built.
/// `None` means no budget is created and withdrawals always succeed.
budget: Option<f32>,
/// Optional classifier deciding whether a finished attempt is retryable.
/// When absent, every result is classified as `Action::Success`.
classify: Option<ClassifyFn>,
/// Retry limit per request; stored here and exposed by `Policy::max_retries`.
max_retries_per_request: u32,
/// Optional predicate selecting which requests the policy applies to.
/// `None` leaves the policy unscoped (no request is filtered out).
scope: Option<ScopeFn>,
}
impl Builder {
pub(crate) fn scoped(
scope: impl Fn(&Url, &http::Method) -> bool + Send + Sync + 'static,
) -> Self {
Self {
budget: Some(0.2),
classify: None,
max_retries_per_request: 2,
scope: Some(Arc::new(scope)),
}
}
#[must_use]
pub fn no_budget(mut self) -> Self {
self.budget = None;
self
}
#[must_use]
pub fn max_extra_load(mut self, extra_percent: f32) -> Self {
assert!(extra_percent >= 0.0);
assert!(extra_percent <= 1000.0);
self.budget = Some(extra_percent);
self
}
#[must_use]
pub fn max_retries_per_request(mut self, max: u32) -> Self {
self.max_retries_per_request = max;
self
}
#[must_use]
pub fn classify_fn<F>(mut self, func: F) -> Self
where
F: Fn(ReqRep<'_>) -> Action + Send + Sync + 'static,
{
self.classify = Some(Arc::new(func));
self
}
pub(crate) fn default() -> Builder {
Self {
budget: None,
classify: Some(Arc::new(|rr| {
if rr.is_protocol_nack() {
Action::Retryable
} else {
Action::Success
}
})),
max_retries_per_request: 2,
scope: None,
}
}
pub(crate) fn into_policy(self) -> Policy {
let budget = self
.budget
.map(|p| Arc::new(Budget::new(Duration::from_secs(10), 10, p)));
Policy {
budget,
classify: self.classify,
max_retries_per_request: self.max_retries_per_request,
scope: self.scope,
}
}
}
/// Debug output lists only the budget and the retry limit; the classifier and
/// scope closures are not printed.
impl std::fmt::Debug for Builder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            budget,
            max_retries_per_request,
            ..
        } = self;
        let mut out = f.debug_struct("Builder");
        out.field("budget", budget);
        out.field("max_retries_per_request", max_retries_per_request);
        out.finish()
    }
}
/// Runtime retry policy produced by `Builder::into_policy`.
///
/// Cloning is cheap: the budget and both callbacks are behind `Arc`, so clones
/// share the same budget state.
#[derive(Clone)]
pub(crate) struct Policy {
/// Shared retry budget; `None` means withdrawals are always allowed.
budget: Option<Arc<Budget>>,
/// Classifier for in-scope results; `None` classifies everything as success.
classify: Option<ClassifyFn>,
/// Retry limit per request, returned by `max_retries`.
max_retries_per_request: u32,
/// Scope predicate; `None` means no request is filtered out.
scope: Option<ScopeFn>,
}
impl Policy {
pub(crate) fn classify_result(
&self,
url: &Url,
method: &http::Method,
result: &Result<Response, Error>,
) -> Action {
if let Some(ref scope) = self.scope
&& !scope(url, method)
{
return Action::Success;
}
let Some(ref classify) = self.classify else {
return Action::Success;
};
let Ok(rr) = ReqRep::new(url, method, result) else {
return Action::Success;
};
classify(rr)
}
pub(crate) fn deposit(&self) {
if let Some(ref budget) = self.budget {
budget.deposit();
}
}
pub(crate) fn can_withdraw(&self) -> bool {
self.budget.as_ref().is_none_or(|b| b.withdraw())
}
pub(crate) fn max_retries(&self) -> u32 {
self.max_retries_per_request
}
}
/// Debug output lists only the budget and the retry limit; the classifier and
/// scope closures are not printed.
impl std::fmt::Debug for Policy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            budget,
            max_retries_per_request,
            ..
        } = self;
        let mut out = f.debug_struct("Policy");
        out.field("budget", budget);
        out.field("max_retries_per_request", max_retries_per_request);
        out.finish()
    }
}
/// Windowed retry budget: deposits add credit, withdrawals spend it, and
/// credit is forgotten once it rotates out of the TTL window.
mod budget {
    use std::sync::{Mutex, MutexGuard};
    use std::time::{Duration, Instant};

    /// Budget shared between all requests of a policy.
    ///
    /// The TTL is split into a fixed number of slots. Activity is accumulated
    /// in a `writer` counter and committed to the current slot whenever at
    /// least one slot duration has passed.
    pub(super) struct Budget {
        state: Mutex<BudgetState>,
        /// Credit that is always available, independent of deposits.
        reserve: isize,
        /// Number of slots the TTL is divided into.
        slots: usize,
        /// Length of one slot (`ttl / slots`).
        slot_duration: Duration,
        /// Credit added by one deposit.
        pub(super) deposit_amount: isize,
        /// Credit consumed by one successful withdrawal.
        pub(super) withdraw_amount: isize,
    }

    /// Mutable part of the budget, guarded by the mutex in [`Budget`].
    struct BudgetState {
        /// Committed credit per slot.
        buckets: Vec<isize>,
        /// Credit accumulated since the last rotation (may be negative).
        writer: isize,
        /// Slot the next rotation will commit `writer` into.
        gen_index: usize,
        /// Timestamp of the last rotation (or of construction).
        gen_time: Instant,
    }

    impl Budget {
        /// Creates a budget.
        ///
        /// * `ttl` - window length, divided into 10 slots.
        /// * `min_per_sec` - reserve of withdrawals per second of `ttl` that
        ///   is available without any deposit.
        /// * `retry_percent` - ratio of withdrawals to deposits: `0.0` means
        ///   deposits earn nothing, values up to `1.0` make one withdrawal
        ///   cost `1 / retry_percent` deposits, larger values scale deposits
        ///   by 1000 to keep integer precision.
        ///
        /// # Panics
        ///
        /// Panics if `ttl` is outside `1s..=60s` or `retry_percent` is
        /// outside `0.0..=1000.0` (NaN included).
        pub(super) fn new(ttl: Duration, min_per_sec: u32, retry_percent: f32) -> Self {
            assert!(ttl >= Duration::from_secs(1));
            assert!(ttl <= Duration::from_secs(60));
            assert!(retry_percent >= 0.0);
            assert!(retry_percent <= 1000.0);

            let (deposit_amount, withdraw_amount) = if retry_percent == 0.0 {
                // Deposits are worthless; only the reserve can be spent.
                (0isize, 1isize)
            } else if retry_percent <= 1.0 {
                (1, (1.0 / retry_percent) as isize)
            } else {
                (1000, (1000.0 / retry_percent) as isize)
            };

            // The reserve is expressed in withdrawal units so that
            // `min_per_sec * ttl` withdrawals are possible without deposits.
            let reserve = (min_per_sec as isize)
                .saturating_mul(ttl.as_secs() as isize)
                .saturating_mul(withdraw_amount);

            let num_slots: usize = 10;
            let slot_duration = ttl / num_slots as u32;

            Budget {
                state: Mutex::new(BudgetState {
                    buckets: vec![0isize; num_slots],
                    writer: 0,
                    gen_index: 0,
                    gen_time: Instant::now(),
                }),
                reserve,
                slots: num_slots,
                slot_duration,
                deposit_amount,
                withdraw_amount,
            }
        }

        /// Adds one deposit, timestamped with the current time.
        pub(super) fn deposit(&self) {
            self.deposit_at(Instant::now());
        }

        /// Tries to spend one withdrawal, timestamped with the current time.
        /// Returns `true` when enough credit was available.
        pub(super) fn withdraw(&self) -> bool {
            self.withdraw_at(Instant::now())
        }

        /// Same as [`Budget::deposit`] with an explicit clock reading.
        ///
        /// A `now` that is earlier than the last rotation is treated as "no
        /// time elapsed", so stale timestamps never rotate the window.
        pub(super) fn deposit_at(&self, now: Instant) {
            let mut state = self.lock_state();
            self.advance(&mut state, now);
            // Saturate instead of overflowing, consistent with `sum`.
            state.writer = state.writer.saturating_add(self.deposit_amount);
        }

        /// Same as [`Budget::withdraw`] with an explicit clock reading.
        pub(super) fn withdraw_at(&self, now: Instant) -> bool {
            let mut state = self.lock_state();
            self.advance(&mut state, now);
            if self.sum(&state) >= self.withdraw_amount {
                state.writer = state.writer.saturating_sub(self.withdraw_amount);
                true
            } else {
                false
            }
        }

        /// Locks the state, recovering the data if a previous holder panicked.
        fn lock_state(&self) -> MutexGuard<'_, BudgetState> {
            self.state
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
        }

        /// Total spendable credit: uncommitted writer + all slots + reserve.
        fn sum(&self, state: &BudgetState) -> isize {
            let windowed: isize = state.buckets.iter().copied().fold(0, isize::saturating_add);
            state
                .writer
                .saturating_add(windowed)
                .saturating_add(self.reserve)
        }

        /// Rotates the window up to `now`.
        ///
        /// Does nothing while less than one slot duration has passed since
        /// the last rotation. Otherwise the writer is committed to the current
        /// slot, every slot that was skipped entirely is zeroed, and the
        /// generation moves past them. The work is bounded by the number of
        /// slots no matter how long the budget sat idle.
        fn advance(&self, state: &mut BudgetState, now: Instant) {
            let elapsed = now.saturating_duration_since(state.gen_time);
            if elapsed < self.slot_duration {
                return;
            }

            let committed = std::mem::take(&mut state.writer);
            state.buckets[state.gen_index] = committed;

            // Number of slots to zero: the count of times `slot_duration` can
            // be subtracted from `elapsed` while the remainder is still
            // strictly greater than one slot. `new` guarantees a slot of at
            // least 100ms, so the divisor is non-zero; `max(1)` is defensive.
            let slot_nanos = self.slot_duration.as_nanos().max(1);
            let skipped = elapsed.as_nanos().saturating_sub(1) / slot_nanos;
            let slot_count = self.slots as u128;

            if skipped >= slot_count {
                // Every slot (including the one just committed) has expired.
                state.buckets.fill(0);
            } else {
                // `skipped < slots`, so the cast cannot truncate.
                let skipped = skipped as usize;
                let mut idx = state.gen_index;
                for _ in 0..skipped {
                    idx = (idx + 1) % self.slots;
                    state.buckets[idx] = 0;
                }
            }

            let step = 1 + (skipped % slot_count) as usize;
            state.gen_index = (state.gen_index + step) % self.slots;
            state.gen_time = now;
        }
    }

    impl std::fmt::Debug for Budget {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("Budget")
                .field("deposit_amount", &self.deposit_amount)
                .field("withdraw_amount", &self.withdraw_amount)
                .field("reserve", &self.reserve)
                .finish()
        }
    }
}
/// Types handed to user-supplied classifier callbacks.
mod classify {
    use super::{Error, Response, Url};

    /// View of one finished attempt: the request's URI and method plus either
    /// the response status code or the error that occurred.
    #[derive(Debug)]
    pub struct ReqRep<'a> {
        uri: http::Uri,
        method: &'a http::Method,
        result: Result<http::StatusCode, &'a Error>,
    }

    impl<'a> ReqRep<'a> {
        /// Builds the view from the request URL, method and attempt result.
        ///
        /// Fails when `url` cannot be converted into an `http::Uri`.
        pub(super) fn new(
            url: &Url,
            method: &'a http::Method,
            result: &'a Result<Response, Error>,
        ) -> Result<Self, http::Error> {
            let uri = url.to_http_uri()?;
            // Keep only the status of a response; errors are borrowed as-is.
            let result = result.as_ref().map(|resp| resp.status());
            Ok(Self {
                uri,
                method,
                result,
            })
        }

        /// The HTTP method of the request.
        pub fn method(&self) -> &http::Method {
            self.method
        }

        /// The URI of the request.
        pub fn uri(&self) -> &http::Uri {
            &self.uri
        }

        /// The response status, or `None` if the attempt ended in an error.
        pub fn status(&self) -> Option<http::StatusCode> {
            match self.result {
                Ok(code) => Some(code),
                Err(_) => None,
            }
        }

        /// The error of the attempt, or `None` if a response was received.
        pub fn error(&self) -> Option<&(dyn std::error::Error + 'static)> {
            match self.result {
                Ok(_) => None,
                Err(err) => Some(err as &(dyn std::error::Error + 'static)),
            }
        }

        /// Consumes the view and marks the attempt as retryable.
        pub fn retryable(self) -> Action {
            Action::Retryable
        }

        /// Consumes the view and marks the attempt as successful.
        pub fn success(self) -> Action {
            Action::Success
        }

        /// `true` when the attempt failed with an error that reports itself
        /// as a connection reset.
        pub(super) fn is_protocol_nack(&self) -> bool {
            self.result
                .err()
                .is_some_and(|err| err.is_connection_reset())
        }
    }

    /// Verdict of a classifier for one attempt.
    #[must_use]
    #[derive(Debug, PartialEq)]
    pub enum Action {
        /// The attempt is final; do not retry.
        Success,
        /// The attempt may be retried.
        Retryable,
    }
}
// Unit tests for the retry builder, policy, budget and classifier types.
#[cfg(test)]
mod tests {
use super::*;
// Checks the deposit/withdraw amounts `Budget::new` derives from the retry
// percentage. Columns: percent, expected deposit_amount, expected
// withdraw_amount, description.
#[test]
fn budget_scaling_table() {
let cases: &[(f32, isize, isize, &str)] = &[
(0.0, 0, 1, "zero percent: deposit=0, withdraw=1"),
(0.2, 1, 5, "20%: 5 deposits per retry"),
(0.5, 1, 2, "50%: 2 deposits per retry"),
(1.0, 1, 1, "100%: 1:1 ratio"),
(2.5, 1000, 400, "250%: high-precision scaling"),
(10.0, 1000, 100, "1000%: high-precision scaling"),
];
for &(pct, exp_deposit, exp_withdraw, desc) in cases {
let b = Budget::new(Duration::from_secs(10), 0, pct);
assert_eq!(b.deposit_amount, exp_deposit, "{desc}: deposit_amount");
assert_eq!(b.withdraw_amount, exp_withdraw, "{desc}: withdraw_amount");
}
}
// Counts how many withdrawals succeed for a given reserve and number of
// deposits. Columns: min_per_sec, ttl seconds, percent, deposits, expected
// withdrawals, description.
#[test]
fn budget_reserve_table() {
let cases: &[(u32, u64, f32, u32, u32, &str)] = &[
(5, 10, 0.2, 10, 52, "standard budget with deposits"),
(0, 10, 1.0, 5, 5, "no reserve, 1:1 deposits"),
(3, 10, 0.0, 0, 30, "reserve only, zero percent"),
(0, 10, 0.5, 0, 0, "no reserve, no deposits"),
];
for &(min_per_sec, ttl, pct, deposits, expected, desc) in cases {
let budget = Budget::new(Duration::from_secs(ttl), min_per_sec, pct);
for _ in 0..deposits {
budget.deposit();
}
let mut count = 0u32;
while budget.withdraw() {
count += 1;
// Safety valve so a broken budget cannot loop forever.
if count > 10_000 {
break;
}
}
assert_eq!(count, expected, "{desc}");
}
}
// Checks the scope predicate installed by `for_host`. Columns: URL, expected
// scope result, description.
#[test]
fn scope_table() {
let policy = for_host("example.com".to_string()).into_policy();
let scope = policy.scope.as_ref().expect("for_host should set a scope");
let cases: &[(&str, bool, &str)] = &[
("https://example.com/test", true, "exact host match"),
("https://example.com/a/b/c", true, "deep path on matching host"),
("https://other.com/test", false, "different host"),
("https://sub.example.com/", false, "subdomain is not exact match"),
("http://example.com/", true, "http scheme, same host"),
];
for &(url_str, expected, desc) in cases {
let url: Url = url_str.parse().unwrap();
assert_eq!(scope(&url, &http::Method::GET), expected, "{desc}");
}
}
// Checks that the configured retry limit is reported unchanged by the policy.
#[test]
fn max_retries_table() {
let cases: &[(u32, &str)] = &[
(0, "zero means no retries"),
(1, "one retry"),
(2, "default value"),
(5, "custom higher value"),
(100, "large value"),
];
for &(max, desc) in cases {
let policy = for_host("x")
.max_retries_per_request(max)
.no_budget()
.into_policy();
assert_eq!(policy.max_retries(), max, "{desc}");
}
}
// Exercises `Policy::classify_result` across four policies: a custom 503
// classifier, `never()`, the crate default, and a scoped policy without a
// classifier. Columns: policy, URL, attempt result, expected action,
// description.
#[test]
fn classify_result_table() {
let policy_503 = for_host("example.com".to_string())
.no_budget()
.classify_fn(|rr| {
if rr.status() == Some(http::StatusCode::SERVICE_UNAVAILABLE) {
rr.retryable()
} else {
rr.success()
}
})
.into_policy();
let policy_never = never().into_policy();
let policy_default = Builder::default().into_policy();
let policy_no_classify = for_host("example.com".to_string())
.no_budget()
.into_policy();
type RequestResult = Result<Response, Error>;
let err: RequestResult = Err(Error::builder("test"));
let conn_reset_err: RequestResult =
Err(Error::request(std::io::Error::from(std::io::ErrorKind::ConnectionReset)));
let resp_503: RequestResult = Ok(Response::synthetic(
http::StatusCode::SERVICE_UNAVAILABLE,
"https://example.com/api",
));
let cases: &[(&Policy, &str, &RequestResult, Action, &str)] = &[
(
&policy_503,
"https://example.com/api",
&resp_503,
Action::Retryable,
"503-policy: 503 response IS retryable",
),
(
&policy_503,
"https://example.com/api",
&err,
Action::Success,
"503-policy: error is not a 503",
),
(
&policy_503,
"https://other.com/api",
&err,
Action::Success,
"503-policy: out of scope",
),
(&policy_never, "https://example.com/", &err, Action::Success, "never: always success"),
(
&policy_default,
"https://example.com/",
&err,
Action::Success,
"default: builder error is not a connection reset",
),
(
&policy_default,
"https://example.com/",
&conn_reset_err,
Action::Retryable,
"default: connection reset IS retryable",
),
(
&policy_no_classify,
"https://example.com/",
&err,
Action::Success,
"no-classify: scope matches but no classify fn → success",
),
];
for (policy, url_str, result, expected, desc) in cases {
let url: Url = url_str.parse().unwrap();
let action = policy.classify_result(&url, &http::Method::GET, result);
assert_eq!(action, *expected, "{desc}");
}
}
// Real-time test: with a 1s TTL each slot lasts 100ms, so sleeping 350ms makes
// the next deposit rotate the window across several slots.
#[test]
fn budget_multi_slot_advance() {
let budget = Budget::new(Duration::from_secs(1), 0, 1.0);
for _ in 0..5 {
budget.deposit();
}
std::thread::sleep(Duration::from_millis(350));
budget.deposit();
let mut count = 0u32;
while budget.withdraw() {
count += 1;
if count > 100 {
break;
}
}
assert!(count >= 1, "expected at least 1 withdrawal after multi-slot advance, got {count}");
}
// Checks that each Debug implementation mentions its own type or variant
// name. Columns: value, expected substring, description.
#[test]
fn debug_table() {
let cases: &[(&dyn std::fmt::Debug, &str, &str)] = &[
(&for_host("example.com"), "Builder", "Builder debug"),
(&Action::Success, "Success", "Action::Success"),
(&Action::Retryable, "Retryable", "Action::Retryable"),
(&Builder::default().into_policy(), "Policy", "Policy debug"),
(&Budget::new(Duration::from_secs(10), 5, 0.2), "Budget", "Budget debug"),
];
for &(val, needle, desc) in cases {
let s = format!("{val:?}");
assert!(s.contains(needle), "{desc}: expected {needle:?} in {s:?}");
}
}
// Calls every public accessor of `ReqRep` from inside a classifier and checks
// that the classifier's verdict is returned by the policy.
#[test]
fn classify_fn_exercises_accessors() {
let policy = for_host("example.com".to_string())
.no_budget()
.classify_fn(|rr| {
let _method = rr.method();
let _uri = rr.uri();
let _status = rr.status();
let _err = rr.error();
rr.retryable()
})
.into_policy();
let url: Url = "https://example.com/api".parse().unwrap();
let result: Result<Response, Error> = Err(Error::builder("test"));
assert_eq!(policy.classify_result(&url, &http::Method::POST, &result), Action::Retryable,);
}
// Without a budget every withdrawal must be allowed.
#[test]
fn no_budget() {
let policy = for_host("x").no_budget().into_policy();
for _ in 0..100 {
assert!(policy.can_withdraw());
}
}
// With 0.5 a withdrawal costs 2. `into_policy` uses a 10s TTL and 10 per
// second, giving a reserve of 10 * 10 * 2 = 200; 20 deposits add 20, so
// 220 / 2 = 110 withdrawals succeed.
#[test]
fn max_extra_load() {
let policy = for_host("x").max_extra_load(0.5).into_policy();
for _ in 0..20 {
policy.deposit();
}
let mut count = 0;
while policy.can_withdraw() {
count += 1;
if count > 200 {
break;
}
}
assert_eq!(count, 110, "expected 110 withdrawals, got {count}");
}
// A negative percentage must trip the assertion in `max_extra_load`.
#[test]
#[should_panic]
fn max_extra_load_panics_on_negative() {
let _ = for_host("x").max_extra_load(-1.0);
}
}