use crate::backoff::Backoff;
/// Attempt limit that `Retry::new` installs until the caller overrides it
/// through the `max_attempts` builder method.
const DEFAULT_MAX_ATTEMPTS: u32 = 5;
/// Decision produced by a classifier closure, telling `Retry::run` how to
/// react to a failed attempt.
///
/// The enum is `#[non_exhaustive]`, so matches written outside this crate
/// need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryAction {
/// Try again after the next delay taken from the policy's backoff.
Retry,
/// Try again after the carried duration. If the policy was built with
/// `respect_retry_after(false)`, the backoff delay is slept instead.
RetryAfter(core::time::Duration),
/// Stop retrying; `Retry::run` returns the error to its caller.
GiveUp,
}
/// Retry policy: a backoff schedule paired with an attempt limit and a switch
/// that controls whether `RetryAction::RetryAfter` durations are honored.
///
/// Configure it with the builder-style methods, each of which takes and
/// returns the policy by value.
#[derive(Debug, Clone, Copy)]
pub struct Retry {
// Source of the delays slept between attempts.
backoff: Backoff,
// Total number of calls allowed, counting the first one. `new` and the
// `max_attempts` builder keep this at 1 or more.
max_attempts: u32,
// When true, a `RetryAfter(d)` classification sleeps for `d`; when false,
// the backoff delay is used instead.
respect_retry_after: bool,
}
impl Retry {
    /// Creates a policy that draws its delays from `backoff`.
    ///
    /// The attempt limit starts at `DEFAULT_MAX_ATTEMPTS` and
    /// `RetryAction::RetryAfter` durations are honored until changed with
    /// [`Retry::respect_retry_after`].
    #[must_use]
    pub fn new(backoff: Backoff) -> Self {
        let max_attempts = DEFAULT_MAX_ATTEMPTS;
        let respect_retry_after = true;
        Self {
            backoff,
            max_attempts,
            respect_retry_after,
        }
    }

    /// Returns the policy with its attempt limit replaced by `attempts`.
    ///
    /// The limit counts every call, including the first one. A value of `0`
    /// is raised to `1`, so the limit is never zero.
    #[must_use]
    pub fn max_attempts(self, attempts: u32) -> Self {
        Self {
            max_attempts: attempts.max(1),
            ..self
        }
    }

    /// Returns the policy with `RetryAction::RetryAfter` handling switched on
    /// (`true`) or off (`false`).
    #[must_use]
    pub fn respect_retry_after(self, yes: bool) -> Self {
        Self {
            respect_retry_after: yes,
            ..self
        }
    }

    /// Returns the configured attempt limit.
    #[must_use]
    pub const fn attempts(&self) -> u32 {
        self.max_attempts
    }

