use std::sync::Arc;
use std::time::Duration;
use tower::retry::budget::{Budget as _, TpsBudget as Budget};
/// Configures how requests are retried; turned into a [`Policy`] by
/// `Builder::into_policy`.
#[derive(Debug)]
pub struct Builder {
    /// Ratio handed to the retry budget; `None` means retries are not limited
    /// by a budget at all.
    budget: Option<f32>,
    /// Decides whether a finished request should be retried.
    classifier: classify::Classifier,
    /// Upper bound on retries for a single request.
    max_retries_per_request: u32,
    /// Restricts which requests are eligible for retries.
    scope: scope::Scoped,
}
/// Runtime retry policy built from a [`Builder`].
///
/// Cloning a `Policy` shares the same budget (it sits behind an `Arc`) while
/// copying the per-instance retry counter.
#[derive(Clone, Debug)]
pub(crate) struct Policy {
    /// Shared retry budget, if one was configured.
    budget: Option<Arc<Budget>>,
    /// Decides whether a finished request should be retried.
    classifier: classify::Classifier,
    /// Upper bound on retries for a single request.
    max_retries_per_request: u32,
    /// Number of retries this policy instance has scheduled so far.
    retry_cnt: u32,
    /// Restricts which requests are eligible for retries.
    scope: scope::Scoped,
}
pub fn for_host<S>(host: S) -> Builder
where
S: for<'a> PartialEq<&'a str> + Send + Sync + 'static,
{
scoped(move |req| host == req.uri().host().unwrap_or(""))
}
/// Creates a retry [`Builder`] that never retries.
///
/// Its scope rejects every request and it carries no budget.
pub fn never() -> Builder {
    let rejects_everything = scoped(|_| false);
    rejects_everything.no_budget()
}
/// Wraps a request predicate in a `ScopeFn` and builds a scoped [`Builder`] from it.
fn scoped<F>(func: F) -> Builder
where
    F: Fn(&Req) -> bool + Send + Sync + 'static,
{
    let predicate_scope = scope::ScopeFn(func);
    Builder::scoped(predicate_scope)
}
impl Builder {
pub fn scoped(scope: impl scope::Scope) -> Self {
Self {
budget: Some(0.2),
classifier: classify::Classifier::Never,
max_retries_per_request: 2, scope: scope::Scoped::Dyn(Arc::new(scope)),
}
}
pub fn no_budget(mut self) -> Self {
self.budget = None;
self
}
pub fn max_extra_load(mut self, extra_percent: f32) -> Self {
assert!(extra_percent >= 0.0);
assert!(extra_percent <= 1000.0);
self.budget = Some(extra_percent);
self
}
pub fn max_retries_per_request(mut self, max: u32) -> Self {
self.max_retries_per_request = max;
self
}
pub fn classify_fn<F>(self, func: F) -> Self
where
F: Fn(classify::ReqRep<'_>) -> classify::Action + Send + Sync + 'static,
{
self.classify(classify::ClassifyFn(func))
}
pub fn classify(mut self, classifier: impl classify::Classify) -> Self {
self.classifier = classify::Classifier::Dyn(Arc::new(classifier));
self
}
pub(crate) fn default() -> Builder {
Self {
budget: None,
classifier: classify::Classifier::ProtocolNacks,
max_retries_per_request: 2, scope: scope::Scoped::Unscoped,
}
}
pub(crate) fn into_policy(self) -> Policy {
let budget = self
.budget
.map(|p| Arc::new(Budget::new(Duration::from_secs(10), 10, p)));
Policy {
budget,
classifier: self.classifier,
max_retries_per_request: self.max_retries_per_request,
retry_cnt: 0,
scope: self.scope,
}
}
}
/// The request type handled by the retry policy.
type Req = http::Request<crate::async_impl::body::Body>;

impl<B> tower::retry::Policy<Req, http::Response<B>, crate::Error> for Policy {
    // Retries are scheduled immediately; no backoff delay is applied.
    type Future = std::future::Ready<()>;

