use std::fmt;
#[cfg(feature = "tokio")]
use std::future::Future;
#[cfg(feature = "tokio")]
use std::pin::Pin;
use std::time::{Duration, Instant};
use qubit_common::BoxError;
use qubit_function::{BiConsumer, BiFunction, Consumer};
use crate::event::RetryListeners;
use crate::{
AttemptFailure, AttemptFailureDecision, RetryAfterHint, RetryBuilder, RetryConfigError,
RetryContext, RetryError, RetryErrorReason, RetryOptions,
};
/// Retry policy that runs an operation repeatedly until it succeeds or the
/// policy decides to give up.
///
/// `E` is the error type produced by the retried operation; it defaults to
/// [`BoxError`].
pub struct Retry<E = BoxError> {
    /// Attempt limit, elapsed-time budget, base delay and jitter settings.
    options: RetryOptions,
    /// Per-attempt time limit. Only the asynchronous runner enforces it; the
    /// synchronous runner passes `None` everywhere.
    attempt_timeout: Option<Duration>,
    /// Optional hook that derives a retry delay from a failure and its context.
    retry_after_hint: Option<RetryAfterHint<E>>,
    /// When `true`, panics raised inside listeners are caught and replaced by
    /// the listener result type's `Default` value.
    isolate_listener_panics: bool,
    /// Listeners notified before each attempt, after a successful attempt and on
    /// the final error, plus the failure listeners that classify each failure.
    listeners: RetryListeners<E>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive would add an
// `E: Clone` bound, so a `Retry` over a non-`Clone` error type could never be
// cloned even when all of its fields are cloneable. These bounds require only
// what the fields actually need, and they hold wherever the derived impl held.
impl<E> Clone for Retry<E>
where
    RetryAfterHint<E>: Clone,
    RetryListeners<E>: Clone,
{
    fn clone(&self) -> Self {
        Self {
            options: self.options.clone(),
            attempt_timeout: self.attempt_timeout,
            retry_after_hint: self.retry_after_hint.clone(),
            isolate_listener_panics: self.isolate_listener_panics,
            listeners: self.listeners.clone(),
        }
    }
}
impl<E> Retry<E> {
#[inline]
pub fn builder() -> RetryBuilder<E> {
RetryBuilder::new()
}
pub fn from_options(options: RetryOptions) -> Result<Self, RetryConfigError> {
Self::builder().options(options).build()
}
#[inline]
pub fn options(&self) -> &RetryOptions {
&self.options
}
pub fn run<T, F>(&self, mut operation: F) -> Result<T, RetryError<E>>
where
F: FnMut() -> Result<T, E>,
{
let mut operation = SyncValueOperation::new(&mut operation);
self.run_sync_operation(&mut operation)?;
Ok(operation.into_value())
}
fn run_sync_operation(&self, operation: &mut dyn SyncAttempt<E>) -> Result<(), RetryError<E>> {
let start = Instant::now();
let mut attempts = 0;
let mut last_failure = None;
loop {
if let Some(error) = self.elapsed_error(start, attempts, last_failure.take(), None) {
return Err(self.emit_error(error));
}
attempts += 1;
let before_context = self.context(start, attempts, Duration::ZERO, None);
self.emit_before_attempt(&before_context);
let attempt_start = Instant::now();
match operation.call() {
Ok(()) => {
let context = self.context(start, attempts, attempt_start.elapsed(), None);
self.emit_attempt_success(&context);
return Ok(());
}
Err(failure) => {
let context = self.context(start, attempts, attempt_start.elapsed(), None);
match self.handle_failure(start, attempts, failure, context) {
RetryFlowAction::Retry { delay, failure } => {
if !delay.is_zero() {
std::thread::sleep(delay);
}
last_failure = Some(failure);
}
RetryFlowAction::Finished(error) => return Err(self.emit_error(error)),
}
}
}
}
}
#[cfg(feature = "tokio")]
pub async fn run_async<T, F, Fut>(&self, mut operation: F) -> Result<T, RetryError<E>>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, E>>,
{
let mut operation = AsyncValueOperation::new(&mut operation);
self.run_async_operation(&mut operation).await?;
Ok(operation.into_value())
}
#[cfg(feature = "tokio")]
async fn run_async_operation(
&self,
operation: &mut dyn AsyncAttempt<E>,
) -> Result<(), RetryError<E>> {
let start = Instant::now();
let mut attempts = 0;
let mut last_failure = None;
loop {
if let Some(error) =
self.elapsed_error(start, attempts, last_failure.take(), self.attempt_timeout)
{
return Err(self.emit_error(error));
}
attempts += 1;
let before_context =
self.context(start, attempts, Duration::ZERO, self.attempt_timeout);
self.emit_before_attempt(&before_context);
let attempt_start = Instant::now();
let result = if let Some(timeout) = self.attempt_timeout {
match tokio::time::timeout(timeout, operation.call()).await {
Ok(result) => result,
Err(_) => Err(AttemptFailure::Timeout),
}
} else {
operation.call().await
};
let context = self.context(
start,
attempts,
attempt_start.elapsed(),
self.attempt_timeout,
);
match result {
Ok(()) => {
self.emit_attempt_success(&context);
return Ok(());
}
Err(failure) => match self.handle_failure(start, attempts, failure, context) {
RetryFlowAction::Retry { delay, failure } => {
sleep_async(delay).await;
last_failure = Some(failure);
}
RetryFlowAction::Finished(error) => return Err(self.emit_error(error)),
},
}
}
}
pub(super) fn new(
options: RetryOptions,
attempt_timeout: Option<Duration>,
retry_after_hint: Option<RetryAfterHint<E>>,
isolate_listener_panics: bool,
listeners: RetryListeners<E>,
) -> Self {
Self {
options,
attempt_timeout,
retry_after_hint,
isolate_listener_panics,
listeners,
}
}
fn context(
&self,
start: Instant,
attempt: u32,
attempt_elapsed: Duration,
attempt_timeout: Option<Duration>,
) -> RetryContext {
RetryContext::new(
attempt,
self.options.max_attempts.get(),
self.options.max_elapsed,
start.elapsed(),
attempt_elapsed,
attempt_timeout,
)
}
fn handle_failure(
&self,
start: Instant,
attempts: u32,
failure: AttemptFailure<E>,
context: RetryContext,
) -> RetryFlowAction<E> {
let hint = self
.retry_after_hint
.as_ref()
.and_then(|hint| hint.apply(&failure, &context));
let context = context.with_retry_after_hint(hint);
let decision = self.failure_decision(&failure, &context);
if decision == AttemptFailureDecision::Abort {
return RetryFlowAction::Finished(RetryError::new(
RetryErrorReason::Aborted,
Some(failure),
context,
));
}
let max_attempts = self.options.max_attempts.get();
if attempts >= max_attempts {
return RetryFlowAction::Finished(RetryError::new(
RetryErrorReason::AttemptsExceeded,
Some(failure),
context,
));
}
let delay = self.retry_delay(decision, attempts, hint);
let context = context.with_next_delay(delay);
if let Some(max_elapsed) = self.options.max_elapsed
&& will_exceed_elapsed(start.elapsed(), delay, max_elapsed)
{
return RetryFlowAction::Finished(RetryError::new(
RetryErrorReason::MaxElapsedExceeded,
Some(failure),
context,
));
}
RetryFlowAction::Retry { delay, failure }
}
fn failure_decision(
&self,
failure: &AttemptFailure<E>,
context: &RetryContext,
) -> AttemptFailureDecision {
let mut decision = AttemptFailureDecision::UseDefault;
for listener in &self.listeners.failure {
let current = self.invoke_listener(|| listener.apply(failure, context));
if current != AttemptFailureDecision::UseDefault {
decision = current;
}
}
decision
}
fn retry_delay(
&self,
decision: AttemptFailureDecision,
attempts: u32,
hint: Option<Duration>,
) -> Duration {
match decision {
AttemptFailureDecision::RetryAfter(delay) => delay,
AttemptFailureDecision::UseDefault => hint.unwrap_or_else(|| {
self.options
.jitter
.delay_for_attempt(&self.options.delay, attempts)
}),
AttemptFailureDecision::Retry | AttemptFailureDecision::Abort => self
.options
.jitter
.delay_for_attempt(&self.options.delay, attempts),
}
}
fn elapsed_error(
&self,
start: Instant,
attempts: u32,
last_failure: Option<AttemptFailure<E>>,
attempt_timeout: Option<Duration>,
) -> Option<RetryError<E>> {
let max_elapsed = self.options.max_elapsed?;
let elapsed = start.elapsed();
if elapsed < max_elapsed {
return None;
}
Some(RetryError::new(
RetryErrorReason::MaxElapsedExceeded,
last_failure,
self.context(start, attempts, Duration::ZERO, attempt_timeout),
))
}
fn emit_before_attempt(&self, context: &RetryContext) {
for listener in &self.listeners.before_attempt {
self.invoke_listener(|| {
listener.accept(context);
});
}
}
fn emit_attempt_success(&self, context: &RetryContext) {
for listener in &self.listeners.attempt_success {
self.invoke_listener(|| {
listener.accept(context);
});
}
}
fn emit_error(&self, error: RetryError<E>) -> RetryError<E> {
for listener in &self.listeners.error {
self.invoke_listener(|| {
listener.accept(&error, error.context());
});
}
error
}
fn invoke_listener<R>(&self, call: impl FnOnce() -> R) -> R
where
R: Default,
{
if self.isolate_listener_panics {
std::panic::catch_unwind(std::panic::AssertUnwindSafe(call)).unwrap_or_default()
} else {
call()
}
}
}
impl<E> fmt::Debug for Retry<E> {
    /// Shows the options and attempt timeout. The output is marked
    /// non-exhaustive because the remaining fields are not printed.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut debug = formatter.debug_struct("Retry");
        debug.field("options", &self.options);
        debug.field("attempt_timeout", &self.attempt_timeout);
        debug.finish_non_exhaustive()
    }
}
/// A single type-erased attempt driven by the synchronous retry loop.
///
/// The loop only learns whether the attempt succeeded; any value it produced is
/// kept by the implementor (for example `SyncValueOperation`).
trait SyncAttempt<E> {
    /// Runs one attempt, returning `Ok(())` on success or the failure to classify.
    fn call(&mut self) -> Result<(), AttemptFailure<E>>;
}
/// Adapts a value-returning operation to `SyncAttempt` by storing the value of
/// a successful call.
struct SyncValueOperation<T, F> {
    /// The wrapped operation, invoked once per attempt.
    operation: F,
    /// The value from a successful call, if one has happened.
    value: Option<T>,
}
impl<T, F> SyncValueOperation<T, F> {
    /// Wraps `operation`; nothing is stored until a call succeeds.
    fn new(operation: F) -> Self {
        Self {
            value: None,
            operation,
        }
    }

