Skip to main content

amaters_server/
retry.rs

1//! Retry logic for transient failures in the AmateRS server.
2//!
3//! This module provides:
4//! - [`RetryPolicy`] — configurable exponential backoff with jitter.
5//! - [`ErrorClassification`] — trait for classifying errors as transient or permanent.
6//! - [`retry_with_backoff`] — generic async retry driver.
7//!
8//! **Important:** Only use [`retry_with_backoff`] for *idempotent* operations.
9//! Non-idempotent writes MUST NOT be wrapped in retry logic without sequence
10//! numbers or other deduplication mechanisms at the caller level.
11
12use std::time::Duration;
13
14// ---------------------------------------------------------------------------
15// Jitter PRNG
16// ---------------------------------------------------------------------------
17
/// Minimal xorshift64 PRNG used to jitter retry delays.
///
/// Implemented in-file so that backoff jitter does not pull in an external
/// PRNG crate.  The output is adequate for spreading retries apart; it is
/// NOT cryptographically secure.
struct Xorshift64(u64);

impl Xorshift64 {
    /// State used when the wall clock is unavailable, or when seed mixing
    /// happens to produce zero (xorshift must never hold an all-zero state).
    const FALLBACK_SEED: u64 = 0xDEAD_BEEF_CAFE_BABE;

    /// Create a generator whose seed differs on every call.
    ///
    /// The seed combines the wall clock (nanoseconds since the UNIX epoch)
    /// with a process-wide call counter and is then scrambled with the
    /// splitmix64 finalizer.  The counter matters: two retry loops started
    /// within the same clock tick would otherwise receive identical seeds and
    /// therefore identical jitter, re-synchronising exactly the retries that
    /// jitter is meant to spread out.
    ///
    /// If the clock reports a time before the UNIX epoch, a fixed constant
    /// stands in for the timestamp; the counter still keeps seeds distinct.
    fn seeded() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        // `AtomicUsize` rather than `AtomicU64`: it exists on every target
        // that has atomics at all, including 32-bit ones.
        static SEQUENCE: AtomicUsize = AtomicUsize::new(0);
        const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            // Truncating u128 -> u64 is intentional: only the fast-changing
            // low bits are useful as seed material.
            .map(|elapsed| elapsed.as_nanos() as u64)
            .unwrap_or(Self::FALLBACK_SEED);
        // Relaxed is enough: only uniqueness of the ticket matters, not its
        // ordering relative to other memory operations.
        let ticket = SEQUENCE.fetch_add(1, Ordering::Relaxed) as u64;

        // splitmix64 finalizer: a bijection on u64 that diffuses every input
        // bit, so near-identical inputs produce unrelated xorshift states.
        let mut z = nanos
            .wrapping_add(ticket.wrapping_mul(GOLDEN_GAMMA))
            .wrapping_add(GOLDEN_GAMMA);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;

