use std::collections::HashMap;
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;

use crate::error::Error;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryStrategy {
    pub name: String,
    pub description: String,
}
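
/// Coordinates error recovery for task execution: classifies errors, selects a
/// recovery strategy per (error type, task type) pair, tracks error history and
/// recovery outcomes, and supports rolling state back to checkpoints.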
pub struct ErrorRecovery {
    strategies: Arc<Mutex<RecoveryStrategies>>,
    error_tracker: Arc<Mutex<ErrorTracker>>,
    state_recovery: Arc<Mutex<StateRecovery>>,
    #[allow(dead_code)]
    config: ErrorRecoveryConfig,
}

impl Default for ErrorRecovery {
    fn default() -> Self {
        Self::new()
    }
}

impl ErrorRecovery {
    pub fn new() -> Self {
        Self {
            strategies: Arc::new(Mutex::new(RecoveryStrategies::new())),
            error_tracker: Arc::new(Mutex::new(ErrorTracker::new())),
            state_recovery: Arc::new(Mutex::new(StateRecovery::new())),
            config: ErrorRecoveryConfig::default(),
        }
    }
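
    /// Attempt to recover from `error` raised while executing `task_node`:
    /// classify the error, record it, look up a strategy for the
    /// (error type, task type) pair, execute it, and return the outcome
    /// together with how long recovery took.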
    pub async fn attempt_recovery(
        &self,
        error: &Error,
        task_node: &super::task_graph::TaskNode,
        task_graph: &super::task_graph::TaskGraph,
    ) -> Result<RecoveryResult, Error> {
        let start_time = std::time::Instant::now();
        let error_type = self.classify_error(error);

        tracing::info!("Attempting recovery for error: {:?}", error_type);

        {
            let mut tracker = self.error_tracker.lock().await;
            tracker.record_error(error, &error_type, task_node.id());
        }

        let strategies = self.strategies.lock().await;
        let strategy = strategies.get_strategy(&error_type, task_node.task_type())?;

        tracing::info!("Using recovery strategy: {}", strategy.name());

        let recovery_result = match strategy {
            RecoveryStrategyType::RetryWithBackoff => {
                self.execute_retry_with_backoff(task_node, task_graph, &error_type)
                    .await
            }
            RecoveryStrategyType::AlternativeExecution => {
                self.execute_alternative_execution(task_node, task_graph, &error_type)
                    .await
            }
            RecoveryStrategyType::StateRollback => {
                self.execute_state_rollback(task_node, task_graph, &error_type)
                    .await
            }
            RecoveryStrategyType::ComponentFallback => {
                self.execute_component_fallback(task_node, task_graph, &error_type)
                    .await
            }
            RecoveryStrategyType::SkipAndContinue => {
                self.execute_skip_and_continue(task_node, task_graph, &error_type)
                    .await
            }
            RecoveryStrategyType::ManualIntervention => {
                self.execute_manual_intervention(task_node, task_graph, &error_type)
                    .await
            }
            RecoveryStrategyType::ExtendTimeout
            | RecoveryStrategyType::ParallelExecution
            | RecoveryStrategyType::ResourceOptimization => {
                self.execute_retry_with_backoff(task_node, task_graph, &error_type)
                    .await
            }
        }?;

        let recovery_time_ms = start_time.elapsed().as_millis() as u64;

        {
            let mut tracker = self.error_tracker.lock().await;
            tracker.record_recovery_attempt(&error_type, recovery_result.success, recovery_time_ms);
        }

        if recovery_result.success {
            tracing::info!(
                "Recovery successful for error {:?} in {}ms: {}",
                error_type,
                recovery_time_ms,
                recovery_result.details
            );
        } else {
            tracing::warn!(
                "Recovery failed for error {:?} in {}ms: {}",
                error_type,
                recovery_time_ms,
                recovery_result.details
            );
        }

        Ok(RecoveryResult {
            success: recovery_result.success,
            strategy_used: strategy.name().to_string(),
            recovery_time_ms,
            details: recovery_result.details,
            recovered_state: recovery_result.recovered_state,
            fallback_actions: recovery_result.fallback_actions,
        })
    }
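
    /// Handle a task that exceeded its `timeout_ms` budget by applying a
    /// task-type-specific timeout strategy (extend the timeout, switch to
    /// parallel execution, optimize resources, or skip the task).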
    pub async fn handle_timeout(
        &self,
        task_node: &super::task_graph::TaskNode,
        timeout_ms: u64,
    ) -> Result<RecoveryResult, Error> {
        let start_time = std::time::Instant::now();
        let error_type = ErrorType::Timeout;

        tracing::warn!(
            "Handling timeout for task {} ({}ms)",
            task_node.name(),
            timeout_ms
        );

        {
            let mut tracker = self.error_tracker.lock().await;
            tracker.record_error(
                &Error::Validation(format!("Task {} timed out", task_node.name())),
                &error_type,
                task_node.id(),
            );
        }

        let strategies = self.strategies.lock().await;
        let strategy = strategies.get_timeout_strategy(task_node)?;

        let recovery_result = match strategy {
            RecoveryStrategyType::ExtendTimeout => self.extend_timeout(task_node, timeout_ms).await,
            RecoveryStrategyType::ParallelExecution => {
                self.parallel_execution_recovery(task_node).await
            }
            RecoveryStrategyType::ResourceOptimization => self.optimize_resources(task_node).await,
            _ => {
                self.execute_skip_and_continue(
                    task_node,
                    &super::task_graph::TaskGraph::new(),
                    &error_type,
                )
                .await
            }
        }?;

        // Time actually spent executing the recovery path.
        let recovery_time_ms = start_time.elapsed().as_millis() as u64;

        Ok(RecoveryResult {
            success: recovery_result.success,
            strategy_used: strategy.name().to_string(),
            recovery_time_ms,
            details: recovery_result.details,
            recovered_state: recovery_result.recovered_state,
            fallback_actions: recovery_result.fallback_actions,
        })
    }
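
    /// Map a crate-level `Error` onto the coarser `ErrorType` categories used
    /// for strategy lookup.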
    fn classify_error(&self, error: &Error) -> ErrorType {
        match error {
            Error::Validation(_) => ErrorType::ValidationError,
            Error::Io(_) => ErrorType::IoError,
            Error::Network(_) => ErrorType::NetworkError,
            Error::ResourceExhausted(_) => ErrorType::ResourceExhausted,
            Error::M2ExecutionError(_) => ErrorType::M2Error,
            _ => ErrorType::Unknown,
        }
    }
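
    /// Retry the task with exponential backoff: attempt `n` sleeps for
    /// `base_delay_ms * 2^(n - 1)` milliseconds before re-executing, up to the
    /// task's `max_retries`.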
    async fn execute_retry_with_backoff(
        &self,
        task_node: &super::task_graph::TaskNode,
        _task_graph: &super::task_graph::TaskGraph,
        _error_type: &ErrorType,
    ) -> Result<StrategyExecutionResult, Error> {
        let max_retries = task_node.max_retries();
        let mut attempt = 0;
        let base_delay_ms = 1000;

        while attempt < max_retries {
            attempt += 1;
            let delay_ms = base_delay_ms * (2_u64.pow(attempt - 1));

            tracing::info!(
                "Retry attempt {} for task {} (delay: {}ms)",
                attempt,
                task_node.name(),
                delay_ms
            );

            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;

            let retry_success = self.simulate_retry_execution(task_node).await?;

            if retry_success {
                return Ok(StrategyExecutionResult {
                    success: true,
                    details: format!(
                        "Task {} recovered after {} retry attempts",
                        task_node.name(),
                        attempt
                    ),
                    recovered_state: Some(serde_json::json!({"retries_used": attempt})),
                    fallback_actions: vec![],
                });
            }
        }

        Ok(StrategyExecutionResult {
            success: false,
            details: format!(
                "Task {} failed after {} retry attempts",
                task_node.name(),
                max_retries
            ),
            recovered_state: None,
            fallback_actions: vec!["skip_task".to_string(), "notify_failure".to_string()],
        })
    }
    async fn execute_alternative_execution(
        &self,
        task_node: &super::task_graph::TaskNode,
        _task_graph: &super::task_graph::TaskGraph,
        error_type: &ErrorType,
    ) -> Result<StrategyExecutionResult, Error> {
        tracing::info!(
            "Attempting alternative execution for task {}",
            task_node.name()
        );

        let alternative_path = match task_node.task_type() {
            super::task_graph::TaskType::ProtocolGeneration => {
                self.get_alternative_protocol_generation_path(error_type)
                    .await
            }
            super::task_graph::TaskType::CodeAnalysis => {
                self.get_alternative_code_analysis_path(error_type).await
            }
            super::task_graph::TaskType::WebAutomation => {
                self.get_alternative_web_automation_path(error_type).await
            }
            _ => self.get_general_alternative_path(error_type).await,
        };

        if let Some(alternative_path) = alternative_path {
            Ok(StrategyExecutionResult {
                success: true,
                details: format!(
                    "Alternative execution path found for task {}",
                    task_node.name()
                ),
                recovered_state: Some(serde_json::json!({"alternative_path": alternative_path})),
                fallback_actions: vec![],
            })
        } else {
            Ok(StrategyExecutionResult {
                success: false,
                details: format!(
                    "No alternative execution path available for task {}",
                    task_node.name()
                ),
                recovered_state: None,
                fallback_actions: vec!["manual_review".to_string()],
            })
        }
    }
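
    /// Roll the failing task back to its most recent state checkpoint, if one
    /// has been recorded.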
    async fn execute_state_rollback(
        &self,
        task_node: &super::task_graph::TaskNode,
        _task_graph: &super::task_graph::TaskGraph,
        _error_type: &ErrorType,
    ) -> Result<StrategyExecutionResult, Error> {
        tracing::info!("Attempting state rollback for task {}", task_node.name());

        let mut state_recovery = self.state_recovery.lock().await;
        let rollback_result = state_recovery
            .rollback_to_checkpoint(task_node.id())
            .await?;

        if rollback_result.success {
            Ok(StrategyExecutionResult {
                success: true,
                details: format!(
                    "State rolled back successfully for task {}",
                    task_node.name()
                ),
                recovered_state: Some(rollback_result.state_data),
                fallback_actions: vec![],
            })
        } else {
            Ok(StrategyExecutionResult {
                success: false,
                details: format!("State rollback failed for task {}", task_node.name()),
                recovered_state: None,
                fallback_actions: vec!["reset_execution".to_string()],
            })
        }
    }
    async fn execute_component_fallback(
        &self,
        task_node: &super::task_graph::TaskNode,
        _task_graph: &super::task_graph::TaskGraph,
        error_type: &ErrorType,
    ) -> Result<StrategyExecutionResult, Error> {
        tracing::info!(
            "Attempting component fallback for task {}",
            task_node.name()
        );

        let fallback_components = self
            .determine_fallback_components(task_node, error_type)
            .await?;

        if !fallback_components.is_empty() {
            Ok(StrategyExecutionResult {
                success: true,
                details: format!(
                    "Component fallback executed with {} components",
                    fallback_components.len()
                ),
                recovered_state: Some(
                    serde_json::json!({"fallback_components": fallback_components}),
                ),
                fallback_actions: vec![],
            })
        } else {
            Ok(StrategyExecutionResult {
                success: false,
                details: "No suitable fallback components available".to_string(),
                recovered_state: None,
                fallback_actions: vec!["manual_intervention".to_string()],
            })
        }
    }
    async fn execute_skip_and_continue(
        &self,
        task_node: &super::task_graph::TaskNode,
        _task_graph: &super::task_graph::TaskGraph,
        _error_type: &ErrorType,
    ) -> Result<StrategyExecutionResult, Error> {
        tracing::info!(
            "Skipping task {} and continuing execution",
            task_node.name()
        );

        Ok(StrategyExecutionResult {
            success: true,
            details: format!("Task {} skipped, execution continuing", task_node.name()),
            recovered_state: Some(serde_json::json!({"skipped": true, "task_id": task_node.id()})),
            fallback_actions: vec![],
        })
    }
    async fn execute_manual_intervention(
        &self,
        task_node: &super::task_graph::TaskNode,
        _task_graph: &super::task_graph::TaskGraph,
        error_type: &ErrorType,
    ) -> Result<StrategyExecutionResult, Error> {
        tracing::warn!("Manual intervention required for task {}", task_node.name());

        let intervention_required = serde_json::json!({
            "task_id": task_node.id(),
            "task_name": task_node.name(),
            "error_type": format!("{:?}", error_type),
            "timestamp": chrono::Utc::now().timestamp(),
            "intervention_url": format!("https://admin.reasonkit.sh/intervention/{}", task_node.id())
        });

        Ok(StrategyExecutionResult {
            success: false,
            details: "Manual intervention required - execution paused".to_string(),
            recovered_state: Some(intervention_required),
            fallback_actions: vec!["notify_admin".to_string(), "pause_execution".to_string()],
        })
    }
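
    /// Stand-in for re-executing a task during retry recovery. The fixed
    /// `random_value` below makes the outcome deterministic (always below the
    /// 0.7 success threshold, i.e. always reported as successful); a real
    /// implementation would re-run the task itself.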
    async fn simulate_retry_execution(
        &self,
        _task_node: &super::task_graph::TaskNode,
    ) -> Result<bool, Error> {
        let success_rate = 0.7;
        let random_value = 0.5f64;

        Ok(random_value < success_rate)
    }
    async fn get_alternative_protocol_generation_path(
        &self,
        error_type: &ErrorType,
    ) -> Option<String> {
        match error_type {
            ErrorType::ProtocolError => Some("use_simplified_protocol".to_string()),
            ErrorType::ResourceExhausted => Some("defer_generation".to_string()),
            _ => Some("manual_generation".to_string()),
        }
    }

    async fn get_alternative_code_analysis_path(&self, error_type: &ErrorType) -> Option<String> {
        match error_type {
            ErrorType::ThinkToolError => Some("use_basic_analysis".to_string()),
            ErrorType::M2Error => Some("use_fallback_model".to_string()),
            _ => Some("basic_static_analysis".to_string()),
        }
    }

    async fn get_alternative_web_automation_path(&self, error_type: &ErrorType) -> Option<String> {
        match error_type {
            ErrorType::NetworkError => Some("retry_with_different_proxy".to_string()),
            ErrorType::Timeout => Some("use_headless_mode".to_string()),
            _ => Some("skip_web_automation".to_string()),
        }
    }

    async fn get_general_alternative_path(&self, error_type: &ErrorType) -> Option<String> {
        match error_type {
            ErrorType::MemoryError => Some("reduce_memory_usage".to_string()),
            ErrorType::RateLimitError => Some("wait_and_retry".to_string()),
            _ => Some("execute_basic_version".to_string()),
        }
    }
    async fn determine_fallback_components(
        &self,
        task_node: &super::task_graph::TaskNode,
        error_type: &ErrorType,
    ) -> Result<Vec<String>, Error> {
        let mut fallback_components = Vec::new();

        match task_node.task_type() {
            super::task_graph::TaskType::ProtocolGeneration => {
                fallback_components.push("reasonkit-core".to_string());
                if matches!(error_type, ErrorType::M2Error) {
                    fallback_components.push("reasonkit-pro".to_string());
                }
            }
            super::task_graph::TaskType::CodeAnalysis => {
                fallback_components
                    .extend(["reasonkit-core".to_string(), "reasonkit-web".to_string()]);
            }
            _ => {
                fallback_components.push("reasonkit-core".to_string());
            }
        }

        Ok(fallback_components)
    }
    async fn extend_timeout(
        &self,
        _task_node: &super::task_graph::TaskNode,
        current_timeout: u64,
    ) -> Result<StrategyExecutionResult, Error> {
        let extended_timeout = current_timeout * 2;

        Ok(StrategyExecutionResult {
            success: true,
            details: format!(
                "Extended timeout from {}ms to {}ms",
                current_timeout, extended_timeout
            ),
            recovered_state: Some(serde_json::json!({"extended_timeout": extended_timeout})),
            fallback_actions: vec![],
        })
    }
    async fn parallel_execution_recovery(
        &self,
        _task_node: &super::task_graph::TaskNode,
    ) -> Result<StrategyExecutionResult, Error> {
        Ok(StrategyExecutionResult {
            success: true,
            details: "Switching to parallel execution".to_string(),
            recovered_state: Some(serde_json::json!({"parallel_execution": true})),
            fallback_actions: vec![],
        })
    }

    async fn optimize_resources(
        &self,
        _task_node: &super::task_graph::TaskNode,
    ) -> Result<StrategyExecutionResult, Error> {
        Ok(StrategyExecutionResult {
            success: true,
            details: "Optimized resource allocation".to_string(),
            recovered_state: Some(serde_json::json!({"resource_optimization": true})),
            fallback_actions: vec![],
        })
    }
    pub async fn get_recovery_statistics(&self) -> Result<RecoveryStatistics, Error> {
        let tracker = self.error_tracker.lock().await;
        Ok(tracker.get_statistics())
    }

    pub async fn reset(&self) -> Result<(), Error> {
        {
            let mut tracker = self.error_tracker.lock().await;
            tracker.reset();
        }

        {
            let mut state_recovery = self.state_recovery.lock().await;
            state_recovery.reset();
        }

        tracing::info!("Error recovery system reset");
        Ok(())
    }
}
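
/// Coarse-grained error categories used when selecting a recovery strategy.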
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ErrorType {
    Timeout,
    ValidationError,
    IoError,
    NetworkError,
    MemoryError,
    RateLimitError,
    AuthenticationError,
    AuthorizationError,
    ResourceExhausted,
    DependencyError,
    ProtocolError,
    ThinkToolError,
    M2Error,
    Unknown,
}
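
/// The recovery strategies this module knows how to apply.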
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(dead_code)]
enum RecoveryStrategyType {
    RetryWithBackoff,
    AlternativeExecution,
    StateRollback,
    ComponentFallback,
    SkipAndContinue,
    ManualIntervention,
    ExtendTimeout,
    ParallelExecution,
    ResourceOptimization,
}

impl RecoveryStrategyType {
    fn name(&self) -> &'static str {
        match self {
            RecoveryStrategyType::RetryWithBackoff => "RetryWithBackoff",
            RecoveryStrategyType::AlternativeExecution => "AlternativeExecution",
            RecoveryStrategyType::StateRollback => "StateRollback",
            RecoveryStrategyType::ComponentFallback => "ComponentFallback",
            RecoveryStrategyType::SkipAndContinue => "SkipAndContinue",
            RecoveryStrategyType::ManualIntervention => "ManualIntervention",
            RecoveryStrategyType::ExtendTimeout => "ExtendTimeout",
            RecoveryStrategyType::ParallelExecution => "ParallelExecution",
            RecoveryStrategyType::ResourceOptimization => "ResourceOptimization",
        }
    }
}
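
/// Lookup table mapping `(ErrorType, TaskType)` pairs to the strategy to
/// apply, with a fallback to the `General` task type when no exact entry
/// exists.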
#[derive(Debug)]
struct RecoveryStrategies {
    strategies: HashMap<(ErrorType, super::task_graph::TaskType), RecoveryStrategyType>,
}

impl RecoveryStrategies {
    fn new() -> Self {
        let mut strategies = HashMap::new();

        strategies.insert(
            (
                ErrorType::Timeout,
                super::task_graph::TaskType::ProtocolGeneration,
            ),
            RecoveryStrategyType::RetryWithBackoff,
        );
        strategies.insert(
            (
                ErrorType::Timeout,
                super::task_graph::TaskType::CodeAnalysis,
            ),
            RecoveryStrategyType::ExtendTimeout,
        );
        strategies.insert(
            (
                ErrorType::ValidationError,
                super::task_graph::TaskType::ProtocolGeneration,
            ),
            RecoveryStrategyType::AlternativeExecution,
        );
        strategies.insert(
            (
                ErrorType::ValidationError,
                super::task_graph::TaskType::General,
            ),
            RecoveryStrategyType::RetryWithBackoff,
        );
        strategies.insert(
            (
                ErrorType::ProtocolError,
                super::task_graph::TaskType::ProtocolGeneration,
            ),
            RecoveryStrategyType::StateRollback,
        );
        strategies.insert(
            (
                ErrorType::ThinkToolError,
                super::task_graph::TaskType::CodeAnalysis,
            ),
            RecoveryStrategyType::ComponentFallback,
        );
        strategies.insert(
            (
                ErrorType::M2Error,
                super::task_graph::TaskType::ProtocolGeneration,
            ),
            RecoveryStrategyType::ComponentFallback,
        );
        strategies.insert(
            (
                ErrorType::ResourceExhausted,
                super::task_graph::TaskType::EnterpriseWorkflow,
            ),
            RecoveryStrategyType::ResourceOptimization,
        );
        strategies.insert(
            (
                ErrorType::MemoryError,
                super::task_graph::TaskType::MultiAgentCoordination,
            ),
            RecoveryStrategyType::StateRollback,
        );

        Self { strategies }
    }
    fn get_strategy(
        &self,
        error_type: &ErrorType,
        task_type: super::task_graph::TaskType,
    ) -> Result<RecoveryStrategyType, Error> {
        let key = (error_type.clone(), task_type);
        self.strategies
            .get(&key)
            .cloned()
            .or_else(|| {
                self.strategies
                    .get(&(error_type.clone(), super::task_graph::TaskType::General))
                    .cloned()
            })
            .ok_or_else(|| {
                Error::Validation(format!(
                    "No recovery strategy for error {:?} and task type {:?}",
                    error_type, task_type
                ))
            })
    }

    fn get_timeout_strategy(
        &self,
        task_node: &super::task_graph::TaskNode,
    ) -> Result<RecoveryStrategyType, Error> {
        match task_node.task_type() {
            super::task_graph::TaskType::ProtocolGeneration => {
                Ok(RecoveryStrategyType::ExtendTimeout)
            }
            super::task_graph::TaskType::CodeAnalysis => {
                Ok(RecoveryStrategyType::ParallelExecution)
            }
            super::task_graph::TaskType::EnterpriseWorkflow => {
                Ok(RecoveryStrategyType::ResourceOptimization)
            }
            _ => Ok(RecoveryStrategyType::RetryWithBackoff),
        }
    }
}
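
/// Records error occurrences and recovery outcomes (keeping at most the last
/// 1000 error records) and derives aggregate recovery statistics from them.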
#[derive(Debug)]
struct ErrorTracker {
    error_history: Vec<ErrorRecord>,
    error_patterns: HashMap<ErrorType, u32>,
    recovery_success_rates: HashMap<RecoveryStrategyType, (u32, u32)>,
    total_errors: u32,
    total_recoveries: u32,
    successful_recoveries: u32,
}

impl ErrorTracker {
    fn new() -> Self {
        Self {
            error_history: Vec::new(),
            error_patterns: HashMap::new(),
            recovery_success_rates: HashMap::new(),
            total_errors: 0,
            total_recoveries: 0,
            successful_recoveries: 0,
        }
    }

    fn record_error(&mut self, error: &Error, error_type: &ErrorType, task_id: &str) {
        let record = ErrorRecord {
            timestamp: chrono::Utc::now(),
            error_type: error_type.clone(),
            error_message: error.to_string(),
            task_id: task_id.to_string(),
            context: serde_json::json!({}),
        };

        self.error_history.push(record);
        *self.error_patterns.entry(error_type.clone()).or_insert(0) += 1;
        self.total_errors += 1;

        if self.error_history.len() > 1000 {
            self.error_history.remove(0);
        }
    }

    fn record_recovery_attempt(
        &mut self,
        _error_type: &ErrorType,
        success: bool,
        _recovery_time_ms: u64,
    ) {
        self.total_recoveries += 1;
        if success {
            self.successful_recoveries += 1;
        }
        // The strategy actually used is not passed in here, so per-strategy
        // success rates are currently attributed to RetryWithBackoff.
        let strategy_key = RecoveryStrategyType::RetryWithBackoff;
        let (success_count, total_count) = self
            .recovery_success_rates
            .entry(strategy_key)
            .or_insert((0, 0));
        *total_count += 1;
        if success {
            *success_count += 1;
        }
    }

    fn get_statistics(&self) -> RecoveryStatistics {
        let overall_success_rate = if self.total_recoveries > 0 {
            self.successful_recoveries as f64 / self.total_recoveries as f64
        } else {
            0.0
        };

        RecoveryStatistics {
            total_errors: self.total_errors,
            total_recovery_attempts: self.total_recoveries,
            successful_recoveries: self.successful_recoveries,
            overall_success_rate,
            error_patterns: self.error_patterns.clone(),
            recovery_success_rates: self
                .recovery_success_rates
                .iter()
                .map(|(strategy, (success, total))| {
                    (
                        strategy.name().to_string(),
                        (*success as f64 / *total as f64),
                    )
                })
                .collect(),
        }
    }

    fn reset(&mut self) {
        self.error_history.clear();
        self.error_patterns.clear();
        self.recovery_success_rates.clear();
        self.total_errors = 0;
        self.total_recoveries = 0;
        self.successful_recoveries = 0;
    }
}
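
/// Maintains a bounded list of per-task state checkpoints and supports rolling
/// back to the most recent checkpoint recorded for a task.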
#[derive(Debug)]
struct StateRecovery {
    checkpoints: Vec<StateCheckpoint>,
    #[allow(dead_code)]
    max_checkpoints: usize,
}

impl StateRecovery {
    fn new() -> Self {
        Self {
            checkpoints: Vec::new(),
            max_checkpoints: 50,
        }
    }

    async fn rollback_to_checkpoint(&mut self, task_id: &str) -> Result<RollbackResult, Error> {
        let checkpoint = self
            .checkpoints
            .iter()
            .rev()
            .find(|cp| cp.task_id == task_id)
            .cloned();

        if let Some(checkpoint) = checkpoint {
            Ok(RollbackResult {
                success: true,
                state_data: checkpoint.state_data,
                timestamp: checkpoint.timestamp,
            })
        } else {
            Ok(RollbackResult {
                success: false,
                state_data: serde_json::json!({}),
                timestamp: chrono::Utc::now(),
            })
        }
    }

    #[allow(dead_code)]
    fn add_checkpoint(&mut self, task_id: &str, state_data: serde_json::Value) {
        let checkpoint = StateCheckpoint {
            task_id: task_id.to_string(),
            state_data,
            timestamp: chrono::Utc::now(),
        };

        self.checkpoints.push(checkpoint);

        if self.checkpoints.len() > self.max_checkpoints {
            self.checkpoints.remove(0);
        }
    }

    fn reset(&mut self) {
        self.checkpoints.clear();
    }
}
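
/// Outcome of a recovery attempt, returned to callers of `ErrorRecovery`.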
#[derive(Debug, Clone)]
pub struct RecoveryResult {
    pub success: bool,
    pub strategy_used: String,
    pub recovery_time_ms: u64,
    pub details: String,
    pub recovered_state: Option<serde_json::Value>,
    pub fallback_actions: Vec<String>,
}

#[derive(Debug, Clone)]
struct StrategyExecutionResult {
    success: bool,
    details: String,
    recovered_state: Option<serde_json::Value>,
    fallback_actions: Vec<String>,
}

#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ErrorRecord {
    timestamp: chrono::DateTime<chrono::Utc>,
    error_type: ErrorType,
    error_message: String,
    task_id: String,
    context: serde_json::Value,
}

#[derive(Debug, Clone)]
struct StateCheckpoint {
    task_id: String,
    state_data: serde_json::Value,
    timestamp: chrono::DateTime<chrono::Utc>,
}

#[derive(Debug, Clone)]
struct RollbackResult {
    success: bool,
    state_data: serde_json::Value,
    #[allow(dead_code)]
    timestamp: chrono::DateTime<chrono::Utc>,
}

#[derive(Debug, Clone)]
pub struct RecoveryStatistics {
    pub total_errors: u32,
    pub total_recovery_attempts: u32,
    pub successful_recoveries: u32,
    pub overall_success_rate: f64,
    pub error_patterns: HashMap<ErrorType, u32>,
    pub recovery_success_rates: Vec<(String, f64)>,
}

#[derive(Debug, Clone)]
pub struct ErrorRecoveryConfig {
    pub max_retry_attempts: u32,
    pub base_retry_delay_ms: u64,
    pub max_retry_delay_ms: u64,
    pub enable_automatic_recovery: bool,
    pub enable_state_rollback: bool,
    pub recovery_timeout_ms: u64,
}

impl Default for ErrorRecoveryConfig {
    fn default() -> Self {
        Self {
            max_retry_attempts: 3,
            base_retry_delay_ms: 1000,
            max_retry_delay_ms: 30000,
            enable_automatic_recovery: true,
            enable_state_rollback: true,
            recovery_timeout_ms: 60000,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_classification() {
        let recovery = ErrorRecovery::new();
        let validation_error = Error::Validation("Test timeout".to_string());
        let error_type = recovery.classify_error(&validation_error);

        assert_eq!(error_type, ErrorType::ValidationError);
    }

    #[test]
    fn test_recovery_result_creation() {
        let result = RecoveryResult {
            success: true,
            strategy_used: "RetryWithBackoff".to_string(),
            recovery_time_ms: 2000,
            details: "Recovery successful".to_string(),
            recovered_state: Some(serde_json::json!({"retries": 2})),
            fallback_actions: vec![],
        };

        assert!(result.success);
        assert_eq!(result.recovery_time_ms, 2000);
    }

    #[tokio::test]
    async fn test_error_recovery_creation() {
        let recovery = ErrorRecovery::new();
        assert!(recovery
            .attempt_recovery(
                &Error::Validation("Test error".to_string()),
                &super::super::task_graph::TaskNode::new(
                    "test".to_string(),
                    "Test Task".to_string(),
                    super::super::task_graph::TaskType::General,
                    super::super::task_graph::TaskPriority::Normal,
                    "Test task".to_string(),
                ),
                &super::super::task_graph::TaskGraph::new(),
            )
            .await
            .is_ok());
    }
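
    // A sketch of an additional test for the General-task fallback in strategy
    // lookup; it assumes only the mappings registered in
    // RecoveryStrategies::new() above.
    #[test]
    fn test_strategy_falls_back_to_general_task_type() {
        let strategies = RecoveryStrategies::new();
        // There is no (ValidationError, CodeAnalysis) entry, so the lookup
        // should fall back to the (ValidationError, General) mapping.
        let strategy = strategies
            .get_strategy(
                &ErrorType::ValidationError,
                super::super::task_graph::TaskType::CodeAnalysis,
            )
            .expect("expected fallback to the General task type mapping");
        assert_eq!(strategy, RecoveryStrategyType::RetryWithBackoff);
    }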
}