    /// Returns a reference to the backoff this policy was built with.
    #[must_use]
    pub const fn backoff(&self) -> &Backoff {
        &self.backoff
    }
}
#[cfg(feature = "runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "runtime")))]
impl Retry {
    /// Calls `operation` repeatedly until it succeeds or retrying stops.
    ///
    /// After each failure:
    ///
    /// 1. If the attempt limit has been reached, the error is returned
    ///    without consulting `classify`.
    /// 2. Otherwise `classify` inspects the error. `GiveUp` returns it,
    ///    `Retry` sleeps for the next backoff delay, and `RetryAfter(d)`
    ///    sleeps for `d` when `respect_retry_after` is on, or for the
    ///    backoff delay when it is off.
    ///
    /// Sleeping goes through `crate::rt::sleep`.
    ///
    /// # Errors
    ///
    /// Returns the error from the most recent failed attempt when the limit
    /// is exhausted or the classifier answers `RetryAction::GiveUp`.
    pub async fn run<F, Fut, T, E, C>(&self, mut operation: F, classify: C) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: core::future::Future<Output = Result<T, E>>,
        C: Fn(&E) -> RetryAction,
    {
        let mut schedule = self.backoff.iter();
        // Failed attempts so far; equal to the number of calls made once a
        // failure has been recorded.
        let mut failures = 0u32;
        loop {
            let failure = match operation().await {
                Ok(value) => return Ok(value),
                Err(failure) => failure,
            };
            failures += 1;
            if failures >= self.max_attempts {
                return Err(failure);
            }
            let pause = match classify(&failure) {
                RetryAction::GiveUp => return Err(failure),
                RetryAction::Retry => schedule.next_delay(),
                RetryAction::RetryAfter(requested) => {
                    // The schedule is advanced even when the requested delay
                    // wins, so every retry consumes exactly one backoff step.
                    let scheduled = schedule.next_delay();
                    if self.respect_retry_after {
                        requested
                    } else {
                        scheduled
                    }
                }
            };
            crate::rt::sleep(pause).await;
        }
    }
}
/// Classifier that defers to the error's own `is_retryable` answer.
///
/// Returns `RetryAction::Retry` when `error.is_retryable()` is true and
/// `RetryAction::GiveUp` otherwise. Its signature lets it be passed directly
/// as the `classify` argument of `Retry::run`.
#[must_use]
pub fn retry_if_retryable<E: error_forge::ForgeError>(error: &E) -> RetryAction {
    if !error.is_retryable() {
        return RetryAction::GiveUp;
    }
    RetryAction::Retry
}
// Unit tests for the policy builder, the `run` loop, and the
// `retry_if_retryable` classifier.
//
// NOTE(review): most async tests below call `Retry::run`, which is compiled
// only with feature `runtime`, yet they carry no matching `cfg`. Confirm the
// test build always enables that feature, or gate these tests accordingly.
#[cfg(test)]
mod tests {
use super::{Retry, RetryAction, retry_if_retryable};
use crate::backoff::Backoff;
use core::time::Duration;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
// Policy with a constant 10 ms backoff, used by the tests that do not assert
// on elapsed time.
fn fast_policy() -> Retry {
Retry::new(Backoff::constant(Duration::from_millis(10)))
}
// A limit of 0 is raised to 1; any other value is stored unchanged.
#[test]
fn test_max_attempts_floor_is_one() {
assert_eq!(fast_policy().max_attempts(0).attempts(), 1);
assert_eq!(fast_policy().max_attempts(7).attempts(), 7);
}
// The operation fails on its first two calls and succeeds on the third, so
// `run` must return that success after exactly three calls.
#[tokio::test(start_paused = true)]
async fn test_succeeds_after_transient_failures() {
let calls = Arc::new(AtomicU32::new(0));
let c = calls.clone();
let result: Result<u32, &str> = fast_policy()
.max_attempts(5)
.run(
move || {
let c = c.clone();
async move {
// `n` is the 1-based number of this call.
let n = c.fetch_add(1, Ordering::Relaxed) + 1;
if n < 3 { Err("transient") } else { Ok(n) }
}
},
|_| RetryAction::Retry,
)
.await;
assert_eq!(result, Ok(3));
assert_eq!(calls.load(Ordering::Relaxed), 3);
}
// An operation that always fails is called exactly `max_attempts` times
// before its error is returned.
#[tokio::test(start_paused = true)]
async fn test_gives_up_after_max_attempts() {
let calls = Arc::new(AtomicU32::new(0));
let c = calls.clone();
let result: Result<(), &str> = fast_policy()
.max_attempts(3)
.run(
move || {
let c = c.clone();
async move {
let _ = c.fetch_add(1, Ordering::Relaxed);
Err("always")
}
},
|_| RetryAction::Retry,
)
.await;
assert_eq!(result, Err("always"));
assert_eq!(
calls.load(Ordering::Relaxed),
3,
"operation runs exactly max_attempts times"
);
}
// A `GiveUp` classification ends the loop after the first failure even though
// nine more attempts would be allowed.
#[tokio::test(start_paused = true)]
async fn test_give_up_classification_stops_immediately() {
let calls = Arc::new(AtomicU32::new(0));
let c = calls.clone();
let result: Result<(), &str> = fast_policy()
.max_attempts(10)
.run(
move || {
let c = c.clone();
async move {
let _ = c.fetch_add(1, Ordering::Relaxed);
Err("fatal")
}
},
|_| RetryAction::GiveUp,
)
.await;
assert_eq!(result, Err("fatal"));
assert_eq!(
calls.load(Ordering::Relaxed),
1,
"GiveUp stops after the first attempt"
);
}
// With two attempts there is exactly one sleep between them; it must last the
// 30 s requested by `RetryAfter`, not the 1 s backoff.
// NOTE(review): gated on feature `tokio`, presumably because the elapsed-time
// check only holds when `crate::rt::sleep` runs on tokio's paused clock —
// confirm.
#[cfg(feature = "tokio")]
#[tokio::test(start_paused = true)]
async fn test_retry_after_is_honored_when_enabled() {
let start = tokio::time::Instant::now();
let policy = Retry::new(Backoff::constant(Duration::from_secs(1)))
.max_attempts(2)
.respect_retry_after(true);
let _: Result<(), &str> = policy
.run(
|| async { Err("rejected") },
|_| RetryAction::RetryAfter(Duration::from_secs(30)),
)
.await;
assert_eq!(start.elapsed(), Duration::from_secs(30));
}
// Same setup with `respect_retry_after(false)`: the single sleep must last
// the 1 s backoff delay, ignoring the 30 s request.
#[cfg(feature = "tokio")]
#[tokio::test(start_paused = true)]
async fn test_retry_after_is_ignored_when_disabled() {
let start = tokio::time::Instant::now();
let policy = Retry::new(Backoff::constant(Duration::from_secs(1)))
.max_attempts(2)
.respect_retry_after(false);
let _: Result<(), &str> = policy
.run(
|| async { Err("rejected") },
|_| RetryAction::RetryAfter(Duration::from_secs(30)),
)
.await;
assert_eq!(start.elapsed(), Duration::from_secs(1));
}
// `retry_if_retryable` is used as the classifier; the test expects an
// `AppError::config` error to be classified as `GiveUp`, so the operation
// runs only once.
#[tokio::test(start_paused = true)]
async fn test_retry_if_retryable_helper() {
use error_forge::AppError;
let calls = Arc::new(AtomicU32::new(0));
let c = calls.clone();
let result: Result<(), AppError> = fast_policy()
.max_attempts(5)
.run(
move || {
let c = c.clone();
async move {
let _ = c.fetch_add(1, Ordering::Relaxed);
Err(AppError::config("bad"))
}
},
retry_if_retryable,
)
.await;
assert!(result.is_err());
assert_eq!(
calls.load(Ordering::Relaxed),
1,
"config errors are not retryable"
);
}
}