use core::time::Duration;
/// Strategy used by [`RetryPolicy::next_delay`] to compute the wait before a retry.
///
/// Every variant is deterministic: the same `attempt` always yields the same delay.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum BackoffStrategy {
    /// Always wait `base`, whatever the attempt number.
    Fixed { base: Duration },
    /// Wait `base * 2^attempt`, never more than `max_delay`.
    Exponential { base: Duration, max_delay: Duration },
    /// Deterministic stand-in for "decorrelated jitter": the midpoint of
    /// `[base, min(base * 3^(attempt + 1), max_delay)]`, never more than
    /// `max_delay`. No randomness is involved.
    DecorrelatedJitter { base: Duration, max_delay: Duration },
}

/// Retry configuration: an attempt budget plus the backoff strategy.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct RetryPolicy {
    /// Maximum number of attempts. It is stored for callers and is not
    /// consulted by [`RetryPolicy::next_delay`]. Whether it counts the initial
    /// attempt is up to the caller.
    pub max_attempts: u32,
    /// How the delay grows between attempts.
    pub strategy: BackoffStrategy,
}

impl RetryPolicy {
    /// Factor applied to `base` to derive the default `max_delay` cap.
    const DEFAULT_MAX_DELAY_FACTOR: u32 = 1024;

    /// Default cap for the growing strategies: `base * 1024`.
    ///
    /// Saturates at [`Duration::MAX`] instead of panicking, as `Duration * u32`
    /// would, when `base` is very large.
    fn default_max_delay(base: Duration) -> Duration {
        base.saturating_mul(Self::DEFAULT_MAX_DELAY_FACTOR)
    }

    /// Policy that waits `base` before every retry.
    #[must_use]
    pub fn fixed(max_attempts: u32, base: Duration) -> Self {
        Self {
            max_attempts,
            strategy: BackoffStrategy::Fixed { base },
        }
    }

    /// Policy whose delay doubles every attempt, starting at `base`.
    ///
    /// The delay is capped at `base * 1024`, saturating at `Duration::MAX`.
    #[must_use]
    pub fn exponential(max_attempts: u32, base: Duration) -> Self {
        Self {
            max_attempts,
            strategy: BackoffStrategy::Exponential {
                base,
                max_delay: Self::default_max_delay(base),
            },
        }
    }

    /// Policy using the deterministic decorrelated-jitter schedule
    /// (see [`BackoffStrategy::DecorrelatedJitter`]).
    ///
    /// The delay is capped at `base * 1024`, saturating at `Duration::MAX`.
    #[must_use]
    pub fn decorrelated_jitter(max_attempts: u32, base: Duration) -> Self {
        Self {
            max_attempts,
            strategy: BackoffStrategy::DecorrelatedJitter {
                base,
                max_delay: Self::default_max_delay(base),
            },
        }
    }

    /// Delay to wait before retry number `attempt`.
    ///
    /// For `Fixed` and `Exponential`, `attempt == 0` yields `base`.
    /// `max_attempts` is not checked here; callers decide when to stop.
    /// All arithmetic saturates, so this never panics, even for `u32::MAX`.
    #[must_use]
    pub fn next_delay(&self, attempt: u32) -> Duration {
        match &self.strategy {
            BackoffStrategy::Fixed { base } => *base,
            BackoffStrategy::Exponential { base, max_delay } => {
                // 2^attempt. Once the shift would overflow u32 (attempt >= 32),
                // use u32::MAX. The product is clamped by `max_delay` anyway.
                let multiplier = 1_u32.checked_shl(attempt).unwrap_or(u32::MAX);
                base.saturating_mul(multiplier).min(*max_delay)
            }
            BackoffStrategy::DecorrelatedJitter { base, max_delay } => {
                let mut upper = *base;
                for _ in 0..=attempt {
                    let next = upper.saturating_mul(3).min(*max_delay);
                    if next == upper {
                        // Fixed point (capped at `max_delay`, or `base` is zero):
                        // further iterations cannot change `upper`. Stopping here
                        // keeps huge `attempt` values from looping billions of times.
                        break;
                    }
                    upper = next;
                }
                // Midpoint of [base, upper]. When max_delay < base, upper < base
                // and this collapses to base, which is then capped below.
                let half_range = upper.saturating_sub(*base) / 2;
                base.saturating_add(half_range).min(*max_delay)
            }
        }
    }
}
/// Value of an HTTP `Retry-After` header: either a delay or a date.
///
/// Parsed via `FromStr` and formatted via `Display` when `std` or `alloc` is enabled.
#[cfg(any(feature = "chrono", feature = "std", feature = "alloc"))]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum RetryAfter {
/// Delay-seconds form, e.g. `Retry-After: 120`.
Delay(Duration),
/// HTTP-date form; the representation depends on the `chrono` feature (see `RetryAfterDate`).
Date(
// NOTE(review): `serde(rename)` on an unnamed field of a newtype variant likely
// does not change the serialized form (the variant is still keyed as "Date") —
// confirm whether a variant-level rename was intended.
#[cfg_attr(feature = "serde", serde(rename = "date"))]
RetryAfterDate,
),
}
/// HTTP-date payload of [`RetryAfter::Date`], parsed into a timestamp.
#[cfg(feature = "chrono")]
pub type RetryAfterDate = chrono::DateTime<chrono::FixedOffset>;

