use std::sync::Arc;
use std::time::{Duration, Instant};

use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

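/// Retries fallible async operations with configurable backoff and guards
/// them behind a shared circuit breaker. `initialize` must be called before
/// `execute_with_recovery`.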
#[derive(Debug, Clone)]
pub struct ErrorRecovery {
    /// Retry policy applied to every operation.
    policy: RetryPolicy,
    /// Shared circuit breaker consulted before and updated after each operation.
    circuit_breaker: Arc<RwLock<CircuitBreakerPolicy>>,
    /// Aggregate statistics across all executed operations.
    stats: Arc<RwLock<RecoveryStats>>,
    /// Set by `initialize`, cleared by `shutdown`.
    initialized: bool,
}

impl ErrorRecovery {
    /// Creates a recovery manager with default retry and breaker policies.
    pub fn new() -> Self {
        Self {
            policy: RetryPolicy::default(),
            circuit_breaker: Arc::new(RwLock::new(CircuitBreakerPolicy::new())),
            stats: Arc::new(RwLock::new(RecoveryStats::new())),
            initialized: false,
        }
    }

    /// Creates a recovery manager from an explicit configuration.
    pub fn with_config(config: RecoveryConfig) -> Self {
        Self {
            policy: config.retry_policy,
            circuit_breaker: Arc::new(RwLock::new(config.circuit_breaker)),
            stats: Arc::new(RwLock::new(RecoveryStats::new())),
            initialized: false,
        }
    }

    pub async fn initialize(&mut self) -> Result<(), RecoveryError> {
        self.circuit_breaker.write().await.reset();
        self.stats.write().await.reset();
        self.initialized = true;
        Ok(())
    }

    pub async fn shutdown(&mut self) -> Result<(), RecoveryError> {
        self.initialized = false;
        Ok(())
    }

    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

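    /// Runs `operation` until it succeeds, retries are exhausted, or the
    /// circuit breaker rejects it. A minimal usage sketch (assumes a
    /// caller-defined async `fetch_remote`, which is not part of this module):
    ///
    /// ```ignore
    /// let mut recovery = ErrorRecovery::new();
    /// recovery.initialize().await?;
    /// let value = recovery
    ///     .execute_with_recovery(|| Box::pin(async { fetch_remote().await }))
    ///     .await?;
    /// ```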
    pub async fn execute_with_recovery<F, T, E>(
        &self,
        operation: F,
    ) -> Result<T, RecoveryError>
    where
        F: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T, E>> + Send>>,
        E: std::error::Error + Send + Sync + 'static,
    {
        if !self.initialized {
            return Err(RecoveryError::NotInitialized);
        }

        // Check the breaker under a short-lived lock; holding the guard across
        // the operation and the backoff sleeps would serialize every caller.
        {
            let mut breaker = self.circuit_breaker.write().await;
            if !breaker.can_execute() {
                return Err(RecoveryError::CircuitBreakerOpen);
            }
        }

        let mut attempt = 0;
        let start_time = Instant::now();

        loop {
            attempt += 1;

            match operation().await {
                Ok(value) => {
                    self.circuit_breaker.write().await.record_success();
                    let mut stats = self.stats.write().await;
                    // `attempt` counts the initial try, so retries are one fewer.
                    stats.record_success(attempt - 1, start_time.elapsed());
                    return Ok(value);
                }
                Err(error) => {
                    let error_type = self.classify_error(&error);

                    if attempt > self.policy.max_retries {
                        self.circuit_breaker.write().await.record_failure();
                        let mut stats = self.stats.write().await;
                        stats.record_failure(attempt - 1, start_time.elapsed(), error_type);
                        return Err(RecoveryError::MaxRetriesExceeded);
                    }

                    if !self.should_retry(&error_type, attempt) {
                        self.circuit_breaker.write().await.record_failure();
                        let mut stats = self.stats.write().await;
                        stats.record_failure(attempt - 1, start_time.elapsed(), error_type.clone());
                        return Err(RecoveryError::NonRetryableError(error_type));
                    }

                    tokio::time::sleep(self.calculate_delay(attempt)).await;
                }
            }
        }
    }

    /// Heuristically classifies an error by substring-matching its message;
    /// anything unrecognized falls through to `ErrorType::Unknown`.
    fn classify_error<E: std::error::Error>(&self, error: &E) -> ErrorType {
        let error_string = error.to_string().to_lowercase();

        if error_string.contains("timeout") || error_string.contains("connection") {
            ErrorType::Network
        } else if error_string.contains("permission") || error_string.contains("unauthorized") {
            ErrorType::Permission
        } else if error_string.contains("not found") || error_string.contains("missing") {
            ErrorType::NotFound
        } else if error_string.contains("conflict") || error_string.contains("duplicate") {
            ErrorType::Conflict
        } else if error_string.contains("rate limit") || error_string.contains("throttle") {
            ErrorType::RateLimit
        } else {
            ErrorType::Unknown
        }
    }

    fn should_retry(&self, error_type: &ErrorType, attempt: usize) -> bool {
        match error_type {
            ErrorType::Network => true,
            ErrorType::RateLimit => true,
            ErrorType::Permission => false,
            ErrorType::NotFound => false,
            ErrorType::Conflict => false,
            // Give unknown errors a few attempts before giving up.
            ErrorType::Unknown => attempt < 3,
        }
    }

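    // Backoff delay for 1-based attempt n: exponential grows as
    // base * multiplier^(n - 1) (with the defaults of 100 ms base and x2
    // multiplier: 100 ms, 200 ms, 400 ms, ...), linear as
    // base + increment * (n - 1); both are capped at `max_delay`.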
    fn calculate_delay(&self, attempt: usize) -> Duration {
        match &self.policy.strategy {
            RetryStrategy::ExponentialBackoff(config) => {
                // Saturate instead of overflowing for pathological attempt counts.
                let factor = config.multiplier.saturating_pow(attempt as u32 - 1);
                config.base_delay.saturating_mul(factor).min(config.max_delay)
            }
            RetryStrategy::Linear(config) => {
                let step = config.increment.saturating_mul((attempt - 1) as u32);
                config.base_delay.saturating_add(step).min(config.max_delay)
            }
            RetryStrategy::Fixed(config) => config.delay,
        }
    }

    /// Returns a snapshot of the accumulated recovery statistics.
    pub async fn get_stats(&self) -> RecoveryStats {
        self.stats.read().await.clone()
    }

    /// Returns the current circuit-breaker state and counters.
    pub async fn get_circuit_breaker_status(&self) -> CircuitBreakerStatus {
        self.circuit_breaker.read().await.get_status()
    }

    /// Force-closes the circuit breaker and clears its counters.
    pub async fn reset_circuit_breaker(&self) {
        self.circuit_breaker.write().await.reset();
    }
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RetryPolicy {
    pub max_retries: usize,
    pub strategy: RetryStrategy,
}

impl Default for RetryPolicy {
    fn default() -> Self {
        Self {
            max_retries: 3,
            strategy: RetryStrategy::ExponentialBackoff(ExponentialBackoffConfig::default()),
        }
    }
}

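/// Backoff strategy applied between retry attempts. For example, a policy
/// with five linear-backoff retries (a sketch; the values are illustrative,
/// not defaults):
///
/// ```ignore
/// let policy = RetryPolicy {
///     max_retries: 5,
///     strategy: RetryStrategy::Linear(LinearBackoffConfig::default()),
/// };
/// ```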
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RetryStrategy {
    ExponentialBackoff(ExponentialBackoffConfig),
    Linear(LinearBackoffConfig),
    Fixed(FixedDelayConfig),
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExponentialBackoffConfig {
    pub base_delay: Duration,
    pub max_delay: Duration,
    pub multiplier: u32,
}

impl Default for ExponentialBackoffConfig {
    fn default() -> Self {
        Self {
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            multiplier: 2,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LinearBackoffConfig {
    pub base_delay: Duration,
    pub increment: Duration,
    pub max_delay: Duration,
}

impl Default for LinearBackoffConfig {
    fn default() -> Self {
        Self {
            base_delay: Duration::from_millis(100),
            increment: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FixedDelayConfig {
    pub delay: Duration,
}

impl Default for FixedDelayConfig {
    fn default() -> Self {
        Self {
            delay: Duration::from_millis(1000),
        }
    }
}

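/// Three-state circuit breaker: `Closed` admits all calls, `Open` rejects
/// them until `timeout` has elapsed since the last failure, and `HalfOpen`
/// admits probe calls. `success_threshold` consecutive successes close the
/// breaker again, while a single failure while half-open reopens it.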
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CircuitBreakerPolicy {
    state: CircuitBreakerState,
    failure_count: usize,
    success_count: usize,
    failure_threshold: usize,
    success_threshold: usize,
    timeout: Duration,
    // `Instant` is not serializable; the field is skipped and restored as `None`.
    #[serde(skip)]
    last_failure_time: Option<Instant>,
}

impl CircuitBreakerPolicy {
    pub fn new() -> Self {
        Self {
            state: CircuitBreakerState::Closed,
            failure_count: 0,
            success_count: 0,
            failure_threshold: 5,
            success_threshold: 3,
            timeout: Duration::from_secs(60),
            last_failure_time: None,
        }
    }

    pub fn can_execute(&mut self) -> bool {
        match self.state {
            CircuitBreakerState::Closed => true,
            CircuitBreakerState::Open => {
                // Once the cool-down has elapsed, move to half-open and admit
                // a probe call; otherwise keep rejecting.
                match self.last_failure_time {
                    Some(last_failure) if last_failure.elapsed() >= self.timeout => {
                        self.state = CircuitBreakerState::HalfOpen;
                        self.success_count = 0;
                        true
                    }
                    _ => false,
                }
            }
            CircuitBreakerState::HalfOpen => true,
        }
    }

    pub fn record_success(&mut self) {
        self.success_count += 1;
        self.failure_count = 0;

        // Enough consecutive successes while half-open close the breaker.
        if self.state == CircuitBreakerState::HalfOpen
            && self.success_count >= self.success_threshold
        {
            self.state = CircuitBreakerState::Closed;
            self.success_count = 0;
        }
    }

    pub fn record_failure(&mut self) {
        self.failure_count += 1;
        self.success_count = 0;
        self.last_failure_time = Some(Instant::now());

        // A single failure while half-open reopens the breaker immediately;
        // otherwise it opens once the failure threshold is reached.
        if self.failure_count >= self.failure_threshold
            || self.state == CircuitBreakerState::HalfOpen
        {
            self.state = CircuitBreakerState::Open;
        }
    }

    pub fn reset(&mut self) {
        self.state = CircuitBreakerState::Closed;
        self.failure_count = 0;
        self.success_count = 0;
        self.last_failure_time = None;
    }

    pub fn get_status(&self) -> CircuitBreakerStatus {
        CircuitBreakerStatus {
            state: self.state.clone(),
            failure_count: self.failure_count,
            success_count: self.success_count,
            last_failure_time: self.last_failure_time,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CircuitBreakerState {
    Closed,
    Open,
    HalfOpen,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CircuitBreakerStatus {
    pub state: CircuitBreakerState,
    pub failure_count: usize,
    pub success_count: usize,
    #[serde(skip)]
    pub last_failure_time: Option<Instant>,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ErrorType {
    Network,
    Permission,
    NotFound,
    Conflict,
    RateLimit,
    Unknown,
}

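/// Aggregate counters kept across operations. Retry counts exclude the
/// initial attempt, so a first-try success records zero retries.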
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RecoveryStats {
    pub total_operations: usize,
    pub successful_operations: usize,
    pub failed_operations: usize,
    pub total_retries: usize,
    pub avg_retries_per_operation: f64,
    pub avg_operation_duration: Duration,
    pub error_type_distribution: std::collections::HashMap<ErrorType, usize>,
}

impl RecoveryStats {
    pub fn new() -> Self {
        Self {
            total_operations: 0,
            successful_operations: 0,
            failed_operations: 0,
            total_retries: 0,
            avg_retries_per_operation: 0.0,
            avg_operation_duration: Duration::ZERO,
            error_type_distribution: std::collections::HashMap::new(),
        }
    }

    pub fn reset(&mut self) {
        *self = Self::new();
    }

    pub fn record_success(&mut self, retry_count: usize, duration: Duration) {
        self.total_operations += 1;
        self.successful_operations += 1;
        self.total_retries += retry_count;

        self.update_averages(duration);
    }

    pub fn record_failure(&mut self, retry_count: usize, duration: Duration, error_type: ErrorType) {
        self.total_operations += 1;
        self.failed_operations += 1;
        self.total_retries += retry_count;

        *self.error_type_distribution.entry(error_type).or_insert(0) += 1;

        self.update_averages(duration);
    }

    fn update_averages(&mut self, duration: Duration) {
        if self.total_operations > 0 {
            self.avg_retries_per_operation =
                self.total_retries as f64 / self.total_operations as f64;
            // Incremental running mean: avg_n = (avg_{n-1} * (n - 1) + x_n) / n.
            let n = self.total_operations as u32;
            self.avg_operation_duration =
                (self.avg_operation_duration * (n - 1) + duration) / n;
        }
    }
}

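/// Bundles the retry policy and circuit-breaker settings consumed by
/// `ErrorRecovery::with_config`. A minimal sketch with a non-default retry
/// budget:
///
/// ```ignore
/// let config = RecoveryConfig {
///     retry_policy: RetryPolicy { max_retries: 5, ..RetryPolicy::default() },
///     circuit_breaker: CircuitBreakerPolicy::new(),
/// };
/// let recovery = ErrorRecovery::with_config(config);
/// ```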
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RecoveryConfig {
    pub retry_policy: RetryPolicy,
    pub circuit_breaker: CircuitBreakerPolicy,
}

impl Default for RecoveryConfig {
    fn default() -> Self {
        Self {
            retry_policy: RetryPolicy::default(),
            circuit_breaker: CircuitBreakerPolicy::new(),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RecoveryResult<T> {
    pub value: T,
    pub retry_count: usize,
    pub duration: Duration,
    pub first_try_success: bool,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RecoveryError {
    NotInitialized,
    MaxRetriesExceeded,
    CircuitBreakerOpen,
    NonRetryableError(ErrorType),
    Timeout,
    ConfigurationError(String),
}

impl std::fmt::Display for RecoveryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RecoveryError::NotInitialized => write!(f, "Error recovery system not initialized"),
            RecoveryError::MaxRetriesExceeded => write!(f, "Maximum retry attempts exceeded"),
            RecoveryError::CircuitBreakerOpen => write!(f, "Circuit breaker is open"),
            RecoveryError::NonRetryableError(error_type) => {
                write!(f, "Non-retryable error type: {:?}", error_type)
            }
            RecoveryError::Timeout => write!(f, "Operation timeout"),
            RecoveryError::ConfigurationError(msg) => write!(f, "Configuration error: {}", msg),
        }
    }
}

impl std::error::Error for RecoveryError {}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_error_recovery_creation() {
        let recovery = ErrorRecovery::new();
        assert!(!recovery.is_initialized());
    }

    #[tokio::test]
    async fn test_error_recovery_initialization() {
        let mut recovery = ErrorRecovery::new();
        let result = recovery.initialize().await;
        assert!(result.is_ok());
        assert!(recovery.is_initialized());
    }

    #[tokio::test]
    async fn test_error_recovery_shutdown() {
        let mut recovery = ErrorRecovery::new();
        recovery.initialize().await.unwrap();
        let result = recovery.shutdown().await;
        assert!(result.is_ok());
        assert!(!recovery.is_initialized());
    }

    #[tokio::test]
    async fn test_successful_operation() {
        let mut recovery = ErrorRecovery::new();
        recovery.initialize().await.unwrap();

        let result = recovery
            .execute_with_recovery(|| Box::pin(async { Ok::<i32, std::io::Error>(42) }))
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);

        let stats = recovery.get_stats().await;
        assert_eq!(stats.total_operations, 1);
        assert_eq!(stats.successful_operations, 1);
        assert_eq!(stats.failed_operations, 0);
    }

    #[tokio::test]
    async fn test_retry_on_network_error() {
        let mut recovery = ErrorRecovery::new();
        recovery.initialize().await.unwrap();

        let attempt_count = Arc::new(AtomicUsize::new(0));

        let result = recovery
            .execute_with_recovery(|| {
                let attempt_count = Arc::clone(&attempt_count);
                Box::pin(async move {
                    let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                    if attempt < 3 {
                        Err::<i32, std::io::Error>(std::io::Error::new(
                            std::io::ErrorKind::TimedOut,
                            "network timeout",
                        ))
                    } else {
                        Ok(42)
                    }
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);

        let stats = recovery.get_stats().await;
        assert_eq!(stats.total_operations, 1);
        assert_eq!(stats.successful_operations, 1);
        // Two failed attempts preceded the success on the third try.
        assert_eq!(stats.total_retries, 2);
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let mut recovery = ErrorRecovery::new();
        recovery.initialize().await.unwrap();

        let result = recovery
            .execute_with_recovery(|| {
                Box::pin(async {
                    Err::<i32, std::io::Error>(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "network timeout",
                    ))
                })
            })
            .await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), RecoveryError::MaxRetriesExceeded));

        let stats = recovery.get_stats().await;
        assert_eq!(stats.total_operations, 1);
        assert_eq!(stats.successful_operations, 0);
        assert_eq!(stats.failed_operations, 1);
        assert_eq!(stats.total_retries, 3);
    }

    #[tokio::test]
    async fn test_non_retryable_error() {
        let mut recovery = ErrorRecovery::new();
        recovery.initialize().await.unwrap();

        let result = recovery
            .execute_with_recovery(|| {
                Box::pin(async {
                    Err::<i32, std::io::Error>(std::io::Error::new(
                        std::io::ErrorKind::PermissionDenied,
                        "permission denied",
                    ))
                })
            })
            .await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RecoveryError::NonRetryableError(ErrorType::Permission)
        ));

        let stats = recovery.get_stats().await;
        assert_eq!(stats.total_operations, 1);
        assert_eq!(stats.successful_operations, 0);
        assert_eq!(stats.failed_operations, 1);
        assert_eq!(stats.total_retries, 0);
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let mut recovery = ErrorRecovery::new();
        recovery.initialize().await.unwrap();

        // Each failed operation records one breaker failure; the default
        // threshold of five opens the breaker.
        for _ in 0..6 {
            let _ = recovery
                .execute_with_recovery(|| {
                    Box::pin(async {
                        Err::<i32, std::io::Error>(std::io::Error::new(
                            std::io::ErrorKind::TimedOut,
                            "network timeout",
                        ))
                    })
                })
                .await;
        }

        let result = recovery
            .execute_with_recovery(|| Box::pin(async { Ok::<i32, std::io::Error>(42) }))
            .await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), RecoveryError::CircuitBreakerOpen));
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let mut recovery = ErrorRecovery::new();
        recovery.initialize().await.unwrap();

        for _ in 0..6 {
            let _ = recovery
                .execute_with_recovery(|| {
                    Box::pin(async {
                        Err::<i32, std::io::Error>(std::io::Error::new(
                            std::io::ErrorKind::TimedOut,
                            "network timeout",
                        ))
                    })
                })
                .await;
        }

        recovery.reset_circuit_breaker().await;

        let result = recovery
            .execute_with_recovery(|| Box::pin(async { Ok::<i32, std::io::Error>(42) }))
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_error_classification() {
        let recovery = ErrorRecovery::new();

        let network_error = std::io::Error::new(std::io::ErrorKind::TimedOut, "connection timeout");
        assert_eq!(recovery.classify_error(&network_error), ErrorType::Network);

        let permission_error =
            std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied");
        assert_eq!(recovery.classify_error(&permission_error), ErrorType::Permission);

        let rate_limit_error = std::io::Error::new(std::io::ErrorKind::Other, "rate limit exceeded");
        assert_eq!(recovery.classify_error(&rate_limit_error), ErrorType::RateLimit);
    }

    #[tokio::test]
    async fn test_retry_policy_configuration() {
        let config = RecoveryConfig {
            retry_policy: RetryPolicy {
                max_retries: 5,
                strategy: RetryStrategy::Fixed(FixedDelayConfig {
                    delay: Duration::from_millis(500),
                }),
            },
            circuit_breaker: CircuitBreakerPolicy::new(),
        };

        let mut recovery = ErrorRecovery::with_config(config);
        recovery.initialize().await.unwrap();

        let start_time = Instant::now();
        let result = recovery
            .execute_with_recovery(|| {
                Box::pin(async {
                    Err::<i32, std::io::Error>(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "network timeout",
                    ))
                })
            })
            .await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), RecoveryError::MaxRetriesExceeded));

        // Five retries at a fixed 500 ms delay sleep five times; allow a
        // little slack below the nominal 2500 ms.
        assert!(start_time.elapsed() >= Duration::from_millis(2400));
    }

    #[tokio::test]
    async fn test_recovery_stats() {
        let mut recovery = ErrorRecovery::new();
        recovery.initialize().await.unwrap();

        let _ = recovery
            .execute_with_recovery(|| Box::pin(async { Ok::<i32, std::io::Error>(42) }))
            .await;

        let _ = recovery
            .execute_with_recovery(|| {
                Box::pin(async {
                    Err::<i32, std::io::Error>(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "network timeout",
                    ))
                })
            })
            .await;

        let stats = recovery.get_stats().await;
        assert_eq!(stats.total_operations, 2);
        assert_eq!(stats.successful_operations, 1);
        assert_eq!(stats.failed_operations, 1);
        assert!(stats.error_type_distribution.contains_key(&ErrorType::Network));
    }

    #[test]
    fn test_retry_policy_default() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_retries, 3);
        assert!(matches!(policy.strategy, RetryStrategy::ExponentialBackoff(_)));
    }

    #[test]
    fn test_circuit_breaker_policy() {
        let mut policy = CircuitBreakerPolicy::new();
        assert!(policy.can_execute());

        // Reaching the failure threshold opens the breaker.
        for _ in 0..5 {
            policy.record_failure();
        }
        assert!(!policy.can_execute());

        // Force the half-open state and close it with enough successes.
        policy.state = CircuitBreakerState::HalfOpen;
        policy.success_count = 0;
        for _ in 0..3 {
            policy.record_success();
        }
        assert!(policy.can_execute());
    }

    #[test]
    fn test_recovery_stats_creation() {
        let stats = RecoveryStats::new();
        assert_eq!(stats.total_operations, 0);
        assert_eq!(stats.successful_operations, 0);
        assert_eq!(stats.failed_operations, 0);
        assert_eq!(stats.total_retries, 0);
    }

    #[test]
    fn test_recovery_error_display() {
        let error = RecoveryError::MaxRetriesExceeded;
        let error_string = format!("{}", error);
        assert!(error_string.contains("Maximum retry attempts exceeded"));
    }
}