use std::sync::Arc;
use std::time::Duration;
use http::{HeaderMap, StatusCode};
use crate::response::Response;
/// Caller-supplied predicate that decides whether a received [`Response`]
/// should be retried.
///
/// Stored behind an `Arc` so cloning a policy is cheap; the `Send + Sync`
/// bounds allow the policy to be shared across threads.
pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync + 'static>;
/// Describes how (and how often) a failed request is retried.
///
/// Every variant carries an attempt budget and an optional [`ShouldRetryFn`]
/// which, when present, replaces [`default_should_retry`] when deciding
/// whether a received response is retryable.
#[derive(Clone)]
#[must_use = "apply with `ClientBuilder::retry` or `RequestBuilder::retry`"]
pub enum RetryPolicy {
    /// Retry a fixed number of times with a one-second pause between tries.
    /// `uses_jitter` always reports `true` for this variant.
    Count {
        /// Attempt budget, reported by `max_attempts`.
        attempts: u32,
        /// Optional override of the default retryable-status check.
        should_retry: Option<ShouldRetryFn>,
    },
    /// Retry with the same `delay` before every attempt.
    Linear {
        /// Attempt budget, reported by `max_attempts`.
        attempts: u32,
        /// Constant pause before each retry.
        delay: Duration,
        /// Whether the computed delay is randomised.
        jitter: bool,
        /// Optional override of the default retryable-status check.
        should_retry: Option<ShouldRetryFn>,
    },
    /// Retry with a delay of `base_delay * 2^attempt`, capped at `max_delay`.
    Exponential {
        /// Attempt budget, reported by `max_attempts`.
        attempts: u32,
        /// Delay used for attempt 0; doubled for each subsequent attempt.
        base_delay: Duration,
        /// Upper bound on the computed delay.
        max_delay: Duration,
        /// Whether the computed delay is randomised.
        jitter: bool,
        /// Optional override of the default retryable-status check.
        should_retry: Option<ShouldRetryFn>,
    },
}
impl RetryPolicy {
    /// Creates a [`RetryPolicy::Count`] policy with a budget of `attempts`.
    ///
    /// Retries are spaced one second apart and jitter is always applied to
    /// that computed pause; [`RetryPolicy::with_jitter`] has no effect on it.
    pub fn count(attempts: u32) -> Self {
        Self::Count {
            attempts,
            should_retry: None,
        }
    }

    /// Creates a [`RetryPolicy::Linear`] policy that waits `delay` before
    /// every retry. Jitter is off by default.
    pub fn linear(attempts: u32, delay: Duration) -> Self {
        Self::Linear {
            attempts,
            delay,
            should_retry: None,
            jitter: false,
        }
    }

