ai_lib/error_handling/
recovery.rs

1//! Error recovery strategies and management
2
3use crate::error_handling::{ErrorContext, SuggestedAction};
4use crate::metrics::Metrics;
5use crate::types::AiLibError;
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::sync::{Arc, Mutex};
10use std::time::Instant;
11
/// Error recovery manager
///
/// Classifies incoming errors, keeps a bounded history of them, aggregates
/// per-type patterns, and dispatches to a registered [`RecoveryStrategy`]
/// when one exists for the error's type.
pub struct ErrorRecoveryManager {
    /// Recent error records, oldest at the front; `record_error` drops the
    /// oldest entry once more than 1000 are stored.
    error_history: Arc<Mutex<VecDeque<ErrorRecord>>>,
    /// At most one strategy per error type, installed via `register_strategy`.
    recovery_strategies: HashMap<ErrorType, Box<dyn RecoveryStrategy>>,
    // Metrics and monitoring
    /// Optional metrics sink; when present, `record_error` increments an
    /// `errors.<type>` counter on it for every recorded error.
    metrics: Option<Arc<dyn Metrics>>,
    /// Instant at which this manager was constructed.
    #[allow(dead_code)] // Reserved for future use
    start_time: Instant,
    // Error pattern analysis
    /// Aggregated statistics per error type, updated on every recorded error.
    error_patterns: Arc<Mutex<HashMap<ErrorType, ErrorPattern>>>,
}
23
24impl Default for ErrorRecoveryManager {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
/// Error pattern analysis for intelligent recovery
///
/// One instance exists per [`ErrorType`] that has been recorded; it is
/// created on the first occurrence and updated on every later one.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorPattern {
    /// The error category this pattern aggregates.
    pub error_type: ErrorType,
    /// Number of recorded errors of this type (starts at 1).
    pub count: u32,
    /// Timestamp of the first recorded error of this type.
    pub first_occurrence: chrono::DateTime<chrono::Utc>,
    /// Timestamp of the most recently recorded error of this type.
    pub last_occurrence: chrono::DateTime<chrono::Utc>,
    /// Errors per minute over the span between `first_occurrence` and
    /// `last_occurrence`; stays at 0.0 until that span reaches one minute.
    pub frequency: f64, // errors per minute
    /// Recommended action currently associated with this error type.
    pub suggested_action: SuggestedAction,
    /// Counter of recovery attempts for this error type; initialised to 0.
    pub recovery_attempts: u32,
    /// Counter of successful recoveries for this error type; initialised to 0.
    pub successful_recoveries: u32,
}
42
/// Record of an error for tracking patterns
///
/// Instances are appended to the manager's bounded error history.
#[derive(Debug, Clone)]
pub struct ErrorRecord {
    /// Category the error was classified into.
    pub error_type: ErrorType,
    /// Context supplied by the caller when the error was handled.
    pub context: ErrorContext,
    /// UTC time at which the error was recorded.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
50
/// Types of errors for categorization
///
/// Each variant corresponds to one `AiLibError` variant as mapped by
/// `ErrorRecoveryManager::classify_error`; anything unmapped becomes
/// [`ErrorType::Unknown`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ErrorType {
    /// Classified from `AiLibError::RateLimitExceeded`.
    RateLimit,
    /// Classified from `AiLibError::NetworkError`.
    Network,
    /// Classified from `AiLibError::AuthenticationError`.
    Authentication,
    /// Classified from `AiLibError::ProviderError`.
    Provider,
    /// Classified from `AiLibError::TimeoutError`.
    Timeout,
    /// Classified from `AiLibError::ConfigurationError`.
    Configuration,
    /// Classified from `AiLibError::InvalidRequest`.
    Validation,
    /// Classified from `AiLibError::SerializationError`.
    Serialization,
    /// Classified from `AiLibError::DeserializationError`.
    Deserialization,
    /// Classified from `AiLibError::FileError`.
    FileOperation,
    /// Classified from `AiLibError::ModelNotFound`.
    ModelNotFound,
    /// Classified from `AiLibError::ContextLengthExceeded`.
    ContextLengthExceeded,
    /// Classified from `AiLibError::UnsupportedFeature`.
    UnsupportedFeature,
    /// Fallback for any `AiLibError` variant without a dedicated category.
    Unknown,
}
69
/// Trait for implementing recovery strategies
///
/// A strategy is registered per [`ErrorType`] on the manager. When an error
/// of that type is handled, `can_recover` is consulted first and `recover`
/// is invoked only if it returned `true`.
#[async_trait]
pub trait RecoveryStrategy: Send + Sync {
    /// Check if this strategy can recover from the given error
    ///
    /// Returning `false` makes the manager skip `recover` and report the
    /// original error to its caller.
    async fn can_recover(&self, error: &AiLibError) -> bool;