    /// Consumes the adapter and returns the stored value.
    ///
    /// # Panics
    ///
    /// Panics if no value has been stored.
    fn into_value(self) -> T {
        let Self { value, .. } = self;
        value.expect("retry loop succeeded without an operation value")
    }
}
impl<T, E, F> SyncAttempt<E> for SyncValueOperation<T, F>
where
F: FnMut() -> Result<T, E>,
{
fn call(&mut self) -> Result<(), AttemptFailure<E>> {
match (self.operation)() {
Ok(value) => {
self.value = Some(value);
Ok(())
}
Err(error) => Err(AttemptFailure::Error(error)),
}
}
}
/// Boxed, pinned future for one asynchronous attempt.
///
/// It carries no `Send` bound, so futures built from it are not `Send`.
#[cfg(feature = "tokio")]
type AsyncAttemptFuture<'a, E> = Pin<Box<dyn Future<Output = Result<(), AttemptFailure<E>>> + 'a>>;
/// A single type-erased attempt driven by the asynchronous retry loop.
///
/// The loop only learns whether the attempt succeeded; any value it produced is
/// kept by the implementor (for example `AsyncValueOperation`).
#[cfg(feature = "tokio")]
trait AsyncAttempt<E> {
    /// Starts one attempt; the returned future borrows `self` until it completes.
    fn call(&mut self) -> AsyncAttemptFuture<'_, E>;
}
/// Adapts a future-returning operation to `AsyncAttempt` by storing the value
/// of a successful call.
#[cfg(feature = "tokio")]
struct AsyncValueOperation<T, F> {
    /// The wrapped operation, invoked once per attempt.
    operation: F,
    /// The value from a successful call, if one has happened.
    value: Option<T>,
}
#[cfg(feature = "tokio")]
impl<T, F> AsyncValueOperation<T, F> {
    /// Wraps `operation`; nothing is stored until a call succeeds.
    fn new(operation: F) -> Self {
        Self {
            value: None,
            operation,
        }
    }