/// HTTP-date payload of [`RetryAfter::Date`], kept as raw text
/// (no `chrono`, `alloc` without `std`).
#[cfg(all(not(feature = "chrono"), feature = "alloc", not(feature = "std")))]
pub type RetryAfterDate = alloc::string::String;

/// HTTP-date payload of [`RetryAfter::Date`], kept as raw text
/// (no `chrono`, with `std`).
#[cfg(all(not(feature = "chrono"), feature = "std"))]
pub type RetryAfterDate = std::string::String;
/// Error returned by `RetryAfter`'s `FromStr` implementation.
///
/// The cause lives in a private field that only exists when `std` or `alloc`
/// is enabled; without either, the error carries no detail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryAfterParseError(
#[cfg(any(feature = "std", feature = "alloc"))]
// NOTE(review): this field is read by the `Display` impl whenever it exists;
// confirm the `dead_code` allowance for non-`std` builds is still needed.
#[cfg_attr(not(feature = "std"), allow(dead_code))]
RetryAfterParseErrorInner,
);
/// Private cause stored in `RetryAfterParseError`; which variant exists depends on `chrono`.
#[cfg(any(feature = "std", feature = "alloc"))]
#[derive(Debug, Clone, PartialEq, Eq)]
enum RetryAfterParseErrorInner {
/// The value was not delay-seconds and `chrono` could not parse it as an RFC 2822 date.
#[cfg(feature = "chrono")]
InvalidDate(chrono::ParseError),
/// Without `chrono`: the only rejected input, an empty (whitespace-only) value.
#[cfg(not(feature = "chrono"))]
InvalidFormat,
}
impl core::fmt::Display for RetryAfterParseError {
    /// Writes a human-readable reason for the parse failure.
    ///
    /// Without `std`/`alloc` the error stores no cause, so a generic message
    /// is written.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        #[cfg(not(any(feature = "std", feature = "alloc")))]
        {
            f.write_str("invalid Retry-After value")
        }
        #[cfg(any(feature = "std", feature = "alloc"))]
        {
            match &self.0 {
                #[cfg(feature = "chrono")]
                RetryAfterParseErrorInner::InvalidDate(cause) => {
                    write!(f, "invalid Retry-After date: {cause}")
                }
                #[cfg(not(feature = "chrono"))]
                RetryAfterParseErrorInner::InvalidFormat => {
                    f.write_str("Retry-After value must be delta-seconds or an HTTP-date")
                }
            }
        }
    }
}
#[cfg(feature = "std")]
impl std::error::Error for RetryAfterParseError {
    /// Exposes the underlying `chrono` parse failure as the error source.
    #[cfg(feature = "chrono")]
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // With `chrono` enabled, `InvalidDate` is the inner enum's only
        // variant, so this pattern is irrefutable.
        let RetryAfterParseErrorInner::InvalidDate(cause) = &self.0;
        Some(cause)
    }
}
#[cfg(any(feature = "std", feature = "alloc"))]
impl core::str::FromStr for RetryAfter {
    type Err = RetryAfterParseError;

