use std::{num::NonZeroU32, time::Duration};
use futures::Future;
use thiserror::Error;
use tracing::error;
/// How a failing async operation is retried.
///
/// `Immediate` and `After` make one initial attempt and then up to `max_retries`
/// further attempts; `Exponential` delegates scheduling to the `backoff` crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryStrategy {
    /// Retry up to `max_retries` times after the initial attempt, with no delay in between.
    Immediate { max_retries: NonZeroU32 },
    /// Retry up to `max_retries` times after the initial attempt, sleeping `duration`
    /// before each retry.
    After {
        max_retries: NonZeroU32,
        duration: Duration,
    },
    /// Exponential backoff: `min_duration` is used as the initial retry interval and
    /// `max_duration` as the backoff's maximum total elapsed time (`max_elapsed_time`).
    Exponential {
        min_duration: Duration,
        max_duration: Duration,
    },
}
impl Default for RetryStrategy {
fn default() -> Self {
Self::Immediate {
max_retries: NonZeroU32::new(3).unwrap(),
}
}
}
/// Boxed tracer type, used only to name the tracer type parameter `T` when callers pass
/// `None` (no tracer) to the internal retry helpers.
type DynTracer<E> = Box<dyn Fn(E) + Send + Sync>;
impl RetryStrategy {
    /// Runs `f` under this strategy without tracing intermediate failures.
    ///
    /// Resolves to the first `Ok` produced by `f`, or to the error of the last attempt
    /// once the strategy stops retrying.
    pub async fn retry<O, E, Fut, F>(self, f: F) -> std::result::Result<O, E>
    where
        E: std::fmt::Debug,
        Fut: Future<Output = std::result::Result<O, E>>,
        F: Fn() -> Fut,
    {
        // A concrete (boxed) tracer type is needed only to satisfy the generic `T`.
        let no_tracer: Option<DynTracer<E>> = None;
        retry_with_strategy(f, self, no_tracer).await
    }

    /// Same as [`RetryStrategy::retry`], but every failed attempt that is followed by
    /// another attempt has its error handed to `tracer`.
    pub async fn retry_trace<O, E, Fut, F, T>(self, f: F, tracer: T) -> std::result::Result<O, E>
    where
        E: std::fmt::Debug,
        Fut: Future<Output = std::result::Result<O, E>>,
        F: Fn() -> Fut,
        T: Fn(E),
    {
        let tracer = Some(tracer);
        retry_with_strategy(f, self, tracer).await
    }

    /// Builds a `backoff::ExponentialBackoff` from an `Exponential` strategy.
    ///
    /// `min_duration` becomes the initial interval and `max_duration` the maximum elapsed
    /// time; all other backoff parameters keep the builder defaults.
    ///
    /// # Errors
    /// Fails for `Immediate` and `After`, which have no exponential representation.
    pub fn into_backoff(self) -> anyhow::Result<backoff::ExponentialBackoff> {
        match self {
            Self::Immediate { .. } | Self::After { .. } => {
                anyhow::bail!("retry strategy is not exponential")
            }
            Self::Exponential {
                min_duration,
                max_duration,
            } => Ok(backoff::ExponentialBackoffBuilder::new()
                .with_initial_interval(min_duration)
                .with_max_elapsed_time(Some(max_duration))
                .build()),
        }
    }
}
impl TryFrom<RetryStrategy> for backoff::ExponentialBackoff {
type Error = anyhow::Error;
fn try_from(value: RetryStrategy) -> anyhow::Result<Self, Self::Error> {
value.into_backoff()
}
}
/// What the caller should do with an error that will not be retried (any more).
///
/// This type only records the choice. NOTE(review): presumably `Terminate` stops the
/// surrounding work and `Ignore` lets it continue — confirm against the consumers.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FatalStrategy {
    /// Default choice.
    #[default]
    Terminate,
    Ignore,
}
/// Error produced by operations in this module, classified by whether it is retried.
#[derive(Error, Debug)]
pub enum OperationError {
    /// Error that `OperationError::retry` retries according to `retry_strategy`.
    #[error("Transient operation error: {err}")]
    Transient {
        /// Underlying error; also rendered in the `Display` message.
        err: anyhow::Error,
        /// How the failed operation is retried.
        retry_strategy: RetryStrategy,
        /// Strategy given to the `Fatal` error produced if retrying still fails.
        fatal_strategy: FatalStrategy,
    },
    /// Error that is not retried: `OperationError::retry` returns it unchanged.
    #[error("Fatal operation error: {err}")]
    Fatal {
        /// Underlying error; also rendered in the `Display` message.
        err: anyhow::Error,
        /// How the caller should handle this error (only recorded here).
        strategy: FatalStrategy,
    },
}
impl OperationError {
    /// Shared implementation of `retry` and `retry_trace`.
    ///
    /// For a `Transient` error, runs `f` under its `retry_strategy` (forwarding `tracer`).
    /// If every attempt fails, the last attempt's error is wrapped as `Fatal` with this
    /// error's `fatal_strategy`. A `Fatal` error is returned as-is and `f` is never called.
    ///
    /// NOTE(review): errors returned by `f` are retried regardless of their variant, so a
    /// `Fatal` produced by `f` does not stop the retry loop, and its own `strategy` is
    /// replaced by this error's `fatal_strategy`. Confirm this is intended.
    async fn retry_internal<O, Fut, F, T>(self, f: F, tracer: Option<T>) -> Result<O>
    where
        Fut: Future<Output = Result<O>>,
        F: Fn() -> Fut,
        T: Fn(OperationError),
    {
        match self {
            Self::Transient {
                retry_strategy,
                fatal_strategy,
                ..
            } => retry_with_strategy(f, retry_strategy, tracer)
                .await
                .map_err(|err| Self::Fatal {
                    err: err.into_err(),
                    strategy: fatal_strategy,
                }),
            // Spelled out instead of `_` so that adding a variant forces a decision here.
            Self::Fatal { .. } => Err(self),
        }
    }