    /// Creates a [`RetryPolicy::Exponential`] policy: the pause before
    /// attempt `n` is `base_delay * 2^n`, capped at `max_delay`.
    /// Jitter is on by default.
    pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
        Self::Exponential {
            attempts,
            base_delay,
            max_delay,
            should_retry: None,
            jitter: true,
        }
    }

    /// Enables or disables jitter on the computed back-off.
    ///
    /// Ignored for [`RetryPolicy::Count`], which has no jitter setting.
    #[must_use = "chain with `ClientBuilder::retry` or `RequestBuilder::retry`"]
    pub fn with_jitter(mut self, jitter: bool) -> Self {
        match &mut self {
            Self::Linear { jitter: j, .. } | Self::Exponential { jitter: j, .. } => *j = jitter,
            Self::Count { .. } => {}
        }
        self
    }

    /// Installs `f` as the retry predicate for received responses, replacing
    /// [`default_should_retry`]. All other settings of the policy are kept.
    ///
    /// Transport failures are still retried regardless of `f`.
    #[must_use = "chain with `ClientBuilder::retry` or `RequestBuilder::retry`"]
    pub fn with_should_retry(mut self, f: ShouldRetryFn) -> Self {
        match &mut self {
            Self::Count { should_retry, .. }
            | Self::Linear { should_retry, .. }
            | Self::Exponential { should_retry, .. } => *should_retry = Some(f),
        }
        self
    }

    /// Returns the configured attempt budget.
    pub fn max_attempts(&self) -> u32 {
        match self {
            Self::Count { attempts, .. }
            | Self::Linear { attempts, .. }
            | Self::Exponential { attempts, .. } => *attempts,
        }
    }

    /// Computed back-off before `attempt`, without jitter or server hints.
    ///
    /// * `Count`: always one second.
    /// * `Linear`: always `delay`.
    /// * `Exponential`: `base_delay * 2^attempt` (saturating), capped at
    ///   `max_delay`; attempt 0 therefore waits `base_delay`.
    pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
        match self {
            Self::Count { .. } => Duration::from_secs(1),
            Self::Linear { delay, .. } => *delay,
            Self::Exponential {
                base_delay,
                max_delay,
                ..
            } => {
                let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
                exp.min(*max_delay)
            }
        }
    }

    /// How long to wait before the next try after a response was received.
    ///
    /// A parsable `Retry-After` header wins and is returned exactly as the
    /// server sent it: it is neither jittered nor clamped to `max_delay`.
    /// Otherwise the computed back-off from `delay_before_attempt` is used,
    /// jittered when `uses_jitter` is true.
    pub(crate) fn delay_after_response(&self, attempt: u32, headers: &HeaderMap) -> Duration {
        if let Some(server_delay) = parse_retry_after(headers) {
            // `apply_jitter` maps `d` into `[d/2, d)`, i.e. it can only shorten
            // the pause. Jittering the server's value would retry before the
            // moment the server said it would be ready.
            return server_delay;
        }
        let backoff = self.delay_before_attempt(attempt);
        if self.uses_jitter() {
            apply_jitter(backoff)
        } else {
            backoff
        }
    }

    /// Whether the computed back-off is randomised. Always `true` for `Count`.
    pub(crate) fn uses_jitter(&self) -> bool {
        match self {
            Self::Count { .. } => true,
            Self::Linear { jitter, .. } | Self::Exponential { jitter, .. } => *jitter,
        }
    }

    /// Whether a custom predicate was installed via `with_should_retry`.
    pub(crate) fn has_custom_should_retry(&self) -> bool {
        self.custom_should_retry().is_some()
    }

    /// Decides whether another attempt should be made.
    ///
    /// A transport failure always retries. Otherwise the custom predicate, if
    /// any, decides; failing that, [`default_should_retry`] on the status.
    pub(crate) fn should_retry_response(
        &self,
        response: &Response,
        transport_failed: bool,
    ) -> bool {
        if transport_failed {
            return true;
        }
        match self.custom_should_retry() {
            Some(f) => f(response),
            None => default_should_retry(response.status()),
        }
    }

    /// Borrows the custom retry predicate, whichever variant holds it.
    fn custom_should_retry(&self) -> Option<&ShouldRetryFn> {
        match self {
            Self::Count { should_retry, .. }
            | Self::Linear { should_retry, .. }
            | Self::Exponential { should_retry, .. } => should_retry.as_ref(),
        }
    }
}
pub fn default_should_retry(status: StatusCode) -> bool {
matches!(status.as_u16(), 408 | 429 | 502 | 503 | 504)
}
/// Parses a `Retry-After` header into the delay the server asked for.
///
/// Both forms allowed by RFC 9110 §10.2.3 are understood:
///
/// * `delay-seconds` (e.g. `Retry-After: 120`) yields that many seconds.
/// * An IMF-fixdate HTTP-date (e.g. `Retry-After: Sun, 06 Nov 1994 08:49:37 GMT`)
///   yields the time remaining until that instant, or [`Duration::ZERO`] if
///   it has already passed.
///
/// Returns `None` when the header is absent, is not visible ASCII, or matches
/// neither form. The obsolete RFC 850 and asctime date forms are not supported.
pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
    let value = headers.get(http::header::RETRY_AFTER)?.to_str().ok()?.trim();
    if let Ok(secs) = value.parse::<u64>() {
        return Some(Duration::from_secs(secs));
    }
    let target = parse_imf_fixdate(value)?;
    // A date in the past means "retry now" rather than "no hint".
    Some(
        target
            .duration_since(std::time::SystemTime::now())
            .unwrap_or(Duration::ZERO),
    )
}

/// Parses an IMF-fixdate such as `Sun, 06 Nov 1994 08:49:37 GMT`
/// (RFC 9110 §5.6.7) into an absolute time. Years before 1970 are rejected.
fn parse_imf_fixdate(value: &str) -> Option<std::time::SystemTime> {
    let mut fields = value.split_ascii_whitespace();
    // The day name carries no information; only require its trailing comma.
    if !fields.next()?.ends_with(',') {
        return None;
    }
    let day = parse_fixed_digits(fields.next()?, 2)?;
    let month = match fields.next()? {
        "Jan" => 1,
        "Feb" => 2,
        "Mar" => 3,
        "Apr" => 4,
        "May" => 5,
        "Jun" => 6,
        "Jul" => 7,
        "Aug" => 8,
        "Sep" => 9,
        "Oct" => 10,
        "Nov" => 11,
        "Dec" => 12,
        _ => return None,
    };
    let year = parse_fixed_digits(fields.next()?, 4)?;
    let clock = fields.next()?;
    if fields.next()? != "GMT" || fields.next().is_some() {
        return None;
    }

    let mut hms = clock.split(':');
    let hour = parse_fixed_digits(hms.next()?, 2)?;
    let minute = parse_fixed_digits(hms.next()?, 2)?;
    let second = parse_fixed_digits(hms.next()?, 2)?;
    if hms.next().is_some()
        || year < 1970
        || !(1..=31).contains(&day)
        || hour > 23
        || minute > 59
        || second > 60
    {
        return None;
    }

    let days = u64::try_from(days_from_civil(i64::from(year), month, day)).ok()?;
    let secs = days * 86_400
        + u64::from(hour) * 3_600
        + u64::from(minute) * 60
        + u64::from(second);
    std::time::UNIX_EPOCH.checked_add(Duration::from_secs(secs))
}

