stream-tungstenite 0.6.1

A streaming implementation of the Tungstenite WebSocket protocol
Documentation
//! Retry strategies for connection attempts.

use rand::{rngs::StdRng, Rng, SeedableRng};
use std::time::Duration;

use crate::error::ConnectError;

use dyn_clone::DynClone;

/// Policy that decides whether, and after how long, a failed connection
/// attempt should be retried.
///
/// The `DynClone` supertrait lets boxed strategies be duplicated; see
/// [`RetryStrategy::clone_box`].
pub trait RetryStrategy: Send + Sync + DynClone + 'static {
    /// Return the delay to wait before the next attempt, or `None` to give up.
    ///
    /// `error` is the failure produced by the previous attempt and `attempt`
    /// is the attempt number supplied by the caller.
    fn next_delay(&mut self, error: &ConnectError, attempt: u32) -> Option<Duration>;

    /// Return the strategy to its initial state (called after a successful connection).
    fn reset(&mut self);

    /// Duplicate this strategy behind a `Box<dyn RetryStrategy>`.
    fn clone_box(&self) -> Box<dyn RetryStrategy>
    where
        Self: Sized + 'static,
    {
        dyn_clone::clone_box::<Self>(self)
    }
}

// Makes `Box<dyn RetryStrategy>` cloneable by delegating to `DynClone`.
dyn_clone::clone_trait_object!(RetryStrategy);

/// Exponential backoff retry strategy.
///
/// Each retry waits for the current backoff delay (capped at `max`) with
/// random jitter applied. The stored delay is then multiplied by `factor` for
/// the following attempt.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// Delay used for the first retry; restored by `reset`.
    initial: Duration,
    /// Upper bound on the un-jittered backoff delay.
    max: Duration,
    /// Multiplier applied to the delay after each retry.
    factor: f64,
    /// Jitter fraction (0.0 to 1.0): each delay is perturbed by up to `±jitter * delay`.
    jitter: f64,
    /// Maximum number of attempts (`None` for unlimited).
    max_attempts: Option<u32>,
    /// RNG seed, kept so `reset` can replay the same jitter sequence (for deterministic testing).
    seed: Option<u64>,

    // Runtime state
    /// Delay for the next backoff step, before jitter is applied.
    current_delay: Duration,
    /// Source of jitter randomness.
    rng: StdRng,
}

impl ExponentialBackoff {
    /// Create a new exponential backoff strategy.
    ///
    /// Defaults: growth factor `2.0`, jitter `0.1` (±10%), unlimited attempts,
    /// and an RNG seeded from the operating system.
    ///
    /// * `initial` - delay used for the first retry and after a reset.
    /// * `max` - upper bound on the un-jittered backoff delay.
    #[must_use]
    pub fn new(initial: Duration, max: Duration) -> Self {
        Self {
            initial,
            max,
            factor: 2.0,
            jitter: 0.1,
            max_attempts: None,
            seed: None,
            current_delay: initial,
            rng: StdRng::from_os_rng(),
        }
    }

    /// Set the growth factor the delay is multiplied by after each retry.
    ///
    /// Values below `1.0` make successive delays shrink.
    #[must_use]
    pub const fn with_factor(mut self, factor: f64) -> Self {
        self.factor = factor;
        self
    }

    /// Set the jitter fraction, clamped to `0.0..=1.0`.
    ///
    /// Each delay is randomly perturbed by up to `±jitter * delay` and never
    /// drops below zero. `NaN` is treated as `0.0`, which disables jitter.
    #[must_use]
    pub fn with_jitter(mut self, jitter: f64) -> Self {
        // `f64::clamp` propagates NaN, which would poison every later jitter range.
        self.jitter = if jitter.is_nan() {
            0.0
        } else {
            jitter.clamp(0.0, 1.0)
        };
        self
    }

    /// Set the maximum number of attempts.
    ///
    /// `next_delay` returns `None` once the attempt number reaches `max`.
    #[must_use]
    pub const fn with_max_attempts(mut self, max: u32) -> Self {
        self.max_attempts = Some(max);
        self
    }