    /// Consumes the adapter and returns the stored value.
    ///
    /// # Panics
    ///
    /// Panics if no value has been stored.
    fn into_value(self) -> T {
        let Self { value, .. } = self;
        value.expect("retry loop succeeded without an operation value")
    }
}
#[cfg(feature = "tokio")]
impl<T, E, F, Fut> AsyncAttempt<E> for AsyncValueOperation<T, F>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    /// Invokes the operation once and awaits it. A success value is stored for
    /// later retrieval; an error is wrapped as `AttemptFailure::Error`.
    fn call(&mut self) -> AsyncAttemptFuture<'_, E> {
        Box::pin(async move {
            let outcome = (self.operation)().await;
            outcome
                .map(|produced| self.value = Some(produced))
                .map_err(AttemptFailure::Error)
        })
    }
}
/// What a retry loop should do after `Retry::handle_failure` has classified a
/// failed attempt.
enum RetryFlowAction<E> {
    /// Wait for `delay`, then start another attempt. `failure` is kept so the
    /// loop can attach it to a `MaxElapsedExceeded` error raised at the start of
    /// the next iteration.
    Retry {
        delay: Duration,
        failure: AttemptFailure<E>,
    },
    /// Stop retrying; the loop passes this error to the error listeners and
    /// returns it.
    Finished(RetryError<E>),
}
/// Returns `true` when waiting `delay` after `elapsed` would reach or pass
/// `max_elapsed`. An overflowing sum counts as exceeding the budget.
fn will_exceed_elapsed(elapsed: Duration, delay: Duration, max_elapsed: Duration) -> bool {
    match elapsed.checked_add(delay) {
        Some(projected) => projected >= max_elapsed,
        None => true,
    }
}
/// Sleeps on the tokio timer for `delay`, returning immediately for a zero delay.
#[cfg(feature = "tokio")]
async fn sleep_async(delay: Duration) {
    if delay.is_zero() {
        return;
    }
    tokio::time::sleep(delay).await;
}