    /// Attempt to recover from the error
    ///
    /// The returned `Result` is handed back unchanged to the caller of
    /// `ErrorRecoveryManager::handle_error`.
    async fn recover(&self, error: &AiLibError, context: &ErrorContext) -> Result<(), AiLibError>;
}
79
80impl ErrorRecoveryManager {
81    /// Create a new error recovery manager
82    pub fn new() -> Self {
83        Self {
84            error_history: Arc::new(Mutex::new(VecDeque::new())),
85            recovery_strategies: HashMap::new(),
86            metrics: None,
87            start_time: Instant::now(),
88            error_patterns: Arc::new(Mutex::new(HashMap::new())),
89        }
90    }
91
92    /// Create a new error recovery manager with metrics
93    pub fn with_metrics(metrics: Arc<dyn Metrics>) -> Self {
94        Self {
95            error_history: Arc::new(Mutex::new(VecDeque::new())),
96            recovery_strategies: HashMap::new(),
97            metrics: Some(metrics),
98            start_time: Instant::now(),
99            error_patterns: Arc::new(Mutex::new(HashMap::new())),
100        }
101    }
102
103    /// Register a recovery strategy for a specific error type
104    pub fn register_strategy(
105        &mut self,
106        error_type: ErrorType,
107        strategy: Box<dyn RecoveryStrategy>,
108    ) {
109        self.recovery_strategies.insert(error_type, strategy);
110    }
111
112    /// Handle an error and attempt recovery
113    pub async fn handle_error(
114        &self,
115        error: &AiLibError,
116        context: &ErrorContext,
117    ) -> Result<(), AiLibError> {
118        let error_type = self.classify_error(error);
119
120        // Record the error
121        self.record_error(error_type.clone(), context.clone()).await;
122
123        // Try to find a recovery strategy
124        if let Some(strategy) = self.recovery_strategies.get(&error_type) {
125            if strategy.can_recover(error).await {
126                return strategy.recover(error, context).await;
127            }
128        }
129
130        Err((*error).clone())
131    }
132
133    /// Classify an error into a specific type
134    fn classify_error(&self, error: &AiLibError) -> ErrorType {
135        match error {
136            AiLibError::RateLimitExceeded(_) => ErrorType::RateLimit,
137            AiLibError::NetworkError(_) => ErrorType::Network,
138            AiLibError::AuthenticationError(_) => ErrorType::Authentication,
139            AiLibError::ProviderError(_) => ErrorType::Provider,
140            AiLibError::TimeoutError(_) => ErrorType::Timeout,
141            AiLibError::ConfigurationError(_) => ErrorType::Configuration,
142            AiLibError::InvalidRequest(_) => ErrorType::Validation,
143            AiLibError::SerializationError(_) => ErrorType::Serialization,
144            AiLibError::DeserializationError(_) => ErrorType::Deserialization,
145            AiLibError::FileError(_) => ErrorType::FileOperation,
146            AiLibError::ModelNotFound(_) => ErrorType::ModelNotFound,
147            AiLibError::ContextLengthExceeded(_) => ErrorType::ContextLengthExceeded,
148            AiLibError::UnsupportedFeature(_) => ErrorType::UnsupportedFeature,
149            _ => ErrorType::Unknown,
150        }
151    }
152
153    /// Generate intelligent suggested action based on error pattern
154    fn generate_suggested_action(
155        &self,
156        error_type: &ErrorType,
157        pattern: &ErrorPattern,
158    ) -> SuggestedAction {
159        match error_type {
160            ErrorType::RateLimit => {
161                if pattern.frequency > 10.0 {
162                    SuggestedAction::SwitchProvider {
163                        alternative_providers: vec!["groq".to_string(), "anthropic".to_string()],
164                    }
165                } else {
166                    SuggestedAction::Retry {
167                        delay_ms: 60000,
168                        max_attempts: 3,
169                    }
170                }
171            }
172            ErrorType::Network => SuggestedAction::Retry {
173                delay_ms: 2000,
174                max_attempts: 5,
175            },
176            ErrorType::Authentication => SuggestedAction::CheckCredentials,
177            ErrorType::Provider => SuggestedAction::SwitchProvider {
178                alternative_providers: vec!["openai".to_string(), "groq".to_string()],
179            },
180            ErrorType::Timeout => SuggestedAction::Retry {
181                delay_ms: 5000,
182                max_attempts: 3,
183            },
184            ErrorType::ContextLengthExceeded => SuggestedAction::ReduceRequestSize {
185                max_tokens: Some(1000),
186            },
187            ErrorType::ModelNotFound => SuggestedAction::ContactSupport {
188                reason: "Model not found - please verify model name".to_string(),
189            },
190            _ => SuggestedAction::NoAction,
191        }
192    }
193
194    /// Record an error in the history
195    async fn record_error(&self, error_type: ErrorType, mut context: ErrorContext) {
196        let now = chrono::Utc::now();
197        let record = ErrorRecord {
198            error_type: error_type.clone(),
199            context: context.clone(),
200            timestamp: now,
201        };
202        // Update error patterns (drop lock before await in metrics)
203        self.update_error_pattern(&error_type, now).await;
204
205        // Generate suggested action based on pattern
206        let suggested_action = self.get_suggested_action_for_error(&error_type).await;
207        context.suggested_action = suggested_action;
208        {
209            let mut history = self.error_history.lock().unwrap();
210            history.push_back(record);
211            if history.len() > 1000 {
212                history.pop_front();
213            }
214        }
215
216        // Keep only the last 1000 records
217        // Record metrics
218        if let Some(metrics) = &self.metrics {
219            metrics
220                .incr_counter(&format!("errors.{}", self.error_type_name(&error_type)), 1)
221                .await;
222        }
223    }
224
225    /// Update error pattern analysis
226    async fn update_error_pattern(
227        &self,
228        error_type: &ErrorType,
229        timestamp: chrono::DateTime<chrono::Utc>,
230    ) {
231        let mut patterns = self.error_patterns.lock().unwrap();
232        let entry = patterns.entry(error_type.clone());
233        use std::collections::hash_map::Entry;
234        match entry {
235            Entry::Occupied(mut occ) => {
236                let pattern = occ.get_mut();
237                pattern.count += 1;
238                pattern.last_occurrence = timestamp;
239                let duration = pattern
240                    .last_occurrence
241                    .signed_duration_since(pattern.first_occurrence);
242                if duration.num_minutes() > 0 {
243                    pattern.frequency = pattern.count as f64 / duration.num_minutes() as f64;
244                }
245                pattern.suggested_action = self.generate_suggested_action(error_type, pattern);
246            }
247            Entry::Vacant(vac) => {
248                vac.insert(ErrorPattern {
249                    error_type: error_type.clone(),
250                    count: 1,
251                    first_occurrence: timestamp,
252                    last_occurrence: timestamp,
253                    frequency: 0.0,
254                    suggested_action: SuggestedAction::NoAction,
255                    recovery_attempts: 0,
256                    successful_recoveries: 0,
257                });
258            }
259        }
260    }
261
262    /// Get suggested action for a specific error type
263    async fn get_suggested_action_for_error(&self, error_type: &ErrorType) -> SuggestedAction {
264        let patterns = self.error_patterns.lock().unwrap();
265        patterns
266            .get(error_type)
267            .map(|p| p.suggested_action.clone())
268            .unwrap_or(SuggestedAction::NoAction)
269    }
270
271    /// Get error type name for metrics
272    fn error_type_name(&self, error_type: &ErrorType) -> String {
273        match error_type {
274            ErrorType::RateLimit => "rate_limit".to_string(),
275            ErrorType::Network => "network".to_string(),
276            ErrorType::Authentication => "authentication".to_string(),
277            ErrorType::Provider => "provider".to_string(),
278            ErrorType::Timeout => "timeout".to_string(),
279            ErrorType::Configuration => "configuration".to_string(),
280            ErrorType::Validation => "validation".to_string(),
281            ErrorType::Serialization => "serialization".to_string(),
282            ErrorType::Deserialization => "deserialization".to_string(),
283            ErrorType::FileOperation => "file_operation".to_string(),
284            ErrorType::ModelNotFound => "model_not_found".to_string(),
285            ErrorType::ContextLengthExceeded => "context_length_exceeded".to_string(),
286            ErrorType::UnsupportedFeature => "unsupported_feature".to_string(),
287            ErrorType::Unknown => "unknown".to_string(),
288        }
289    }
290
291    /// Get error patterns for analysis
292    pub fn get_error_patterns(&self) -> HashMap<ErrorType, ErrorPattern> {
293        self.error_patterns.lock().unwrap().clone()
294    }
295
296    /// Get error statistics
297    pub fn get_error_statistics(&self) -> ErrorStatistics {
298        let patterns = self.error_patterns.lock().unwrap();
299        let total_errors: u32 = patterns.values().map(|p| p.count).sum();
300        let most_common_error = patterns
301            .values()
302            .max_by_key(|p| p.count)
303            .map(|p| p.error_type.clone());
304
305        ErrorStatistics {
306            total_errors,
307            unique_error_types: patterns.len(),
308            most_common_error,
309            patterns: patterns.clone(),
310        }
311    }
312
313    /// Reset all error tracking
314    pub fn reset(&self) {
315        let mut history = self.error_history.lock().unwrap();
316        history.clear();
317
318        let mut patterns = self.error_patterns.lock().unwrap();
319        patterns.clear();
320    }
321}
322
/// Error statistics for monitoring and analysis
///
/// A point-in-time snapshot produced by
/// `ErrorRecoveryManager::get_error_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorStatistics {
    /// Sum of `count` over all recorded patterns.
    pub total_errors: u32,
    /// Number of distinct error types that have a recorded pattern.
    pub unique_error_types: usize,
    /// Error type with the highest `count`, or `None` if nothing was recorded.
    pub most_common_error: Option<ErrorType>,
    /// Copy of the per-type patterns at the time the snapshot was taken.
    pub patterns: HashMap<ErrorType, ErrorPattern>,
}