apalis_core/backend/poll_strategy/strategies/backoff.rs

1use std::{
2    sync::atomic::{AtomicU64, Ordering},
3    time::Duration,
4};
5
6use futures_core::stream::BoxStream;
7use futures_util::{StreamExt, stream};
8
9use crate::backend::poll_strategy::{IntervalStrategy, PollContext, PollStrategy};
10
/// Process-wide state of the pseudo-random generator used to jitter backoff delays.
///
/// Stored in an atomic so it can be shared between threads without a lock. Callers should
/// advance it with a single read-modify-write operation so that no two calls observe the
/// same state.
static JITTER_STATE: AtomicU64 = AtomicU64::new(1_u64);
13
14/// A polling strategy that applies exponential backoff to an inner interval strategy
15#[derive(Clone, Debug)]
16pub struct BackoffStrategy {
17    interval: IntervalStrategy,
18    backoff_config: BackoffConfig,
19    default_delay: Duration,
20}
21impl BackoffStrategy {
22    /// Create a new BackoffStrategy wrapping an IntervalStrategy with the given BackoffConfig
23    #[must_use]
24    pub fn new(inner: IntervalStrategy, config: BackoffConfig) -> Self {
25        Self {
26            default_delay: inner.poll_interval,
27            interval: inner,
28            backoff_config: config,
29        }
30    }
31}
32
33impl PollStrategy for BackoffStrategy {
34    type Stream = BoxStream<'static, ()>;
35
36    fn poll_strategy(self: Box<Self>, ctx: &PollContext) -> Self::Stream {
37        let backoff_config = self.backoff_config.clone();
38        let current_delay = self.interval.poll_interval;
39        let default_delay = self.default_delay;
40
41        stream::unfold(
42            (ctx.clone(), current_delay),
43            move |(ctx, mut current_delay)| {
44                let fut = futures_timer::Delay::new(current_delay);
45                let backoff_config = backoff_config.clone();
46                async move {
47                    fut.await;
48                    let failed = ctx.prev_count.load(Ordering::Relaxed) == 0;
49                    current_delay = backoff_config.next_delay(default_delay, current_delay, failed);
50                    Some(((), (ctx, current_delay)))
51                }
52            },
53        )
54        .boxed()
55    }
56}
57
/// Settings for exponential backoff with pseudo-random jitter.
#[derive(Clone, Debug)]
pub struct BackoffConfig {
    /// Upper bound applied to the delay after it has been multiplied.
    max_delay: Duration,
    /// Factor the current delay is multiplied by when a poll is treated as failed.
    multiplier: f64,
    /// Fraction of the base delay used as the jitter range (`with_jitter` clamps it to 0.0..=1.0).
    jitter_factor: f64,
}
65
66impl Default for BackoffConfig {
67    fn default() -> Self {
68        Self {
69            max_delay: Duration::from_secs(60),
70            multiplier: 2.0,
71            jitter_factor: 0.1,
72        }
73    }
74}
75
76impl BackoffConfig {
77    /// Create a new BackoffConfig with the specified maximum delay
78    #[must_use]
79    pub fn new(max: Duration) -> Self {
80        Self {
81            max_delay: max,
82            ..Default::default()
83        }
84    }
85
86    /// Set the multiplier for exponential backoff
87    #[must_use]
88    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
89        self.multiplier = multiplier;
90        self
91    }
92
93    /// Set the jitter factor (0.0 to 1.0) for randomizing delays
94    #[must_use]
95    pub fn with_jitter(mut self, jitter_factor: f64) -> Self {
96        self.jitter_factor = jitter_factor.clamp(0.0, 1.0);
97        self
98    }
99
100    /// Calculate the next delay with backoff and jitter
101    fn next_delay(
102        &self,
103        default_delay: Duration,
104        current_delay: Duration,
105        failed: bool,
106    ) -> Duration {
107        let base_delay = if failed {
108            // Exponential backoff on failure
109            let next = Duration::from_secs_f64(current_delay.as_secs_f64() * self.multiplier);
110            next.min(self.max_delay)
111        } else {
112            // Reset to initial on success
113            default_delay
114        };
115
116        // Add jitter using a simple LCG (Linear Congruential Generator)
117        if self.jitter_factor > 0.0 {
118            // Simple deterministic pseudo-random number generation
119            let mut state = JITTER_STATE.load(Ordering::Relaxed);
120            state = state.wrapping_mul(1103515245).wrapping_add(12345);
121            JITTER_STATE.store(state, Ordering::Relaxed);
122
123            // Convert to 0.0-1.0 range
124            let normalized = (state as f64) / (u64::MAX as f64);
125
126            // Apply jitter: -jitter_factor to +jitter_factor
127            let jitter_range = base_delay.as_secs_f64() * self.jitter_factor;
128            let jitter = (normalized - 0.5) * 2.0 * jitter_range;
129            let jittered = base_delay.as_secs_f64() + jitter;
130            Duration::from_secs_f64(jittered.max(0.0))
131        } else {
132            base_delay
133        }
134    }
135}