amaters_server/retry.rs
1//! Retry logic for transient failures in the AmateRS server.
2//!
3//! This module provides:
4//! - [`RetryPolicy`] — configurable exponential backoff with jitter.
5//! - [`ErrorClassification`] — trait for classifying errors as transient or permanent.
6//! - [`retry_with_backoff`] — generic async retry driver.
7//!
8//! **Important:** Only use [`retry_with_backoff`] for *idempotent* operations.
9//! Non-idempotent writes MUST NOT be wrapped in retry logic without sequence
10//! numbers or other deduplication mechanisms at the caller level.
11
12use std::time::Duration;
13
14// ---------------------------------------------------------------------------
15// Jitter PRNG
16// ---------------------------------------------------------------------------
17
/// Minimal xorshift64 PRNG used to produce backoff jitter.
///
/// Used to produce approximate uniform jitter without pulling in an external
/// PRNG crate. The output is sufficient for backoff jitter purposes; it is
/// NOT cryptographically secure.
struct Xorshift64(u64);

impl Xorshift64 {
    /// Non-zero fallback state. xorshift has a fixed point at 0, so the state
    /// must never be zero.
    const FALLBACK_SEED: u64 = 0xDEAD_BEEF_CAFE_BABE;

    /// Create a generator seeded from the wall clock plus a process-wide counter.
    ///
    /// Seeding from the clock alone gives identical jitter sequences to every
    /// generator created within the same clock tick. That happens when many
    /// tasks fail at once, and more often on platforms with coarse clock
    /// resolution, and it defeats the purpose of jitter. The counter makes
    /// every seed distinct. The combination is scrambled with the SplitMix64
    /// finalizer so that nearby inputs give unrelated starting states.
    fn seeded() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        static SEED_COUNTER: AtomicUsize = AtomicUsize::new(0);

