use std::{fmt, sync::Arc, time::Duration};
use trillium_client::Conn;
/// How a computed backoff delay is randomized before it is used.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) enum Jitter {
    /// Use the (capped) delay exactly as computed.
    None,
    /// Replace the (capped) delay with a uniformly random duration between zero and
    /// that delay, inclusive. This is the default mode.
    #[default]
    Full,
}
/// Caller-supplied delay function: it is given the retry number and the `Conn` passed to
/// `Backoff::delay`, and returns the (pre-cap, pre-jitter) delay.
type CustomFn = Arc<dyn Fn(u32, &Conn) -> Duration + Send + Sync>;

/// Strategy for deriving the base delay (before capping and jitter) from the retry number.
#[derive(Clone)]
pub(crate) enum Kind {
    /// The same delay before every retry.
    Constant(Duration),
    /// `step * retry_number`, saturating on overflow.
    Linear(Duration),
    /// `base * 2^(retry_number - 1)`, saturating on overflow; retry 1 (and retry 0) waits `base`.
    Exponential(Duration),
    /// Delay computed by a caller-supplied function.
    Custom(CustomFn),
}
/// Retry backoff policy: a base-delay strategy, an optional upper bound, and a jitter mode.
#[derive(Clone)]
pub(crate) struct Backoff {
    /// How the base delay is derived from the retry number.
    pub(crate) kind: Kind,
    /// Upper bound applied to the base delay before jitter; `None` leaves it unbounded.
    pub(crate) max_delay: Option<Duration>,
    /// Randomization applied after the cap.
    pub(crate) jitter: Jitter,
}
impl Default for Backoff {
fn default() -> Self {
Self {
kind: Kind::Exponential(Duration::from_millis(100)),
max_delay: None,
jitter: Jitter::Full,
}
}
}
impl Backoff {
    /// Computes how long to wait before retry number `retry_number`.
    ///
    /// The pipeline has three steps:
    /// 1. Derive a raw delay from [`Kind`]. The built-in strategies saturate rather than overflow.
    /// 2. Clamp the raw delay to `max_delay`, if one is set.
    /// 3. Randomize the result according to `jitter`.
    ///
    /// `conn` is only consulted by [`Kind::Custom`].
    ///
    /// NOTE(review): `Exponential` treats `retry_number == 1` as the first retry (factor
    /// `2^0`). With `retry_number == 0`, `Exponential` still yields `base` but `Linear`
    /// yields zero. Confirm callers count retries from 1.
    pub(crate) fn delay(&self, retry_number: u32, conn: &Conn) -> Duration {
        let raw = match &self.kind {
            Kind::Constant(fixed) => *fixed,
            Kind::Linear(step) => step.saturating_mul(retry_number),
            Kind::Exponential(initial) => {
                // Shift so the first retry (1) uses a multiplier of 2^0 = 1.
                let exponent = retry_number.saturating_sub(1);
                let factor = 2u32.saturating_pow(exponent);
                initial.saturating_mul(factor)
            }
            Kind::Custom(compute) => compute(retry_number, conn),
        };

        // The cap is applied before jitter, so the jittered value never exceeds it.
        let bounded = match self.max_delay {
            Some(ceiling) if raw > ceiling => ceiling,
            _ => raw,
        };

        match self.jitter {
            Jitter::Full => full_jitter(bounded),
            Jitter::None => bounded,
        }
    }
}
/// Returns a uniformly random duration in the inclusive range `[0, max]` ("full jitter").
///
/// Sampling happens at nanosecond precision. If `max` has more nanoseconds than fit in a
/// `u64`, the upper bound is clamped to `u64::MAX` nanoseconds. A zero `max` returns
/// `Duration::ZERO` without drawing a random number.
fn full_jitter(max: Duration) -> Duration {
    let upper_nanos = u64::try_from(max.as_nanos()).unwrap_or(u64::MAX);
    if upper_nanos == 0 {
        return Duration::ZERO;
    }
    let sampled = fastrand::u64(0..=upper_nanos);
    Duration::from_nanos(sampled)
}
impl fmt::Debug for Backoff {
    /// Formats as `Backoff { kind: .., max_delay: .., jitter: .. }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructuring makes the compiler flag any field added later but not printed here.
        let Self {
            kind,
            max_delay,
            jitter,
        } = self;

        let mut out = f.debug_struct("Backoff");
        out.field("kind", kind);
        out.field("max_delay", max_delay);
        out.field("jitter", jitter);
        out.finish()
    }
}
impl fmt::Debug for Kind {
    /// Formats as a one-field tuple, e.g. `Linear(250ms)`.
    ///
    /// A custom function cannot be printed, so that variant shows the placeholder `"<fn>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (name, value): (&str, &dyn fmt::Debug) = match self {
            Self::Constant(d) => ("Constant", d),
            Self::Linear(d) => ("Linear", d),
            Self::Exponential(d) => ("Exponential", d),
            Self::Custom(_) => ("Custom", &"<fn>"),
        };
        f.debug_tuple(name).field(value).finish()
    }
}