#![forbid(unsafe_code)]
#![deny(
clippy::dbg_macro,
missing_copy_implementations,
rustdoc::missing_crate_level_docs,
missing_debug_implementations,
missing_docs,
nonstandard_style,
unused_qualifications
)]
// Compiled only when rustdoc collects doctests: attaching README.md as this empty module's
// documentation makes the README's code samples run as doctests.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
mod backoff;
use backoff::{Backoff, Kind};
use std::{
borrow::Cow,
fmt,
sync::Arc,
time::{Duration, Instant},
};
use trillium_client::{
Body, ClientHandler, Conn, ConnExt,
KnownHeaderName::{Connection, ContentLength, Expect, Host, RetryAfter, TransferEncoding},
Method, Result, Status,
};
/// Returns whether `method` is idempotent in the sense of RFC 9110 §9.2.2, i.e. whether
/// repeating it is expected to have the same effect as sending it once.
///
/// Only these methods are retried unless the handler is configured with `with_all_methods`.
fn is_idempotent(method: Method) -> bool {
    // The safe methods (RFC 9110 §9.2.1) are idempotent by definition...
    let is_safe = matches!(
        method,
        Method::Get | Method::Head | Method::Options | Method::Trace
    );
    // ...and PUT and DELETE are idempotent without being safe.
    is_safe || matches!(method, Method::Put | Method::Delete)
}
/// Caller-supplied replacement for the built-in "should this conn be retried?" check.
type Predicate = Arc<dyn Fn(&Conn) -> bool + Sync + Send>;
/// Caller-supplied replacement for the entire retry decision: given the completed conn and
/// the 1-based retry number, yields the delay before retrying, or `None` to stop.
type Decision = Arc<dyn Fn(&Conn, u32) -> Option<Duration> + Sync + Send>;
/// A [`ClientHandler`] that re-issues requests whose response or error is deemed retryable.
///
/// With the default configuration, a request is retried when it uses an idempotent method
/// (GET, HEAD, PUT, DELETE, OPTIONS, TRACE) and either ends with an error on the conn or
/// receives `429 Too Many Requests` / `503 Service Unavailable`. At most 4 attempts are made
/// in total, within a 30-second budget measured from the first attempt. A `Retry-After`
/// header given in whole seconds overrides the computed backoff delay.
///
/// Requests whose body cannot be cloned are never retried, because the body could not be
/// re-sent.
#[derive(Clone)]
pub struct RetryHandler {
    /// Strategy (plus max-delay cap and jitter) used to compute the delay before a retry.
    backoff: Backoff,
    /// Maximum total number of attempts, counting the initial request.
    max_attempts: u32,
    /// Time budget for the whole chain of attempts, measured from the first attempt.
    max_elapsed: Duration,
    /// Response statuses that trigger a retry under the built-in check.
    statuses: Arc<[Status]>,
    /// When true, non-idempotent methods are retried as well.
    all_methods: bool,
    /// Whether a conn that ends with an error (rather than a status) is retried.
    transport_errors: bool,
    /// Whether a `Retry-After` response header replaces the computed backoff delay.
    honor_retry_after: bool,
    /// Optional cap on delays taken from `Retry-After`.
    max_retry_after: Option<Duration>,
    /// Optional replacement for the built-in retry check.
    predicate: Option<Predicate>,
    /// Optional replacement for the whole retry decision (check and delay).
    decision: Option<Decision>,
}
impl Default for RetryHandler {
fn default() -> Self {
Self {
backoff: Backoff::default(),
max_attempts: 4,
max_elapsed: Duration::from_secs(30),
statuses: Arc::from([Status::TooManyRequests, Status::ServiceUnavailable].as_slice()),
all_methods: false,
transport_errors: true,
honor_retry_after: true,
max_retry_after: None,
predicate: None,
decision: None,
}
}
}
impl RetryHandler {
    /// Creates a handler with the default configuration; equivalent to
    /// [`RetryHandler::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Uses the constant backoff strategy, parameterized by `delay`.
    ///
    /// The [`with_max_delay`](Self::with_max_delay) cap and jitter are configured separately.
    /// If a `Retry-After` header is honored and present, it replaces the backoff delay.
    #[must_use]
    pub fn with_constant_backoff(mut self, delay: Duration) -> Self {
        self.backoff.kind = Kind::Constant(delay);
        self
    }