    /// Decides whether the attempt that produced `result` should be retried.
    ///
    /// A `Success` classification deposits into the budget (if any) and stops.
    /// A `Retryable` classification schedules a retry only when the request is
    /// in scope and a budget withdrawal succeeds (or no budget is configured).
    fn retry(
        &mut self,
        req: &mut Req,
        result: &mut crate::Result<http::Response<B>>,
    ) -> Option<Self::Future> {
        match self.classifier.classify(req, result) {
            classify::Action::Success => {
                log::trace!("shouldn't retry!");
                if let Some(ref budget) = self.budget {
                    budget.deposit();
                }
                None
            }
            classify::Action::Retryable => {
                log::trace!("could retry!");
                // The scope must be enforced here. Once this returns `Some`,
                // tower's `Retry` re-sends the request clone it already holds.
                // `clone_request` only controls whether a *further* clone
                // exists. Checking the scope there alone would still let one
                // out-of-scope retry through and spend budget on it.
                if !self.scope.applies_to(req) {
                    return None;
                }
                if self.budget.as_ref().map(|b| b.withdraw()).unwrap_or(true) {
                    self.retry_cnt += 1;
                    Some(std::future::ready(()))
                } else {
                    log::debug!("retryable but could not withdraw from budget");
                    None
                }
            }
        }
    }

    /// Produces a copy of `req` to keep for a possible retry.
    ///
    /// Returns `None`, making no (further) retry possible, when a retry has
    /// already been made and the request is out of scope, when
    /// `max_retries_per_request` has been reached, or when the body cannot be
    /// cloned.
    fn clone_request(&mut self, req: &Req) -> Option<Req> {
        // Defensive: `retry` already refuses out-of-scope requests.
        if self.retry_cnt > 0 && !self.scope.applies_to(req) {
            return None;
        }
        if self.retry_cnt >= self.max_retries_per_request {
            log::trace!("max_retries_per_request hit");
            return None;
        }
        let body = req.body().try_clone()?;
        let mut new = http::Request::new(body);
        *new.method_mut() = req.method().clone();
        *new.uri_mut() = req.uri().clone();
        *new.version_mut() = req.version();
        *new.headers_mut() = req.headers().clone();
        *new.extensions_mut() = req.extensions().clone();
        Some(new)
    }
}
/// Reports whether `err` wraps a protocol-level error that this module treats
/// as retryable.
///
/// The candidate error sits three `source()` links below `err`. With the
/// `http3` feature, an `h3` `ConnectionError` whose display text is `"timeout"`
/// qualifies. With the `http2` feature, a remote GOAWAY carrying `NO_ERROR` or
/// a remote reset carrying `REFUSED_STREAM` qualifies. Anything else does not.
fn is_retryable_error(err: &crate::Error) -> bool {
    use std::error::Error as _;

    // Two hops down the source chain; a shorter chain has nothing to inspect.
    let inner = match err.source().and_then(|outer| outer.source()) {
        Some(inner) => inner,
        None => return false,
    };

    #[cfg(not(any(feature = "http3", feature = "http2")))]
    let _unused = inner;

    #[cfg(feature = "http3")]
    if let Some(err) = inner
        .source()
        .and_then(|cause| cause.downcast_ref::<h3::error::ConnectionError>())
    {
        log::trace!("determining if HTTP/3 error {err} can be retried");
        return err.to_string().as_str() == "timeout";
    }

    #[cfg(feature = "http2")]
    if let Some(err) = inner
        .source()
        .and_then(|cause| cause.downcast_ref::<h2::Error>())
    {
        let remote = err.is_remote();
        let graceful_go_away =
            err.is_go_away() && remote && err.reason() == Some(h2::Reason::NO_ERROR);
        let refused_stream =
            err.is_reset() && remote && err.reason() == Some(h2::Reason::REFUSED_STREAM);
        if graceful_go_away || refused_stream {
            return true;
        }
    }

    false
}
mod scope {
    /// Decides whether a request is eligible for retries.
    pub trait Scope: Send + Sync + 'static {
        /// Returns `true` if retries may be attempted for `req`.
        fn applies_to(&self, req: &super::Req) -> bool;
    }

    /// Adapts a request predicate closure into a [`Scope`].
    pub struct ScopeFn<F>(pub(super) F);

    impl<F> Scope for ScopeFn<F>
    where
        F: Fn(&super::Req) -> bool + Send + Sync + 'static,
    {
        fn applies_to(&self, req: &super::Req) -> bool {
            let ScopeFn(predicate) = self;
            predicate(req)
        }
    }

