use crate::backoff::BackoffStrategy;
use crate::sleep::Sleeper;
use core::fmt;
use rand::SeedableRng;
use rand::rngs::SmallRng;
/// Reason a retry loop stopped without producing a value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryErrorKind {
    /// The backoff strategy declined to schedule another attempt.
    Exhausted,
    /// The `when` predicate refused to retry the error that was returned.
    PredicateRejected,
}
/// Error returned when a retry loop stops without a successful result.
#[derive(Clone, Debug)]
pub struct RetryError<E> {
    /// Why the loop stopped.
    kind: RetryErrorKind,
    /// Number of the attempt whose failure ended the loop (attempts start at 1).
    attempts: u8,
    /// Attempt budget reported by the backoff strategy.
    max_attempts: u8,
    /// Delay slept before the final attempt, or `None` if the loop never slept.
    last_delay_ms: Option<u64>,
    /// The error returned by the final attempt, if one was recorded.
    cause: Option<E>,
}
impl<E> RetryError<E> {
fn new(
kind: RetryErrorKind,
attempts: u8,
max_attempts: u8,
last_delay_ms: Option<u64>,
cause: Option<E>,
) -> Self {
Self {
kind,
attempts,
max_attempts,
last_delay_ms,
cause,
}
}
pub fn cause(&self) -> Option<&E> {
self.cause.as_ref()
}
pub fn into_cause(self) -> Option<E> {
self.cause
}
pub fn attempts(&self) -> u8 {
self.attempts
}
pub fn max_attempts(&self) -> u8 {
self.max_attempts
}
pub fn last_delay_ms(&self) -> Option<u64> {
self.last_delay_ms
}
pub fn kind(&self) -> RetryErrorKind {
self.kind
}
}
/// Renders a one-line summary: why the loop stopped, then the last delay and
/// the cause (each only when present).
impl<E: fmt::Display> fmt::Display for RetryError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let headline = match self.kind {
            RetryErrorKind::PredicateRejected => {
                write!(f, "retry aborted by predicate on attempt {}", self.attempts)
            }
            RetryErrorKind::Exhausted => write!(
                f,
                "retry exhausted after {} of {} attempts",
                self.attempts, self.max_attempts
            ),
        };
        headline?;

        // Optional suffixes, always in this order: delay first, then cause.
        if let Some(ms) = self.last_delay_ms {
            write!(f, " (last delay {}ms)", ms)?;
        }
        match &self.cause {
            Some(cause) => write!(f, ": {}", cause),
            None => Ok(()),
        }
    }
}
/// `RetryError` is a standard error whenever its cause type is one.
///
/// NOTE(review): `source()` keeps its default (`None`), so the wrapped cause is
/// not exposed through the error chain; overriding it would require adding an
/// `E: 'static` bound to this impl.
#[cfg(feature = "std")]
impl<E: std::error::Error> std::error::Error for RetryError<E> {}
/// Successful result of a retry loop, together with statistics about how it
/// was reached.
///
/// Derives `Clone` (when `T: Clone`) for parity with [`RetryError`], which is
/// also cloneable.
#[derive(Debug, Clone)]
pub struct RetryOutcome<T> {
    /// The value returned by the successful attempt.
    value: T,
    /// Number of the attempt that succeeded (attempts start at 1).
    attempts: u8,
    /// Sum of the delays slept before success, in milliseconds.
    cumulative_delay_ms: u64,
}
impl<T> RetryOutcome<T> {
fn new(value: T, attempts: u8, cumulative_delay_ms: u64) -> Self {
Self {
value,
attempts,
cumulative_delay_ms,
}
}
pub fn attempts(&self) -> u8 {
self.attempts
}
pub fn cumulative_delay_ms(&self) -> u64 {
self.cumulative_delay_ms
}
pub fn value(&self) -> &T {
&self.value
}
pub fn into_inner(self) -> T {
self.value
}
}
/// Extension trait that turns a fallible operation into a configurable retry
/// loop.
pub trait Retryable<T, E> {
    /// Wraps `self` in a [`RetryBuilder`] whose retry schedule is driven by
    /// `backoff`.
    fn retry<B>(self, backoff: B) -> RetryBuilder<Self, B, T, E, fn(&E) -> bool>
    where
        B: BackoffStrategy,
        Self: Sized;
}
/// Any `FnMut() -> Result<T, E>` (closure or fn item) can be retried.
impl<F, T, E> Retryable<T, E> for F
where
    F: FnMut() -> Result<T, E>,
{
    /// Starts a builder with no predicate and no callbacks configured.
    fn retry<B: BackoffStrategy>(self, backoff: B) -> RetryBuilder<Self, B, T, E, fn(&E) -> bool> {
        use core::marker::PhantomData;

        RetryBuilder {
            notify: None,
            on_success: None,
            on_failure: None,
            when: None,
            backoff,
            operation: self,
            _phantom_t: PhantomData,
            _phantom_e: PhantomData,
        }
    }
}
/// Configures and runs a retry loop around a fallible operation.
///
/// Produced by [`Retryable::retry`]. Nothing runs until `call` or
/// `call_with_sleeper` consumes the builder, so dropping it unused is almost
/// certainly a mistake; hence `#[must_use]`.
#[must_use = "a retry builder does nothing until `call` or `call_with_sleeper` is invoked"]
pub struct RetryBuilder<F, B, T, E, W> {
    /// The operation being retried.
    operation: F,
    /// Strategy deciding whether, and how long, to wait before the next attempt.
    backoff: B,
    /// Optional predicate; an error it rejects aborts the loop immediately.
    when: Option<W>,
    /// Called before each retry sleep with (error, attempt, delay in ms).
    notify: Option<fn(&E, u8, u64)>,
    /// Called once on success with (value, attempt).
    on_success: Option<fn(&T, u8)>,
    /// Called once with the final error before the loop returns `Err`.
    on_failure: Option<fn(&RetryError<E>)>,
    // `T` and `E` appear only in the operation's and callbacks' signatures.
    _phantom_t: core::marker::PhantomData<T>,
    _phantom_e: core::marker::PhantomData<E>,
}
impl<F, B, T, E, W> RetryBuilder<F, B, T, E, W>
where
F: FnMut() -> Result<T, E>,
B: BackoffStrategy,
W: Fn(&E) -> bool,
{
pub fn when<P>(self, predicate: P) -> RetryBuilder<F, B, T, E, P>
where
P: Fn(&E) -> bool,
{
RetryBuilder {
operation: self.operation,
backoff: self.backoff,
when: Some(predicate),
notify: self.notify,
on_success: self.on_success,
on_failure: self.on_failure,
_phantom_t: core::marker::PhantomData,
_phantom_e: core::marker::PhantomData,
}
}
pub fn notify(mut self, callback: fn(&E, u8, u64)) -> Self {
self.notify = Some(callback);
self
}
pub fn on_success(mut self, callback: fn(&T, u8)) -> Self {
self.on_success = Some(callback);
self
}
pub fn on_failure(mut self, callback: fn(&RetryError<E>)) -> Self {
self.on_failure = Some(callback);
self
}
#[cfg(feature = "std")]
pub fn call(self) -> Result<RetryOutcome<T>, RetryError<E>> {
use crate::sleep::StdSleeper;
self.call_with_sleeper(StdSleeper)
}
pub fn call_with_sleeper<S: Sleeper>(
mut self,
sleeper: S,
) -> Result<RetryOutcome<T>, RetryError<E>> {
let mut rng = SmallRng::from_os_rng();
let mut attempt = 1u8;
let max_attempts = self.backoff.max_attempts();
let mut cumulative_delay_ms: u64 = 0;
let mut last_delay_ms: Option<u64> = None;
loop {
match (self.operation)() {
Ok(value) => {
if let Some(callback) = self.on_success {
callback(&value, attempt);
}
return Ok(RetryOutcome::new(value, attempt, cumulative_delay_ms));
}
Err(error) => {
if let Some(ref predicate) = self.when {
if !predicate(&error) {
let retry_error = RetryError::new(
RetryErrorKind::PredicateRejected,
attempt,
max_attempts,
last_delay_ms,
Some(error),
);
if let Some(callback) = self.on_failure {
callback(&retry_error);
}
return Err(retry_error);
}
}
if !self.backoff.should_retry(attempt) {
let retry_error = RetryError::new(
RetryErrorKind::Exhausted,
attempt,
max_attempts,
last_delay_ms,
Some(error),
);
if let Some(callback) = self.on_failure {
callback(&retry_error);
}
return Err(retry_error);
}
match self.backoff.delay(attempt, &mut rng) {
Some(delay_ms) => {
if let Some(notify) = self.notify {
notify(&error, attempt, delay_ms);
}
sleeper.sleep_ms(delay_ms);
cumulative_delay_ms = cumulative_delay_ms.saturating_add(delay_ms);
last_delay_ms = Some(delay_ms);
attempt = attempt.saturating_add(1);
}
None => {
let retry_error = RetryError::new(
RetryErrorKind::Exhausted,
attempt,
max_attempts,
last_delay_ms,
Some(error),
);
if let Some(callback) = self.on_failure {
callback(&retry_error);
}
return Err(retry_error);
}
}
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backoff::{ConstantBackoff, ExponentialBackoff};
    use crate::sleep::FnSleeper;

    #[derive(Debug, PartialEq)]
    enum TestError {
        Retryable,
        Fatal,
    }

    #[test]
    fn test_retry_success_on_first_attempt() {
        fn always_succeeds() -> Result<i32, TestError> {
            Ok(42)
        }
        let result = always_succeeds
            .retry(ExponentialBackoff::default())
            .call_with_sleeper(FnSleeper(|_| {}));
        let outcome = result.expect("retry should succeed");
        assert_eq!(outcome.attempts(), 1);
        assert_eq!(outcome.into_inner(), 42);
    }

    #[test]
    fn test_retry_success_after_failures() {
        use core::cell::Cell;
        let attempts = Cell::new(0);
        let operation = || {
            let current = attempts.get();
            attempts.set(current + 1);
            if current < 2 {
                Err(TestError::Retryable)
            } else {
                Ok(42)
            }
        };
        let result = operation
            .retry(ExponentialBackoff::default().max_attempts(3))
            .call_with_sleeper(FnSleeper(|_| {}));
        let outcome = result.expect("retry should succeed");
        assert_eq!(outcome.attempts(), 3);
        assert_eq!(outcome.into_inner(), 42);
        assert_eq!(attempts.get(), 3);
    }

    #[test]
    fn test_retry_exhausted() {
        fn always_fails() -> Result<i32, TestError> {
            Err(TestError::Retryable)
        }
        let result = always_fails
            .retry(ExponentialBackoff::default().max_attempts(3))
            .call_with_sleeper(FnSleeper(|_| {}));
        let err = result.expect_err("retry should exhaust");
        assert_eq!(err.kind(), RetryErrorKind::Exhausted);
        assert_eq!(err.attempts(), 3);
        assert_eq!(err.max_attempts(), 3);
        assert!(err.last_delay_ms().is_some());
        if let Some(cause) = err.cause() {
            assert_eq!(cause, &TestError::Retryable);
        } else {
            panic!("expected underlying cause");
        }
    }

    #[test]
    fn test_retry_when_predicate() {
        fn fails_with_fatal() -> Result<i32, TestError> {
            Err(TestError::Fatal)
        }
        let result = fails_with_fatal
            .retry(ExponentialBackoff::default())
            .when(|e| matches!(e, TestError::Retryable))
            .call_with_sleeper(FnSleeper(|_| {}));
        let err = result.expect_err("retry should stop due to predicate");
        assert_eq!(err.kind(), RetryErrorKind::PredicateRejected);
        if let Some(cause) = err.cause() {
            assert_eq!(cause, &TestError::Fatal);
        } else {
            panic!("expected underlying cause");
        }
    }

    #[test]
    fn test_retry_notify_callback() {
        use core::cell::Cell;
        use core::sync::atomic::{AtomicUsize, Ordering};
        // How many times notify fired, and the attempt it reported last.
        static NOTIFY_COUNT: AtomicUsize = AtomicUsize::new(0);
        static LAST_NOTIFIED_ATTEMPT: AtomicUsize = AtomicUsize::new(0);
        fn test_notify(_: &TestError, attempt: u8, _: u64) {
            assert!(attempt >= 1);
            NOTIFY_COUNT.fetch_add(1, Ordering::SeqCst);
            LAST_NOTIFIED_ATTEMPT.store(attempt as usize, Ordering::SeqCst);
        }
        let attempts = Cell::new(0);
        let operation = || {
            let current = attempts.get();
            attempts.set(current + 1);
            if current < 2 {
                Err(TestError::Retryable)
            } else {
                Ok(42)
            }
        };
        NOTIFY_COUNT.store(0, Ordering::SeqCst);
        LAST_NOTIFIED_ATTEMPT.store(0, Ordering::SeqCst);
        let result = operation
            .retry(ExponentialBackoff::default().max_attempts(3))
            .notify(test_notify)
            .call_with_sleeper(FnSleeper(|_| {}));
        let outcome = result.expect("retry should succeed");
        assert_eq!(outcome.attempts(), 3);
        // Attempts 1 and 2 failed, so notify must have fired before each retry.
        assert_eq!(NOTIFY_COUNT.load(Ordering::SeqCst), 2);
        assert_eq!(LAST_NOTIFIED_ATTEMPT.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_on_success_callback_invoked() {
        use core::cell::Cell;
        use core::sync::atomic::{AtomicUsize, Ordering};
        static SUCCESS_ATTEMPT: AtomicUsize = AtomicUsize::new(0);
        fn on_success(_: &i32, attempt: u8) {
            SUCCESS_ATTEMPT.store(attempt as usize, Ordering::SeqCst);
        }
        let attempts = Cell::new(0);
        let operation = || {
            let current = attempts.get();
            attempts.set(current + 1);
            if current < 1 {
                Err(TestError::Retryable)
            } else {
                Ok(7)
            }
        };
        SUCCESS_ATTEMPT.store(0, Ordering::SeqCst);
        let outcome = operation
            .retry(ExponentialBackoff::default().max_attempts(3))
            .on_success(on_success)
            .call_with_sleeper(FnSleeper(|_| {}))
            .expect("retry should succeed");
        assert_eq!(outcome.into_inner(), 7);
        assert_eq!(SUCCESS_ATTEMPT.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_on_failure_callback_invoked() {
        use core::sync::atomic::{AtomicUsize, Ordering};
        static FAILURE_KIND: AtomicUsize = AtomicUsize::new(0);
        fn on_failure(err: &RetryError<TestError>) {
            let marker = match err.kind() {
                RetryErrorKind::Exhausted => 1,
                RetryErrorKind::PredicateRejected => 2,
            };
            FAILURE_KIND.store(marker, Ordering::SeqCst);
        }
        fn always_fails() -> Result<(), TestError> {
            Err(TestError::Retryable)
        }
        FAILURE_KIND.store(0, Ordering::SeqCst);
        let result = always_fails
            .retry(ExponentialBackoff::default().max_attempts(2))
            .on_failure(on_failure)
            .call_with_sleeper(FnSleeper(|_| {}));
        assert!(result.is_err());
        assert_eq!(FAILURE_KIND.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_on_failure_callback_reports_predicate_rejection() {
        use core::sync::atomic::{AtomicUsize, Ordering};
        static FAILURE_KIND: AtomicUsize = AtomicUsize::new(0);
        fn on_failure(err: &RetryError<TestError>) {
            let marker = match err.kind() {
                RetryErrorKind::Exhausted => 1,
                RetryErrorKind::PredicateRejected => 2,
            };
            FAILURE_KIND.store(marker, Ordering::SeqCst);
        }
        fn fails_with_fatal() -> Result<(), TestError> {
            Err(TestError::Fatal)
        }
        FAILURE_KIND.store(0, Ordering::SeqCst);
        let err = fails_with_fatal
            .retry(ExponentialBackoff::default())
            .when(|e| matches!(e, TestError::Retryable))
            .on_failure(on_failure)
            .call_with_sleeper(FnSleeper(|_| {}))
            .expect_err("predicate should reject the fatal error");
        assert_eq!(FAILURE_KIND.load(Ordering::SeqCst), 2);
        // Rejected on the very first attempt, so no delay was ever slept.
        assert_eq!(err.attempts(), 1);
        assert!(err.last_delay_ms().is_none());
    }

    #[test]
    fn test_constant_backoff_retry() {
        use core::cell::Cell;
        let attempts = Cell::new(0);
        let operation = || {
            let current = attempts.get();
            attempts.set(current + 1);
            if current < 1 {
                Err(TestError::Retryable)
            } else {
                Ok(42)
            }
        };
        let result = operation
            .retry(ConstantBackoff::new().delay_ms(10).max_attempts(2))
            .call_with_sleeper(FnSleeper(|_| {}));
        let outcome = result.expect("retry should succeed");
        assert_eq!(outcome.attempts(), 2);
        assert_eq!(outcome.into_inner(), 42);
        assert_eq!(attempts.get(), 2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_retry_error_display() {
        let exhausted: RetryError<&str> =
            RetryError::new(RetryErrorKind::Exhausted, 3, 3, Some(40), Some("boom"));
        assert_eq!(
            std::format!("{}", exhausted),
            "retry exhausted after 3 of 3 attempts (last delay 40ms): boom"
        );
        let rejected: RetryError<&str> =
            RetryError::new(RetryErrorKind::PredicateRejected, 1, 5, None, None);
        assert_eq!(
            std::format!("{}", rejected),
            "retry aborted by predicate on attempt 1"
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_retry_with_std_sleeper() {
        use core::cell::Cell;
        let attempts = Cell::new(0);
        let operation = || {
            let current = attempts.get();
            attempts.set(current + 1);
            if current < 1 {
                Err(TestError::Retryable)
            } else {
                Ok(42)
            }
        };
        let start = std::time::Instant::now();
        let result = operation
            .retry(
                ConstantBackoff::new()
                    .delay_ms(10)
                    .max_attempts(2)
                    .jitter_factor(0.0),
            )
            .call();
        let elapsed = start.elapsed();
        let outcome = result.expect("retry should succeed");
        assert_eq!(outcome.attempts(), 2);
        assert_eq!(outcome.into_inner(), 42);
        // One 10ms sleep with zero jitter; the bound leaves 1ms of slack.
        assert!(elapsed.as_millis() >= 9);
    }
}