    /// Uses the linear backoff strategy, parameterized by `step`.
    ///
    /// The [`with_max_delay`](Self::with_max_delay) cap and jitter are configured separately.
    #[must_use]
    pub fn with_linear_backoff(mut self, step: Duration) -> Self {
        self.backoff.kind = Kind::Linear(step);
        self
    }

    /// Uses the exponential backoff strategy, parameterized by `base`.
    ///
    /// The [`with_max_delay`](Self::with_max_delay) cap and jitter are configured separately.
    #[must_use]
    pub fn with_exponential_backoff(mut self, base: Duration) -> Self {
        self.backoff.kind = Kind::Exponential(base);
        self
    }

    /// Uses a caller-supplied backoff function.
    ///
    /// `f` receives a `u32` retry counter and the conn that just completed, and returns a
    /// delay. The [`with_max_delay`](Self::with_max_delay) cap and jitter are configured
    /// separately.
    #[must_use]
    pub fn with_custom_backoff(
        mut self,
        f: impl Fn(u32, &Conn) -> Duration + Send + Sync + 'static,
    ) -> Self {
        self.backoff.kind = Kind::Custom(Arc::new(f));
        self
    }

    /// Caps the delay produced by the backoff strategy at `max`.
    #[must_use]
    pub fn with_max_delay(mut self, max: Duration) -> Self {
        self.backoff.max_delay = Some(max);
        self
    }

    /// Disables jitter on backoff delays.
    #[must_use]
    pub fn without_jitter(mut self) -> Self {
        self.backoff.jitter = backoff::Jitter::None;
        self
    }

    /// Sets the maximum total number of attempts, counting the initial request (default 4).
    ///
    /// A value of `0` or `1` disables retries entirely.
    #[must_use]
    pub fn with_max_attempts(mut self, max_attempts: u32) -> Self {
        self.max_attempts = max_attempts;
        self
    }

    /// Sets the time budget for a request and all of its retries, measured from the first
    /// attempt (default 30 seconds).
    ///
    /// A retry is abandoned if its delay would reach the deadline, and each follow-up's
    /// timeout is clamped to the time remaining.
    #[must_use]
    pub fn with_max_elapsed(mut self, max_elapsed: Duration) -> Self {
        self.max_elapsed = max_elapsed;
        self
    }

    /// Replaces the set of response statuses that trigger a retry.
    ///
    /// The default set is `429 Too Many Requests` and `503 Service Unavailable`. This set is
    /// not consulted when [`retry_when`](Self::retry_when) or
    /// [`with_decision`](Self::with_decision) is configured.
    #[must_use]
    pub fn with_statuses(mut self, statuses: impl IntoIterator<Item = Status>) -> Self {
        self.statuses = statuses.into_iter().collect();
        self
    }

    /// Allows non-idempotent methods (such as `POST`) to be retried.
    ///
    /// By default only GET, HEAD, PUT, DELETE, OPTIONS and TRACE are retried. Requests whose
    /// body cannot be cloned are still never retried.
    #[must_use]
    pub fn with_all_methods(mut self) -> Self {
        self.all_methods = true;
        self
    }

    /// Sets whether a conn that ends with an error, rather than a response status, is
    /// retried (default `true`).
    ///
    /// The method check still applies unless [`with_all_methods`](Self::with_all_methods)
    /// is set.
    #[must_use]
    pub fn with_transport_errors(mut self, retry: bool) -> Self {
        self.transport_errors = retry;
        self
    }

