1use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::RwLock;
14
/// Backoff formula used by `RetryConfig::calculate_delay` to space out retries.
///
/// Serialized in `snake_case` (e.g. `"exponential_with_jitter"`).
/// NOTE(review): serde's derived impl also passes the variant index to
/// serializers, so reordering variants may break index-based formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RetryStrategy {
    /// Every attempt waits `base_delay`.
    Fixed,
    /// Attempt `n` waits `base_delay * (n + 1)`.
    Linear,
    /// Attempt `n` waits `base_delay * 2^n`. This is the default strategy.
    #[default]
    Exponential,
    /// Like `Exponential`, plus a random offset proportional to `jitter_factor`
    /// (the offset may be negative).
    ExponentialWithJitter,
}
29
30impl std::fmt::Display for RetryStrategy {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 RetryStrategy::Fixed => write!(f, "fixed"),
34 RetryStrategy::Linear => write!(f, "linear"),
35 RetryStrategy::Exponential => write!(f, "exponential"),
36 RetryStrategy::ExponentialWithJitter => write!(f, "exponential_with_jitter"),
37 }
38 }
39}
40
/// Policy describing how many times, and how far apart, an operation is retried.
///
/// Serialized with `camelCase` field names (e.g. `maxRetries`, `baseDelay`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetryConfig {
    /// Retry budget: `RetryState::can_retry` allows another retry while the
    /// recorded attempt count is below this value.
    pub max_retries: u32,
    /// Starting delay fed into the strategy formula in `calculate_delay`.
    pub base_delay: Duration,
    /// Upper bound applied to every delay returned by `calculate_delay`.
    pub max_delay: Duration,
    /// Backoff formula used by `calculate_delay`.
    pub strategy: RetryStrategy,
    /// Scale of the random offset used by `RetryStrategy::ExponentialWithJitter`.
    /// `with_jitter_factor` clamps it to `[0.0, 1.0]`, but direct assignment or
    /// deserialization bypasses that clamp.
    pub jitter_factor: f64,
    /// Whether timeout errors should be retried.
    /// NOTE(review): confirm which code paths actually consult this flag.
    pub retry_on_timeout: bool,
    /// Case-insensitive substrings; an error type containing any of them is
    /// treated as retryable by `is_retryable`.
    pub retryable_errors: Vec<String>,
}
60
61impl Default for RetryConfig {
62 fn default() -> Self {
63 Self {
64 max_retries: 3,
65 base_delay: Duration::from_millis(1000),
66 max_delay: Duration::from_secs(30),
67 strategy: RetryStrategy::Exponential,
68 jitter_factor: 0.1,
69 retry_on_timeout: true,
70 retryable_errors: vec![
71 "network".to_string(),
72 "timeout".to_string(),
73 "rate_limit".to_string(),
74 "temporary".to_string(),
75 ],
76 }
77 }
78}
79
80impl RetryConfig {
81 pub fn new(max_retries: u32, base_delay: Duration) -> Self {
83 Self {
84 max_retries,
85 base_delay,
86 ..Default::default()
87 }
88 }
89
90 pub fn with_strategy(mut self, strategy: RetryStrategy) -> Self {
92 self.strategy = strategy;
93 self
94 }
95
96 pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
98 self.max_delay = max_delay;
99 self
100 }
101
102 pub fn with_jitter_factor(mut self, factor: f64) -> Self {
104 self.jitter_factor = factor.clamp(0.0, 1.0);
105 self
106 }
107
108 pub fn with_retry_on_timeout(mut self, retry: bool) -> Self {
110 self.retry_on_timeout = retry;
111 self
112 }
113
114 pub fn with_retryable_error(mut self, error_type: impl Into<String>) -> Self {
116 self.retryable_errors.push(error_type.into());
117 self
118 }
119
120 pub fn calculate_delay(&self, attempt: u32) -> Duration {
122 let base_ms = self.base_delay.as_millis() as f64;
123 let max_ms = self.max_delay.as_millis() as f64;
124
125 let delay_ms = match self.strategy {
126 RetryStrategy::Fixed => base_ms,
127 RetryStrategy::Linear => base_ms * (attempt as f64 + 1.0),
128 RetryStrategy::Exponential => base_ms * 2.0_f64.powi(attempt as i32),
129 RetryStrategy::ExponentialWithJitter => {
130 let exp_delay = base_ms * 2.0_f64.powi(attempt as i32);
131 let jitter = exp_delay * self.jitter_factor * rand_jitter();
132 exp_delay + jitter
133 }
134 };
135
136 Duration::from_millis(delay_ms.min(max_ms) as u64)
137 }
138
139 pub fn is_retryable(&self, error_type: &str) -> bool {
141 self.retryable_errors
142 .iter()
143 .any(|e| error_type.to_lowercase().contains(&e.to_lowercase()))
144 }
145
146 pub fn validate(&self) -> Result<(), String> {
148 if self.max_retries == 0 {
149 return Err("max_retries must be greater than 0".to_string());
150 }
151 if self.base_delay.is_zero() {
152 return Err("base_delay must be greater than 0".to_string());
153 }
154 if self.max_delay < self.base_delay {
155 return Err("max_delay must be >= base_delay".to_string());
156 }
157 Ok(())
158 }
159}
160
/// Returns a pseudo-random value in `[-1.0, 1.0)` used to spread out retry delays.
///
/// Not derived from the system clock: on platforms where `SystemTime` has
/// microsecond (or coarser) resolution, nanosecond-based arithmetic collapses
/// to a handful of distinct values. Not cryptographically secure, which is
/// acceptable for backoff jitter.
fn rand_jitter() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    // Each `RandomState::new()` is initialized with fresh random SipHash keys,
    // so finishing an empty hasher yields a different 64-bit value per call.
    let bits = RandomState::new().build_hasher().finish();
    // Keep the top 53 bits to get a uniform f64 in [0, 1), then rescale to [-1, 1).
    let unit = (bits >> 11) as f64 / (1u64 << 53) as f64;
    unit * 2.0 - 1.0
}
171
/// Outcome of a retry decision; see `RetryHandler::handle_failure`.
///
/// Serialized in `snake_case` (e.g. `"max_retries_exceeded"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RetryResult {
    /// The operation succeeded (not returned by `handle_failure` in this file).
    Success,
    /// The failure was recorded and another attempt is allowed.
    Retry,
    /// The retry budget (`max_retries`) is exhausted.
    MaxRetriesExceeded,
    /// The error type was rejected by `RetryConfig::is_retryable`.
    NotRetryable,
    /// No retry state exists for the given operation id.
    Skipped,
}
187
188impl std::fmt::Display for RetryResult {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 match self {
191 RetryResult::Success => write!(f, "success"),
192 RetryResult::Retry => write!(f, "retry"),
193 RetryResult::MaxRetriesExceeded => write!(f, "max_retries_exceeded"),
194 RetryResult::NotRetryable => write!(f, "not_retryable"),
195 RetryResult::Skipped => write!(f, "skipped"),
196 }
197 }
198}
199
/// Bookkeeping for a single retried operation.
///
/// Serialized with `camelCase` field names (e.g. `operationId`, `lastError`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetryState {
    /// Identifier this state was created for.
    pub operation_id: String,
    /// Number of failures recorded via `record_attempt`.
    pub attempt: u32,
    /// Retry policy governing this operation.
    pub config: RetryConfig,
    /// Creation time (UTC), set by `new`.
    pub started_at: DateTime<Utc>,
    /// Time of the most recent `record_attempt`/`record_success`; `None` until then.
    pub last_attempt_at: Option<DateTime<Utc>>,
    /// Error passed to the most recent `record_attempt`.
    pub last_error: Option<String>,
    /// Sum of all delays added via `add_delay`.
    pub total_delay: Duration,
    /// Set to `true` by `record_success`.
    pub succeeded: bool,
}
221
222impl RetryState {
223 pub fn new(operation_id: impl Into<String>, config: RetryConfig) -> Self {
225 Self {
226 operation_id: operation_id.into(),
227 attempt: 0,
228 config,
229 started_at: Utc::now(),
230 last_attempt_at: None,
231 last_error: None,
232 total_delay: Duration::ZERO,
233 succeeded: false,
234 }
235 }
236
237 pub fn can_retry(&self) -> bool {
239 self.attempt < self.config.max_retries
240 }
241
242 pub fn next_delay(&self) -> Duration {
244 self.config.calculate_delay(self.attempt)
245 }
246
247 pub fn record_attempt(&mut self, error: Option<String>) {
249 self.attempt += 1;
250 self.last_attempt_at = Some(Utc::now());
251 self.last_error = error;
252 }
253
254 pub fn record_success(&mut self) {
256 self.succeeded = true;
257 self.last_attempt_at = Some(Utc::now());
258 }
259
260 pub fn add_delay(&mut self, delay: Duration) {
262 self.total_delay += delay;
263 }
264
265 pub fn elapsed(&self) -> Duration {
267 let elapsed = Utc::now().signed_duration_since(self.started_at);
268 elapsed.to_std().unwrap_or(Duration::ZERO)
269 }
270}
271
/// Tracks per-operation retry state, keyed by operation id.
#[derive(Debug)]
pub struct RetryHandler {
    /// State for every operation started and not yet completed or cleared.
    states: HashMap<String, RetryState>,
    /// Config cloned into each operation begun with `start`.
    default_config: RetryConfig,
}
280
281impl Default for RetryHandler {
282 fn default() -> Self {
283 Self::new()
284 }
285}
286
287impl RetryHandler {
288 pub fn new() -> Self {
290 Self {
291 states: HashMap::new(),
292 default_config: RetryConfig::default(),
293 }
294 }
295
296 pub fn with_default_config(config: RetryConfig) -> Self {
298 Self {
299 states: HashMap::new(),
300 default_config: config,
301 }
302 }
303
304 pub fn start(&mut self, operation_id: &str) -> &RetryState {
306 self.start_with_config(operation_id, self.default_config.clone())
307 }
308
309 pub fn start_with_config(&mut self, operation_id: &str, config: RetryConfig) -> &RetryState {
311 let state = RetryState::new(operation_id, config);
312 self.states.insert(operation_id.to_string(), state);
313 self.states.get(operation_id).unwrap()
314 }
315
316 pub fn handle_failure(
318 &mut self,
319 operation_id: &str,
320 error_type: &str,
321 error_message: &str,
322 ) -> RetryResult {
323 let state = match self.states.get_mut(operation_id) {
324 Some(s) => s,
325 None => return RetryResult::Skipped,
326 };
327
328 if !state.config.is_retryable(error_type) {
330 return RetryResult::NotRetryable;
331 }
332
333 if !state.can_retry() {
335 return RetryResult::MaxRetriesExceeded;
336 }
337
338 state.record_attempt(Some(error_message.to_string()));
340
341 RetryResult::Retry
342 }
343
344 pub fn get_retry_delay(&self, operation_id: &str) -> Option<Duration> {
346 self.states.get(operation_id).map(|s| s.next_delay())
347 }
348
349 pub fn record_delay(&mut self, operation_id: &str, delay: Duration) {
351 if let Some(state) = self.states.get_mut(operation_id) {
352 state.add_delay(delay);
353 }
354 }
355
356 pub fn record_success(&mut self, operation_id: &str) {
358 if let Some(state) = self.states.get_mut(operation_id) {
359 state.record_success();
360 }
361 }
362
363 pub fn get_state(&self, operation_id: &str) -> Option<&RetryState> {
365 self.states.get(operation_id)
366 }
367
368 pub fn get_attempt(&self, operation_id: &str) -> Option<u32> {
370 self.states.get(operation_id).map(|s| s.attempt)
371 }
372
373 pub fn can_retry(&self, operation_id: &str) -> bool {
375 self.states
376 .get(operation_id)
377 .map(|s| s.can_retry())
378 .unwrap_or(false)
379 }
380
381 pub fn complete(&mut self, operation_id: &str) -> Option<RetryState> {
383 self.states.remove(operation_id)
384 }
385
386 pub fn clear(&mut self) {
388 self.states.clear();
389 }
390
391 pub fn active_count(&self) -> usize {
393 self.states.len()
394 }
395
396 pub fn set_default_config(&mut self, config: RetryConfig) {
398 self.default_config = config;
399 }
400
401 pub fn default_config(&self) -> &RetryConfig {
403 &self.default_config
404 }
405
406 pub async fn execute_with_retry<F, Fut, T, E>(
408 &mut self,
409 operation_id: &str,
410 mut operation: F,
411 ) -> Result<T, E>
412 where
413 F: FnMut() -> Fut,
414 Fut: std::future::Future<Output = Result<T, E>>,
415 E: std::fmt::Display,
416 {
417 self.start(operation_id);
418
419 loop {
420 match operation().await {
421 Ok(result) => {
422 self.record_success(operation_id);
423 return Ok(result);
424 }
425 Err(e) => {
426 let error_msg = e.to_string();
427 let result = self.handle_failure(operation_id, "general", &error_msg);
428
429 match result {
430 RetryResult::Retry => {
431 if let Some(delay) = self.get_retry_delay(operation_id) {
432 tokio::time::sleep(delay).await;
433 self.record_delay(operation_id, delay);
434 }
435 }
436 _ => return Err(e),
437 }
438 }
439 }
440 }
441 }
442}
443
/// `RetryHandler` shared across tasks behind a tokio `RwLock`.
// `dead_code` is allowed presumably because only other modules use this alias;
// nothing in the visible part of this file references it.
#[allow(dead_code)]
pub type SharedRetryHandler = Arc<RwLock<RetryHandler>>;
447
448#[allow(dead_code)]
450pub fn new_shared_retry_handler() -> SharedRetryHandler {
451 Arc::new(RwLock::new(RetryHandler::new()))
452}
453
#[cfg(test)]
mod tests {
    //! Unit tests for retry configuration, per-operation state, and the handler.
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay, Duration::from_millis(1000));
        assert_eq!(config.strategy, RetryStrategy::Exponential);
    }

    #[test]
    fn test_retry_strategy_display_and_default() {
        assert_eq!(RetryStrategy::Fixed.to_string(), "fixed");
        assert_eq!(RetryStrategy::Linear.to_string(), "linear");
        assert_eq!(RetryStrategy::Exponential.to_string(), "exponential");
        assert_eq!(
            RetryStrategy::ExponentialWithJitter.to_string(),
            "exponential_with_jitter"
        );
        assert_eq!(RetryStrategy::default(), RetryStrategy::Exponential);
    }

    #[test]
    fn test_retry_config_calculate_delay_fixed() {
        let config =
            RetryConfig::new(3, Duration::from_millis(100)).with_strategy(RetryStrategy::Fixed);

        assert_eq!(config.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(config.calculate_delay(1), Duration::from_millis(100));
        assert_eq!(config.calculate_delay(2), Duration::from_millis(100));
    }

    #[test]
    fn test_retry_config_calculate_delay_linear() {
        let config =
            RetryConfig::new(3, Duration::from_millis(100)).with_strategy(RetryStrategy::Linear);

        assert_eq!(config.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(config.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(config.calculate_delay(2), Duration::from_millis(300));
    }

    #[test]
    fn test_retry_config_calculate_delay_exponential() {
        let config = RetryConfig::new(3, Duration::from_millis(100))
            .with_strategy(RetryStrategy::Exponential);

        assert_eq!(config.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(config.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(config.calculate_delay(2), Duration::from_millis(400));
    }

    #[test]
    fn test_retry_config_calculate_delay_jitter_bounds() {
        let config = RetryConfig::new(3, Duration::from_millis(100))
            .with_strategy(RetryStrategy::ExponentialWithJitter)
            .with_jitter_factor(0.5);

        // Jitter keeps each delay within +/-50% of the nominal exponential delay.
        for attempt in 0..3 {
            let nominal = 100u64 << attempt;
            let delay = config.calculate_delay(attempt).as_millis() as u64;
            assert!(delay >= nominal / 2, "attempt {}: {} too small", attempt, delay);
            assert!(delay <= nominal + nominal / 2, "attempt {}: {} too large", attempt, delay);
        }
    }

    #[test]
    fn test_retry_config_max_delay() {
        let config = RetryConfig::new(10, Duration::from_millis(100))
            .with_strategy(RetryStrategy::Exponential)
            .with_max_delay(Duration::from_millis(500));

        assert_eq!(config.calculate_delay(5), Duration::from_millis(500));
    }

    #[test]
    fn test_retry_config_jitter_factor_clamped() {
        assert_eq!(RetryConfig::default().with_jitter_factor(2.0).jitter_factor, 1.0);
        assert_eq!(RetryConfig::default().with_jitter_factor(-0.5).jitter_factor, 0.0);
        assert_eq!(RetryConfig::default().with_jitter_factor(0.25).jitter_factor, 0.25);
    }

    #[test]
    fn test_retry_config_is_retryable() {
        let config = RetryConfig::default();

        assert!(config.is_retryable("network_error"));
        assert!(config.is_retryable("timeout"));
        assert!(config.is_retryable("rate_limit_exceeded"));
        assert!(!config.is_retryable("invalid_input"));
    }

    #[test]
    fn test_retry_config_is_retryable_case_insensitive() {
        let config = RetryConfig::default();
        assert!(config.is_retryable("NETWORK_ERROR"));

        let custom = config.with_retryable_error("Database");
        assert!(custom.is_retryable("database_locked"));
    }

    #[test]
    fn test_retry_config_validate() {
        let valid = RetryConfig::default();
        assert!(valid.validate().is_ok());

        let invalid_retries = RetryConfig {
            max_retries: 0,
            ..Default::default()
        };
        assert!(invalid_retries.validate().is_err());

        let invalid_delay = RetryConfig {
            base_delay: Duration::ZERO,
            ..Default::default()
        };
        assert!(invalid_delay.validate().is_err());
    }

    #[test]
    fn test_retry_config_validate_max_delay_below_base() {
        let config = RetryConfig::new(3, Duration::from_secs(10))
            .with_max_delay(Duration::from_secs(1));
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_retry_state_creation() {
        let config = RetryConfig::default();
        let state = RetryState::new("op-1", config);

        assert_eq!(state.operation_id, "op-1");
        assert_eq!(state.attempt, 0);
        assert!(!state.succeeded);
        assert!(state.can_retry());
    }

    #[test]
    fn test_retry_state_record_attempt() {
        let config = RetryConfig::new(3, Duration::from_millis(100));
        let mut state = RetryState::new("op-1", config);

        state.record_attempt(Some("Error 1".to_string()));
        assert_eq!(state.attempt, 1);
        assert_eq!(state.last_error, Some("Error 1".to_string()));
        assert!(state.can_retry());

        state.record_attempt(Some("Error 2".to_string()));
        state.record_attempt(Some("Error 3".to_string()));
        assert_eq!(state.attempt, 3);
        assert!(!state.can_retry());
    }

    #[test]
    fn test_retry_handler_start() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");

        assert_eq!(handler.active_count(), 1);
        assert!(handler.get_state("op-1").is_some());
    }

    #[test]
    fn test_retry_handler_restart_resets_state() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");
        handler.handle_failure("op-1", "network", "Error 1");
        assert_eq!(handler.get_attempt("op-1"), Some(1));

        handler.start("op-1");
        assert_eq!(handler.get_attempt("op-1"), Some(0));
        assert_eq!(handler.active_count(), 1);
    }

    #[test]
    fn test_retry_handler_handle_failure() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");

        let result = handler.handle_failure("op-1", "network", "Connection failed");
        assert_eq!(result, RetryResult::Retry);
        assert_eq!(handler.get_attempt("op-1"), Some(1));
    }

    #[test]
    fn test_retry_handler_handle_failure_not_retryable() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");

        let result = handler.handle_failure("op-1", "invalid_input", "Bad request");
        assert_eq!(result, RetryResult::NotRetryable);
    }

    #[test]
    fn test_retry_handler_handle_failure_max_exceeded() {
        let config = RetryConfig::new(2, Duration::from_millis(100));
        let mut handler = RetryHandler::with_default_config(config);
        handler.start("op-1");

        handler.handle_failure("op-1", "network", "Error 1");
        handler.handle_failure("op-1", "network", "Error 2");
        let result = handler.handle_failure("op-1", "network", "Error 3");

        assert_eq!(result, RetryResult::MaxRetriesExceeded);
    }

    #[test]
    fn test_retry_handler_unknown_operation() {
        let mut handler = RetryHandler::new();

        assert_eq!(
            handler.handle_failure("missing", "network", "boom"),
            RetryResult::Skipped
        );
        assert!(!handler.can_retry("missing"));
        assert_eq!(handler.get_attempt("missing"), None);
        assert!(handler.get_retry_delay("missing").is_none());
        assert!(handler.complete("missing").is_none());
    }

    #[test]
    fn test_retry_handler_record_success() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");
        handler.record_success("op-1");

        let state = handler.get_state("op-1").unwrap();
        assert!(state.succeeded);
    }

    #[test]
    fn test_retry_handler_record_delay_accumulates() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");
        handler.record_delay("op-1", Duration::from_millis(100));
        handler.record_delay("op-1", Duration::from_millis(250));

        let state = handler.get_state("op-1").unwrap();
        assert_eq!(state.total_delay, Duration::from_millis(350));
    }

    #[test]
    fn test_retry_handler_complete() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");

        let state = handler.complete("op-1");
        assert!(state.is_some());
        assert_eq!(handler.active_count(), 0);
    }

    #[test]
    fn test_retry_handler_clear() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");
        handler.start("op-2");
        assert_eq!(handler.active_count(), 2);

        handler.clear();
        assert_eq!(handler.active_count(), 0);
    }

    #[test]
    fn test_retry_result_display() {
        assert_eq!(format!("{}", RetryResult::Success), "success");
        assert_eq!(format!("{}", RetryResult::Retry), "retry");
        assert_eq!(
            format!("{}", RetryResult::MaxRetriesExceeded),
            "max_retries_exceeded"
        );
        assert_eq!(format!("{}", RetryResult::NotRetryable), "not_retryable");
        assert_eq!(format!("{}", RetryResult::Skipped), "skipped");
    }
}