use std::collections::{BTreeMap, BTreeSet};
use std::sync::Arc;
use std::time::Duration;
use http::{HeaderMap, Method, StatusCode};
use crate::IDEMPOTENCY_KEY_HEADER;
use crate::error::{Error, TimeoutPhase, TransportErrorKind};
use crate::util::exponential_backoff_with_jitter;
/// Identifies what is being evaluated for retry in a [`RetryDecision`].
///
/// `RetryPolicy::should_retry_decision` matches on this value to pick the
/// allow-list and the optional per-reason retry window to consult.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RetryReason {
/// A response status code.
Status(StatusCode),
/// A transport error of the given kind.
Transport(TransportErrorKind),
/// A timeout in the given phase.
Timeout(TimeoutPhase),
/// A response-body read error (going by the variant name; the site that
/// produces it is outside this section of the file).
ResponseBodyRead,
}
/// Read-only description of one retry evaluation.
///
/// Passed by reference to [`RetryClassifier::should_retry`] and to
/// `RetryPolicy::should_retry_decision`; fields are exposed through accessors.
#[derive(Clone, Debug)]
pub struct RetryDecision {
// Attempt number supplied by the constructor's caller; compared against
// per-reason retry windows. NOTE(review): numbering base (0 or 1) is not
// visible here — confirm against the call site.
attempt: usize,
// Attempt budget supplied by the constructor's caller.
max_attempts: usize,
// HTTP method of the request under evaluation.
method: Method,
// Request URI, stored as an owned string.
uri: String,
// What triggered this evaluation.
reason: RetryReason,
}
impl RetryDecision {
pub(crate) fn new(
attempt: usize,
max_attempts: usize,
method: Method,
uri: String,
reason: RetryReason,
) -> Self {
Self {
attempt,
max_attempts,
method,
uri,
reason,
}
}
pub fn attempt(&self) -> usize {
self.attempt
}
pub fn max_attempts(&self) -> usize {
self.max_attempts
}
pub fn method(&self) -> &Method {
&self.method
}
pub fn uri(&self) -> &str {
&self.uri
}
pub fn reason(&self) -> RetryReason {
self.reason
}
pub fn status(&self) -> Option<StatusCode> {
match self.reason {
RetryReason::Status(status) => Some(status),
_ => None,
}
}
pub fn transport_error_kind(&self) -> Option<TransportErrorKind> {
match self.reason {
RetryReason::Transport(kind) => Some(kind),
_ => None,
}
}
pub fn timeout_phase(&self) -> Option<TimeoutPhase> {
match self.reason {
RetryReason::Timeout(phase) => Some(phase),
_ => None,
}
}
pub fn is_response_body_read_error(&self) -> bool {
matches!(self.reason, RetryReason::ResponseBodyRead)
}
}
/// Caller-supplied retry predicate.
///
/// When one is installed with `RetryPolicy::retry_classifier`, the policy's
/// `should_retry_decision` returns this trait's verdict directly and does not
/// consult its own allow-lists or retry windows.
pub trait RetryClassifier: Send + Sync {
/// Returns `true` if the situation described by `decision` should be retried.
fn should_retry(&self, decision: &RetryDecision) -> bool;
}
/// Predicate deciding, from the method and headers alone, whether a request
/// supports being retried.
///
/// This file provides [`StrictRetryEligibility`] and
/// [`PermissiveRetryEligibility`]; where the predicate is consulted is outside
/// this section of the file.
pub trait RetryEligibility: Send + Sync {
/// Returns `true` if a request with this `method` and these `headers` supports retry.
fn supports_retry(&self, method: &Method, headers: &HeaderMap) -> bool;
}
/// [`RetryEligibility`] that defers to `request_supports_retry`: a request is
/// eligible when its method is GET, HEAD, PUT, DELETE, OPTIONS or TRACE, or
/// when its headers contain `IDEMPOTENCY_KEY_HEADER`.
#[derive(Default)]
pub struct StrictRetryEligibility;
impl RetryEligibility for StrictRetryEligibility {
/// Returns the result of `request_supports_retry(method, headers)`.
fn supports_retry(&self, method: &Method, headers: &HeaderMap) -> bool {
request_supports_retry(method, headers)
}
}
/// [`RetryEligibility`] that treats every request as eligible, ignoring both
/// the method and the headers.
#[derive(Default)]
pub struct PermissiveRetryEligibility;
impl RetryEligibility for PermissiveRetryEligibility {
/// Always returns `true`; both arguments are unused.
fn supports_retry(&self, _method: &Method, _headers: &HeaderMap) -> bool {
true
}
}
/// Retry configuration: attempt budget, backoff parameters, per-reason
/// allow-lists, optional per-reason retry windows, and an optional custom
/// classifier.
///
/// Built from [`RetryPolicy::standard`] or [`RetryPolicy::disabled`] and
/// adjusted with the consuming builder methods. `validate` checks the values.
#[derive(Clone)]
pub struct RetryPolicy {
// Attempt budget; `validate` rejects zero.
max_attempts: usize,
// First argument of the backoff pair passed to `exponential_backoff_with_jitter`;
// `validate` rejects zero.
base_backoff: Duration,
// Second argument of the backoff pair; `validate` rejects zero and values
// below `base_backoff`.
max_backoff: Duration,
// Jitter parameter for the backoff helper; `validate` requires a finite
// value within 0.0..=1.0.
jitter_ratio: f64,
// Status codes (as raw `u16`) for which a retry may be approved.
retryable_status_codes: BTreeSet<u16>,
// Transport error kinds for which a retry may be approved.
retryable_transport_error_kinds: BTreeSet<TransportErrorKind>,
// Timeout phases for which a retry may be approved.
retryable_timeout_phases: BTreeSet<TimeoutPhase>,
// Whether response-body read errors may be retried.
retry_on_response_body_read_error: bool,
// Optional per-status limit: a retry is approved only while
// `decision.attempt() < limit`. Statuses without an entry are unlimited here.
status_retry_windows: BTreeMap<u16, usize>,
// Same limit semantics, keyed by transport error kind.
transport_retry_windows: BTreeMap<TransportErrorKind, usize>,
// Same limit semantics, keyed by timeout phase.
timeout_retry_windows: BTreeMap<TimeoutPhase, usize>,
// Same limit semantics for response-body read errors; `None` means no limit.
response_body_read_retry_window: Option<usize>,
// When set, `should_retry_decision` returns this classifier's verdict and
// skips all of the allow-lists and windows above.
retry_classifier: Option<Arc<dyn RetryClassifier>>,
}
/// Debug output lists every plain-data setting and reports whether a custom
/// classifier is installed.
///
/// The classifier is a `dyn RetryClassifier`, which carries no `Debug` bound, so
/// it cannot be printed itself. Its presence is still shown because an installed
/// classifier overrides every allow-list and retry window printed here (see
/// `should_retry_decision`); omitting it would make the dump misleading.
impl std::fmt::Debug for RetryPolicy {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Opaque trait object: render only a presence marker, `None` when absent.
        let retry_classifier = self.retry_classifier.as_ref().map(|_| "<custom>");
        formatter
            .debug_struct("RetryPolicy")
            .field("max_attempts", &self.max_attempts)
            .field("base_backoff", &self.base_backoff)
            .field("max_backoff", &self.max_backoff)
            .field("jitter_ratio", &self.jitter_ratio)
            .field("retryable_status_codes", &self.retryable_status_codes)
            .field(
                "retryable_transport_error_kinds",
                &self.retryable_transport_error_kinds,
            )
            .field("retryable_timeout_phases", &self.retryable_timeout_phases)
            .field(
                "retry_on_response_body_read_error",
                &self.retry_on_response_body_read_error,
            )
            .field("status_retry_windows", &self.status_retry_windows)
            .field("transport_retry_windows", &self.transport_retry_windows)
            .field("timeout_retry_windows", &self.timeout_retry_windows)
            .field(
                "response_body_read_retry_window",
                &self.response_body_read_retry_window,
            )
            .field("retry_classifier", &retry_classifier)
            .finish()
    }
}
impl RetryPolicy {
pub fn disabled() -> Self {
Self {
max_attempts: 1,
base_backoff: Duration::from_millis(200),
max_backoff: Duration::from_secs(2),
jitter_ratio: 0.0,
retryable_status_codes: default_retryable_status_codes(),
retryable_transport_error_kinds: default_retryable_transport_error_kinds(),
retryable_timeout_phases: default_retryable_timeout_phases(),
retry_on_response_body_read_error: true,
status_retry_windows: BTreeMap::new(),
transport_retry_windows: BTreeMap::new(),
timeout_retry_windows: BTreeMap::new(),
response_body_read_retry_window: None,
retry_classifier: None,
}
}
pub fn standard() -> Self {
Self {
max_attempts: 3,
base_backoff: Duration::from_millis(200),
max_backoff: Duration::from_secs(2),
jitter_ratio: 0.2,
retryable_status_codes: default_retryable_status_codes(),
retryable_transport_error_kinds: default_retryable_transport_error_kinds(),
retryable_timeout_phases: default_retryable_timeout_phases(),
retry_on_response_body_read_error: true,
status_retry_windows: BTreeMap::new(),
transport_retry_windows: BTreeMap::new(),
timeout_retry_windows: BTreeMap::new(),
response_body_read_retry_window: None,
retry_classifier: None,
}
}
pub fn max_attempts(mut self, max_attempts: usize) -> Self {
self.max_attempts = max_attempts;
self
}
pub fn base_backoff(mut self, base_backoff: Duration) -> Self {
self.base_backoff = base_backoff;
self
}
pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
self.max_backoff = max_backoff;
self
}
pub fn jitter_ratio(mut self, jitter_ratio: f64) -> Self {
self.jitter_ratio = jitter_ratio;
self
}
pub fn retryable_status_codes(mut self, codes: impl IntoIterator<Item = u16>) -> Self {
self.retryable_status_codes = codes.into_iter().collect();
self
}
pub fn retryable_transport_error_kinds(
mut self,
kinds: impl IntoIterator<Item = TransportErrorKind>,
) -> Self {
self.retryable_transport_error_kinds = kinds.into_iter().collect();
self
}
pub fn retryable_timeout_phases(
mut self,
phases: impl IntoIterator<Item = TimeoutPhase>,
) -> Self {
self.retryable_timeout_phases = phases.into_iter().collect();
self
}
pub fn retry_on_response_body_read_error(mut self, retry: bool) -> Self {
self.retry_on_response_body_read_error = retry;
self
}
pub fn status_retry_window(mut self, status: u16, max_attempts: usize) -> Self {
self.status_retry_windows.insert(status, max_attempts);
self
}
pub fn transport_retry_window(mut self, kind: TransportErrorKind, max_attempts: usize) -> Self {
self.transport_retry_windows.insert(kind, max_attempts);
self
}
pub fn timeout_retry_window(mut self, phase: TimeoutPhase, max_attempts: usize) -> Self {
self.timeout_retry_windows.insert(phase, max_attempts);
self
}
pub fn response_body_read_retry_window(mut self, max_attempts: usize) -> Self {
self.response_body_read_retry_window = Some(max_attempts);
self
}
pub fn retry_classifier(mut self, retry_classifier: Arc<dyn RetryClassifier>) -> Self {
self.retry_classifier = Some(retry_classifier);
self
}
pub(crate) fn configured_max_attempts(&self) -> usize {
self.max_attempts
}
pub(crate) fn configured_max_backoff(&self) -> Duration {
self.max_backoff
}
pub(crate) fn validate(&self) -> crate::Result<()> {
if self.max_attempts == 0 {
return Err(self.invalid_policy("max_attempts must be greater than zero"));
}
if self.base_backoff.is_zero() {
return Err(self.invalid_policy("base_backoff must be greater than zero"));
}
if self.max_backoff.is_zero() {
return Err(self.invalid_policy("max_backoff must be greater than zero"));
}
if self.max_backoff < self.base_backoff {
return Err(
self.invalid_policy("max_backoff must be greater than or equal to base_backoff")
);
}
if !self.jitter_ratio.is_finite() || !(0.0..=1.0).contains(&self.jitter_ratio) {
return Err(self.invalid_policy("jitter_ratio must be finite and between 0.0 and 1.0"));
}
if self
.retryable_status_codes
.iter()
.any(|status| !is_valid_retryable_status_code(*status))
{
return Err(
self.invalid_policy("retryable status codes must be valid non-success statuses")
);
}
if self
.status_retry_windows
.keys()
.any(|status| !is_valid_retryable_status_code(*status))
{
return Err(
self.invalid_policy("status retry windows must target valid non-success statuses")
);
}
if self.status_retry_windows.values().any(|limit| *limit == 0) {
return Err(self.invalid_policy("status retry windows must be greater than zero"));
}
if self
.transport_retry_windows
.values()
.any(|limit| *limit == 0)
{
return Err(self.invalid_policy("transport retry windows must be greater than zero"));
}
if self.timeout_retry_windows.values().any(|limit| *limit == 0) {
return Err(self.invalid_policy("timeout retry windows must be greater than zero"));
}
if self.response_body_read_retry_window == Some(0) {
return Err(
self.invalid_policy("response body read retry window must be greater than zero")
);
}
Ok(())
}
fn invalid_policy(&self, message: &'static str) -> Error {
Error::InvalidRetryPolicy {
max_attempts: self.max_attempts,
base_backoff_ms: self.base_backoff.as_millis(),
max_backoff_ms: self.max_backoff.as_millis(),
jitter_ratio: self.jitter_ratio,
message,
}
}
fn should_retry_status(&self, status: StatusCode) -> bool {
self.retryable_status_codes.contains(&status.as_u16())
}
pub(crate) fn is_retryable_status(&self, status: StatusCode) -> bool {
self.should_retry_status(status)
}
fn is_within_retry_window(limit: Option<usize>, attempt: usize) -> bool {
match limit {
Some(limit) => attempt < limit,
None => true,
}
}
pub(crate) fn should_retry_decision(&self, decision: &RetryDecision) -> bool {
if let Some(retry_classifier) = &self.retry_classifier {
return retry_classifier.should_retry(decision);
}
match decision.reason() {
RetryReason::Status(status) => {
let window = self.status_retry_windows.get(&status.as_u16()).copied();
self.should_retry_status(status)
&& Self::is_within_retry_window(window, decision.attempt())
}
RetryReason::Transport(kind) => {
let window = self.transport_retry_windows.get(&kind).copied();
self.retryable_transport_error_kinds.contains(&kind)
&& Self::is_within_retry_window(window, decision.attempt())
}
RetryReason::Timeout(phase) => {
let window = self.timeout_retry_windows.get(&phase).copied();
self.retryable_timeout_phases.contains(&phase)
&& Self::is_within_retry_window(window, decision.attempt())
}
RetryReason::ResponseBodyRead => {
self.retry_on_response_body_read_error
&& Self::is_within_retry_window(
self.response_body_read_retry_window,
decision.attempt(),
)
}
}
}
pub(crate) fn backoff_for_retry(&self, retry_index: usize) -> Duration {
exponential_backoff_with_jitter(
retry_index,
self.base_backoff,
self.max_backoff,
self.jitter_ratio,
)
}
}
/// The default policy is [`RetryPolicy::standard`].
impl Default for RetryPolicy {
/// Returns `RetryPolicy::standard()`.
fn default() -> Self {
Self::standard()
}
}
/// Status codes retried by default: 429, 500, 502, 503 and 504.
fn default_retryable_status_codes() -> BTreeSet<u16> {
    BTreeSet::from([429, 500, 502, 503, 504])
}
/// `true` when `status` is accepted by `StatusCode::from_u16` and the resulting
/// code is not a success status.
fn is_valid_retryable_status_code(status: u16) -> bool {
    match StatusCode::from_u16(status) {
        Ok(code) => !code.is_success(),
        // Values `StatusCode` refuses to represent can never be retried on.
        Err(_) => false,
    }
}
/// Transport error kinds retried by default: `Dns`, `Connect` and `Read`.
fn default_retryable_transport_error_kinds() -> BTreeSet<TransportErrorKind> {
    BTreeSet::from([
        TransportErrorKind::Dns,
        TransportErrorKind::Connect,
        TransportErrorKind::Read,
    ])
}
/// Timeout phases retried by default: `Transport` and `ResponseBody`.
fn default_retryable_timeout_phases() -> BTreeSet<TimeoutPhase> {
    BTreeSet::from([TimeoutPhase::Transport, TimeoutPhase::ResponseBody])
}
/// `true` when the method is idempotent, or failing that, when the headers
/// contain `IDEMPOTENCY_KEY_HEADER`.
pub(crate) fn request_supports_retry(method: &Method, headers: &HeaderMap) -> bool {
    if is_method_idempotent(method) {
        return true;
    }
    // Only the presence of the header matters; its value is not inspected.
    headers.contains_key(IDEMPOTENCY_KEY_HEADER)
}
/// `true` for GET, HEAD, PUT, DELETE, OPTIONS and TRACE; `false` for every
/// other method.
fn is_method_idempotent(method: &Method) -> bool {
    [
        Method::GET,
        Method::HEAD,
        Method::PUT,
        Method::DELETE,
        Method::OPTIONS,
        Method::TRACE,
    ]
    .contains(method)
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use http::{Method, StatusCode};

    use super::{RetryDecision, RetryPolicy, RetryReason};

    /// Asserts that `validate` refuses `policy`, reporting failures at the call site.
    #[track_caller]
    fn assert_rejected(policy: RetryPolicy) {
        assert!(policy.validate().is_err());
    }

    #[test]
    fn jittered_backoff_never_exceeds_configured_max_backoff() {
        let cap = Duration::from_millis(120);
        let policy = RetryPolicy::standard()
            .base_backoff(Duration::from_millis(100))
            .max_backoff(cap)
            .jitter_ratio(1.0);

        // Sample repeatedly because the jittered value may differ on each call.
        assert!((0..256).all(|_| policy.backoff_for_retry(3) <= cap));
    }

    #[test]
    fn validate_rejects_nan_jitter_ratio() {
        assert_rejected(
            RetryPolicy::standard()
                .base_backoff(Duration::from_millis(100))
                .max_backoff(Duration::from_millis(500))
                .jitter_ratio(f64::NAN),
        );
    }

    #[test]
    fn validate_rejects_zero_attempts_and_backoff() {
        assert_rejected(RetryPolicy::standard().max_attempts(0));
        assert_rejected(RetryPolicy::standard().base_backoff(Duration::ZERO));
        assert_rejected(RetryPolicy::standard().max_backoff(Duration::ZERO));
    }

    #[test]
    fn validate_rejects_zero_retry_windows() {
        assert_rejected(RetryPolicy::standard().status_retry_window(503, 0));
        assert_rejected(RetryPolicy::standard().response_body_read_retry_window(0));
    }

    #[test]
    fn validate_rejects_invalid_or_success_status_policy_entries() {
        // 99 is not accepted as a status code at all; 200 is a success status.
        assert_rejected(RetryPolicy::standard().retryable_status_codes([99]));
        assert_rejected(RetryPolicy::standard().retryable_status_codes([200]));
        // The same two rules apply to the keys of status retry windows.
        assert_rejected(RetryPolicy::standard().status_retry_window(1000, 2));
        assert_rejected(RetryPolicy::standard().status_retry_window(204, 2));
    }

    #[test]
    fn retry_decision_accessors_reflect_reason() {
        let uri = "https://api.example.com/v1/items";
        let decision = RetryDecision::new(
            1,
            3,
            Method::GET,
            uri.to_owned(),
            RetryReason::Status(StatusCode::TOO_MANY_REQUESTS),
        );

        assert_eq!(decision.attempt(), 1);
        assert_eq!(decision.max_attempts(), 3);
        assert_eq!(decision.method(), &Method::GET);
        assert_eq!(decision.uri(), uri);
        assert_eq!(decision.status(), Some(StatusCode::TOO_MANY_REQUESTS));
        assert_eq!(decision.transport_error_kind(), None);
        assert_eq!(decision.timeout_phase(), None);
        assert!(!decision.is_response_body_read_error());
    }
}