    /// Seed the jitter RNG, making delays reproducible (for deterministic testing).
    ///
    /// The seed is remembered, so `reset` restarts the same random sequence.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self.rng = StdRng::seed_from_u64(seed);
        self
    }

    /// Create a fast reconnect strategy
    /// - 100ms initial delay
    /// - 5s max delay
    /// - 1.5x growth factor
    /// - 10% jitter
    #[must_use]
    pub fn fast() -> Self {
        Self::new(Duration::from_millis(100), Duration::from_secs(5))
            .with_factor(1.5)
            .with_jitter(0.1)
    }

    /// Create a standard strategy
    /// - 1s initial delay
    /// - 60s max delay
    /// - 2x growth factor
    /// - 10% jitter
    #[must_use]
    pub fn standard() -> Self {
        Self::new(Duration::from_secs(1), Duration::from_secs(60))
            .with_factor(2.0)
            .with_jitter(0.1)
    }

    /// Create a conservative strategy
    /// - 2s initial delay
    /// - 120s max delay
    /// - 2x growth factor
    /// - 5% jitter
    #[must_use]
    pub fn conservative() -> Self {
        Self::new(Duration::from_secs(2), Duration::from_secs(120))
            .with_factor(2.0)
            .with_jitter(0.05)
    }

    /// Randomly perturb `duration` by up to `±jitter * duration`, never going below zero.
    ///
    /// Returns `duration` unchanged, without advancing the RNG, when there is
    /// nothing to perturb (zero jitter or a zero duration).
    fn apply_jitter(&mut self, duration: Duration) -> Duration {
        let secs = duration.as_secs_f64();
        let jitter_range = secs * self.jitter;

        // `random_range` panics on an empty range. `-0.0..0.0` (zero jitter or a
        // zero duration) and non-finite bounds must therefore be short-circuited.
        if jitter_range <= 0.0 || !jitter_range.is_finite() {
            return duration;
        }

        let offset = self.rng.random_range(-jitter_range..jitter_range);
        let result = (secs + offset).max(0.0);

        // Saturate rather than panic if jitter pushes past what `Duration` can hold.
        Duration::try_from_secs_f64(result).unwrap_or(Duration::MAX)
    }
}

impl Default for ExponentialBackoff {
    fn default() -> Self {
        Self::standard()
    }
}

impl RetryStrategy for ExponentialBackoff {
    /// Compute the delay before the next connection attempt.
    ///
    /// Returns `None` in two cases: `error` is not retryable, or `attempt` has
    /// reached the configured maximum.
    ///
    /// If the error suggests a delay, that delay is jittered and returned. It is
    /// not capped at `max`, and the backoff does not advance.
    ///
    /// Otherwise the current backoff delay is capped at `max`, jittered and
    /// returned. The stored delay is then grown by the factor for the next call.
    fn next_delay(&mut self, error: &ConnectError, attempt: u32) -> Option<Duration> {
        // Check if error is retryable
        if !error.is_retryable() {
            return None;
        }

        // Check max attempts
        if let Some(max) = self.max_attempts {
            if attempt >= max {
                return None;
            }
        }

        // Check for error-suggested delay
        if let Some(suggested) = error.suggested_delay() {
            return Some(self.apply_jitter(suggested));
        }

        // Calculate delay with exponential backoff
        let delay = self.current_delay.min(self.max);
        let delay_with_jitter = self.apply_jitter(delay);

        // Grow the stored delay for the next call, keeping it inside [0, max].
        // A NaN product falls back to `max`. The lower bound stops a negative
        // factor from reaching `Duration::from_secs_f64`, which panics on negatives.
        let max_secs = self.max.as_secs_f64();
        let grown = self.current_delay.as_secs_f64() * self.factor;
        let next_secs = if grown.is_nan() {
            max_secs
        } else {
            grown.clamp(0.0, max_secs)
        };
        // `as_secs_f64` may round up past the representable range for huge `max`.
        self.current_delay = Duration::try_from_secs_f64(next_secs).unwrap_or(self.max);

        Some(delay_with_jitter)
    }

