mx-core 0.1.0

Core utilities for MultiversX Rust services.
Documentation
//! Retry policy with exponential backoff and jitter support.

use std::future::Future;
use std::time::Duration;

use rand::Rng;
use tokio::time::sleep;

/// Delay cap passed to `with_exponential_backoff` by `RetryPolicy::default()`.
const DEFAULT_MAX_DELAY: Duration = Duration::from_secs(30);
/// Jitter factor used by `RetryPolicy::default()`: the computed delay is
/// scaled by a random factor within ±25%.
const DEFAULT_JITTER_FACTOR: f64 = 0.25;

/// Configuration for retry behavior with exponential backoff support.
///
/// A policy is either fixed-delay (see `RetryPolicy::new`) or exponential
/// (see `RetryPolicy::with_exponential_backoff`); the private
/// `use_exponential_backoff` flag records which mode `next_delay` uses.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of times `execute` invokes the operation
    /// (expected to be at least 1).
    pub max_attempts: u32,
    /// Delay returned for every attempt in fixed mode, and the starting
    /// delay that is doubled per attempt in exponential mode.
    pub base_delay: Duration,
    /// Cap applied to the exponentially grown delay (before jitter).
    /// Not consulted in fixed-delay mode.
    pub max_delay: Duration,
    /// Relative jitter applied around the computed delay; values `<= 0.0`
    /// disable jitter. `with_exponential_backoff` clamps it to `[0.0, 1.0]`,
    /// but direct writes to this public field are not clamped.
    pub jitter_factor: f64,
    // `true` selects exponential growth in `next_delay`; `false` selects a
    // constant `base_delay`.
    use_exponential_backoff: bool,
    // Optional filter over the error's `Display` text; when `None`, every
    // error is treated as retryable.
    retryable_predicate: Option<fn(&str) -> bool>,
}

impl RetryPolicy {
    /// Creates a fixed-delay policy (legacy constructor).
    ///
    /// Every retry waits exactly `delay_ms` milliseconds: jitter is disabled,
    /// exponential growth is off, and no retryable predicate is installed.
    pub fn new(max_attempts: u32, delay_ms: u64) -> Self {
        // Base and cap are the same value because the delay never grows.
        let fixed_delay = Duration::from_millis(delay_ms);
        Self {
            retryable_predicate: None,
            use_exponential_backoff: false,
            jitter_factor: 0.0,
            max_delay: fixed_delay,
            base_delay: fixed_delay,
            max_attempts,
        }
    }

    /// Creates a policy whose delay doubles after each attempt.
    ///
    /// The delay starts at `base_delay` and is capped at `max_delay`.
    /// `jitter_factor` is clamped into `[0.0, 1.0]` before being stored;
    /// `0.0` turns jitter off. No retryable predicate is installed.
    pub fn with_exponential_backoff(
        max_attempts: u32,
        base_delay: Duration,
        max_delay: Duration,
        jitter_factor: f64,
    ) -> Self {
        let bounded_jitter = jitter_factor.clamp(0.0, 1.0);
        Self {
            retryable_predicate: None,
            use_exponential_backoff: true,
            jitter_factor: bounded_jitter,
            max_delay,
            base_delay,
            max_attempts,
        }
    }

    /// Sets a predicate function to determine if an error should trigger a retry.
    pub fn with_retryable_predicate(mut self, predicate: fn(&str) -> bool) -> Self {
        self.retryable_predicate = Some(predicate);
        self
    }

    /// Computes the delay for a given (zero-based) attempt number.
    ///
    /// In fixed-delay mode this is always `base_delay`. In exponential mode
    /// the delay is `base_delay * 2^attempt`, computed at millisecond
    /// resolution and capped at `max_delay`. When `jitter_factor` is positive
    /// the capped value is then randomized, so the returned delay can exceed
    /// `max_delay` by up to the jitter fraction.
    ///
    /// This method contains self-contained exponential backoff logic.
    /// There is a separate [`super::ExponentialBackoff`] type elsewhere in the
    /// resilience module, but this implementation is intentionally kept inline:
    /// `RetryPolicy` needs tight control over attempt numbering and jitter that
    /// would not benefit from an extra layer of indirection.
    pub fn next_delay(&self, attempt: u32) -> Duration {
        if !self.use_exponential_backoff {
            return self.base_delay;
        }

        // `as_millis` yields a `u128`. Clamp to the `u64` range first so the
        // narrowing cast is lossless; a bare `as u64` would silently wrap a
        // huge duration into a tiny one.
        let whole_millis = |d: Duration| d.as_millis().min(u128::from(u64::MAX)) as u64;
        let base_ms = whole_millis(self.base_delay);
        let max_ms = whole_millis(self.max_delay);

        let delay_ms = if base_ms == 0 {
            // Doubling zero is still zero, for every attempt number. Without
            // this guard, attempts >= 64 would jump straight to `max_delay`.
            0
        } else {
            // `checked_shl` is `None` once the shift reaches the bit width
            // (attempt >= 64); with a non-zero base the true product is then
            // beyond any `u64` cap, so the cap itself is the answer.
            1u64.checked_shl(attempt).map_or(max_ms, |multiplier| {
                base_ms.saturating_mul(multiplier).min(max_ms)
            })
        };

        let delay = Duration::from_millis(delay_ms);
        if self.jitter_factor > 0.0 {
            self.apply_jitter(delay)
        } else {
            delay
        }
    }

