use std::sync::Arc;
use http::Request;
use crate::{
Body,
client::layer::retry::{Action, Classifier, ClassifyFn, ReqRep, ScopeFn, Scoped},
};
/// Retry policy used by the client.
///
/// Construct one with [`Policy::never`], [`Policy::for_host`] or [`Default::default`], then
/// refine it with the builder methods on this type.
#[derive(Clone)]
pub struct Policy {
    /// Retry budget value; `None` means no budget is configured (see [`Policy::no_budget`]).
    pub(crate) budget: Option<f32>,
    /// Decides, per request/response pair, which [`Action`] to take.
    pub(crate) classifier: Classifier,
    /// Upper bound on the number of retries attempted for a single request.
    pub(crate) max_retries_per_request: u32,
    /// Decides which requests this policy applies to.
    pub(crate) scope: Scoped,
}

// `Classifier` and `Scoped` can wrap arbitrary user closures (`ClassifyFn` / `ScopeFn`), which
// are not `Debug`, so only the plain-data fields are printed and the output is marked
// non-exhaustive.
impl std::fmt::Debug for Policy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Policy")
            .field("budget", &self.budget)
            .field("max_retries_per_request", &self.max_retries_per_request)
            .finish_non_exhaustive()
    }
}
impl Policy {
    /// Creates a policy that never retries.
    ///
    /// Its scope rejects every request, and no retry budget is configured.
    #[inline]
    #[must_use]
    pub fn never() -> Policy {
        Self::scoped(|_| false).no_budget()
    }

    /// Creates a policy that only applies to requests whose URI host equals `host`.
    ///
    /// Requests whose URI has no host are out of scope. The returned policy starts with a
    /// budget of `0.2`, at most 2 retries per request, and the `Classifier::Never` classifier.
    /// Use [`Policy::classify_fn`] to install a classifier that actually requests retries.
    #[inline]
    #[must_use]
    pub fn for_host<S>(host: S) -> Policy
    where
        S: for<'a> PartialEq<&'a str> + Send + Sync + 'static,
    {
        Self::scoped(move |req| {
            req.uri()
                .host()
                .is_some_and(|request_host| host == request_host)
        })
    }

    /// Builds a policy whose scope is decided by `func`.
    ///
    /// Starting values: budget `Some(0.2)`, `Classifier::Never`, 2 retries per request.
    #[inline]
    fn scoped<F>(func: F) -> Policy
    where
        F: Fn(&Request<Body>) -> bool + Send + Sync + 'static,
    {
        Self {
            budget: Some(0.2),
            classifier: Classifier::Never,
            max_retries_per_request: 2,
            scope: Scoped::Dyn(Arc::new(ScopeFn(func))),
        }
    }

    /// Removes any configured retry budget (sets it to `None`).
    #[inline]
    #[must_use]
    pub fn no_budget(mut self) -> Self {
        self.budget = None;
        self
    }

    /// Sets the retry budget to `extra_percent`.
    ///
    /// NOTE(review): the value appears to be a fraction of extra load allowed through retries
    /// (the built-in default used by `for_host` is `0.2`). Confirm this against the code that
    /// consumes the budget.
    ///
    /// # Panics
    ///
    /// Panics if `extra_percent` is NaN, negative, or greater than `1000.0`.
    #[inline]
    #[must_use]
    pub fn max_extra_load(mut self, extra_percent: f32) -> Self {
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected as well. The message
        // reports the offending value instead of a bare comparison.
        assert!(
            (0.0..=1000.0).contains(&extra_percent),
            "max_extra_load must be within 0.0..=1000.0, got {extra_percent}"
        );
        self.budget = Some(extra_percent);
        self
    }

    /// Sets the maximum number of retries attempted for a single request.
    #[inline]
    #[must_use]
    pub fn max_retries_per_request(mut self, max: u32) -> Self {
        self.max_retries_per_request = max;
        self
    }

    /// Installs a custom classifier.
    ///
    /// `func` receives each [`ReqRep`] and returns the [`Action`] to take. It replaces whatever
    /// classifier the policy had before.
    #[inline]
    #[must_use]
    pub fn classify_fn<F>(mut self, func: F) -> Self
    where
        F: Fn(ReqRep<'_>) -> Action + Send + Sync + 'static,
    {
        self.classifier = Classifier::Dyn(Arc::new(ClassifyFn(func)));
        self
    }
}
impl Default for Policy {
fn default() -> Self {
Self {
budget: None,
classifier: Classifier::ProtocolNacks,
max_retries_per_request: 2,
scope: Scoped::Unscoped,
}
}
}