use std::{error::Error as StdError, time::Duration};
/// Policy describing whether and how a failed operation is retried.
///
/// The default policy is [`RetryOptions::Disabled`].
#[derive(Debug, Clone, Default, PartialEq)]
pub enum RetryOptions {
    /// Never retry.
    #[default]
    Disabled,
    /// Retry with the same delay every time.
    Constant {
        /// Delay reported for every retry.
        delay: Duration,
        /// Maximum number of retries; `None` means no limit.
        max_retries: Option<usize>,
    },
    /// Retry with a delay that is multiplied by `factor` after each retry.
    Exponential {
        /// Delay reported for the first retry.
        initial_delay: Duration,
        /// Multiplier applied to the previous delay to obtain the next one.
        factor: f64,
        /// Upper bound applied to every delay; `None` means uncapped.
        max_delay: Option<Duration>,
        /// Maximum number of retries; `None` means no limit.
        max_retries: Option<usize>,
    },
}
/// A single retry decision produced by `RetryState::next`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct RetryDecision {
/// 1-based index of this retry since the state was created or last reset.
pub(crate) attempt: usize,
/// Delay attached to this retry (presumably the time to wait before
/// retrying; confirm against the caller).
pub(crate) delay: Duration,
}
/// Mutable bookkeeping that turns a `RetryOptions` policy into a sequence of
/// `RetryDecision`s.
#[derive(Debug, Clone)]
pub(crate) struct RetryState {
/// The policy being applied; `override_delay` may modify its delays in place.
options: RetryOptions,
/// Number of retries handed out since creation or the last `reset`.
attempts: usize,
/// Delay returned by the most recent successful `next` call, if any; the
/// exponential policy grows its next delay from this value.
last_delay: Option<Duration>,
}
impl RetryState {
    /// Creates a state for `options` with no retries handed out yet.
    pub(crate) fn new(options: RetryOptions) -> Self {
        Self {
            options,
            attempts: 0,
            last_delay: None,
        }
    }

    /// Returns the next retry decision, or `None` when retrying is disabled
    /// or the policy's `max_retries` budget is exhausted.
    ///
    /// Attempts are numbered from 1. A `None` result leaves the state
    /// untouched, so repeated calls keep returning `None` until [`reset`]
    /// is called.
    ///
    /// For the exponential policy the first delay is `initial_delay`; every
    /// later delay is the previous one multiplied by `factor`, then capped at
    /// `max_delay` when one is set. The multiplication never panics: see
    /// [`scale_delay`] for how out-of-range products are handled.
    ///
    /// [`reset`]: Self::reset
    /// [`scale_delay`]: Self::scale_delay
    pub(crate) fn next(&mut self) -> Option<RetryDecision> {
        let attempt = self.attempts.saturating_add(1);

        let (max_retries, delay) = match &self.options {
            RetryOptions::Disabled => return None,
            RetryOptions::Constant { delay, max_retries } => (*max_retries, *delay),
            RetryOptions::Exponential {
                initial_delay,
                factor,
                max_delay,
                max_retries,
            } => {
                let grown = if attempt == 1 {
                    *initial_delay
                } else {
                    Self::scale_delay(self.last_delay.unwrap_or(*initial_delay), *factor)
                };
                let capped = max_delay.map_or(grown, |cap| grown.min(cap));
                (*max_retries, capped)
            }
        };

        if max_retries.is_some_and(|limit| attempt > limit) {
            return None;
        }

        self.attempts = attempt;
        self.last_delay = Some(delay);
        Some(RetryDecision { attempt, delay })
    }

    /// Forgets every retry handed out so far, so the next decision is
    /// attempt 1 again. The policy itself, including any delay set through
    /// [`override_delay`](Self::override_delay), is kept.
    pub(crate) fn reset(&mut self) {
        self.attempts = 0;
        self.last_delay = None;
    }

    /// Replaces the policy's base delay with `delay`.
    ///
    /// * `Disabled`: no effect.
    /// * `Constant`: every subsequent retry uses `delay`.
    /// * `Exponential`: `initial_delay` becomes `delay`, and an existing
    ///   `max_delay` is raised to at least `delay` so the cap never sits
    ///   below the new starting point.
    ///
    /// The attempt counter and the last delay are not modified, so an
    /// exponential sequence that is already in progress keeps growing from
    /// its previous delay until [`reset`](Self::reset) is called.
    pub(crate) fn override_delay(&mut self, delay: Duration) {
        match &mut self.options {
            RetryOptions::Disabled => {}
            RetryOptions::Constant { delay: current, .. } => *current = delay,
            RetryOptions::Exponential {
                initial_delay,
                max_delay,
                ..
            } => {
                *initial_delay = delay;
                if let Some(cap) = max_delay {
                    *cap = (*cap).max(delay);
                }
            }
        }
    }

    /// Multiplies `delay` by `factor` without panicking.
    ///
    /// `Duration::mul_f64` panics when the product is negative, NaN or too
    /// large for a `Duration`. Here a product that is too large saturates to
    /// `Duration::MAX`, and a negative or NaN product leaves `delay`
    /// unchanged. For every other input the result equals `mul_f64`.
    fn scale_delay(delay: Duration, factor: f64) -> Duration {
        let seconds = delay.as_secs_f64() * factor;
        Duration::try_from_secs_f64(seconds).unwrap_or(if seconds > 0.0 {
            Duration::MAX
        } else {
            delay
        })
    }
}
pub(crate) fn is_retryable_request_error(error: &reqwest::Error) -> bool {
error.is_timeout() || error.is_connect() || source_contains_io_error(error)
}
/// Returns `true` when any error in `error`'s source chain is a
/// `std::io::Error`. The top-level error itself is not inspected, only the
/// causes reachable through `source()`.
fn source_contains_io_error(error: &reqwest::Error) -> bool {
    // Destructure the `&&dyn Error` argument so `source()` is called on the
    // inner reference and the returned cause keeps the chain's lifetime.
    std::iter::successors(error.source(), |&cause| cause.source())
        .any(|cause| cause.is::<std::io::Error>())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the value `RetryState::next` is expected to return for the
    /// given attempt number and delay in milliseconds.
    fn expected(attempt: usize, millis: u64) -> Option<RetryDecision> {
        Some(RetryDecision {
            attempt,
            delay: Duration::from_millis(millis),
        })
    }

    /// After `reset`, an exponential sequence starts over at attempt 1 with
    /// the initial delay.
    #[test]
    fn retry_state_resets_after_successful_connection() {
        let policy = RetryOptions::Exponential {
            initial_delay: Duration::from_millis(50),
            factor: 2.0,
            max_delay: Some(Duration::from_millis(200)),
            max_retries: Some(3),
        };
        let mut state = RetryState::new(policy);

        assert_eq!(state.next(), expected(1, 50));
        assert_eq!(state.next(), expected(2, 100));

        state.reset();

        assert_eq!(state.next(), expected(1, 50));
    }

    /// An overridden delay replaces the configured constant delay, and the
    /// retry limit still applies.
    #[test]
    fn retry_state_applies_override_delay() {
        let policy = RetryOptions::Constant {
            delay: Duration::from_millis(50),
            max_retries: Some(2),
        };
        let mut state = RetryState::new(policy);

        state.override_delay(Duration::from_millis(10));

        assert_eq!(state.next(), expected(1, 10));
        assert_eq!(state.next(), expected(2, 10));
        assert_eq!(state.next(), None);
    }
}