    /// Scales `delay` by a random factor drawn from
    /// `[1 - jitter_factor, 1 + jitter_factor)`.
    ///
    /// The result is computed in whole milliseconds and never drops below
    /// 1 ms, even when `delay` itself is zero.
    fn apply_jitter(&self, delay: Duration) -> Duration {
        // A uniform sample in [0, 1) stretched over the full jitter width and
        // re-centred on zero.
        let unit_sample = rand::rng().random::<f64>();
        let offset = unit_sample * (self.jitter_factor * 2.0) - self.jitter_factor;
        let scale = 1.0 + offset;

        let scaled_ms = delay.as_millis() as f64 * scale;
        Duration::from_millis(scaled_ms.max(1.0) as u64)
    }

    /// Reports whether an error with the given message should be retried.
    ///
    /// Without a configured predicate every error is considered retryable.
    fn is_retryable(&self, error_msg: &str) -> bool {
        self.retryable_predicate
            .map_or(true, |should_retry| should_retry(error_msg))
    }

    /// Executes an async operation with retry logic.
    ///
    /// `operation` is called up to `max_attempts` times, sleeping for
    /// `next_delay(attempt)` between consecutive attempts. It is always
    /// called at least once: a policy with `max_attempts == 0` behaves like
    /// a single-attempt policy instead of panicking.
    ///
    /// # Errors
    ///
    /// Returns the error from the most recent attempt when either
    /// - the retryable predicate rejects that error's `Display` text, or
    /// - all attempts have been used up.
    pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        E: std::fmt::Display,
    {
        // Guarantee one attempt so there is always a real result to return.
        let total_attempts = self.max_attempts.max(1);
        let mut attempts_made: u32 = 0;

        loop {
            let err = match operation().await {
                Ok(value) => return Ok(value),
                Err(err) => err,
            };

            if !self.is_retryable(&err.to_string()) {
                return Err(err);
            }

            // Cannot overflow: `attempts_made < total_attempts <= u32::MAX`.
            attempts_made += 1;
            if attempts_made >= total_attempts {
                return Err(err);
            }

            // Delays are indexed by the zero-based attempt that just failed.
            sleep(self.next_delay(attempts_made - 1)).await;
        }
    }
}

impl Default for RetryPolicy {
    /// Exponential backoff with three attempts, starting at 200 ms, using the
    /// module's default delay cap and jitter factor.
    fn default() -> Self {
        let attempts = 3;
        let initial_delay = Duration::from_millis(200);
        Self::with_exponential_backoff(
            attempts,
            initial_delay,
            DEFAULT_MAX_DELAY,
            DEFAULT_JITTER_FACTOR,
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Five-attempt exponential policy with jitter disabled, so delays are
    /// deterministic.
    fn exponential(base: Duration, cap: Duration) -> RetryPolicy {
        RetryPolicy::with_exponential_backoff(5, base, cap, 0.0)
    }

    // Fixed mode returns the same delay for every attempt.
    #[test]
    fn test_fixed_delay() {
        let policy = RetryPolicy::new(3, 200);
        for attempt in 0..3 {
            assert_eq!(policy.next_delay(attempt), Duration::from_millis(200));
        }
    }

    // The delay doubles per attempt while below the cap.
    #[test]
    fn test_exponential_no_jitter() {
        let policy = exponential(Duration::from_millis(100), Duration::from_secs(30));
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400)] {
            assert_eq!(policy.next_delay(attempt), Duration::from_millis(expected_ms));
        }
    }

    // Growth stops at `max_delay`.
    #[test]
    fn test_capped_at_max_delay() {
        let cap = Duration::from_secs(1);
        let policy = exponential(Duration::from_millis(100), cap);
        assert_eq!(policy.next_delay(20), cap);
    }

    // An attempt number beyond the shift width must not overflow past the cap.
    #[test]
    fn test_overflow_protection() {
        let cap = Duration::from_secs(3600);
        let policy = exponential(Duration::from_secs(1), cap);
        assert!(policy.next_delay(100) <= cap);
    }

    // A first-try success is returned as-is.
    #[tokio::test]
    async fn test_execute_success() {
        let policy = RetryPolicy::new(3, 10);
        let outcome: Result<i32, String> = policy.execute(|| async { Ok(42) }).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    // Two transient failures followed by a success use exactly three calls.
    #[tokio::test]
    async fn test_execute_retries_then_succeeds() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let calls = AtomicU32::new(0);
        let calls_ref = &calls;

        let policy = RetryPolicy::new(3, 1);
        let outcome: Result<i32, String> = policy
            .execute(|| async move {
                let seen_before = calls_ref.fetch_add(1, Ordering::SeqCst);
                if seen_before < 2 {
                    Err("transient".to_string())
                } else {
                    Ok(42)
                }
            })
            .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}