    /// Restart the backoff from the initial delay.
    ///
    /// If a seed was configured, the RNG is re-seeded so the jitter sequence repeats.
    fn reset(&mut self) {
        self.current_delay = self.initial;
        // Re-seed RNG if we have a seed
        if let Some(seed) = self.seed {
            self.rng = StdRng::seed_from_u64(seed);
        }
    }
}

/// Fixed delay retry strategy: waits the same amount of time before every retry.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FixedDelay {
    /// Delay between attempts
    delay: Duration,
    /// Maximum number of attempts (`None` for unlimited)
    max_attempts: Option<u32>,
}

impl FixedDelay {
    /// Build a strategy that waits `delay` before every retry, with no attempt limit.
    #[must_use]
    pub const fn new(delay: Duration) -> Self {
        let max_attempts = None;
        Self { max_attempts, delay }
    }

    /// Stop retrying once the attempt number reaches `max`.
    #[must_use]
    pub const fn with_max_attempts(self, max: u32) -> Self {
        Self {
            max_attempts: Some(max),
            ..self
        }
    }
}

impl RetryStrategy for FixedDelay {
    /// Return the fixed delay, or `None` when the error is not retryable or the
    /// attempt limit has been reached.
    fn next_delay(&mut self, error: &ConnectError, attempt: u32) -> Option<Duration> {
        let limit_reached = matches!(self.max_attempts, Some(limit) if attempt >= limit);
        (error.is_retryable() && !limit_reached).then_some(self.delay)
    }

    /// No-op: a fixed delay carries no per-run state.
    fn reset(&mut self) {}
}

/// No retry strategy: gives up after the first failure.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct NoRetry;

impl RetryStrategy for NoRetry {
    /// Never retries, whatever the error or attempt number.
    fn next_delay(&mut self, _: &ConnectError, _: u32) -> Option<Duration> {
        None
    }

    /// No-op: `NoRetry` holds no state.
    fn reset(&mut self) {
        // Nothing to reset.
    }
}

/// Retry strategy that delegates every decision to a closure.
///
/// The closure receives the connection error and the attempt number. It
/// returns the delay before the next attempt, or `None` to stop retrying.
/// Closure bounds are enforced by the constructor and the trait impls rather
/// than the type itself, following the Rust API guideline of keeping bounds off
/// struct definitions.
pub struct CustomRetry<F> {
    f: F,
}

impl<F> std::fmt::Debug for CustomRetry<F> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures are opaque, so only the type name is shown.
        formatter.debug_struct("CustomRetry").finish_non_exhaustive()
    }
}

impl<F> CustomRetry<F>
where
    F: Fn(&ConnectError, u32) -> Option<Duration> + Send + Sync + Clone + 'static,
{
    /// Create a custom retry strategy
    #[must_use]
    pub const fn new(f: F) -> Self {
        Self { f }
    }
}

impl<F> Clone for CustomRetry<F>
where
    F: Fn(&ConnectError, u32) -> Option<Duration> + Send + Sync + Clone + 'static,
{
    fn clone(&self) -> Self {
        Self { f: self.f.clone() }
    }
}