    /// Sets whether a `Retry-After` response header, given in whole seconds, replaces the
    /// computed delay (default `true`).
    #[must_use]
    pub fn with_honor_retry_after(mut self, honor: bool) -> Self {
        self.honor_retry_after = honor;
        self
    }

    /// Caps delays taken from a `Retry-After` header at `max`.
    ///
    /// Without a cap, the server-advised delay is used as-is. A retry is still abandoned if
    /// that delay would exceed the [`with_max_elapsed`](Self::with_max_elapsed) budget.
    #[must_use]
    pub fn with_max_retry_after(mut self, max: Duration) -> Self {
        self.max_retry_after = Some(max);
        self
    }

    /// Replaces the built-in retry check with `predicate`.
    ///
    /// The built-in check covers method idempotency, errors, and the status set. With a
    /// predicate, the delay still comes from the backoff strategy (and `Retry-After`). The
    /// attempt limit, the time budget, and the "uncloneable body is never retried" rule
    /// still apply. Ignored when [`with_decision`](Self::with_decision) is set.
    #[must_use]
    pub fn retry_when(mut self, predicate: impl Fn(&Conn) -> bool + Send + Sync + 'static) -> Self {
        self.predicate = Some(Arc::new(predicate));
        self
    }

    /// Takes full control of whether, and after how long, to retry.
    ///
    /// `decision` receives the completed conn and the 1-based retry number (1 for the first
    /// retry). It returns `Some(delay)` to retry after `delay`, or `None` to stop. This
    /// bypasses [`retry_when`](Self::retry_when), the status set, the method check, and the
    /// backoff strategy. `Retry-After` handling, the attempt limit, and the time budget
    /// still apply.
    #[must_use]
    pub fn with_decision(
        mut self,
        decision: impl Fn(&Conn, u32) -> Option<Duration> + Send + Sync + 'static,
    ) -> Self {
        self.decision = Some(Arc::new(decision));
        self
    }

    /// Returns the base delay before retry number `retry_number`, or `None` to not retry.
    ///
    /// A configured decision closure wins outright; otherwise the retry check gates the
    /// backoff delay.
    fn decide(&self, conn: &Conn, retry_number: u32) -> Option<Duration> {
        if let Some(decision) = &self.decision {
            return decision(conn, retry_number);
        }
        self.should_retry(conn)
            .then(|| self.backoff.delay(retry_number, conn))
    }

    /// The retry check: the custom predicate if set; otherwise method idempotency, then
    /// error handling, then membership of the response status in the configured set.
    fn should_retry(&self, conn: &Conn) -> bool {
        if let Some(predicate) = &self.predicate {
            return predicate(conn);
        }
        if !self.all_methods && !is_idempotent(conn.method()) {
            return false;
        }
        // An errored conn is judged solely by the transport-error setting.
        if conn.error().is_some() {
            return self.transport_errors;
        }
        conn.status()
            .is_some_and(|status| self.statuses.contains(&status))
    }

    /// Applies `Retry-After` (capped by `max_retry_after`) over `base_delay` when enabled.
    fn effective_delay(&self, conn: &Conn, base_delay: Duration) -> Duration {
        if !self.honor_retry_after {
            return base_delay;
        }
        match retry_after(conn) {
            Some(advised) => self.max_retry_after.map_or(advised, |cap| advised.min(cap)),
            None => base_delay,
        }
    }