        // xorshift requires a non-zero state.
        Self(if z == 0 { Self::FALLBACK_SEED } else { z })
    }

    /// Advance the generator and return the next pseudo-random `u64`.
    ///
    /// Uses the (13, 7, 17) xorshift64 shift triple.  A non-zero state never
    /// becomes zero, so the returned value is never zero either.
    fn next(&mut self) -> u64 {
        let mut state = self.0;
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        self.0 = state;
        state
    }

    /// Produce a value in `[0.0, 1.0)`.
    fn next_f64(&mut self) -> f64 {
        // Keep the top 53 bits: they fit an f64 mantissa exactly, so the
        // division by 2^53 is exact and the result is strictly below 1.0.
        let mantissa = self.next() >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }
}
59
60// ---------------------------------------------------------------------------
61// Public types
62// ---------------------------------------------------------------------------
63
64/// Retry policy for transient failures.
65///
66/// ## Safety note
67/// IMPORTANT: Only use for idempotent operations — non-idempotent writes must
68/// not be wrapped in [`retry_with_backoff`] without explicit caller opt-in and
69/// deduplication.
70#[derive(Debug, Clone)]
71pub struct RetryPolicy {
72    /// Total number of attempts including the first (1 = no retry).
73    pub max_attempts: u32,
74    /// Base delay in milliseconds for the first retry.
75    pub base_delay_ms: u64,
76    /// Maximum delay cap in milliseconds.
77    pub max_delay_ms: u64,
78    /// Jitter factor applied to each computed delay.
79    ///
80    /// `0.0` = no jitter; `0.1` = ±10% uniform jitter.
81    /// Valid range: `[0.0, 1.0)`.
82    pub jitter_factor: f64,
83}
84
85impl Default for RetryPolicy {
86    fn default() -> Self {
87        Self {
88            max_attempts: 3,
89            base_delay_ms: 100,
90            max_delay_ms: 5_000,
91            jitter_factor: 0.1,
92        }
93    }
94}
95
96impl RetryPolicy {
97    /// Compute the sleep duration for retry attempt `n` (0-indexed; n=0 is the
98    /// first retry, i.e. after the first failed attempt).
99    ///
100    /// Formula: `min(max_delay_ms, base_delay_ms * 2^n) * uniform(1 - jitter, 1 + jitter)`
101    fn delay_for_attempt(&self, n: u32, rng: &mut Xorshift64) -> Duration {
102        // Saturating exponentiation to avoid u64 overflow.
103        // 2^n using checked_shl to guard against n >= 64.
104        let multiplier: u64 = 1u64.checked_shl(n).unwrap_or(u64::MAX);
105        let base: u64 = self.base_delay_ms.saturating_mul(multiplier);
106        let capped = base.min(self.max_delay_ms);
107
108        let factor = if self.jitter_factor <= 0.0 {
109            1.0_f64
110        } else {
111            let j = self.jitter_factor.min(1.0);
112            // uniform in [1 - j, 1 + j]
113            let r = rng.next_f64(); // [0, 1)
114            1.0 - j + 2.0 * j * r
115        };
116
117        let ms = (capped as f64 * factor).max(0.0) as u64;
118        Duration::from_millis(ms)
119    }
120}
121
122// ---------------------------------------------------------------------------
123// Error classification
124// ---------------------------------------------------------------------------
125
/// Classifies an error as transient (worth retrying) or permanent.
///
/// An error is transient when repeating the same operation might succeed,
/// for example after a momentary I/O interruption.  Permanent errors (e.g.
/// `NotFound`, auth failure) must report `false` so that they are surfaced
/// to the caller immediately instead of being retried.
pub trait ErrorClassification {
    /// Returns `true` when the error is transient, meaning the operation
    /// should be retried (subject to the [`RetryPolicy`] limits).
    fn is_transient(&self) -> bool;
}
136
137// ---------------------------------------------------------------------------
138// Core retry driver
139// ---------------------------------------------------------------------------
140
141/// Retry `op` with exponential backoff and jitter according to `policy`.
142///
143/// - If `op` returns `Ok(v)`, returns immediately.
144/// - If `op` returns `Err(e)` and `e.is_transient()` is `true`, sleeps for
145///   the computed delay and tries again (up to `policy.max_attempts` times total).
146/// - If `op` returns `Err(e)` and `e.is_transient()` is `false`, returns
147///   the error immediately (no further attempts).
148/// - After exhausting all attempts, returns the last error.
149///
150/// # Example
151/// ```rust,no_run
152/// # use amaters_server::retry::{RetryPolicy, ErrorClassification, retry_with_backoff};
153/// # #[derive(Debug)] struct MyErr { transient: bool }
154/// # impl std::fmt::Display for MyErr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "err") } }
155/// # impl ErrorClassification for MyErr { fn is_transient(&self) -> bool { self.transient } }
156/// # async fn demo() {
157/// let policy = RetryPolicy::default();
158/// let result = retry_with_backoff(|| async { Ok::<_, MyErr>(42) }, &policy).await;
159/// # }
160/// ```
161pub async fn retry_with_backoff<F, T, E, Fut>(mut op: F, policy: &RetryPolicy) -> Result<T, E>
162where
163    F: FnMut() -> Fut,
164    Fut: std::future::Future<Output = Result<T, E>>,
165    E: ErrorClassification + std::fmt::Debug,
166{
167    let mut rng = Xorshift64::seeded();
168    let max = policy.max_attempts.max(1);
169
170    for attempt in 0..max {
171        match op().await {
172            Ok(val) => return Ok(val),
173            Err(err) => {
174                let is_last = attempt + 1 >= max;
175                if is_last || !err.is_transient() {
176                    return Err(err);
177                }
178                // Compute retry delay: n = attempt (0-indexed first retry).
179                let delay = policy.delay_for_attempt(attempt, &mut rng);
180                tokio::time::sleep(delay).await;
181            }
182        }
183    }
184
185    // Unreachable: the loop always returns on the last attempt, but the
186    // compiler cannot see that without an explicit unreachable.  Calling op
187    // one final time satisfies the type-checker without introducing a panic.
188    op().await
189}
190
191// ---------------------------------------------------------------------------
192// Tests
193// ---------------------------------------------------------------------------
194
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Minimal test error types -----------------------------------------

    /// Error with one retriable and one non-retriable variant.
    #[derive(Debug, Clone, PartialEq)]
    enum TestError {
        Transient,
        Permanent,
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestError::Transient => write!(f, "transient error"),
                TestError::Permanent => write!(f, "permanent error"),
            }
        }
    }

    impl ErrorClassification for TestError {
        fn is_transient(&self) -> bool {
            matches!(self, TestError::Transient)
        }
    }

    /// Jitter-free policy with millisecond-scale delays so tests stay fast.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            base_delay_ms: 1,
            max_delay_ms: 5,
            jitter_factor: 0.0,
        }
    }

    // ---- retry_with_backoff -----------------------------------------------
    //
    // Call counts are tracked with a plain local captured by the `FnMut`
    // closure; the borrow ends once the retry future has been awaited.

    /// op fails with a transient error on attempts 1 and 2, succeeds on attempt 3.
    #[tokio::test]
    async fn test_retry_succeeds_on_third_attempt() {
        let policy = fast_policy(3);
        let mut calls = 0u32;

        let result = retry_with_backoff(
            || {
                calls += 1;
                let attempt = calls;
                async move {
                    if attempt < 3 {
                        Err(TestError::Transient)
                    } else {
                        Ok(attempt)
                    }
                }
            },
            &policy,
        )
        .await;

        assert_eq!(result, Ok(3), "expected success on third attempt");
        assert_eq!(calls, 3);
    }

    /// A permanent error must not be retried — total calls should be 1.
    #[tokio::test]
    async fn test_retry_permanent_error_not_retried() {
        let policy = fast_policy(5);
        let mut calls = 0u32;

        let result: Result<u32, TestError> = retry_with_backoff(
            || {
                calls += 1;
                async { Err(TestError::Permanent) }
            },
            &policy,
        )
        .await;

        assert_eq!(result, Err(TestError::Permanent));
        assert_eq!(calls, 1, "permanent error must not be retried");
    }

    /// When every attempt returns a transient error, total calls must equal
    /// `policy.max_attempts`.
    #[tokio::test]
    async fn test_retry_respects_max_attempts() {
        let policy = fast_policy(4);
        let mut calls = 0u32;

        let result: Result<u32, TestError> = retry_with_backoff(
            || {
                calls += 1;
                async { Err(TestError::Transient) }
            },
            &policy,
        )
        .await;

        assert_eq!(result, Err(TestError::Transient));
        assert_eq!(
            calls, policy.max_attempts,
            "total calls must equal max_attempts"
        );
    }

    /// `max_attempts == 0` is treated as a single attempt: the operation
    /// still runs once and its error is returned without any retry.
    #[tokio::test]
    async fn test_retry_zero_max_attempts_runs_once() {
        let policy = fast_policy(0);
        let mut calls = 0u32;

        let result: Result<u32, TestError> = retry_with_backoff(
            || {
                calls += 1;
                async { Err(TestError::Transient) }
            },
            &policy,
        )
        .await;

        assert_eq!(result, Err(TestError::Transient));
        assert_eq!(calls, 1, "op must run exactly once");
    }

    /// With `base_delay_ms = 50` and no jitter, the two inter-attempt delays
    /// are 50 ms and 100 ms, i.e. 150 ms nominal in total.
    #[tokio::test]
    async fn test_retry_backoff_increases_exponentially() {
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay_ms: 50,
            max_delay_ms: 5_000,
            jitter_factor: 0.0, // no jitter so the lower bound is deterministic
        };
        let mut calls = 0u32;

        let start = std::time::Instant::now();
        let result: Result<u32, TestError> = retry_with_backoff(
            || {
                calls += 1;
                async { Err(TestError::Transient) }
            },
            &policy,
        )
        .await;
        let elapsed = start.elapsed();

        assert!(result.is_err());
        // Two sleeps: 50 ms + 100 ms = 150 ms nominal; allow 2 ms of timer
        // imprecision on the lower bound.
        assert!(
            elapsed >= Duration::from_millis(148),
            "expected elapsed >= 148 ms (150 ms nominal minus 2 ms tolerance), got {:?}",
            elapsed
        );
        assert_eq!(calls, 3);
    }

    // ---- Xorshift64 smoke-tests -------------------------------------------

    #[test]
    fn test_xorshift64_non_zero() {
        let mut rng = Xorshift64::seeded();
        // Ten consecutive values should all be non-zero (the seed is non-zero
        // and xorshift preserves that).
        for _ in 0..10 {
            assert_ne!(rng.next(), 0);
        }
    }

    #[test]
    fn test_xorshift64_f64_in_range() {
        let mut rng = Xorshift64::seeded();
        for _ in 0..1000 {
            let v = rng.next_f64();
            assert!((0.0..1.0).contains(&v), "out of range: {v}");
        }
    }

    // ---- RetryPolicy delay computation ------------------------------------

    #[test]
    fn test_delay_for_attempt_no_jitter() {
        let policy = RetryPolicy {
            max_attempts: 5,
            base_delay_ms: 100,
            max_delay_ms: 1_000,
            jitter_factor: 0.0,
        };
        let mut rng = Xorshift64::seeded();

        // Doubles each retry until the cap (1 000 ms) takes over at n = 4.
        let expected_ms = [100u64, 200, 400, 800, 1_000];
        for (n, &ms) in (0u32..).zip(expected_ms.iter()) {
            assert_eq!(
                policy.delay_for_attempt(n, &mut rng),
                Duration::from_millis(ms),
                "unexpected delay for retry {n}"
            );
        }
    }

    /// Exponents that would overflow `u64` must saturate and land on the cap
    /// rather than wrapping around or panicking.
    #[test]
    fn test_delay_for_attempt_saturates_for_large_exponents() {
        let policy = RetryPolicy {
            max_attempts: 5,
            base_delay_ms: 100,
            max_delay_ms: 1_000,
            jitter_factor: 0.0,
        };
        let mut rng = Xorshift64::seeded();

        for n in [62u32, 63, 64, 65, 1_000, u32::MAX] {
            assert_eq!(
                policy.delay_for_attempt(n, &mut rng),
                Duration::from_millis(1_000),
                "retry {n} must be capped at max_delay_ms"
            );
        }
    }

    /// With ±25% jitter around a 100 ms delay, every sample must fall within
    /// 75..=125 ms.
    #[test]
    fn test_delay_for_attempt_jitter_within_bounds() {
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay_ms: 100,
            max_delay_ms: 5_000,
            jitter_factor: 0.25,
        };
        let mut rng = Xorshift64::seeded();

        for _ in 0..1000 {
            let ms = policy.delay_for_attempt(0, &mut rng).as_millis();
            assert!((75..=125).contains(&ms), "jittered delay out of range: {ms} ms");
        }
    }
}