1use cardinal_config::{DestinationRetry, DestinationRetryBackoffType};
2use serde::{Deserialize, Serialize};
3use std::time::{Duration, Instant};
4
5#[derive(Debug, Serialize, Deserialize)]
6pub enum BackoffStrategy {
7 Exponential,
8 Linear,
9 None,
10}
11
12pub struct RetryState {
13 pub current_attempt: u32,
15
16 pub max_attempts: u32,
18
19 pub base_interval: Duration,
21
22 pub last_attempt_at: Option<Instant>,
24
25 pub next_delay: Duration,
27
28 pub strategy: BackoffStrategy,
30
31 pub max_interval: Option<Duration>,
33}
34
35impl From<DestinationRetry> for RetryState {
36 fn from(value: DestinationRetry) -> Self {
37 let base_interval = Duration::from_millis(value.interval_ms);
38 let max_interval = value.max_interval.map(Duration::from_millis);
39 let initial_delay = max_interval
40 .map(|max| base_interval.min(max))
41 .unwrap_or(base_interval);
42
43 RetryState {
44 current_attempt: 0,
45 max_attempts: value.max_attempts.min(u32::MAX as u64) as u32,
46 base_interval,
47 last_attempt_at: None,
48 next_delay: initial_delay,
49 strategy: match value.backoff_type {
50 DestinationRetryBackoffType::Exponential => BackoffStrategy::Exponential,
51 DestinationRetryBackoffType::Linear => BackoffStrategy::Linear,
52 DestinationRetryBackoffType::None => BackoffStrategy::None,
53 },
54 max_interval,
55 }
56 }
57}
58
59impl RetryState {
60 pub fn register_attempt(&mut self) {
61 self.current_attempt += 1;
62 self.last_attempt_at = Some(Instant::now());
63
64 let mut next_delay = match self.strategy {
66 BackoffStrategy::None => self.base_interval,
67 BackoffStrategy::Linear => self
68 .base_interval
69 .saturating_mul(self.current_attempt.max(1)),
70 BackoffStrategy::Exponential => {
71 let shift = (self.current_attempt - 1).min(31);
72 let multiplier = 1u32 << shift;
73 self.base_interval.saturating_mul(multiplier)
74 }
75 };
76
77 if let Some(max_interval) = self.max_interval {
78 if next_delay > max_interval {
79 next_delay = max_interval;
80 }
81 }
82
83 self.next_delay = next_delay;
84 }
85
86 pub fn can_retry(&self) -> bool {
87 self.current_attempt < self.max_attempts
88 }
89
90 pub async fn sleep_if_retry_allowed(&mut self) -> bool {
91 if self.can_retry() {
92 tokio::time::sleep(self.next_delay).await;
93 true
94 } else {
95 false
96 }
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Returns a state with no recorded attempts, a zero initial delay and no delay cap.
    fn new_state(strategy: BackoffStrategy, max_attempts: u32, base_ms: u64) -> RetryState {
        RetryState {
            current_attempt: 0,
            max_attempts,
            base_interval: Duration::from_millis(base_ms),
            last_attempt_at: None,
            next_delay: Duration::ZERO,
            strategy,
            max_interval: None,
        }
    }

    #[test]
    fn none_backoff_increments_and_uses_fixed_interval() {
        let mut state = new_state(BackoffStrategy::None, 3, 100);

        state.register_attempt();

        assert_eq!(state.current_attempt, 1);
        assert_eq!(state.next_delay, Duration::from_millis(100));
        assert!(state.last_attempt_at.is_some());
    }

    #[test]
    fn linear_backoff_grows_linearly() {
        let mut state = new_state(BackoffStrategy::Linear, 3, 100);

        for expected_ms in [100, 200, 300] {
            state.register_attempt();
            assert_eq!(state.next_delay, Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn exponential_backoff_doubles_each_attempt() {
        let mut state = new_state(BackoffStrategy::Exponential, 4, 50);

        for expected_ms in [50, 100, 200, 400] {
            state.register_attempt();
            assert_eq!(state.next_delay, Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn can_retry_returns_false_when_limit_reached() {
        let mut state = new_state(BackoffStrategy::Linear, 2, 100);

        assert!(state.can_retry());
        state.register_attempt();
        assert!(state.can_retry());
        state.register_attempt();
        assert!(!state.can_retry());
    }

    #[test]
    fn exponential_backoff_saturates_safely_at_large_attempts() {
        let mut state = new_state(BackoffStrategy::Exponential, 32, 1);
        // Start right below the exponent cap so this attempt uses the largest shift (31).
        state.current_attempt = 31;

        // Called directly, without `catch_unwind`: a panic here must fail the test loudly.
        state.register_attempt();

        assert_eq!(state.current_attempt, 32);
        assert_eq!(state.next_delay, Duration::from_millis(1 << 31));
    }

    #[test]
    fn retry_state_from_clamps_initial_delay() {
        let retry = DestinationRetry {
            max_attempts: 3,
            interval_ms: 200,
            backoff_type: DestinationRetryBackoffType::Linear,
            max_interval: Some(150),
        };

        let state = RetryState::from(retry);

        assert_eq!(state.next_delay, Duration::from_millis(150));
    }

    #[test]
    fn max_interval_caps_backoff_growth() {
        let mut state = new_state(BackoffStrategy::Exponential, 4, 100);
        state.max_interval = Some(Duration::from_millis(250));

        for expected_ms in [100, 200, 250, 250] {
            state.register_attempt();
            assert_eq!(state.next_delay, Duration::from_millis(expected_ms));
        }
    }

    /// Simulated request that fails until `attempt` reaches `should_succeed_on`.
    async fn fake_request(
        should_succeed_on: u32,
        attempt: u32,
    ) -> Result<&'static str, &'static str> {
        if attempt >= should_succeed_on {
            Ok("success")
        } else {
            Err("failed")
        }
    }

    // The wall-clock tests below keep their lower bounds, but their upper bounds are
    // deliberately generous (about 2x the expected total). The exact delays are pinned
    // by the deterministic tests above; these only guard against grossly wrong waits
    // without turning flaky on loaded CI machines.

    #[tokio::test]
    async fn retry_loop_with_exponential_backoff_succeeds_after_expected_attempts() {
        let mut state = new_state(BackoffStrategy::Exponential, 5, 100);

        let start = Instant::now();
        let mut result = Err("not started");

        while state.can_retry() {
            result = fake_request(3, state.current_attempt).await;
            if result.is_ok() {
                break;
            }

            state.register_attempt();
            sleep(state.next_delay).await;
        }

        let elapsed = start.elapsed();

        assert_eq!(result, Ok("success"));
        assert_eq!(state.current_attempt, 3);

        // Three failures sleep 100 + 200 + 400 = 700 ms in total.
        assert!(
            elapsed >= Duration::from_millis(650) && elapsed <= Duration::from_millis(1400),
            "elapsed = {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn retry_loop_with_linear_backoff_fails_after_max_attempts() {
        let mut state = new_state(BackoffStrategy::Linear, 4, 100);

        let start = Instant::now();
        let mut result = Err("failed");

        while state.can_retry() {
            result = fake_request(10, state.current_attempt).await;
            if result.is_ok() {
                break;
            }

            state.register_attempt();
            sleep(state.next_delay).await;
        }

        let elapsed = start.elapsed();

        assert_eq!(result, Err("failed"));
        assert_eq!(state.current_attempt, state.max_attempts);

        // Four failures sleep 100 + 200 + 300 + 400 = 1000 ms in total.
        assert!(
            elapsed >= Duration::from_millis(900) && elapsed <= Duration::from_millis(2000),
            "elapsed = {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn retry_loop_with_none_backoff_uses_fixed_delay() {
        let mut state = new_state(BackoffStrategy::None, 3, 100);

        let start = Instant::now();
        let mut result = Err("failed");

        while state.can_retry() {
            result = fake_request(2, state.current_attempt).await;
            if result.is_ok() {
                break;
            }

            state.register_attempt();
            sleep(state.next_delay).await;
        }

        let elapsed = start.elapsed();

        assert_eq!(result, Ok("success"));
        assert_eq!(state.current_attempt, 2);

        // Two failures sleep a fixed 100 ms each, 200 ms in total.
        assert!(
            elapsed >= Duration::from_millis(150) && elapsed <= Duration::from_millis(400),
            "elapsed = {:?}",
            elapsed
        );
    }

    #[test]
    fn retry_state_from_clamps_max_attempts_to_u32_max() {
        let retry = DestinationRetry {
            max_attempts: (u32::MAX as u64) + 42,
            interval_ms: 50,
            backoff_type: DestinationRetryBackoffType::Linear,
            max_interval: None,
        };

        let state = RetryState::from(retry);

        assert_eq!(state.max_attempts, u32::MAX);
    }

    #[test]
    fn exponential_backoff_from_config_respects_max_interval_sequence() {
        let retry = DestinationRetry {
            max_attempts: 5,
            interval_ms: 100,
            backoff_type: DestinationRetryBackoffType::Exponential,
            max_interval: Some(250),
        };

        let mut state = RetryState::from(retry);
        let mut observed = Vec::new();

        for _ in 0..state.max_attempts {
            state.register_attempt();
            observed.push(state.next_delay);
        }

        let expected = [
            Duration::from_millis(100),
            Duration::from_millis(200),
            Duration::from_millis(250),
            Duration::from_millis(250),
            Duration::from_millis(250),
        ];

        assert_eq!(&observed[..], &expected);
        assert!(!state.can_retry());
    }

    #[tokio::test]
    async fn sleep_if_retry_allowed_returns_false_when_no_attempts_left() {
        let retry = DestinationRetry {
            max_attempts: 2,
            interval_ms: 10,
            backoff_type: DestinationRetryBackoffType::Linear,
            max_interval: Some(10),
        };

        let mut state = RetryState::from(retry);

        state.register_attempt();
        assert!(state.can_retry());

        state.register_attempt();
        assert!(!state.can_retry());

        let slept = state.sleep_if_retry_allowed().await;
        assert!(!slept);
        assert_eq!(state.current_attempt, state.max_attempts);
    }

    #[test]
    fn exponential_backoff_does_not_overflow_large_base_interval() {
        let retry = DestinationRetry {
            max_attempts: 100,
            interval_ms: u64::MAX / 4,
            backoff_type: DestinationRetryBackoffType::Exponential,
            max_interval: None,
        };

        let mut state = RetryState::from(retry);

        for _ in 0..40 {
            state.register_attempt();
        }

        assert_eq!(state.next_delay, Duration::MAX);
        assert!(state.can_retry());
    }

    #[tokio::test]
    async fn retry_loop_with_real_waits_respects_limits() {
        let retry = DestinationRetry {
            max_attempts: 4,
            interval_ms: 90,
            backoff_type: DestinationRetryBackoffType::Exponential,
            max_interval: Some(200),
        };

        let mut state = RetryState::from(retry);
        let mut observed_delays = Vec::new();
        let mut sleep_calls = 0;

        while state.can_retry() {
            state.register_attempt();
            observed_delays.push(state.next_delay);

            if !state.can_retry() {
                assert!(!state.sleep_if_retry_allowed().await);
                break;
            }

            assert!(state.next_delay <= Duration::from_millis(200));
            assert!(state.sleep_if_retry_allowed().await);
            sleep_calls += 1;
        }

        assert_eq!(state.current_attempt, state.max_attempts);
        assert_eq!(sleep_calls, (state.max_attempts - 1) as usize);
        assert_eq!(
            observed_delays,
            vec![
                Duration::from_millis(90),
                Duration::from_millis(180),
                Duration::from_millis(200),
                Duration::from_millis(200),
            ]
        );
    }
}