altair-retry 0.3.0

Async retry with exponential backoff, auto-traced via the tracing crate
Documentation
//! The `retry` entry point.

use crate::config::Config;
use crate::error::{Error, PermanentError, Result};
use backon::{BackoffBuilder, ExponentialBuilder};
use std::future::Future;
use std::time::Instant;
use tracing::{Instrument, info_span};

type BoxedError = Box<dyn std::error::Error + Send + Sync>;

/// Run `op` with retry per `config`.
///
/// On success, returns the value. On error, retries with exponential backoff.
/// If `op` returns an error that downcasts to [`PermanentError`], retry stops
/// immediately and the wrapped error is returned via [`Error::Permanent`].
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::too_many_lines
)]
pub async fn retry<T, E, F, Fut>(config: Config, mut op: F) -> Result<T>
where
    F: FnMut() -> Fut + Send,
    Fut: Future<Output = std::result::Result<T, E>> + Send,
    E: Into<BoxedError>,
    T: Send,
{
    let mut backoff_builder = ExponentialBuilder::default()
        .with_min_delay(config.initial_interval)
        .with_max_delay(config.max_interval)
        .with_factor(config.multiplier as f32)
        .with_max_times(config.max_retries as usize);
    if config.jitter {
        backoff_builder = backoff_builder.with_jitter();
    }
    let mut delays = backoff_builder.build();

    let start = Instant::now();
    let mut attempt: u32 = 0;

    let span = info_span!(
        "retry",
        retry.name = %config.name,
        retry.max_attempts = config.max_retries + 1,
    );

    let cancellation = config.cancellation_token.clone();

    async {
        loop {
            if let Some(token) = &cancellation
                && token.is_cancelled()
            {
                tracing::warn!(
                    retry.outcome = "cancelled",
                    retry.attempts = attempt,
                    "retry cancelled before attempt",
                );
                return Err(Error::Cancelled {
                    name: config.name.clone(),
                });
            }

            attempt += 1;
            let attempt_span = info_span!("retry.attempt", retry.attempt = attempt,);

            let outcome = op().instrument(attempt_span).await;

            match outcome {
                Ok(v) => {
                    tracing::info!(
                        retry.outcome = "success",
                        retry.attempts = attempt,
                        retry.elapsed_ms = start.elapsed().as_millis() as u64,
                        "retry succeeded",
                    );
                    return Ok(v);
                }
                Err(e) => {
                    let boxed: BoxedError = e.into();
                    let boxed = match boxed.downcast::<PermanentError>() {
                        Ok(perm) => {
                            tracing::warn!(
                                retry.outcome = "permanent",
                                retry.attempts = attempt,
                                "permanent error encountered",
                            );
                            return Err(Error::Permanent {
                                name: config.name.clone(),
                                source: perm.inner,
                            });
                        }
                        Err(other) => other,
                    };

                    if let Some(delay) = delays.next() {
                        tracing::debug!(
                            retry.attempt = attempt,
                            retry.delay_ms = delay.as_millis() as u64,
                            "retrying after backoff",
                        );
                        if let Some(token) = &cancellation {
                            tokio::select! {
                                () = tokio::time::sleep(delay) => {}
                                () = token.cancelled() => {
                                    tracing::warn!(
                                        retry.outcome = "cancelled",
                                        retry.attempts = attempt,
                                        "retry cancelled during backoff",
                                    );
                                    return Err(Error::Cancelled {
                                        name: config.name.clone(),
                                    });
                                }
                            }
                        } else {
                            tokio::time::sleep(delay).await;
                        }
                    } else {
                        tracing::warn!(
                            retry.outcome = "exhausted",
                            retry.attempts = attempt,
                            "retry attempts exhausted",
                        );
                        return Err(Error::Exhausted {
                            name: config.name.clone(),
                            attempts: attempt,
                            source: boxed,
                        });
                    }
                }
            }
        }
    }
    .instrument(span)
    .await
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    fn fast_config(name: &str) -> Config {
        Config::builder()
            .name(name)
            .max_retries(3)
            .initial_interval(Duration::from_millis(1))
            .max_interval(Duration::from_millis(10))
            .jitter(false)
            .build()
    }

    #[tokio::test]
    async fn succeeds_on_first_try() {
        let r: Result<u32> =
            retry(fast_config("ok"), || async { Ok::<_, std::io::Error>(42) }).await;
        assert_eq!(r.unwrap(), 42);
    }

    #[tokio::test]
    async fn retries_then_succeeds() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let r: Result<u32> = retry(fast_config("flaky"), move || {
            let c = c.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst) + 1;
                if n < 3 {
                    Err::<u32, _>(std::io::Error::other("flake"))
                } else {
                    Ok(n)
                }
            }
        })
        .await;
        assert_eq!(r.unwrap(), 3);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn exhausts_with_attempts_count() {
        let r: Result<u32> = retry(fast_config("always_fail"), || async {
            Err::<u32, _>(std::io::Error::other("nope"))
        })
        .await;
        match r {
            Err(Error::Exhausted { attempts, .. }) => assert_eq!(attempts, 4),
            other => panic!("expected Exhausted, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn cancellation_token_pre_check_aborts() {
        use tokio_util::sync::CancellationToken;
        let token = CancellationToken::new();
        token.cancel();
        let cfg = Config::builder()
            .name("pre_cancelled")
            .max_retries(3)
            .initial_interval(Duration::from_millis(1))
            .jitter(false)
            .cancellation_token(token)
            .build();
        let r: Result<u32> = retry(cfg, || async {
            Err::<u32, _>(std::io::Error::other("never reached"))
        })
        .await;
        assert!(matches!(r, Err(Error::Cancelled { .. })));
    }

    #[tokio::test]
    async fn cancellation_during_backoff_aborts() {
        use tokio_util::sync::CancellationToken;
        let token = CancellationToken::new();
        let inner = token.clone();
        let cfg = Config::builder()
            .name("backoff_cancel")
            .max_retries(5)
            .initial_interval(Duration::from_millis(50))
            .jitter(false)
            .cancellation_token(inner)
            .build();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            token.cancel();
        });
        let r: Result<u32> = retry(cfg, || async {
            Err::<u32, _>(std::io::Error::other("transient"))
        })
        .await;
        assert!(matches!(r, Err(Error::Cancelled { .. })));
    }

    #[tokio::test]
    async fn permanent_error_short_circuits() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let r: Result<u32> = retry(fast_config("permanent"), move || {
            let c = c.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<u32, _>(PermanentError::wrap("invalid creds"))
            }
        })
        .await;
        assert!(matches!(r, Err(Error::Permanent { .. })));
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "should not retry on permanent"
        );
    }
}