    /// The scope stored inside a retry builder or policy.
    #[derive(Clone)]
    pub(super) enum Scoped {
        /// Every request is in scope.
        Unscoped,
        /// A user-supplied [`Scope`] decides.
        Dyn(std::sync::Arc<dyn Scope>),
    }

    impl Scoped {
        /// Evaluates the scope for `req` and logs the result at trace level.
        pub(super) fn applies_to(&self, req: &super::Req) -> bool {
            let ret = match self {
                Self::Dyn(scope) => scope.applies_to(req),
                Self::Unscoped => true,
            };
            log::trace!("retry in scope: {ret}");
            ret
        }
    }

    impl std::fmt::Debug for Scoped {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let label = match self {
                Self::Unscoped => "Unscoped",
                Self::Dyn(_) => "Scoped",
            };
            f.write_str(label)
        }
    }
}
mod classify {
    /// Decides whether a finished request/response pair should be retried.
    pub trait Classify: Send + Sync + 'static {
        /// Inspects `req_rep` and returns the [`Action`] to take.
        fn classify(&self, req_rep: ReqRep<'_>) -> Action;
    }

    /// Adapts a closure into a [`Classify`] implementation.
    pub struct ClassifyFn<F>(pub(super) F);

    impl<F> Classify for ClassifyFn<F>
    where
        F: Fn(ReqRep<'_>) -> Action + Send + Sync + 'static,
    {
        fn classify(&self, req_rep: ReqRep<'_>) -> Action {
            let ClassifyFn(func) = self;
            func(req_rep)
        }
    }

    /// A request paired with its outcome: the response status on success, or
    /// the error that ended the attempt.
    #[derive(Debug)]
    pub struct ReqRep<'a>(&'a super::Req, Result<http::StatusCode, &'a crate::Error>);

    impl ReqRep<'_> {
        /// The request's HTTP method.
        pub fn method(&self) -> &http::Method {
            self.0.method()
        }

        /// The request's URI.
        pub fn uri(&self) -> &http::Uri {
            self.0.uri()
        }

        /// The response status, or `None` if the attempt ended in an error.
        pub fn status(&self) -> Option<http::StatusCode> {
            match self.1 {
                Ok(status) => Some(status),
                Err(_) => None,
            }
        }

        /// The error that ended the attempt, or `None` if a response arrived.
        pub fn error(&self) -> Option<&(dyn std::error::Error + 'static)> {
            self.1
                .err()
                .map(|e| e as &(dyn std::error::Error + 'static))
        }

        /// Marks this outcome as retryable.
        pub fn retryable(self) -> Action {
            Action::Retryable
        }

        /// Marks this outcome as final (no retry).
        pub fn success(self) -> Action {
            Action::Success
        }

        /// `true` when the attempt failed with an error that
        /// `is_retryable_error` accepts; `false` for any response.
        fn is_protocol_nack(&self) -> bool {
            match self.1 {
                Err(e) => super::is_retryable_error(e),
                Ok(_) => false,
            }
        }
    }

    /// Outcome of classifying a request/response pair.
    #[must_use]
    #[derive(Debug)]
    pub enum Action {
        /// Do not retry.
        Success,
        /// A retry may be attempted.
        Retryable,
    }

    /// The classifier stored inside a retry builder or policy.
    #[derive(Clone)]
    pub(super) enum Classifier {
        /// Never retries.
        Never,
        /// Retries only errors accepted by `is_retryable_error`.
        ProtocolNacks,
        /// A user-supplied [`Classify`] decides.
        Dyn(std::sync::Arc<dyn Classify>),
    }

    impl Classifier {
        /// Classifies the outcome `res` of sending `req`.
        pub(super) fn classify<B>(
            &self,
            req: &super::Req,
            res: &Result<http::Response<B>, crate::Error>,
        ) -> Action {
            let outcome = res.as_ref().map(|rep| rep.status());
            let req_rep = ReqRep(req, outcome);
            match self {
                Self::Never => Action::Success,
                Self::ProtocolNacks if req_rep.is_protocol_nack() => Action::Retryable,
                Self::ProtocolNacks => Action::Success,
                Self::Dyn(c) => c.classify(req_rep),
            }
        }
    }

    impl std::fmt::Debug for Classifier {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let label = match self {
                Self::Never => "Never",
                Self::ProtocolNacks => "ProtocolNacks",
                Self::Dyn(_) => "Classifier",
            };
            f.write_str(label)
        }
    }
}