1use crate::error::{Error, ErrorContext, ErrorLogEntry};
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tracing::{error, info, warn};
7
8#[derive(Debug, Clone)]
10pub struct ErrorReporter {
11 logs: Arc<RwLock<Vec<ErrorLogEntry>>>,
12 max_logs: usize,
13}
14
15impl ErrorReporter {
16 pub fn new() -> Self {
18 Self {
19 logs: Arc::new(RwLock::new(Vec::new())),
20 max_logs: 1000,
21 }
22 }
23
24 pub fn with_max_logs(max_logs: usize) -> Self {
26 Self {
27 logs: Arc::new(RwLock::new(Vec::new())),
28 max_logs,
29 }
30 }
31
32 pub async fn report_error(&self, error: &Error, context: Option<ErrorContext>) {
34 let mut log_entry = ErrorLogEntry::new(
35 error.error_type().to_string(),
36 error.to_string(),
37 )
38 .with_recoverable(error.is_recoverable());
39
40 if let Some(ctx) = context {
41 if let Some(tool_id) = ctx.tool_id {
42 log_entry = log_entry.with_tool_id(tool_id);
43 }
44 if let Some(server_id) = ctx.server_id {
45 log_entry = log_entry.with_server_id(server_id);
46 }
47 if let Some(parameters) = ctx.parameters {
48 log_entry = log_entry.with_parameters(parameters);
49 }
50 if let Some(stack_trace) = ctx.stack_trace {
51 log_entry = log_entry.with_stack_trace(stack_trace);
52 }
53 }
54
55 match error {
57 Error::ToolNotFound(_) | Error::PermissionDenied(_) | Error::NamingConflict(_) => {
58 error!("Error: {} - {}", log_entry.error_type, log_entry.message);
59 }
60 Error::TimeoutError(_) | Error::ConnectionError(_) => {
61 warn!("Error: {} - {}", log_entry.error_type, log_entry.message);
62 }
63 _ => {
64 info!("Error: {} - {}", log_entry.error_type, log_entry.message);
65 }
66 }
67
68 let mut logs = self.logs.write().await;
70 logs.push(log_entry);
71
72 if logs.len() > self.max_logs {
74 let remove_count = logs.len() - self.max_logs;
75 logs.drain(0..remove_count);
76 }
77 }
78
79 pub async fn report_error_with_retry(
81 &self,
82 error: &Error,
83 context: Option<ErrorContext>,
84 retry_count: u32,
85 ) {
86 let mut log_entry = ErrorLogEntry::new(
87 error.error_type().to_string(),
88 error.to_string(),
89 )
90 .with_recoverable(error.is_recoverable())
91 .with_retry_count(retry_count);
92
93 if let Some(ctx) = context {
94 if let Some(tool_id) = ctx.tool_id {
95 log_entry = log_entry.with_tool_id(tool_id);
96 }
97 if let Some(server_id) = ctx.server_id {
98 log_entry = log_entry.with_server_id(server_id);
99 }
100 if let Some(parameters) = ctx.parameters {
101 log_entry = log_entry.with_parameters(parameters);
102 }
103 if let Some(stack_trace) = ctx.stack_trace {
104 log_entry = log_entry.with_stack_trace(stack_trace);
105 }
106 }
107
108 warn!(
109 "Error (retry {}): {} - {}",
110 retry_count, log_entry.error_type, log_entry.message
111 );
112
113 let mut logs = self.logs.write().await;
114 logs.push(log_entry);
115
116 if logs.len() > self.max_logs {
117 let remove_count = logs.len() - self.max_logs;
118 logs.drain(0..remove_count);
119 }
120 }
121
122 pub async fn get_logs(&self) -> Vec<ErrorLogEntry> {
124 let logs = self.logs.read().await;
125 logs.clone()
126 }
127
128 pub async fn get_logs_by_type(&self, error_type: &str) -> Vec<ErrorLogEntry> {
130 let logs = self.logs.read().await;
131 logs.iter()
132 .filter(|log| log.error_type == error_type)
133 .cloned()
134 .collect()
135 }
136
137 pub async fn get_logs_by_tool(&self, tool_id: &str) -> Vec<ErrorLogEntry> {
139 let logs = self.logs.read().await;
140 logs.iter()
141 .filter(|log| log.tool_id.as_deref() == Some(tool_id))
142 .cloned()
143 .collect()
144 }
145
146 pub async fn get_logs_by_server(&self, server_id: &str) -> Vec<ErrorLogEntry> {
148 let logs = self.logs.read().await;
149 logs.iter()
150 .filter(|log| log.server_id.as_deref() == Some(server_id))
151 .cloned()
152 .collect()
153 }
154
155 pub async fn log_count(&self) -> usize {
157 let logs = self.logs.read().await;
158 logs.len()
159 }
160
161 pub async fn clear_logs(&self) {
163 let mut logs = self.logs.write().await;
164 logs.clear();
165 info!("Error logs cleared");
166 }
167
168 pub async fn get_statistics(&self) -> ErrorStatistics {
170 let logs = self.logs.read().await;
171
172 let total_errors = logs.len();
173 let recoverable_errors = logs.iter().filter(|l| l.is_recoverable).count();
174 let permanent_errors = total_errors - recoverable_errors;
175
176 let mut error_type_counts = std::collections::HashMap::new();
177 for log in logs.iter() {
178 *error_type_counts.entry(log.error_type.clone()).or_insert(0) += 1;
179 }
180
181 let mut tool_error_counts = std::collections::HashMap::new();
182 for log in logs.iter() {
183 if let Some(tool_id) = &log.tool_id {
184 *tool_error_counts.entry(tool_id.clone()).or_insert(0) += 1;
185 }
186 }
187
188 let mut server_error_counts = std::collections::HashMap::new();
189 for log in logs.iter() {
190 if let Some(server_id) = &log.server_id {
191 *server_error_counts.entry(server_id.clone()).or_insert(0) += 1;
192 }
193 }
194
195 ErrorStatistics {
196 total_errors,
197 recoverable_errors,
198 permanent_errors,
199 error_type_counts,
200 tool_error_counts,
201 server_error_counts,
202 }
203 }
204}
205
206impl Default for ErrorReporter {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
/// Aggregate counts over a set of error log entries, such as the snapshot
/// produced by `ErrorReporter::get_statistics`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorStatistics {
    /// Number of entries counted.
    pub total_errors: usize,
    /// Entries whose error was flagged as recoverable.
    pub recoverable_errors: usize,
    /// Entries whose error was not flagged as recoverable.
    pub permanent_errors: usize,
    /// Entry count per error type name.
    pub error_type_counts: std::collections::HashMap<String, usize>,
    /// Entry count per tool id.
    pub tool_error_counts: std::collections::HashMap<String, usize>,
    /// Entry count per server id.
    pub server_error_counts: std::collections::HashMap<String, usize>,
}