        let clock = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            // Keeping only the low 64 bits of the nanosecond count is fine for seeding.
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(Self::FALLBACK_SEED);
        // Multiplying by an odd constant is a bijection on u64, so distinct
        // counter values always yield distinct nonces.
        let nonce = (SEED_COUNTER.fetch_add(1, Ordering::Relaxed) as u64)
            .wrapping_mul(0x9E37_79B9_7F4A_7C15);
        let seed = Self::splitmix64(clock ^ nonce);
        // xorshift requires a non-zero state.
        Self(if seed == 0 { Self::FALLBACK_SEED } else { seed })
    }

    /// SplitMix64 step + finalizer: a bijection on `u64` that thoroughly mixes bits.
    fn splitmix64(mut z: u64) -> u64 {
        z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Produce the next pseudo-random u64 (xorshift64 with shifts 13/7/17).
    fn next(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Produce a value in `[0.0, 1.0)`.
    fn next_f64(&mut self) -> f64 {
        // Use the top 53 bits for a clean f64 mantissa.
        (self.next() >> 11) as f64 / (1u64 << 53) as f64
    }
}
59
60// ---------------------------------------------------------------------------
61// Public types
62// ---------------------------------------------------------------------------
63
/// Retry policy for transient failures.
///
/// The sleep before retry `n` (0-indexed) is
/// `min(max_delay_ms, base_delay_ms * 2^n)`, scaled by a uniformly random
/// factor in `[1 - j, 1 + j)`, where `j` is the effective jitter factor
/// (see [`RetryPolicy::jitter_factor`]).
///
/// ## Idempotency
/// IMPORTANT: Only use for idempotent operations — non-idempotent writes must
/// not be wrapped in [`retry_with_backoff`] without explicit caller opt-in and
/// deduplication.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts including the first (1 = no retry).
    ///
    /// `0` is treated as `1`: the operation always runs at least once.
    pub max_attempts: u32,
    /// Base delay in milliseconds before the first retry; it doubles for each
    /// subsequent retry (saturating rather than overflowing).
    pub base_delay_ms: u64,
    /// Cap, in milliseconds, on the exponential backoff term.
    ///
    /// Jitter is applied *after* this cap, so an individual sleep may exceed
    /// `max_delay_ms` by up to a factor of `1 + jitter_factor`.
    pub max_delay_ms: u64,
    /// Relative jitter applied to each computed delay.
    ///
    /// `0.0` = no jitter; `0.1` = ±10% uniform jitter. Values `<= 0.0`
    /// disable jitter, and values above `1.0` are clamped to `1.0`.
    pub jitter_factor: f64,
}
84
85impl Default for RetryPolicy {
86 fn default() -> Self {
87 Self {
88 max_attempts: 3,
89 base_delay_ms: 100,
90 max_delay_ms: 5_000,
91 jitter_factor: 0.1,
92 }
93 }
94}
95
96impl RetryPolicy {
97 /// Compute the sleep duration for retry attempt `n` (0-indexed; n=0 is the
98 /// first retry, i.e. after the first failed attempt).
99 ///
100 /// Formula: `min(max_delay_ms, base_delay_ms * 2^n) * uniform(1 - jitter, 1 + jitter)`
101 fn delay_for_attempt(&self, n: u32, rng: &mut Xorshift64) -> Duration {
102 // Saturating exponentiation to avoid u64 overflow.
103 // 2^n using checked_shl to guard against n >= 64.
104 let multiplier: u64 = 1u64.checked_shl(n).unwrap_or(u64::MAX);
105 let base: u64 = self.base_delay_ms.saturating_mul(multiplier);
106 let capped = base.min(self.max_delay_ms);
107
108 let factor = if self.jitter_factor <= 0.0 {
109 1.0_f64
110 } else {
111 let j = self.jitter_factor.min(1.0);
112 // uniform in [1 - j, 1 + j]
113 let r = rng.next_f64(); // [0, 1)
114 1.0 - j + 2.0 * j * r
115 };
116
117 let ms = (capped as f64 * factor).max(0.0) as u64;
118 Duration::from_millis(ms)
119 }
120}
121
122// ---------------------------------------------------------------------------
123// Error classification
124// ---------------------------------------------------------------------------
125
/// Classifies an error as transient (worth retrying) or permanent.
///
/// Return `true` only when repeating the same operation could plausibly
/// succeed, e.g. after a momentary I/O interruption. Errors such as
/// `NotFound` or an authentication failure should report `false` so they are
/// surfaced to the caller right away.
pub trait ErrorClassification {
    /// Whether this error is transient.
    ///
    /// `true` permits another attempt, still bounded by the [`RetryPolicy`]
    /// limits; `false` ends retrying immediately.
    fn is_transient(&self) -> bool;
}
136
137// ---------------------------------------------------------------------------
138// Core retry driver
139// ---------------------------------------------------------------------------
140
141/// Retry `op` with exponential backoff and jitter according to `policy`.
142///
143/// - If `op` returns `Ok(v)`, returns immediately.
144/// - If `op` returns `Err(e)` and `e.is_transient()` is `true`, sleeps for
145/// the computed delay and tries again (up to `policy.max_attempts` times total).
146/// - If `op` returns `Err(e)` and `e.is_transient()` is `false`, returns
147/// the error immediately (no further attempts).
148/// - After exhausting all attempts, returns the last error.
149///
150/// # Example
151/// ```rust,no_run
152/// # use amaters_server::retry::{RetryPolicy, ErrorClassification, retry_with_backoff};
153/// # #[derive(Debug)] struct MyErr { transient: bool }
154/// # impl std::fmt::Display for MyErr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "err") } }
155/// # impl ErrorClassification for MyErr { fn is_transient(&self) -> bool { self.transient } }
156/// # async fn demo() {
157/// let policy = RetryPolicy::default();
158/// let result = retry_with_backoff(|| async { Ok::<_, MyErr>(42) }, &policy).await;
159/// # }
160/// ```
161pub async fn retry_with_backoff<F, T, E, Fut>(mut op: F, policy: &RetryPolicy) -> Result<T, E>
162where
163 F: FnMut() -> Fut,
164 Fut: std::future::Future<Output = Result<T, E>>,
165 E: ErrorClassification + std::fmt::Debug,
166{
167 let mut rng = Xorshift64::seeded();
168 let max = policy.max_attempts.max(1);
169
170 for attempt in 0..max {
171 match op().await {
172 Ok(val) => return Ok(val),
173 Err(err) => {
174 let is_last = attempt + 1 >= max;
175 if is_last || !err.is_transient() {
176 return Err(err);
177 }
178 // Compute retry delay: n = attempt (0-indexed first retry).
179 let delay = policy.delay_for_attempt(attempt, &mut rng);
180 tokio::time::sleep(delay).await;
181 }
182 }
183 }
184
185 // Unreachable: the loop always returns on the last attempt, but the
186 // compiler cannot see that without an explicit unreachable. Calling op
187 // one final time satisfies the type-checker without introducing a panic.
188 op().await
189}
190
191// ---------------------------------------------------------------------------
192// Tests
193// ---------------------------------------------------------------------------
194
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::{ready, Ready};
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    // ---- Fixtures ---------------------------------------------------------

    /// Error whose variant alone decides its retry classification.
    #[derive(Debug, Clone, PartialEq)]
    enum TestError {
        Transient,
        Permanent,
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(match self {
                TestError::Transient => "transient error",
                TestError::Permanent => "permanent error",
            })
        }
    }

    impl ErrorClassification for TestError {
        fn is_transient(&self) -> bool {
            *self == TestError::Transient
        }
    }

    /// Jitter-free policy with a 1 ms base delay, so retry tests stay fast
    /// and deterministic.
    fn quick_policy(max_attempts: u32, max_delay_ms: u64) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            base_delay_ms: 1,
            max_delay_ms,
            jitter_factor: 0.0,
        }
    }

    /// Operation that bumps `calls` on every invocation and always fails with `err`.
    fn always_failing(
        calls: &Arc<AtomicU32>,
        err: TestError,
    ) -> impl FnMut() -> Ready<Result<u32, TestError>> {
        let calls = Arc::clone(calls);
        move || {
            calls.fetch_add(1, Ordering::SeqCst);
            ready(Err(err.clone()))
        }
    }

    // ---- retry_with_backoff -----------------------------------------------

    /// Attempts 1 and 2 fail transiently; attempt 3 succeeds.
    #[tokio::test]
    async fn test_retry_succeeds_on_third_attempt() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);

        let outcome = retry_with_backoff(
            move || {
                // 1-based number of this attempt.
                let attempt = op_calls.fetch_add(1, Ordering::SeqCst) + 1;
                ready(if attempt >= 3 {
                    Ok(attempt)
                } else {
                    Err(TestError::Transient)
                })
            },
            &quick_policy(3, 5),
        )
        .await;

        assert_eq!(outcome, Ok(3), "expected success on third attempt");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A permanent error is surfaced after exactly one call.
    #[tokio::test]
    async fn test_retry_permanent_error_not_retried() {
        let calls = Arc::new(AtomicU32::new(0));

        let outcome = retry_with_backoff(
            always_failing(&calls, TestError::Permanent),
            &quick_policy(5, 10),
        )
        .await;

        assert_eq!(outcome, Err(TestError::Permanent));
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "permanent error must not be retried"
        );
    }

    /// Persistent transient failures exhaust exactly `max_attempts` calls.
    #[tokio::test]
    async fn test_retry_respects_max_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let policy = quick_policy(4, 5);

        let outcome =
            retry_with_backoff(always_failing(&calls, TestError::Transient), &policy).await;

        assert_eq!(outcome, Err(TestError::Transient));
        assert_eq!(
            calls.load(Ordering::SeqCst),
            policy.max_attempts,
            "total calls must equal max_attempts"
        );
    }

    /// With a 50 ms base and no jitter, the two sleeps are 50 ms and 100 ms.
    #[tokio::test]
    async fn test_retry_backoff_increases_exponentially() {
        let calls = Arc::new(AtomicU32::new(0));
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay_ms: 50,
            max_delay_ms: 5_000,
            jitter_factor: 0.0,
        };

        let started = std::time::Instant::now();
        let outcome =
            retry_with_backoff(always_failing(&calls, TestError::Transient), &policy).await;
        let elapsed = started.elapsed();

        assert!(outcome.is_err());
        // 50 ms + 100 ms = 150 ms, minus 2 ms of timer-precision slack.
        assert!(
            elapsed >= Duration::from_millis(148),
            "expected elapsed >= 150 ms, got {:?}",
            elapsed
        );
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    // ---- Xorshift64 -------------------------------------------------------

    #[test]
    fn test_xorshift64_non_zero() {
        let mut rng = Xorshift64::seeded();
        // A non-zero xorshift state never transitions to zero.
        assert!((0..10).all(|_| rng.next() != 0));
    }

    #[test]
    fn test_xorshift64_f64_in_range() {
        let mut rng = Xorshift64::seeded();
        for v in std::iter::repeat_with(|| rng.next_f64()).take(1000) {
            assert!((0.0..1.0).contains(&v), "out of range: {v}");
        }
    }

    // ---- RetryPolicy::delay_for_attempt -----------------------------------

    #[test]
    fn test_delay_for_attempt_no_jitter() {
        let policy = RetryPolicy {
            max_attempts: 5,
            base_delay_ms: 100,
            max_delay_ms: 1_000,
            jitter_factor: 0.0,
        };
        let mut rng = Xorshift64::seeded();
        // Doubles per retry until the 1 s cap takes over at n = 4.
        let expected_ms: [u64; 5] = [100, 200, 400, 800, 1_000];
        for (n, &ms) in (0u32..).zip(expected_ms.iter()) {
            assert_eq!(
                policy.delay_for_attempt(n, &mut rng),
                Duration::from_millis(ms)
            );
        }
    }
}