    /// Retries `f` if this error is `Transient`, without tracing intermediate failures.
    ///
    /// # Errors
    /// Returns `self` unchanged if it is `Fatal`; otherwise returns a `Fatal` error built
    /// from the last failed attempt once the retry strategy is exhausted.
    pub async fn retry<O, Fut, F>(self, f: F) -> Result<O>
    where
        Fut: Future<Output = Result<O>>,
        F: Fn() -> Fut,
    {
        self.retry_internal(f, None as Option<DynTracer<OperationError>>)
            .await
    }

    /// Like `retry`, but passes each failed attempt that is followed by another attempt
    /// to `tracer`.
    ///
    /// # Errors
    /// Same as `retry`.
    pub async fn retry_trace<O, Fut, F, T>(self, f: F, tracer: T) -> Result<O>
    where
        Fut: Future<Output = Result<O>>,
        F: Fn() -> Fut,
        T: Fn(OperationError),
    {
        self.retry_internal(f, Some(tracer)).await
    }

    /// Consumes the error and returns the underlying `anyhow::Error`.
    pub fn into_err(self) -> anyhow::Error {
        match self {
            Self::Transient { err, .. } | Self::Fatal { err, .. } => err,
        }
    }

    /// Borrows the underlying `anyhow::Error`.
    pub fn as_err(&self) -> &anyhow::Error {
        match self {
            Self::Transient { err, .. } | Self::Fatal { err, .. } => err,
        }
    }

    /// Converts a `Transient` error into `Fatal`, keeping its `fatal_strategy`.
    /// A `Fatal` error is returned unchanged.
    pub fn into_fatal(self) -> Self {
        match self {
            Self::Transient {
                err,
                fatal_strategy,
                ..
            } => Self::Fatal {
                err,
                strategy: fatal_strategy,
            },
            Self::Fatal { .. } => self,
        }
    }

