// cardinal_proxy/retry.rs

1use cardinal_config::{DestinationRetry, DestinationRetryBackoffType};
2use serde::{Deserialize, Serialize};
3use std::time::{Duration, Instant};
4
5#[derive(Debug, Serialize, Deserialize)]
6pub enum BackoffStrategy {
7    Exponential,
8    Linear,
9    None,
10}
11
12pub struct RetryState {
13    /// How many attempts have been made so far (starts at 0)
14    pub current_attempt: u32,
15
16    /// Total allowed attempts from the config
17    pub max_attempts: u32,
18
19    /// The base interval between retries
20    pub base_interval: Duration,
21
22    /// The timestamp of the last retry attempt
23    pub last_attempt_at: Option<Instant>,
24
25    /// The computed delay before the next retry
26    pub next_delay: Duration,
27
28    /// Whether exponential or linear backoff is used
29    pub strategy: BackoffStrategy,
30
31    /// Upper bound for the delay if provided in the config
32    pub max_interval: Option<Duration>,
33}
34
35impl From<DestinationRetry> for RetryState {
36    fn from(value: DestinationRetry) -> Self {
37        let base_interval = Duration::from_millis(value.interval_ms);
38        let max_interval = value.max_interval.map(Duration::from_millis);
39        let initial_delay = max_interval
40            .map(|max| base_interval.min(max))
41            .unwrap_or(base_interval);
42
43        RetryState {
44            current_attempt: 0,
45            max_attempts: value.max_attempts.min(u32::MAX as u64) as u32,
46            base_interval,
47            last_attempt_at: None,
48            next_delay: initial_delay,
49            strategy: match value.backoff_type {
50                DestinationRetryBackoffType::Exponential => BackoffStrategy::Exponential,
51                DestinationRetryBackoffType::Linear => BackoffStrategy::Linear,
52                DestinationRetryBackoffType::None => BackoffStrategy::None,
53            },
54            max_interval,
55        }
56    }
57}
58
59impl RetryState {
60    pub fn register_attempt(&mut self) {
61        self.current_attempt += 1;
62        self.last_attempt_at = Some(Instant::now());
63
64        // Compute the next delay based on the strategy
65        let mut next_delay = match self.strategy {
66            BackoffStrategy::None => self.base_interval,
67            BackoffStrategy::Linear => self
68                .base_interval
69                .saturating_mul(self.current_attempt.max(1)),
70            BackoffStrategy::Exponential => {
71                let shift = (self.current_attempt - 1).min(31);
72                let multiplier = 1u32 << shift;
73                self.base_interval.saturating_mul(multiplier)
74            }
75        };
76
77        if let Some(max_interval) = self.max_interval {
78            if next_delay > max_interval {
79                next_delay = max_interval;
80            }
81        }
82
83        self.next_delay = next_delay;
84    }
85
86    pub fn can_retry(&self) -> bool {
87        self.current_attempt < self.max_attempts
88    }
89
90    pub async fn sleep_if_retry_allowed(&mut self) -> bool {
91        if self.can_retry() {
92            tokio::time::sleep(self.next_delay).await;
93            true
94        } else {
95            false
96        }
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    //
    // ──────────────────────────────── HELPERS ────────────────────────────────
    //

    /// Slack allowed on the wall-clock lower bounds in the async retry-loop tests.
    const TIMER_SLACK: Duration = Duration::from_millis(50);

    /// Shorthand for `Duration::from_millis`.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    /// Builds a `RetryState` with no attempts registered yet.
    fn fresh_state(
        strategy: BackoffStrategy,
        base_ms: u64,
        max_attempts: u32,
        max_interval_ms: Option<u64>,
    ) -> RetryState {
        RetryState {
            current_attempt: 0,
            max_attempts,
            base_interval: ms(base_ms),
            last_attempt_at: None,
            next_delay: Duration::ZERO,
            strategy,
            max_interval: max_interval_ms.map(ms),
        }
    }

    /// Simulated request that fails until `attempt` reaches `should_succeed_on`.
    async fn fake_request(
        should_succeed_on: u32,
        attempt: u32,
    ) -> Result<&'static str, &'static str> {
        if attempt >= should_succeed_on {
            Ok("success")
        } else {
            Err("failed")
        }
    }

    /// Drives `state` through a naive retry loop against `fake_request`.
    ///
    /// Returns the last request result and every delay that was slept, in order.
    async fn run_retry_loop(
        state: &mut RetryState,
        should_succeed_on: u32,
    ) -> (Result<&'static str, &'static str>, Vec<Duration>) {
        let mut result = Err("not started");
        let mut delays = Vec::new();

        while state.can_retry() {
            result = fake_request(should_succeed_on, state.current_attempt).await;
            if result.is_ok() {
                break;
            }

            state.register_attempt();
            delays.push(state.next_delay);
            sleep(state.next_delay).await;
        }

        (result, delays)
    }

    /// Asserts that `elapsed` covers the summed `delays`, within `TIMER_SLACK`.
    ///
    /// Only a lower bound is checked, because `tokio::time::sleep` does not wake early.
    /// A wall-clock upper bound is flaky on loaded CI machines. The exact delays are
    /// asserted separately and deterministically by each test.
    fn assert_slept_at_least(elapsed: Duration, delays: &[Duration]) {
        let total: Duration = delays.iter().sum();
        assert!(
            elapsed + TIMER_SLACK >= total,
            "elapsed = {:?}, expected at least {:?}",
            elapsed,
            total
        );
    }

    //
    // ──────────────────────────────── UNIT TESTS ────────────────────────────────
    //

    #[test]
    fn none_backoff_increments_and_uses_fixed_interval() {
        let mut state = fresh_state(BackoffStrategy::None, 100, 3, None);

        state.register_attempt();
        assert_eq!(state.current_attempt, 1);
        assert_eq!(state.next_delay, ms(100));
        assert!(state.last_attempt_at.is_some());

        // The delay does not grow on later attempts.
        state.register_attempt();
        assert_eq!(state.next_delay, ms(100));
    }

    #[test]
    fn linear_backoff_grows_linearly() {
        let mut state = fresh_state(BackoffStrategy::Linear, 100, 3, None);

        for expected in [100, 200, 300] {
            state.register_attempt();
            assert_eq!(state.next_delay, ms(expected));
        }
    }

    #[test]
    fn exponential_backoff_doubles_each_attempt() {
        let mut state = fresh_state(BackoffStrategy::Exponential, 50, 4, None);

        // 1x, 2x, 4x, 8x the base interval.
        for expected in [50, 100, 200, 400] {
            state.register_attempt();
            assert_eq!(state.next_delay, ms(expected));
        }
    }

    #[test]
    fn can_retry_returns_false_when_limit_reached() {
        let mut state = fresh_state(BackoffStrategy::Linear, 100, 2, None);

        assert!(state.can_retry());
        state.register_attempt();
        assert!(state.can_retry());
        state.register_attempt();
        assert!(!state.can_retry());
    }

    #[test]
    fn exponential_backoff_saturates_safely_at_large_attempts() {
        let mut state = fresh_state(BackoffStrategy::Exponential, 1, 64, None);

        // Attempt 32 uses the largest shift that fits a u32 multiplier (1 << 31)...
        state.current_attempt = 31;
        state.register_attempt();
        assert_eq!(state.next_delay, ms(1 << 31));

        // ...and later attempts reuse that cap instead of overflowing the shift.
        state.current_attempt = 40;
        state.register_attempt();
        assert_eq!(state.next_delay, ms(1 << 31));
    }

    #[test]
    fn retry_state_from_clamps_initial_delay() {
        let retry = DestinationRetry {
            max_attempts: 3,
            interval_ms: 200,
            backoff_type: DestinationRetryBackoffType::Linear,
            max_interval: Some(150),
        };

        let state = RetryState::from(retry);

        assert_eq!(state.next_delay, ms(150));
        assert_eq!(state.current_attempt, 0);
    }

    #[test]
    fn max_interval_caps_backoff_growth() {
        let mut state = fresh_state(BackoffStrategy::Exponential, 100, 4, Some(250));

        for expected in [100, 200, 250, 250] {
            state.register_attempt();
            assert_eq!(state.next_delay, ms(expected));
        }
    }

    //
    // ───────────────────────────── RETRY-LOOP TESTS ─────────────────────────────
    //

    #[tokio::test]
    async fn retry_loop_with_exponential_backoff_succeeds_after_expected_attempts() {
        let mut state = fresh_state(BackoffStrategy::Exponential, 100, 5, None);

        let start = Instant::now();
        let (result, delays) = run_retry_loop(&mut state, 3).await;
        let elapsed = start.elapsed();

        assert_eq!(result, Ok("success"));
        assert_eq!(state.current_attempt, 3);
        assert_eq!(delays, [ms(100), ms(200), ms(400)]);
        assert_slept_at_least(elapsed, &delays);
    }

    #[tokio::test]
    async fn retry_loop_with_linear_backoff_fails_after_max_attempts() {
        let mut state = fresh_state(BackoffStrategy::Linear, 100, 4, None);

        let start = Instant::now();
        // Succeeds only from attempt 10 onwards, so all 4 allowed attempts fail.
        let (result, delays) = run_retry_loop(&mut state, 10).await;
        let elapsed = start.elapsed();

        assert_eq!(result, Err("failed"));
        assert_eq!(state.current_attempt, state.max_attempts);
        assert_eq!(delays, [ms(100), ms(200), ms(300), ms(400)]);
        assert_slept_at_least(elapsed, &delays);
    }

    #[tokio::test]
    async fn retry_loop_with_none_backoff_waits_fixed_interval() {
        let mut state = fresh_state(BackoffStrategy::None, 100, 3, None);

        let start = Instant::now();
        let (result, delays) = run_retry_loop(&mut state, 2).await;
        let elapsed = start.elapsed();

        assert_eq!(result, Ok("success"));
        assert_eq!(state.current_attempt, 2);
        // `None` backoff still waits `base_interval` between attempts; it just never grows.
        assert_eq!(delays, [ms(100), ms(100)]);
        assert_slept_at_least(elapsed, &delays);
    }

    //
    // ───────────────────────────── CONFIG CONVERSION ─────────────────────────────
    //

    #[test]
    fn retry_state_from_clamps_max_attempts_to_u32_max() {
        let retry = DestinationRetry {
            max_attempts: (u32::MAX as u64) + 42,
            interval_ms: 50,
            backoff_type: DestinationRetryBackoffType::Linear,
            max_interval: None,
        };

        let state = RetryState::from(retry);

        assert_eq!(state.max_attempts, u32::MAX);
    }

    #[test]
    fn exponential_backoff_from_config_respects_max_interval_sequence() {
        let retry = DestinationRetry {
            max_attempts: 5,
            interval_ms: 100,
            backoff_type: DestinationRetryBackoffType::Exponential,
            max_interval: Some(250),
        };

        let mut state = RetryState::from(retry);
        let mut observed = Vec::new();

        for _ in 0..state.max_attempts {
            state.register_attempt();
            observed.push(state.next_delay);
        }

        assert_eq!(observed, [ms(100), ms(200), ms(250), ms(250), ms(250)]);
        assert!(!state.can_retry());
    }

    #[tokio::test]
    async fn sleep_if_retry_allowed_returns_false_when_no_attempts_left() {
        let retry = DestinationRetry {
            max_attempts: 2,
            interval_ms: 10,
            backoff_type: DestinationRetryBackoffType::Linear,
            max_interval: Some(10),
        };

        let mut state = RetryState::from(retry);

        state.register_attempt();
        assert!(state.can_retry());

        state.register_attempt();
        assert!(!state.can_retry());

        let slept = state.sleep_if_retry_allowed().await;
        assert!(!slept);
        assert_eq!(state.current_attempt, state.max_attempts);
    }

    #[test]
    fn exponential_backoff_does_not_overflow_large_base_interval() {
        let retry = DestinationRetry {
            max_attempts: 100,
            interval_ms: u64::MAX / 4,
            backoff_type: DestinationRetryBackoffType::Exponential,
            max_interval: None,
        };

        let mut state = RetryState::from(retry);

        for _ in 0..40 {
            state.register_attempt();
        }

        assert_eq!(state.next_delay, Duration::MAX);
        assert!(state.can_retry());
    }

    #[tokio::test]
    async fn retry_loop_with_real_waits_respects_limits() {
        let retry = DestinationRetry {
            max_attempts: 4,
            interval_ms: 90,
            backoff_type: DestinationRetryBackoffType::Exponential,
            max_interval: Some(200),
        };

        let mut state = RetryState::from(retry);
        let mut observed_delays = Vec::new();
        let mut sleep_calls = 0;

        while state.can_retry() {
            state.register_attempt();
            observed_delays.push(state.next_delay);

            if !state.can_retry() {
                assert!(!state.sleep_if_retry_allowed().await);
                break;
            }

            assert!(state.next_delay <= ms(200));
            assert!(state.sleep_if_retry_allowed().await);
            sleep_calls += 1;
        }

        assert_eq!(state.current_attempt, state.max_attempts);
        assert_eq!(sleep_calls, (state.max_attempts - 1) as usize);
        assert_eq!(observed_delays, [ms(90), ms(180), ms(200), ms(200)]);
    }
}
498        );
499    }
500}