    /// Builds the next attempt: same method, URL and request headers, a replayed body, a
    /// timeout clamped to `remaining`, and an incremented attempt counter.
    fn build_followup(&self, conn: &Conn, state: RetryState, remaining: Duration) -> Conn {
        let mut followup = conn.client().build_conn(conn.method(), conn.url().clone());
        let mut headers = conn.request_headers().clone();
        // Drop headers tied to the previous connection/body framing; the client regenerates
        // whichever of them it needs for the new request.
        headers.remove_all([Host, ContentLength, TransferEncoding, Expect, Connection]);
        *followup.request_headers_mut() = headers;
        // NOTE(review): if this second `try_clone` returns `None`, the follow-up is sent
        // without a body. Presumably a clone of a cloneable body stays cloneable — confirm.
        if let Some(BodyReplay::Replayable(body)) = conn.state::<BodyReplay>()
            && let Some(replayed) = body.try_clone()
        {
            followup.set_request_body(replayed);
        }
        let timeout = conn.timeout().map_or(remaining, |t| t.min(remaining));
        followup.set_timeout(timeout);
        followup.insert_state(RetryState {
            attempts: state.attempts + 1,
            deadline: state.deadline,
        });
        followup
    }
}
impl ClientHandler for RetryHandler {
async fn run(&self, conn: &mut Conn) -> Result<()> {
if conn.state::<RetryState>().is_none() {
conn.insert_state(RetryState {
attempts: 1,
deadline: Instant::now() + self.max_elapsed,
});
}
let replay = match conn.request_body() {
None => BodyReplay::None,
Some(body) => match body.try_clone() {
Some(clone) => BodyReplay::Replayable(clone),
None => BodyReplay::OneShot,
},
};
conn.insert_state(replay);
Ok(())
}
async fn after_response(&self, conn: &mut Conn) -> Result<()> {
let Some(state) = conn.state::<RetryState>().copied() else {
return Ok(());
};
if state.attempts >= self.max_attempts {
return Ok(());
}
if matches!(conn.state::<BodyReplay>(), Some(BodyReplay::OneShot)) {
return Ok(());
}
let retry_number = state.attempts;
let Some(base_delay) = self.decide(conn, retry_number) else {
return Ok(());
};
let delay = self.effective_delay(conn, base_delay);
if Instant::now() + delay >= state.deadline {
return Ok(());
}
conn.client().connector().runtime().delay(delay).await;
let remaining = state.deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return Ok(());
}
let followup = self.build_followup(conn, state, remaining);
conn.take_error();
conn.set_followup(followup);
Ok(())
}
fn name(&self) -> Cow<'static, str> {
"RetryHandler".into()
}
}
/// Reads the `Retry-After` response header as a whole number of seconds.
///
/// Returns `None` when the header is absent, or when its trimmed value does not parse as a
/// `u64`. The HTTP-date form of `Retry-After` is therefore not recognized.
fn retry_after(conn: &Conn) -> Option<Duration> {
    let raw = conn.response_headers().get_str(RetryAfter)?;
    let seconds: u64 = raw.trim().parse().ok()?;
    Some(Duration::from_secs(seconds))
}
/// Bookkeeping carried in conn state from one attempt of a request chain to the next.
#[derive(Copy, Clone)]
struct RetryState {
    /// Point at which the chain's time budget runs out, fixed at the first attempt.
    deadline: Instant,
    /// 1-based number of the attempt that this conn represents.
    attempts: u32,
}
/// Whether, and how, the request body can be re-sent on a follow-up attempt.
enum BodyReplay {
    /// The request had no body, so there is nothing to replay.
    None,
    /// The body could not be cloned, so the request must not be retried.
    OneShot,
    /// A clone of the original body, from which each follow-up's body is produced.
    Replayable(Body),
}
impl fmt::Debug for RetryHandler {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Closures have no useful Debug output; report only whether one is configured.
        let opaque = |configured: bool| configured.then_some("<fn>");

        let mut out = f.debug_struct("RetryHandler");
        out.field("backoff", &self.backoff);
        out.field("max_attempts", &self.max_attempts);
        out.field("max_elapsed", &self.max_elapsed);
        out.field("statuses", &self.statuses);
        out.field("all_methods", &self.all_methods);
        out.field("transport_errors", &self.transport_errors);
        out.field("honor_retry_after", &self.honor_retry_after);
        out.field("max_retry_after", &self.max_retry_after);
        out.field("predicate", &opaque(self.predicate.is_some()));
        out.field("decision", &opaque(self.decision.is_some()));
        out.finish()
    }
}