1use chrono::{DateTime, Duration, Utc};
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub enum RetryStrategy {
11 None,
13 Fixed {
15 interval: Duration,
17 max_attempts: u32,
19 },
20 ExponentialBackoff {
22 initial_delay: Duration,
24 multiplier: f64,
26 max_delay: Duration,
28 max_attempts: u32,
30 jitter: bool,
32 },
33 LinearBackoff {
35 initial_delay: Duration,
37 increment: Duration,
39 max_delay: Duration,
41 max_attempts: u32,
43 },
44}
45
46impl Default for RetryStrategy {
47 fn default() -> Self {
48 Self::ExponentialBackoff {
49 initial_delay: Duration::seconds(1),
50 multiplier: 2.0,
51 max_delay: Duration::minutes(15),
52 max_attempts: 5,
53 jitter: true,
54 }
55 }
56}
57
58impl RetryStrategy {
59 pub fn none() -> Self {
61 Self::None
62 }
63
64 pub fn fixed(interval: Duration, max_attempts: u32) -> Self {
66 Self::Fixed {
67 interval,
68 max_attempts,
69 }
70 }
71
72 pub fn exponential_backoff(
74 initial_delay: Duration,
75 multiplier: f64,
76 max_delay: Duration,
77 max_attempts: u32,
78 ) -> Self {
79 Self::ExponentialBackoff {
80 initial_delay,
81 multiplier,
82 max_delay,
83 max_attempts,
84 jitter: true,
85 }
86 }
87
88 pub fn linear_backoff(
90 initial_delay: Duration,
91 increment: Duration,
92 max_delay: Duration,
93 max_attempts: u32,
94 ) -> Self {
95 Self::LinearBackoff {
96 initial_delay,
97 increment,
98 max_delay,
99 max_attempts,
100 }
101 }
102
103 pub fn calculate_delay(&self, attempt: u32) -> Option<Duration> {
105 match self {
106 RetryStrategy::None => None,
107 RetryStrategy::Fixed {
108 interval,
109 max_attempts,
110 } => {
111 if attempt <= *max_attempts {
112 Some(*interval)
113 } else {
114 None
115 }
116 }
117 RetryStrategy::ExponentialBackoff {
118 initial_delay,
119 multiplier,
120 max_delay,
121 max_attempts,
122 jitter,
123 } => {
124 if attempt > *max_attempts {
125 return None;
126 }
127
128 let mut delay = initial_delay.num_milliseconds() as f64;
129 for _ in 1..attempt {
130 delay *= multiplier;
131 }
132
133 delay = delay.min(max_delay.num_milliseconds() as f64);
135
136 if *jitter {
138 let jitter_amount = delay * 0.25;
139 let random_factor = fastrand::f64() * 2.0 - 1.0; delay += jitter_amount * random_factor;
141 }
142
143 Some(Duration::milliseconds(delay as i64))
144 }
145 RetryStrategy::LinearBackoff {
146 initial_delay,
147 increment,
148 max_delay,
149 max_attempts,
150 } => {
151 if attempt > *max_attempts {
152 return None;
153 }
154
155 let delay = *initial_delay + *increment * (attempt as i32 - 1);
156 Some(delay.min(*max_delay))
157 }
158 }
159 }
160
161 pub fn max_attempts(&self) -> u32 {
163 match self {
164 RetryStrategy::None => 0,
165 RetryStrategy::Fixed { max_attempts, .. }
166 | RetryStrategy::ExponentialBackoff { max_attempts, .. }
167 | RetryStrategy::LinearBackoff { max_attempts, .. } => *max_attempts,
168 }
169 }
170
171 pub fn is_retry_enabled(&self) -> bool {
173 !matches!(self, RetryStrategy::None)
174 }
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct RetryPolicy {
180 pub strategy: RetryStrategy,
182 pub retry_on_all_exceptions: bool,
184 pub retryable_exceptions: Vec<String>,
186 pub non_retryable_exceptions: Vec<String>,
188}
189
190impl Default for RetryPolicy {
191 fn default() -> Self {
192 Self {
193 strategy: RetryStrategy::default(),
194 retry_on_all_exceptions: true,
195 retryable_exceptions: vec![],
196 non_retryable_exceptions: vec![
197 "ArgumentError".to_string(),
198 "ValidationError".to_string(),
199 "AuthenticationError".to_string(),
200 "AuthorizationError".to_string(),
201 ],
202 }
203 }
204}
205
206impl RetryPolicy {
207 pub fn new(strategy: RetryStrategy) -> Self {
209 Self {
210 strategy,
211 ..Default::default()
212 }
213 }
214
215 pub fn no_retry() -> Self {
217 Self {
218 strategy: RetryStrategy::None,
219 retry_on_all_exceptions: false,
220 retryable_exceptions: vec![],
221 non_retryable_exceptions: vec![],
222 }
223 }
224
225 pub fn retry_on_all_exceptions(mut self, retry_all: bool) -> Self {
227 self.retry_on_all_exceptions = retry_all;
228 self
229 }
230
231 pub fn add_retryable_exception(mut self, exception_type: String) -> Self {
233 self.retryable_exceptions.push(exception_type);
234 self
235 }
236
237 pub fn add_non_retryable_exception(mut self, exception_type: String) -> Self {
239 self.non_retryable_exceptions.push(exception_type);
240 self
241 }
242
243 pub fn should_retry(&self, exception_type: Option<&str>, attempt: u32) -> bool {
245 if attempt > self.strategy.max_attempts() {
247 return false;
248 }
249
250 if !self.strategy.is_retry_enabled() {
252 return false;
253 }
254
255 let exception_type = match exception_type {
257 Some(ex) => ex,
258 None => return self.retry_on_all_exceptions, };
260
261 if self
263 .non_retryable_exceptions
264 .contains(&exception_type.to_string())
265 {
266 return false;
267 }
268
269 if self.retry_on_all_exceptions {
271 return true;
272 }
273
274 self.retryable_exceptions
276 .contains(&exception_type.to_string())
277 }
278
279 pub fn calculate_retry_time(&self, attempt: u32) -> Option<DateTime<Utc>> {
281 self.strategy
282 .calculate_delay(attempt)
283 .map(|delay| Utc::now() + delay)
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exponential_backoff_calculation() {
        let strategy = RetryStrategy::ExponentialBackoff {
            initial_delay: Duration::seconds(1),
            multiplier: 2.0,
            max_delay: Duration::minutes(5),
            max_attempts: 3,
            jitter: false,
        };

        // Without jitter the delays are exact: 1s, 2s, 4s.
        assert_eq!(strategy.calculate_delay(1), Some(Duration::seconds(1)));
        assert_eq!(strategy.calculate_delay(2), Some(Duration::seconds(2)));
        assert_eq!(strategy.calculate_delay(3), Some(Duration::seconds(4)));
        assert!(strategy.calculate_delay(4).is_none());
    }

    #[test]
    fn test_exponential_backoff_respects_max_delay() {
        let strategy = RetryStrategy::ExponentialBackoff {
            initial_delay: Duration::seconds(1),
            multiplier: 10.0,
            max_delay: Duration::seconds(5),
            max_attempts: 4,
            jitter: false,
        };

        // 1s * 10^2 = 100s, capped at 5s.
        assert_eq!(strategy.calculate_delay(3), Some(Duration::seconds(5)));
    }

    #[test]
    fn test_exponential_backoff_jitter_stays_within_25_percent() {
        let strategy = RetryStrategy::exponential_backoff(
            Duration::seconds(4),
            2.0,
            Duration::minutes(5),
            3,
        );

        for _ in 0..100 {
            let ms = strategy.calculate_delay(1).unwrap().num_milliseconds();
            assert!((3000..5000).contains(&ms), "jittered delay out of range: {}", ms);
        }
    }

    #[test]
    fn test_none_strategy_never_retries() {
        let strategy = RetryStrategy::none();
        assert!(strategy.calculate_delay(1).is_none());
        assert_eq!(strategy.max_attempts(), 0);
        assert!(!strategy.is_retry_enabled());
    }

    #[test]
    fn test_fixed_retry_calculation() {
        let strategy = RetryStrategy::fixed(Duration::seconds(5), 2);

        assert_eq!(strategy.calculate_delay(1), Some(Duration::seconds(5)));
        assert_eq!(strategy.calculate_delay(2), Some(Duration::seconds(5)));
        assert!(strategy.calculate_delay(3).is_none());
    }

    #[test]
    fn test_retry_policy_should_retry() {
        let policy = RetryPolicy::default();

        assert!(policy.should_retry(Some("NetworkError"), 1));
        assert!(!policy.should_retry(Some("ArgumentError"), 1));
        assert!(!policy.should_retry(Some("NetworkError"), 10));
        assert!(policy.should_retry(None, 1));
    }

    #[test]
    fn test_retry_policy_selective_exceptions() {
        let policy = RetryPolicy::default()
            .retry_on_all_exceptions(false)
            .add_retryable_exception("NetworkError".to_string());

        assert!(policy.should_retry(Some("NetworkError"), 1));
        assert!(!policy.should_retry(Some("TimeoutError"), 1));
        assert!(!policy.should_retry(None, 1));
        assert!(!policy.should_retry(Some("ValidationError"), 1));
    }

    #[test]
    fn test_no_retry_policy() {
        let policy = RetryPolicy::no_retry();

        assert!(!policy.should_retry(Some("NetworkError"), 1));
        assert!(policy.calculate_retry_time(1).is_none());
    }

    #[test]
    fn test_calculate_retry_time_is_in_the_future() {
        let policy = RetryPolicy::default();
        let before = Utc::now();

        let retry_at = policy.calculate_retry_time(1).unwrap();
        assert!(retry_at > before);
    }

    #[test]
    fn test_linear_backoff_calculation() {
        let strategy = RetryStrategy::linear_backoff(
            Duration::seconds(1),
            Duration::seconds(2),
            Duration::minutes(1),
            3,
        );

        assert_eq!(strategy.calculate_delay(1), Some(Duration::seconds(1)));
        assert_eq!(strategy.calculate_delay(2), Some(Duration::seconds(3)));
        assert_eq!(strategy.calculate_delay(3), Some(Duration::seconds(5)));
        assert!(strategy.calculate_delay(4).is_none());
    }
}