// newton-core 0.4.16
//
// Newton protocol core SDK.
use std::{future::Future, time::Duration};

/// Configuration for retry with exponential backoff.
///
/// The delay starts at `initial_backoff`, is multiplied by
/// `backoff_multiplier` after every retry, and the multiplied value is capped
/// at `max_backoff`.
///
/// Implements `PartialEq`/`Eq` so configurations can be compared directly,
/// e.g. in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of retry attempts (0 = no retries, just one attempt).
    pub max_retries: u32,
    /// Delay slept before the first retry.
    pub initial_backoff: Duration,
    /// Upper bound applied to the delay each time it is multiplied.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by after each retry (typically 2).
    ///
    /// A value of 1 keeps the delay constant; a value of 0 drops the delay to
    /// zero after the first retry.
    pub backoff_multiplier: u32,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(5),
            backoff_multiplier: 2,
        }
    }
}

/// Retry an async operation with exponential backoff.
///
/// Calls `operation` up to `max_retries + 1` times. On failure, `is_retryable`
/// determines whether to retry or return the error immediately. Backoff doubles
/// (by default) between retries, capped at `max_backoff`.
///
/// The first delay is `initial_backoff` exactly as configured; the
/// `max_backoff` cap is applied each time the delay is multiplied. The
/// multiplication saturates, so very large durations or multipliers clamp to
/// the cap instead of panicking on overflow.
///
/// # Arguments
///
/// * `config` - Retry parameters (max retries, backoff timing)
/// * `operation_name` - Label for tracing logs
/// * `operation` - Async closure producing `Result<T, E>`
/// * `is_retryable` - Predicate: returns `true` if the error is transient
///
/// # Errors
///
/// Returns the error from `operation` as soon as `is_retryable` rejects it, or
/// the error from the final attempt once all retries are used up. The
/// predicate is not consulted for the final attempt.
pub async fn retry_with_backoff<T, E, F, Fut, C>(
    config: &RetryConfig,
    operation_name: &str,
    mut operation: F,
    is_retryable: C,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    C: Fn(&E) -> bool,
    E: std::fmt::Display,
{
    let mut backoff = config.initial_backoff;

    for attempt in 0..=config.max_retries {
        match operation().await {
            Ok(val) => {
                if attempt > 0 {
                    tracing::debug!(operation = operation_name, attempt, "succeeded after retry");
                }
                return Ok(val);
            }
            // The attempt-count check comes first so the predicate is skipped
            // once no retries remain.
            Err(e) if attempt < config.max_retries && is_retryable(&e) => {
                // `as_millis` yields a u128; clamp before narrowing so a huge
                // delay is logged as u64::MAX rather than a wrapped value.
                let backoff_ms = backoff.as_millis().min(u128::from(u64::MAX)) as u64;
                tracing::warn!(
                    operation = operation_name,
                    error = %e,
                    attempt = attempt + 1,
                    max_retries = config.max_retries,
                    backoff_ms = backoff_ms,
                    "retryable error, backing off"
                );
                tokio::time::sleep(backoff).await;
                // `Duration * u32` panics on overflow; saturate first, then
                // apply the configured cap.
                backoff = std::cmp::min(
                    backoff.saturating_mul(config.backoff_multiplier),
                    config.max_backoff,
                );
            }
            Err(e) => return Err(e),
        }
    }

    unreachable!("loop runs max_retries + 1 times and always returns")
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    };

    /// A successful first call is returned as-is.
    #[tokio::test]
    async fn succeeds_on_first_attempt() {
        let policy = RetryConfig {
            max_retries: 3,
            ..Default::default()
        };

        let outcome: Result<u32, String> =
            retry_with_backoff(&policy, "test", || async { Ok(42) }, |_| true).await;

        assert_eq!(outcome, Ok(42));
    }

    /// Two retryable failures followed by a success: three calls in total.
    #[tokio::test]
    async fn retries_on_transient_error() {
        let calls = Arc::new(AtomicU32::new(0));

        let policy = RetryConfig {
            max_retries: 3,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(10),
            backoff_multiplier: 2,
        };

        let outcome: Result<u32, String> = retry_with_backoff(
            &policy,
            "test",
            || {
                let calls = Arc::clone(&calls);
                async move {
                    // The first two calls (counter values 0 and 1) fail.
                    match calls.fetch_add(1, Ordering::SeqCst) {
                        0 | 1 => Err("transient".to_string()),
                        _ => Ok(99),
                    }
                }
            },
            |_| true,
        )
        .await;

        assert_eq!(outcome, Ok(99));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// An error the predicate rejects is returned after a single call.
    #[tokio::test]
    async fn fails_immediately_on_non_retryable() {
        let calls = Arc::new(AtomicU32::new(0));

        let policy = RetryConfig {
            max_retries: 5,
            initial_backoff: Duration::from_millis(1),
            ..Default::default()
        };

        let outcome: Result<u32, String> = retry_with_backoff(
            &policy,
            "test",
            || {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err("permanent".to_string())
                }
            },
            // The predicate rejects every error.
            |_| false,
        )
        .await;

        assert_eq!(outcome, Err("permanent".to_string()));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// When every call fails with a retryable error, the last error is returned.
    #[tokio::test]
    async fn exhausts_retries_and_returns_last_error() {
        let calls = Arc::new(AtomicU32::new(0));

        let policy = RetryConfig {
            max_retries: 2,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            backoff_multiplier: 2,
        };

        let outcome: Result<u32, String> = retry_with_backoff(
            &policy,
            "test",
            || {
                let calls = Arc::clone(&calls);
                async move {
                    let index = calls.fetch_add(1, Ordering::SeqCst);
                    Err(format!("fail-{}", index))
                }
            },
            |_| true,
        )
        .await;

        // The error comes from the final call, whose counter value was 2.
        assert_eq!(outcome, Err("fail-2".to_string()));
        // One initial call plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}