    /// Parses a `Retry-After` header value; surrounding whitespace is ignored.
    ///
    /// Anything `u64::from_str` accepts is read as delay-seconds. A digit
    /// string too large for `u64` saturates to `u64::MAX` seconds. Without
    /// this, it would fall through to the date branch and be rejected with a
    /// misleading date error (with `chrono`) or stored as a bogus date string
    /// (without `chrono`).
    ///
    /// Any other value is treated as an HTTP-date:
    /// - with `chrono`, it is parsed with `DateTime::parse_from_rfc2822`;
    /// - without `chrono`, it is stored verbatim.
    ///
    /// # Errors
    ///
    /// - With `chrono`: returns an error when the value is neither delay-seconds
    ///   nor an RFC 2822 date.
    /// - Without `chrono`: returns an error only when the trimmed value is empty.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let trimmed = s.trim();
        match trimmed.parse::<u64>() {
            Ok(secs) => return Ok(Self::Delay(Duration::from_secs(secs))),
            // An overlong delay is still a delay: saturate rather than
            // misclassify it as a date.
            Err(e) if matches!(e.kind(), core::num::IntErrorKind::PosOverflow) => {
                return Ok(Self::Delay(Duration::from_secs(u64::MAX)));
            }
            Err(_) => {}
        }
        #[cfg(feature = "chrono")]
        {
            chrono::DateTime::parse_from_rfc2822(trimmed)
                .map(Self::Date)
                .map_err(|e| RetryAfterParseError(RetryAfterParseErrorInner::InvalidDate(e)))
        }
        #[cfg(not(feature = "chrono"))]
        {
            if trimmed.is_empty() {
                Err(RetryAfterParseError(
                    RetryAfterParseErrorInner::InvalidFormat,
                ))
            } else {
                Ok(Self::Date(trimmed.into()))
            }
        }
    }
}
#[cfg(any(feature = "std", feature = "alloc"))]
impl core::fmt::Display for RetryAfter {
    /// Formats the value as it should appear in a `Retry-After` header.
    ///
    /// - `Delay` is written as whole seconds; any sub-second part is truncated.
    /// - `Date` (with `chrono`) is converted to UTC and written as an
    ///   IMF-fixdate, e.g. `Wed, 21 Oct 2015 07:28:00 GMT`. RFC 9110 requires
    ///   senders to generate HTTP-dates in this GMT form; the previous
    ///   `to_rfc2822()` output (`... +0000`, or a non-UTC offset) is not one.
    ///   The output still parses back through `FromStr`.
    /// - `Date` (without `chrono`) is written verbatim.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Delay(d) => write!(f, "{}", d.as_secs()),
            #[cfg(feature = "chrono")]
            Self::Date(dt) => write!(
                f,
                "{}",
                dt.with_timezone(&chrono::Utc)
                    .format("%a, %d %b %Y %H:%M:%S GMT")
            ),
            #[cfg(not(feature = "chrono"))]
            Self::Date(s) => f.write_str(s),
        }
    }
}
/// Marker trait with no methods; a type opts in with an empty `impl Idempotent for T {}`
/// and can then satisfy `T: Idempotent` bounds.
///
/// NOTE(review): presumably marks requests that are safe to retry — confirm against the
/// code that bounds on this trait, which is not visible here.
pub trait Idempotent {}
// Unit tests for the backoff schedules and `Retry-After` handling. The
// `RetryAfter` tests are gated on the same features as the code they exercise.
#[cfg(test)]
mod tests {
use super::*;
use core::time::Duration;
// --- RetryPolicy::next_delay schedules ---
#[test]
fn fixed_delay_is_constant() {
let p = RetryPolicy::fixed(5, Duration::from_millis(250));
assert_eq!(p.next_delay(0), Duration::from_millis(250));
assert_eq!(p.next_delay(3), Duration::from_millis(250));
assert_eq!(p.next_delay(100), Duration::from_millis(250));
}
#[test]
fn exponential_doubles_each_attempt() {
let p = RetryPolicy::exponential(10, Duration::from_millis(100));
assert_eq!(p.next_delay(0), Duration::from_millis(100));
assert_eq!(p.next_delay(1), Duration::from_millis(200));
assert_eq!(p.next_delay(2), Duration::from_millis(400));
assert_eq!(p.next_delay(3), Duration::from_millis(800));
}
// The `exponential` constructor sets the cap to `base * 1024`.
#[test]
fn exponential_caps_at_max_delay() {
let p = RetryPolicy::exponential(5, Duration::from_millis(100));
let cap = Duration::from_millis(100) * 1024;
assert_eq!(p.next_delay(100), cap);
}
// `attempt = u32::MAX` exercises the shift-overflow fallback.
#[test]
fn exponential_handles_overflow_gracefully() {
let p = RetryPolicy::exponential(5, Duration::from_secs(1));
let d = p.next_delay(u32::MAX);
assert!(d <= Duration::from_secs(1) * 1024);
}
#[test]
fn jitter_delay_gte_base() {
let base = Duration::from_millis(100);
let p = RetryPolicy::decorrelated_jitter(5, base);
for attempt in 0..10 {
assert!(
p.next_delay(attempt) >= base,
"attempt {attempt}: delay < base"
);
}
}
#[test]
fn jitter_delay_lte_max() {
let base = Duration::from_millis(100);
let p = RetryPolicy::decorrelated_jitter(5, base);
let max = base * 1024;
for attempt in 0..20 {
assert!(
p.next_delay(attempt) <= max,
"attempt {attempt}: delay > max_delay"
);
}
}
// --- RetryAfter parsing and formatting ---
#[cfg(any(feature = "std", feature = "alloc"))]
#[test]
fn parse_delta_seconds() {
let r: RetryAfter = "120".parse().unwrap();
assert_eq!(r, RetryAfter::Delay(Duration::from_secs(120)));
}
#[cfg(any(feature = "std", feature = "alloc"))]
#[test]
fn parse_zero_seconds() {
let r: RetryAfter = "0".parse().unwrap();
assert_eq!(r, RetryAfter::Delay(Duration::ZERO));
}
#[cfg(all(any(feature = "std", feature = "alloc"), feature = "chrono"))]
#[test]
fn parse_http_date() {
let r: RetryAfter = "Wed, 21 Oct 2015 07:28:00 GMT".parse().unwrap();
assert!(matches!(r, RetryAfter::Date(_)));
}
// Only `chrono` builds reject non-numeric text; without it any non-empty value is accepted as a date.
#[cfg(all(any(feature = "std", feature = "alloc"), feature = "chrono"))]
#[test]
fn parse_invalid_returns_error() {
let r: Result<RetryAfter, _> = "not-a-valid-value".parse();
assert!(r.is_err());
}
#[cfg(any(feature = "std", feature = "alloc"))]
#[test]
fn display_delay_round_trips() {
let r = RetryAfter::Delay(Duration::from_secs(60));
assert_eq!(r.to_string(), "60");
}
// Compares parsed values rather than strings, so the formatted text need not match `original`.
#[cfg(all(any(feature = "std", feature = "alloc"), feature = "chrono"))]
#[test]
fn display_date_round_trips() {
let original = "Wed, 21 Oct 2015 07:28:00 +0000";
let r: RetryAfter = original.parse().unwrap();
let back: RetryAfter = r.to_string().parse().unwrap();
assert_eq!(r, back);
}
// Presumably relies on `serde_json` as a dev-dependency — not visible here.
#[cfg(all(any(feature = "std", feature = "alloc"), feature = "serde"))]
#[test]
fn serde_delay_round_trip() {
let r = RetryAfter::Delay(Duration::from_secs(30));
let json = serde_json::to_value(&r).unwrap();
let back: RetryAfter = serde_json::from_value(json).unwrap();
assert_eq!(back, r);
}
// --- Idempotent marker trait: a user type can implement it and satisfy the bound ---
struct GetItems;
impl Idempotent for GetItems {}
fn require_idempotent<R: Idempotent>(_: &R) {}
#[test]
fn idempotent_implementor_accepted_by_generic_fn() {
let req = GetItems;
require_idempotent(&req);
}
// --- RetryAfterParseError ---
#[cfg(all(any(feature = "std", feature = "alloc"), feature = "chrono"))]
#[test]
fn retry_after_parse_error_display() {
let err: Result<RetryAfter, _> = "not-a-valid-date-or-number".parse();
let e = err.unwrap_err();
let s = e.to_string();
assert!(s.contains("invalid Retry-After"));
}
// `source()` is only overridden when `chrono` is enabled.
#[cfg(all(feature = "std", feature = "chrono"))]
#[test]
fn retry_after_parse_error_source() {
use std::error::Error;
let err: Result<RetryAfter, _> = "not-a-date".parse();
let e = err.unwrap_err();
assert!(e.source().is_some());
}
}