1use crate::error_handling::{ErrorContext, SuggestedAction};
4use crate::metrics::Metrics;
5use crate::types::AiLibError;
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::sync::{Arc, Mutex};
10use std::time::Instant;
11
12pub struct ErrorRecoveryManager {
14 error_history: Arc<Mutex<VecDeque<ErrorRecord>>>,
15 recovery_strategies: HashMap<ErrorType, Box<dyn RecoveryStrategy>>,
16 metrics: Option<Arc<dyn Metrics>>,
18 #[allow(dead_code)] start_time: Instant,
20 error_patterns: Arc<Mutex<HashMap<ErrorType, ErrorPattern>>>,
22}
23
24impl Default for ErrorRecoveryManager {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ErrorPattern {
33 pub error_type: ErrorType,
34 pub count: u32,
35 pub first_occurrence: chrono::DateTime<chrono::Utc>,
36 pub last_occurrence: chrono::DateTime<chrono::Utc>,
37 pub frequency: f64, pub suggested_action: SuggestedAction,
39 pub recovery_attempts: u32,
40 pub successful_recoveries: u32,
41}
42
43#[derive(Debug, Clone)]
45pub struct ErrorRecord {
46 pub error_type: ErrorType,
47 pub context: ErrorContext,
48 pub timestamp: chrono::DateTime<chrono::Utc>,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
53pub enum ErrorType {
54 RateLimit,
55 Network,
56 Authentication,
57 Provider,
58 Timeout,
59 Configuration,
60 Validation,
61 Serialization,
62 Deserialization,
63 FileOperation,
64 ModelNotFound,
65 ContextLengthExceeded,
66 UnsupportedFeature,
67 Unknown,
68}
69
70#[async_trait]
72pub trait RecoveryStrategy: Send + Sync {
73 async fn can_recover(&self, error: &AiLibError) -> bool;
75
76 async fn recover(&self, error: &AiLibError, context: &ErrorContext) -> Result<(), AiLibError>;
78}
79
80impl ErrorRecoveryManager {
81 pub fn new() -> Self {
83 Self {
84 error_history: Arc::new(Mutex::new(VecDeque::new())),
85 recovery_strategies: HashMap::new(),
86 metrics: None,
87 start_time: Instant::now(),
88 error_patterns: Arc::new(Mutex::new(HashMap::new())),
89 }
90 }
91
92 pub fn with_metrics(metrics: Arc<dyn Metrics>) -> Self {
94 Self {
95 error_history: Arc::new(Mutex::new(VecDeque::new())),
96 recovery_strategies: HashMap::new(),
97 metrics: Some(metrics),
98 start_time: Instant::now(),
99 error_patterns: Arc::new(Mutex::new(HashMap::new())),
100 }
101 }
102
103 pub fn register_strategy(
105 &mut self,
106 error_type: ErrorType,
107 strategy: Box<dyn RecoveryStrategy>,
108 ) {
109 self.recovery_strategies.insert(error_type, strategy);
110 }
111
112 pub async fn handle_error(
114 &self,
115 error: &AiLibError,
116 context: &ErrorContext,
117 ) -> Result<(), AiLibError> {
118 let error_type = self.classify_error(error);
119
120 self.record_error(error_type.clone(), context.clone()).await;
122
123 if let Some(strategy) = self.recovery_strategies.get(&error_type) {
125 if strategy.can_recover(error).await {
126 return strategy.recover(error, context).await;
127 }
128 }
129
130 Err((*error).clone())
131 }
132
133 fn classify_error(&self, error: &AiLibError) -> ErrorType {
135 match error {
136 AiLibError::RateLimitExceeded(_) => ErrorType::RateLimit,
137 AiLibError::NetworkError(_) => ErrorType::Network,
138 AiLibError::AuthenticationError(_) => ErrorType::Authentication,
139 AiLibError::ProviderError(_) => ErrorType::Provider,
140 AiLibError::TimeoutError(_) => ErrorType::Timeout,
141 AiLibError::ConfigurationError(_) => ErrorType::Configuration,
142 AiLibError::InvalidRequest(_) => ErrorType::Validation,
143 AiLibError::SerializationError(_) => ErrorType::Serialization,
144 AiLibError::DeserializationError(_) => ErrorType::Deserialization,
145 AiLibError::FileError(_) => ErrorType::FileOperation,
146 AiLibError::ModelNotFound(_) => ErrorType::ModelNotFound,
147 AiLibError::ContextLengthExceeded(_) => ErrorType::ContextLengthExceeded,
148 AiLibError::UnsupportedFeature(_) => ErrorType::UnsupportedFeature,
149 _ => ErrorType::Unknown,
150 }
151 }
152
153 fn generate_suggested_action(
155 &self,
156 error_type: &ErrorType,
157 pattern: &ErrorPattern,
158 ) -> SuggestedAction {
159 match error_type {
160 ErrorType::RateLimit => {
161 if pattern.frequency > 10.0 {
162 SuggestedAction::SwitchProvider {
163 alternative_providers: vec!["groq".to_string(), "anthropic".to_string()],
164 }
165 } else {
166 SuggestedAction::Retry {
167 delay_ms: 60000,
168 max_attempts: 3,
169 }
170 }
171 }
172 ErrorType::Network => SuggestedAction::Retry {
173 delay_ms: 2000,
174 max_attempts: 5,
175 },
176 ErrorType::Authentication => SuggestedAction::CheckCredentials,
177 ErrorType::Provider => SuggestedAction::SwitchProvider {
178 alternative_providers: vec!["openai".to_string(), "groq".to_string()],
179 },
180 ErrorType::Timeout => SuggestedAction::Retry {
181 delay_ms: 5000,
182 max_attempts: 3,
183 },
184 ErrorType::ContextLengthExceeded => SuggestedAction::ReduceRequestSize {
185 max_tokens: Some(1000),
186 },
187 ErrorType::ModelNotFound => SuggestedAction::ContactSupport {
188 reason: "Model not found - please verify model name".to_string(),
189 },
190 _ => SuggestedAction::NoAction,
191 }
192 }
193
194 async fn record_error(&self, error_type: ErrorType, mut context: ErrorContext) {
196 let now = chrono::Utc::now();
197 let record = ErrorRecord {
198 error_type: error_type.clone(),
199 context: context.clone(),
200 timestamp: now,
201 };
202 self.update_error_pattern(&error_type, now).await;
204
205 let suggested_action = self.get_suggested_action_for_error(&error_type).await;
207 context.suggested_action = suggested_action;
208 {
209 let mut history = self.error_history.lock().unwrap();
210 history.push_back(record);
211 if history.len() > 1000 {
212 history.pop_front();
213 }
214 }
215
216 if let Some(metrics) = &self.metrics {
219 metrics
220 .incr_counter(&format!("errors.{}", self.error_type_name(&error_type)), 1)
221 .await;
222 }
223 }
224
225 async fn update_error_pattern(
227 &self,
228 error_type: &ErrorType,
229 timestamp: chrono::DateTime<chrono::Utc>,
230 ) {
231 let mut patterns = self.error_patterns.lock().unwrap();
232 let entry = patterns.entry(error_type.clone());
233 use std::collections::hash_map::Entry;
234 match entry {
235 Entry::Occupied(mut occ) => {
236 let pattern = occ.get_mut();
237 pattern.count += 1;
238 pattern.last_occurrence = timestamp;
239 let duration = pattern
240 .last_occurrence
241 .signed_duration_since(pattern.first_occurrence);
242 if duration.num_minutes() > 0 {
243 pattern.frequency = pattern.count as f64 / duration.num_minutes() as f64;
244 }
245 pattern.suggested_action = self.generate_suggested_action(error_type, pattern);
246 }
247 Entry::Vacant(vac) => {
248 vac.insert(ErrorPattern {
249 error_type: error_type.clone(),
250 count: 1,
251 first_occurrence: timestamp,
252 last_occurrence: timestamp,
253 frequency: 0.0,
254 suggested_action: SuggestedAction::NoAction,
255 recovery_attempts: 0,
256 successful_recoveries: 0,
257 });
258 }
259 }
260 }
261
262 async fn get_suggested_action_for_error(&self, error_type: &ErrorType) -> SuggestedAction {
264 let patterns = self.error_patterns.lock().unwrap();
265 patterns
266 .get(error_type)
267 .map(|p| p.suggested_action.clone())
268 .unwrap_or(SuggestedAction::NoAction)
269 }
270
271 fn error_type_name(&self, error_type: &ErrorType) -> String {
273 match error_type {
274 ErrorType::RateLimit => "rate_limit".to_string(),
275 ErrorType::Network => "network".to_string(),
276 ErrorType::Authentication => "authentication".to_string(),
277 ErrorType::Provider => "provider".to_string(),
278 ErrorType::Timeout => "timeout".to_string(),
279 ErrorType::Configuration => "configuration".to_string(),
280 ErrorType::Validation => "validation".to_string(),
281 ErrorType::Serialization => "serialization".to_string(),
282 ErrorType::Deserialization => "deserialization".to_string(),
283 ErrorType::FileOperation => "file_operation".to_string(),
284 ErrorType::ModelNotFound => "model_not_found".to_string(),
285 ErrorType::ContextLengthExceeded => "context_length_exceeded".to_string(),
286 ErrorType::UnsupportedFeature => "unsupported_feature".to_string(),
287 ErrorType::Unknown => "unknown".to_string(),
288 }
289 }
290
291 pub fn get_error_patterns(&self) -> HashMap<ErrorType, ErrorPattern> {
293 self.error_patterns.lock().unwrap().clone()
294 }
295
296 pub fn get_error_statistics(&self) -> ErrorStatistics {
298 let patterns = self.error_patterns.lock().unwrap();
299 let total_errors: u32 = patterns.values().map(|p| p.count).sum();
300 let most_common_error = patterns
301 .values()
302 .max_by_key(|p| p.count)
303 .map(|p| p.error_type.clone());
304
305 ErrorStatistics {
306 total_errors,
307 unique_error_types: patterns.len(),
308 most_common_error,
309 patterns: patterns.clone(),
310 }
311 }
312
313 pub fn reset(&self) {
315 let mut history = self.error_history.lock().unwrap();
316 history.clear();
317
318 let mut patterns = self.error_patterns.lock().unwrap();
319 patterns.clear();
320 }
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize)]
325pub struct ErrorStatistics {
326 pub total_errors: u32,
327 pub unique_error_types: usize,
328 pub most_common_error: Option<ErrorType>,
329 pub patterns: HashMap<ErrorType, ErrorPattern>,
330}