/// Parses exactly `width` ASCII digits (no sign, no surrounding whitespace).
fn parse_fixed_digits(field: &str, width: usize) -> Option<u32> {
    if field.len() != width || !field.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    field.parse().ok()
}

/// Days since 1970-01-01 for a proleptic-Gregorian date
/// (Howard Hinnant's `days_from_civil` algorithm).
fn days_from_civil(year: i64, month: u32, day: u32) -> i64 {
    let (m, d) = (i64::from(month), i64::from(day));
    // Treat March as the first month so the leap day falls at year end.
    let y = if m <= 2 { year - 1 } else { year };
    let era = (if y >= 0 { y } else { y - 399 }) / 400;
    let yoe = y - era * 400;
    let doy = (153 * (if m > 2 { m - 3 } else { m + 9 }) + 2) / 5 + d - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    era * 146_097 + doe - 719_468
}
/// Randomises `delay` into the half-open range `[delay / 2, delay)`.
///
/// A zero delay is returned unchanged. Durations longer than `u64::MAX`
/// nanoseconds are first clamped to that many nanoseconds.
fn apply_jitter(delay: Duration) -> Duration {
    let total = u64::try_from(delay.as_nanos()).unwrap_or(u64::MAX);
    if total == 0 {
        return delay;
    }
    let floor = total / 2;
    // `total - floor` is at least 1 whenever `total >= 1`; the `max` keeps the
    // RNG range non-empty regardless.
    let width = (total - floor).max(1);
    let offset = fastrand::u64(..width);
    Duration::from_nanos(floor + offset)
}
pub(crate) use crate::cancel::sleep_or_cancel;
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `Response`, `StatusCode`,
    // `HeaderMap`, `Arc` and `Duration` imports.
    use super::*;

    /// Builds an empty-bodied response carrying only `status`.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    #[test]
    fn default_should_retry_codes() {
        for code in [
            StatusCode::REQUEST_TIMEOUT,
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ] {
            assert!(default_should_retry(code), "{code} should be retried");
        }
        for code in [
            StatusCode::OK,
            StatusCode::NOT_FOUND,
            StatusCode::INTERNAL_SERVER_ERROR,
        ] {
            assert!(!default_should_retry(code), "{code} should not be retried");
        }
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    #[test]
    fn count_with_should_retry_stays_count() {
        let policy = RetryPolicy::count(2)
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(matches!(policy, RetryPolicy::Count { .. }));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn with_should_retry_preserves_other_settings() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(250))
            .with_jitter(true)
            .with_should_retry(Arc::new(|_| true));
        assert_eq!(policy.max_attempts(), 3);
        assert!(policy.uses_jitter());
        assert_eq!(policy.delay_before_attempt(1), Duration::from_millis(250));
        assert!(policy.has_custom_should_retry());
    }

    #[test]
    fn has_custom_should_retry_defaults_to_false() {
        assert!(!RetryPolicy::count(1).has_custom_should_retry());
        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).has_custom_should_retry());
        assert!(
            !RetryPolicy::exponential(1, Duration::from_secs(1), Duration::from_secs(2))
                .has_custom_should_retry()
        );
    }

    #[test]
    fn transport_failure_always_retries() {
        let policy = RetryPolicy::count(1).with_should_retry(Arc::new(|_| false));
        assert!(policy.should_retry_response(&response_with_status(200), true));
        assert!(!policy.should_retry_response(&response_with_status(200), false));
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_doubles_until_cap() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_before_attempt(3), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_saturates_on_huge_attempt() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(u32::MAX), Duration::from_secs(5));
    }

    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn parse_retry_after_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "3".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(3)));
    }

    #[test]
    fn parse_retry_after_trims_whitespace() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, " 5 ".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(5)));
    }

    #[test]
    fn parse_retry_after_missing_is_none() {
        assert!(parse_retry_after(&HeaderMap::new()).is_none());
    }

    #[test]
    fn delay_after_response_uses_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "2".parse().unwrap());
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &headers),
            Duration::from_secs(2)
        );
    }

    #[test]
    fn delay_after_response_without_header_uses_backoff() {
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &HeaderMap::new()),
            Duration::from_millis(100)
        );
    }

    #[test]
    fn jitter_stays_within_bounds() {
        let base = Duration::from_secs(4);
        for _ in 0..20 {
            let jittered = apply_jitter(base);
            assert!(jittered >= Duration::from_secs(2));
            assert!(jittered <= base);
        }
    }

    #[test]
    fn jitter_of_zero_is_zero() {
        assert_eq!(apply_jitter(Duration::ZERO), Duration::ZERO);
    }

    #[test]
    fn parse_retry_after_invalid_is_none() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "not-a-number".parse().unwrap());
        assert!(parse_retry_after(&headers).is_none());
    }

    #[test]
    fn exponential_uses_jitter_by_default() {
        let policy = RetryPolicy::exponential(3, Duration::from_secs(1), Duration::from_secs(8));
        assert!(policy.uses_jitter());
    }

    #[test]
    fn linear_jitter_disabled_by_default() {
        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).uses_jitter());
    }
}