    /// Strategy to apply if this error ends up fatal: `fatal_strategy` for `Transient`,
    /// `strategy` for `Fatal`.
    pub fn fatal_strategy(&self) -> FatalStrategy {
        match self {
            Self::Transient { fatal_strategy, .. } => *fatal_strategy,
            Self::Fatal { strategy, .. } => *strategy,
        }
    }
}
/// Error intended to be retried, bundled with how to retry it and which fatal strategy
/// to use if retrying fails. Converts into `OperationError::Transient`.
#[derive(Debug)]
pub struct TransientError {
    /// Underlying error.
    err: anyhow::Error,
    /// Retry strategy carried into `OperationError::Transient`.
    retry_strategy: RetryStrategy,
    /// Fatal strategy carried into `OperationError::Transient`.
    fatal_strategy: FatalStrategy,
}
impl TransientError {
pub fn new(
err: impl std::error::Error + Send + Sync + 'static,
retry_strategy: RetryStrategy,
fatal_strategy: FatalStrategy,
) -> Self {
Self {
err: err.into(),
retry_strategy,
fatal_strategy,
}
}
pub fn from_anyhow(
err: anyhow::Error,
retry_strategy: RetryStrategy,
fatal_strategy: FatalStrategy,
) -> Self {
Self {
err,
retry_strategy,
fatal_strategy,
}
}
pub fn from_str(
err: &str,
retry_strategy: RetryStrategy,
fatal_strategy: FatalStrategy,
) -> Self {
Self {
err: anyhow::Error::msg(err.to_string()),
retry_strategy,
fatal_strategy,
}
}
}
impl<E> From<E> for TransientError
where
    E: std::error::Error + Send + Sync + 'static,
{
    /// Wraps any standard error as transient, using the default retry and fatal strategies.
    fn from(value: E) -> Self {
        Self::new(value, RetryStrategy::default(), FatalStrategy::default())
    }
}
impl From<TransientError> for OperationError {
fn from(value: TransientError) -> Self {
Self::Transient {
err: value.err,
retry_strategy: value.retry_strategy,
fatal_strategy: value.fatal_strategy,
}
}
}
impl<T> From<TransientError> for Result<T> {
fn from(value: TransientError) -> Self {
Err(value.into())
}
}
/// Error that should not be retried, bundled with its fatal strategy.
/// Converts into `OperationError::Fatal`.
#[derive(Debug)]
pub struct FatalError {
    /// Underlying error.
    err: anyhow::Error,
    /// Strategy carried into `OperationError::Fatal`.
    strategy: FatalStrategy,
}
impl FatalError {
pub fn new(
err: impl std::error::Error + Send + Sync + 'static,
strategy: FatalStrategy,
) -> Self {
Self {
err: err.into(),
strategy,
}
}
pub fn from_anyhow(err: anyhow::Error, strategy: FatalStrategy) -> Self {
Self { err, strategy }
}
pub fn from_str(err: &str, strategy: FatalStrategy) -> Self {
Self {
err: anyhow::Error::msg(err.to_string()),
strategy,
}
}
}
impl<E> From<E> for FatalError
where
    E: std::error::Error + Send + Sync + 'static,
{
    /// Wraps any standard error as fatal, using the default [`FatalStrategy`].
    fn from(value: E) -> Self {
        Self::new(value, FatalStrategy::default())
    }
}
impl From<FatalError> for OperationError {
fn from(value: FatalError) -> Self {
Self::Fatal {
err: value.err,
strategy: value.strategy,
}
}
}
impl<T> From<FatalError> for Result<T> {
fn from(value: FatalError) -> Self {
Err(value.into())
}
}
/// Result alias whose error type is [`OperationError`].
pub type Result<T> = std::result::Result<T, OperationError>;
/// Runs `f` once and, while it keeps failing, retries it up to `max_retries` more times.
///
/// After each failed attempt that will be retried, the error is handed to `tracer` (if
/// any) and the task sleeps for `duration` (if any). The error of the attempt that
/// exhausts the retry budget is returned to the caller and is not passed to `tracer`.
async fn retry_simple<O, E, Fut, F, T>(
    f: F,
    max_retries: NonZeroU32,
    duration: Option<Duration>,
    tracer: Option<T>,
) -> std::result::Result<O, E>
where
    E: std::fmt::Debug,
    Fut: Future<Output = std::result::Result<O, E>>,
    F: Fn() -> Fut,
    T: Fn(E),
{
    let mut num_retries = 0;
    loop {
        match f().await {
            Ok(output) => return Ok(output),
            // Retry budget exhausted: surface the last error without tracing it.
            Err(err) if num_retries >= max_retries.get() => return Err(err),
            Err(err) => {
                if let Some(tracer) = tracer.as_ref() {
                    tracer(err);
                }
                num_retries += 1;
                if let Some(duration) = duration {
                    tokio::time::sleep(duration).await;
                }
            }
        }
    }
}
/// Runs `f` according to `strategy`, passing retried failures to `tracer` when present.
///
/// `Immediate` and `After` are handled by `retry_simple` (without / with a fixed sleep).
/// `Exponential` is delegated to `backoff::future::retry_notify`, using the backoff built
/// by `RetryStrategy::into_backoff`.
async fn retry_with_strategy<O, E, Fut, F, T>(
    f: F,
    strategy: RetryStrategy,
    tracer: Option<T>,
) -> std::result::Result<O, E>
where
    E: std::fmt::Debug,
    Fut: Future<Output = std::result::Result<O, E>>,
    F: Fn() -> Fut,
    T: Fn(E),
{
    match strategy {
        RetryStrategy::Immediate { max_retries } => {
            retry_simple(f, max_retries, None, tracer).await
        }
        RetryStrategy::After {
            max_retries,
            duration,
        } => retry_simple(f, max_retries, Some(duration), tracer).await,
        exp @ RetryStrategy::Exponential { .. } => {
            let backoff = exp
                .into_backoff()
                .expect("into_backoff always succeeds for RetryStrategy::Exponential");
            backoff::future::retry_notify(
                backoff,
                // `?` converts `E` into `backoff::Error<E>` through its `From` impl, which
                // per the backoff docs marks the error transient, so every failure is retried.
                || async { Ok(f().await?) },
                |err, _| {
                    if let Some(tracer) = tracer.as_ref() {
                        tracer(err);
                    }
                },
            )
            .await
        }
    }
}