impl<F> RetryStrategy for CustomRetry<F>
where
    F: Fn(&ConnectError, u32) -> Option<Duration> + Send + Sync + Clone + 'static,
{
    /// Forward the decision to the wrapped closure unchanged.
    fn next_delay(&mut self, error: &ConnectError, attempt: u32) -> Option<Duration> {
        let callback = &self.f;
        callback(error, attempt)
    }

    /// No-op: this strategy keeps no state of its own.
    fn reset(&mut self) {
        // Nothing to reset.
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exponential_backoff() {
        let mut strategy = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(60))
            .with_factor(2.0)
            .with_jitter(0.0) // No jitter for deterministic testing
            .with_seed(42);

        // Use TcpConnect error which has no suggested_delay
        // (ConnectError::Refused has suggested_delay of 2s)
        let error = ConnectError::TcpConnect("test".into());

        // First delay should be 1s
        let d1 = strategy.next_delay(&error, 1).unwrap();
        assert_eq!(d1, Duration::from_secs(1));

        // Second delay should be 2s
        let d2 = strategy.next_delay(&error, 2).unwrap();
        assert_eq!(d2, Duration::from_secs(2));

        // Third delay should be 4s
        let d3 = strategy.next_delay(&error, 3).unwrap();
        assert_eq!(d3, Duration::from_secs(4));

        // After reset, should start from beginning
        strategy.reset();
        let d1_again = strategy.next_delay(&error, 1).unwrap();
        assert_eq!(d1_again, Duration::from_secs(1));
    }

    #[test]
    fn test_exponential_backoff_max() {
        let mut strategy = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(5))
            .with_jitter(0.0);

        // TcpConnect has no suggested delay, so the backoff curve and its cap are
        // exercised. Refused would short-circuit to its suggested delay.
        let error = ConnectError::TcpConnect("test".into());

        // Delays grow monotonically and never exceed the cap.
        let mut last = Duration::ZERO;
        for attempt in 1..=10 {
            let delay = strategy.next_delay(&error, attempt).unwrap();
            assert!(delay <= Duration::from_secs(5));
            assert!(delay >= last);
            last = delay;
        }

        // 1s, 2s, 4s, then the cap is reached and held.
        assert_eq!(last, Duration::from_secs(5));
    }

    #[test]
    fn test_non_retryable_error() {
        let mut strategy = ExponentialBackoff::standard();
        let error = ConnectError::InvalidUri("bad".into());

        // Non-retryable error should return None
        assert!(strategy.next_delay(&error, 1).is_none());
    }

    #[test]
    fn test_max_attempts() {
        let mut strategy = ExponentialBackoff::fast().with_max_attempts(3);
        let error = ConnectError::Refused;

        assert!(strategy.next_delay(&error, 1).is_some());
        assert!(strategy.next_delay(&error, 2).is_some());
        assert!(strategy.next_delay(&error, 3).is_none()); // Reached max
    }

    #[test]
    fn test_fixed_delay() {
        let mut strategy = FixedDelay::new(Duration::from_secs(5));
        let error = ConnectError::Refused;

        assert_eq!(strategy.next_delay(&error, 1), Some(Duration::from_secs(5)));
        assert_eq!(strategy.next_delay(&error, 2), Some(Duration::from_secs(5)));
        assert_eq!(strategy.next_delay(&error, 3), Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_fixed_delay_max_attempts() {
        let mut strategy = FixedDelay::new(Duration::from_secs(1)).with_max_attempts(2);
        let error = ConnectError::Refused;

        assert_eq!(strategy.next_delay(&error, 1), Some(Duration::from_secs(1)));
        assert!(strategy.next_delay(&error, 2).is_none()); // Reached max
    }

    #[test]
    fn test_no_retry() {
        let mut strategy = NoRetry;
        let error = ConnectError::Refused;

        assert!(strategy.next_delay(&error, 1).is_none());
    }

    #[test]
    fn test_custom_retry() {
        let mut strategy = CustomRetry::new(|_: &ConnectError, attempt: u32| {
            (attempt < 2).then_some(Duration::from_millis(100))
        });
        let error = ConnectError::TcpConnect("test".into());

        assert_eq!(strategy.next_delay(&error, 1), Some(Duration::from_millis(100)));
        assert!(strategy.next_delay(&error, 2).is_none());
    }

    #[test]
    fn test_clone_box() {
        let boxed: Box<dyn RetryStrategy> = FixedDelay::new(Duration::from_secs(3)).clone_box();
        let mut copy = boxed.clone();

        assert_eq!(
            copy.next_delay(&ConnectError::Refused, 1),
            Some(Duration::from_secs(3))
        );
    }
}