1use serde::{Deserialize, Serialize};
2use std::time::Duration;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
6#[serde(tag = "type", rename_all = "snake_case")]
7pub enum BackoffStrategy {
8 Fixed { delay_secs: u64 },
10 Exponential { base_secs: u64, max_secs: u64 },
12}
13
14impl BackoffStrategy {
15 pub fn delay_for(&self, retry: u32) -> Duration {
17 match self {
18 BackoffStrategy::Fixed { delay_secs } => Duration::from_secs(*delay_secs),
19 BackoffStrategy::Exponential {
20 base_secs,
21 max_secs,
22 } => {
23 let exp = 1u64.checked_shl(retry).unwrap_or(u64::MAX);
24 let delay = base_secs.saturating_mul(exp);
25 Duration::from_secs(delay.min(*max_secs))
26 }
27 }
28 }
29}
30
31impl Default for BackoffStrategy {
32 fn default() -> Self {
33 BackoffStrategy::Exponential {
34 base_secs: 1,
35 max_secs: 300,
36 }
37 }
38}
39
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fixed_backoff() {
        let strategy = BackoffStrategy::Fixed { delay_secs: 5 };
        assert_eq!(strategy.delay_for(0), Duration::from_secs(5));
        assert_eq!(strategy.delay_for(1), Duration::from_secs(5));
        assert_eq!(strategy.delay_for(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_backoff() {
        let strategy = BackoffStrategy::Exponential {
            base_secs: 1,
            max_secs: 60,
        };
        assert_eq!(strategy.delay_for(0), Duration::from_secs(1));
        assert_eq!(strategy.delay_for(1), Duration::from_secs(2));
        assert_eq!(strategy.delay_for(2), Duration::from_secs(4));
        assert_eq!(strategy.delay_for(3), Duration::from_secs(8));
        assert_eq!(strategy.delay_for(10), Duration::from_secs(60));
    }

    #[test]
    fn exponential_backoff_overflow() {
        let strategy = BackoffStrategy::Exponential {
            base_secs: 1,
            max_secs: 300,
        };
        assert_eq!(strategy.delay_for(100), Duration::from_secs(300));
    }

    // 2^retry stops fitting in a u64 exactly at retry 64; check both sides of
    // that edge plus the largest possible retry count.
    #[test]
    fn exponential_backoff_shift_boundary() {
        let strategy = BackoffStrategy::Exponential {
            base_secs: 1,
            max_secs: u64::MAX,
        };
        assert_eq!(strategy.delay_for(63), Duration::from_secs(1u64 << 63));
        assert_eq!(strategy.delay_for(64), Duration::from_secs(u64::MAX));
        assert_eq!(strategy.delay_for(u32::MAX), Duration::from_secs(u64::MAX));
    }

    // A huge base must saturate rather than wrap, and still respect the cap.
    #[test]
    fn exponential_backoff_saturates_large_base() {
        let capped = BackoffStrategy::Exponential {
            base_secs: u64::MAX,
            max_secs: 600,
        };
        assert_eq!(capped.delay_for(0), Duration::from_secs(600));
        assert_eq!(capped.delay_for(5), Duration::from_secs(600));

        let uncapped = BackoffStrategy::Exponential {
            base_secs: u64::MAX,
            max_secs: u64::MAX,
        };
        assert_eq!(uncapped.delay_for(1), Duration::from_secs(u64::MAX));
    }

    #[test]
    fn default_backoff() {
        let strategy = BackoffStrategy::default();
        assert_eq!(strategy.delay_for(0), Duration::from_secs(1));
        assert_eq!(strategy.delay_for(8), Duration::from_secs(256));
        assert_eq!(strategy.delay_for(9), Duration::from_secs(300));
        assert_eq!(strategy.delay_for(100), Duration::from_secs(300));
    }

    // Round-trip every variant and compare the whole delay schedule, not just
    // a single retry.
    #[test]
    fn backoff_serde_roundtrip() {
        for strategy in [
            BackoffStrategy::Fixed { delay_secs: 5 },
            BackoffStrategy::Exponential {
                base_secs: 2,
                max_secs: 120,
            },
        ] {
            let json = serde_json::to_string(&strategy).unwrap();
            let deserialized: BackoffStrategy = serde_json::from_str(&json).unwrap();
            for retry in 0..=10 {
                assert_eq!(strategy.delay_for(retry), deserialized.delay_for(retry));
            }
        }
    }

    // Pin the on-the-wire shape produced by `#[serde(tag = "type",
    // rename_all = "snake_case")]` so config files stay compatible.
    #[test]
    fn backoff_wire_format() {
        let fixed = BackoffStrategy::Fixed { delay_secs: 5 };
        assert_eq!(
            serde_json::to_string(&fixed).unwrap(),
            r#"{"type":"fixed","delay_secs":5}"#
        );

        let exponential = BackoffStrategy::Exponential {
            base_secs: 2,
            max_secs: 120,
        };
        assert_eq!(
            serde_json::to_string(&exponential).unwrap(),
            r#"{"type":"exponential","base_secs":2,"max_secs":120}"#
        );

        let parsed: BackoffStrategy =
            serde_json::from_str(r#"{"type":"fixed","delay_secs":7}"#).unwrap();
        assert_eq!(parsed.delay_for(3), Duration::from_secs(7));
    }
}