impl ErrorStatistics {
    /// Returns the error type with the highest count, or `None` if there are none.
    ///
    /// On equal counts the lexicographically smallest type name wins, so the
    /// result does not depend on `HashMap` iteration order.
    pub fn most_common_error_type(&self) -> Option<(String, usize)> {
        Self::max_count(&self.error_type_counts)
    }

    /// Returns the tool id with the most errors, or `None` if no entry had a tool id.
    ///
    /// On equal counts the lexicographically smallest id wins.
    pub fn most_problematic_tool(&self) -> Option<(String, usize)> {
        Self::max_count(&self.tool_error_counts)
    }

    /// Returns the server id with the most errors, or `None` if no entry had a server id.
    ///
    /// On equal counts the lexicographically smallest id wins.
    pub fn most_problematic_server(&self) -> Option<(String, usize)> {
        Self::max_count(&self.server_error_counts)
    }

    /// Percentage (0.0–100.0) of errors that were recoverable; `0.0` when there
    /// are no errors at all.
    pub fn recovery_rate(&self) -> f64 {
        if self.total_errors == 0 {
            0.0
        } else {
            (self.recoverable_errors as f64 / self.total_errors as f64) * 100.0
        }
    }

    /// Picks the entry with the largest count. Ties are broken in favour of the
    /// smallest key, because `HashMap` iteration order is randomized per
    /// instance and `max_by_key` alone would return an arbitrary winner.
    fn max_count(counts: &std::collections::HashMap<String, usize>) -> Option<(String, usize)> {
        counts
            .iter()
            .max_by(|(key_a, count_a), (key_b, count_b)| {
                count_a.cmp(count_b).then_with(|| key_b.cmp(key_a))
            })
            .map(|(key, count)| (key.clone(), *count))
    }
}
257
258pub struct ErrorMessageFormatter;
260
261impl ErrorMessageFormatter {
262 pub fn format_for_user(error: &Error) -> String {
264 error.user_message()
265 }
266
267 pub fn format_with_context(error: &Error, context: &ErrorContext) -> String {
269 let mut message = error.user_message();
270
271 if let Some(tool_id) = &context.tool_id {
272 message.push_str(&format!("\nTool: {}", tool_id));
273 }
274
275 if let Some(parameters) = &context.parameters {
276 message.push_str(&format!("\nParameters: {}", parameters));
277 }
278
279 if let Some(server_id) = &context.server_id {
280 message.push_str(&format!("\nServer: {}", server_id));
281 }
282
283 message
284 }
285
286 pub fn format_for_logging(error: &Error, context: Option<&ErrorContext>) -> String {
288 let mut message = format!("[{}] {}", error.error_type(), error);
289
290 if let Some(ctx) = context {
291 if let Some(tool_id) = &ctx.tool_id {
292 message.push_str(&format!(" [tool: {}]", tool_id));
293 }
294 if let Some(server_id) = &ctx.server_id {
295 message.push_str(&format!(" [server: {}]", server_id));
296 }
297 if let Some(parameters) = &ctx.parameters {
298 message.push_str(&format!(" [params: {}]", parameters));
299 }
300 }
301
302 message
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_create_error_reporter() {
        let reporter = ErrorReporter::new();
        assert_eq!(reporter.log_count().await, 0);
    }

    #[tokio::test]
    async fn test_report_error() {
        let reporter = ErrorReporter::new();
        let error = Error::ToolNotFound("test-tool".to_string());

        reporter.report_error(&error, None).await;

        assert_eq!(reporter.log_count().await, 1);
        let logs = reporter.get_logs().await;
        assert_eq!(logs[0].error_type, "ToolNotFound");
    }

    #[tokio::test]
    async fn test_report_error_with_context() {
        let reporter = ErrorReporter::new();
        let error = Error::ExecutionError("Tool failed".to_string());
        let context = ErrorContext::new()
            .with_tool_id("test-tool".to_string())
            .with_server_id("server1".to_string());

        reporter.report_error(&error, Some(context)).await;

        let logs = reporter.get_logs().await;
        assert_eq!(logs[0].tool_id, Some("test-tool".to_string()));
        assert_eq!(logs[0].server_id, Some("server1".to_string()));
    }

    #[tokio::test]
    async fn test_report_error_with_retry() {
        let reporter = ErrorReporter::new();
        let context = ErrorContext::new().with_server_id("server1".to_string());

        reporter
            .report_error_with_retry(&Error::ConnectionError("conn".to_string()), Some(context), 3)
            .await;

        let logs = reporter.get_logs().await;
        assert_eq!(logs.len(), 1);
        assert_eq!(logs[0].error_type, "ConnectionError");
        assert_eq!(logs[0].server_id, Some("server1".to_string()));
    }

    #[tokio::test]
    async fn test_max_logs_evicts_oldest_entries() {
        let reporter = ErrorReporter::with_max_logs(2);

        for tool_id in ["t1", "t2", "t3"].iter() {
            let context = ErrorContext::new().with_tool_id(tool_id.to_string());
            reporter
                .report_error(&Error::ExecutionError("failed".to_string()), Some(context))
                .await;
        }

        let logs = reporter.get_logs().await;
        assert_eq!(logs.len(), 2);
        assert_eq!(logs[0].tool_id, Some("t2".to_string()));
        assert_eq!(logs[1].tool_id, Some("t3".to_string()));
    }

    #[tokio::test]
    async fn test_zero_max_logs_keeps_nothing() {
        let reporter = ErrorReporter::with_max_logs(0);

        reporter
            .report_error(&Error::ToolNotFound("tool1".to_string()), None)
            .await;

        assert_eq!(reporter.log_count().await, 0);
    }

    #[tokio::test]
    async fn test_get_logs_by_type() {
        let reporter = ErrorReporter::new();

        reporter
            .report_error(&Error::ToolNotFound("tool1".to_string()), None)
            .await;
        reporter
            .report_error(&Error::ToolNotFound("tool2".to_string()), None)
            .await;
        reporter
            .report_error(&Error::ConnectionError("conn".to_string()), None)
            .await;

        let logs = reporter.get_logs_by_type("ToolNotFound").await;
        assert_eq!(logs.len(), 2);
    }

    #[tokio::test]
    async fn test_get_logs_by_tool() {
        let reporter = ErrorReporter::new();

        let context1 = ErrorContext::new().with_tool_id("tool1".to_string());
        let context2 = ErrorContext::new().with_tool_id("tool2".to_string());

        reporter
            .report_error(&Error::ExecutionError("failed".to_string()), Some(context1))
            .await;
        reporter
            .report_error(&Error::ExecutionError("failed".to_string()), Some(context2))
            .await;

        let logs = reporter.get_logs_by_tool("tool1").await;
        assert_eq!(logs.len(), 1);
    }

    #[tokio::test]
    async fn test_get_logs_by_server() {
        let reporter = ErrorReporter::new();

        let context1 = ErrorContext::new().with_server_id("server1".to_string());
        let context2 = ErrorContext::new().with_server_id("server2".to_string());

        reporter
            .report_error(&Error::ExecutionError("failed".to_string()), Some(context1))
            .await;
        reporter
            .report_error(&Error::ExecutionError("failed".to_string()), Some(context2))
            .await;
        reporter
            .report_error(&Error::ToolNotFound("tool1".to_string()), None)
            .await;

        let logs = reporter.get_logs_by_server("server2").await;
        assert_eq!(logs.len(), 1);
        assert_eq!(logs[0].server_id, Some("server2".to_string()));
    }

    #[tokio::test]
    async fn test_get_statistics() {
        let reporter = ErrorReporter::new();

        reporter
            .report_error(&Error::ToolNotFound("tool1".to_string()), None)
            .await;
        reporter
            .report_error(&Error::ToolNotFound("tool2".to_string()), None)
            .await;
        reporter
            .report_error(&Error::ConnectionError("conn".to_string()), None)
            .await;

        let stats = reporter.get_statistics().await;
        assert_eq!(stats.total_errors, 3);
        assert_eq!(stats.error_type_counts.get("ToolNotFound"), Some(&2));
        assert_eq!(stats.error_type_counts.get("ConnectionError"), Some(&1));
    }

    #[tokio::test]
    async fn test_get_statistics_tool_and_server_counts() {
        let reporter = ErrorReporter::new();

        for _ in 0..2 {
            let context = ErrorContext::new()
                .with_tool_id("tool1".to_string())
                .with_server_id("server1".to_string());
            reporter
                .report_error(&Error::ExecutionError("failed".to_string()), Some(context))
                .await;
        }
        let context = ErrorContext::new().with_tool_id("tool2".to_string());
        reporter
            .report_error(&Error::ExecutionError("failed".to_string()), Some(context))
            .await;

        let stats = reporter.get_statistics().await;
        assert_eq!(stats.total_errors, 3);
        assert_eq!(
            stats.recoverable_errors + stats.permanent_errors,
            stats.total_errors
        );
        assert_eq!(stats.tool_error_counts.get("tool1"), Some(&2));
        assert_eq!(stats.tool_error_counts.get("tool2"), Some(&1));
        assert_eq!(stats.server_error_counts.get("server1"), Some(&2));
        assert_eq!(stats.server_error_counts.len(), 1);
        assert_eq!(
            stats.most_problematic_tool(),
            Some(("tool1".to_string(), 2))
        );
    }

    #[tokio::test]
    async fn test_clear_logs() {
        let reporter = ErrorReporter::new();

        reporter
            .report_error(&Error::ToolNotFound("tool1".to_string()), None)
            .await;
        assert_eq!(reporter.log_count().await, 1);

        reporter.clear_logs().await;
        assert_eq!(reporter.log_count().await, 0);
    }

    #[test]
    fn test_error_message_formatter() {
        let error = Error::ToolNotFound("test-tool".to_string());
        let message = ErrorMessageFormatter::format_for_user(&error);
        assert!(message.contains("test-tool"));
    }

    #[test]
    fn test_error_message_formatter_with_context() {
        let error = Error::ExecutionError("failed".to_string());
        let context = ErrorContext::new()
            .with_tool_id("test-tool".to_string())
            .with_server_id("server1".to_string());

        let message = ErrorMessageFormatter::format_with_context(&error, &context);
        assert!(message.contains("test-tool"));
        assert!(message.contains("server1"));
    }

    #[test]
    fn test_error_statistics_most_common() {
        let mut stats = ErrorStatistics {
            total_errors: 3,
            recoverable_errors: 1,
            permanent_errors: 2,
            error_type_counts: std::collections::HashMap::new(),
            tool_error_counts: std::collections::HashMap::new(),
            server_error_counts: std::collections::HashMap::new(),
        };

        stats.error_type_counts.insert("ToolNotFound".to_string(), 2);
        stats.error_type_counts.insert("ConnectionError".to_string(), 1);

        let most_common = stats.most_common_error_type();
        assert_eq!(most_common, Some(("ToolNotFound".to_string(), 2)));
    }

    #[test]
    fn test_error_statistics_recovery_rate() {
        let stats = ErrorStatistics {
            total_errors: 10,
            recoverable_errors: 7,
            permanent_errors: 3,
            error_type_counts: std::collections::HashMap::new(),
            tool_error_counts: std::collections::HashMap::new(),
            server_error_counts: std::collections::HashMap::new(),
        };

        assert_eq!(stats.recovery_rate(), 70.0);
    }

    #[test]
    fn test_error_statistics_empty() {
        let stats = ErrorStatistics {
            total_errors: 0,
            recoverable_errors: 0,
            permanent_errors: 0,
            error_type_counts: std::collections::HashMap::new(),
            tool_error_counts: std::collections::HashMap::new(),
            server_error_counts: std::collections::HashMap::new(),
        };

        assert_eq!(stats.recovery_rate(), 0.0);
        assert_eq!(stats.most_common_error_type(), None);
        assert_eq!(stats.most_problematic_tool(), None);
        assert_eq!(stats.most_problematic_server(), None);
    }
}