use std::fmt;
use std::future::Future;
use std::time::{Duration, Instant};
use qubit_common::BoxError;
use qubit_function::{BiConsumer, BiFunction, Consumer};
use crate::event::RetryListeners;
use crate::{
RetryAbortContext, RetryAttemptContext, RetryAttemptFailure, RetryConfigError, RetryContext,
RetryDecision, RetryError, RetryFailureContext, RetryOptions, RetrySuccessContext,
};
use crate::error::RetryDecider;
use crate::error::RetryFailureAction;
use crate::retry_executor_builder::RetryExecutorBuilder;
/// Executes fallible operations with retries, driven by [`RetryOptions`].
///
/// The executor holds the retry options, the [`RetryDecider`] consulted for
/// operation errors, and the [`RetryListeners`] notified at retry, success,
/// failure and abort points. `E` is the operation's error type and defaults to
/// [`BoxError`].
#[derive(Clone)]
pub struct RetryExecutor<E = BoxError> {
    /// Attempt limit, delay and jitter strategy, and optional elapsed-time budget.
    options: RetryOptions,
    /// Decides whether an operation error is retried or aborts the run.
    retry_decider: RetryDecider<E>,
    /// Optional callbacks for retry, success, failure and abort events.
    listeners: RetryListeners<E>,
}
impl<E> RetryExecutor<E> {
#[inline]
pub fn builder() -> RetryExecutorBuilder<E> {
RetryExecutorBuilder::new()
}
pub fn from_options(options: RetryOptions) -> Result<Self, RetryConfigError> {
Self::builder().options(options).build()
}
#[inline]
pub fn options(&self) -> &RetryOptions {
&self.options
}
pub fn run<T, F>(&self, mut operation: F) -> Result<T, RetryError<E>>
where
F: FnMut() -> Result<T, E>,
{
let start = Instant::now();
let mut attempts = 0;
let mut last_failure = None;
loop {
if let Some(error) = self.take_elapsed_error(start, attempts, &mut last_failure) {
return Err(error);
}
attempts += 1;
match operation() {
Ok(value) => {
self.emit_success(attempts, start.elapsed());
return Ok(value);
}
Err(error) => {
let failure = RetryAttemptFailure::Error(error);
match self.handle_failure(attempts, start, failure) {
RetryFailureAction::Retry { delay, failure } => {
if !delay.is_zero() {
std::thread::sleep(delay);
}
last_failure = Some(failure);
}
RetryFailureAction::Finished(error) => return Err(error),
}
}
}
}
}
pub async fn run_async<T, F, Fut>(&self, mut operation: F) -> Result<T, RetryError<E>>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, E>>,
{
let start = Instant::now();
let mut attempts = 0;
let mut last_failure = None;
loop {
if let Some(error) = self.take_elapsed_error(start, attempts, &mut last_failure) {
return Err(error);
}
attempts += 1;
match operation().await {
Ok(value) => {
self.emit_success(attempts, start.elapsed());
return Ok(value);
}
Err(error) => {
let failure = RetryAttemptFailure::Error(error);
match self.handle_failure(attempts, start, failure) {
RetryFailureAction::Retry { delay, failure } => {
sleep_async(delay).await;
last_failure = Some(failure);
}
RetryFailureAction::Finished(error) => return Err(error),
}
}
}
}
}
pub async fn run_async_with_timeout<T, F, Fut>(
&self,
attempt_timeout: Duration,
mut operation: F,
) -> Result<T, RetryError<E>>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, E>>,
{
let start = Instant::now();
let mut attempts = 0;
let mut last_failure = None;
loop {
if let Some(error) = self.take_elapsed_error(start, attempts, &mut last_failure) {
return Err(error);
}
attempts += 1;
let attempt_start = Instant::now();
match tokio::time::timeout(attempt_timeout, operation()).await {
Ok(Ok(value)) => {
self.emit_success(attempts, start.elapsed());
return Ok(value);
}
Ok(Err(error)) => {
let failure = RetryAttemptFailure::Error(error);
match self.handle_failure(attempts, start, failure) {
RetryFailureAction::Retry { delay, failure } => {
sleep_async(delay).await;
last_failure = Some(failure);
}
RetryFailureAction::Finished(error) => return Err(error),
}
}
Err(_) => {
let failure = RetryAttemptFailure::AttemptTimeout {
elapsed: attempt_start.elapsed(),
timeout: attempt_timeout,
};
match self.handle_failure(attempts, start, failure) {
RetryFailureAction::Retry { delay, failure } => {
sleep_async(delay).await;
last_failure = Some(failure);
}
RetryFailureAction::Finished(error) => return Err(error),
}
}
}
}
}
pub(super) fn new(
options: RetryOptions,
retry_decider: RetryDecider<E>,
listeners: RetryListeners<E>,
) -> Self {
Self {
options,
retry_decider,
listeners,
}
}
fn handle_failure(
&self,
attempts: u32,
start: Instant,
failure: RetryAttemptFailure<E>,
) -> RetryFailureAction<E> {
let elapsed = start.elapsed();
let context = RetryAttemptContext {
attempt: attempts,
max_attempts: self.options.max_attempts.get(),
elapsed,
};
let decision = match &failure {
RetryAttemptFailure::Error(error) => self.retry_decider.apply(error, &context),
RetryAttemptFailure::AttemptTimeout { .. } => RetryDecision::Retry,
};
if decision == RetryDecision::Abort {
self.emit_abort(attempts, elapsed, &failure);
return RetryFailureAction::Finished(RetryError::Aborted {
attempts,
elapsed,
failure,
});
}
let max_attempts = self.options.max_attempts.get();
if attempts >= max_attempts {
let Some(failure) = self.emit_failure(attempts, elapsed, Some(failure)) else {
unreachable!("failure must exist when attempts exceed max_attempts");
};
return RetryFailureAction::Finished(RetryError::AttemptsExceeded {
attempts,
max_attempts,
elapsed,
last_failure: failure,
});
}
let base_delay = self.options.delay.base_delay(attempts);
let delay = self.options.jitter.apply(base_delay);
if let Some(max_elapsed) = self.options.max_elapsed {
if will_exceed_elapsed(start.elapsed(), delay, max_elapsed) {
let last_failure = self.emit_failure(attempts, elapsed, Some(failure));
let error = RetryError::MaxElapsedExceeded {
attempts,
elapsed,
max_elapsed,
last_failure,
};
return RetryFailureAction::Finished(error);
}
}
self.emit_retry(attempts, elapsed, delay, &failure);
RetryFailureAction::Retry { delay, failure }
}
fn take_elapsed_error(
&self,
start: Instant,
attempts: u32,
last_failure: &mut Option<RetryAttemptFailure<E>>,
) -> Option<RetryError<E>> {
let max_elapsed = self.options.max_elapsed?;
let elapsed = start.elapsed();
if elapsed < max_elapsed {
return None;
}
let last_failure = self.emit_failure(attempts, elapsed, last_failure.take());
Some(RetryError::MaxElapsedExceeded {
attempts,
elapsed,
max_elapsed,
last_failure,
})
}
fn emit_retry(
&self,
attempt: u32,
elapsed: Duration,
next_delay: Duration,
failure: &RetryAttemptFailure<E>,
) {
if let Some(listener) = &self.listeners.retry {
listener.accept(
&RetryContext {
attempt,
max_attempts: self.options.max_attempts.get(),
elapsed,
next_delay,
},
failure,
);
}
}
fn emit_success(&self, attempts: u32, elapsed: Duration) {
if let Some(listener) = &self.listeners.success {
listener.accept(&RetrySuccessContext { attempts, elapsed });
}
}
fn emit_failure(
&self,
attempts: u32,
elapsed: Duration,
failure: Option<RetryAttemptFailure<E>>,
) -> Option<RetryAttemptFailure<E>> {
if let Some(listener) = &self.listeners.failure {
listener.accept(&RetryFailureContext { attempts, elapsed }, &failure);
}
failure
}
fn emit_abort(&self, attempts: u32, elapsed: Duration, failure: &RetryAttemptFailure<E>) {
if let Some(listener) = &self.listeners.abort {
listener.accept(&RetryAbortContext { attempts, elapsed }, failure);
}
}
}
/// Debug output shows only the options. The retry decider and listeners are
/// omitted, and `finish_non_exhaustive` marks them as `..`.
impl<E> fmt::Debug for RetryExecutor<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("RetryExecutor");
        out.field("options", &self.options);
        out.finish_non_exhaustive()
    }
}
/// Reports whether waiting `delay` on top of `elapsed` would reach or pass
/// `max_elapsed`.
///
/// Reaching the budget exactly counts as exceeding it. A sum that overflows
/// [`Duration`] is also treated as exceeding it.
fn will_exceed_elapsed(elapsed: Duration, delay: Duration, max_elapsed: Duration) -> bool {
    match elapsed.checked_add(delay) {
        Some(projected) => projected >= max_elapsed,
        None => true,
    }
}
/// Waits for `delay` using the Tokio timer.
///
/// A zero delay returns immediately without calling `tokio::time::sleep`.
async fn sleep_async(delay: Duration) {
    if delay.is_zero() {
        return;
    }
    tokio::time::sleep(delay).await;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_will_exceed_elapsed_handles_boundaries_and_overflow() {
        let one_ms = Duration::from_millis(1);
        let two_ms = Duration::from_millis(2);
        assert!(!will_exceed_elapsed(one_ms, Duration::ZERO, two_ms));
        assert!(will_exceed_elapsed(one_ms, one_ms, two_ms));
        assert!(will_exceed_elapsed(Duration::MAX, one_ms, Duration::MAX));
    }

    #[tokio::test]
    async fn test_sleep_async_handles_zero_and_nonzero_delays() {
        sleep_async(Duration::ZERO).await;
        sleep_async(Duration::from_millis(1)).await;
    }

    #[test]
    fn test_box_error_executor_runs_success_and_failure_paths() {
        let executor: RetryExecutor<BoxError> = RetryExecutor::builder()
            .max_attempts(1)
            .delay(crate::RetryDelay::none())
            .build()
            .expect("executor should be built");
        assert_eq!(executor.options().max_attempts.get(), 1);
        assert!(format!("{executor:?}").contains("RetryExecutor"));
        let value = executor
            .run(|| Ok::<_, BoxError>("default-box-error"))
            .expect("boxed-error executor should return success");
        assert_eq!(value, "default-box-error");
        let error = executor
            .run(|| -> Result<(), BoxError> {
                let source = std::io::Error::new(std::io::ErrorKind::Other, "boxed failure");
                Err(Box::new(source) as BoxError)
            })
            .expect_err("single failed attempt should exceed attempts");
        assert_eq!(error.attempts(), 1);
        assert_eq!(
            error.last_error().map(ToString::to_string).as_deref(),
            Some("boxed failure")
        );
    }

    /// Builds a two-attempt executor with no delay between attempts.
    fn two_attempt_executor() -> RetryExecutor<BoxError> {
        RetryExecutor::builder()
            .max_attempts(2)
            .delay(crate::RetryDelay::none())
            .build()
            .expect("executor should be built")
    }

    /// A timed-out attempt is retried (timeouts bypass the retry decider) and
    /// the value of a later successful attempt is returned.
    #[tokio::test]
    async fn test_run_async_with_timeout_retries_timed_out_attempt() {
        let executor = two_attempt_executor();
        let mut calls = 0_u32;
        let value = executor
            .run_async_with_timeout(Duration::from_millis(5), || {
                calls += 1;
                let call = calls;
                async move {
                    if call == 1 {
                        // Never completes, so the first attempt must time out.
                        std::future::pending::<()>().await;
                    }
                    Ok::<_, BoxError>(call)
                }
            })
            .await
            .expect("second attempt should succeed after the first times out");
        assert_eq!(value, 2);
        assert_eq!(calls, 2);
    }

    /// Attempts that always time out exhaust `max_attempts`.
    #[tokio::test]
    async fn test_run_async_with_timeout_exhausts_attempts_on_repeated_timeouts() {
        let executor = two_attempt_executor();
        let error = executor
            .run_async_with_timeout(Duration::from_millis(1), || {
                std::future::pending::<Result<(), BoxError>>()
            })
            .await
            .expect_err("every attempt times out");
        assert_eq!(error.attempts(), 2);
    }
}