Skip to main content

aster/agents/error_handling/
error_handler.rs

1//! Error Handler
2//!
3//! Provides unified error recording and management for agent execution.
4//! Records errors with timestamps, context, and optional stack traces.
5//!
6//! **Validates: Requirements 15.1, 15.3**
7
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14/// Error severity levels
15#[derive(
16    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default,
17)]
18#[serde(rename_all = "lowercase")]
19pub enum ErrorSeverity {
20    /// Debug level - for development
21    Debug,
22    /// Info level - informational
23    Info,
24    /// Warning level - potential issues
25    Warning,
26    /// Error level - recoverable errors
27    #[default]
28    Error,
29    /// Critical level - unrecoverable errors
30    Critical,
31}
32
33impl std::fmt::Display for ErrorSeverity {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            ErrorSeverity::Debug => write!(f, "debug"),
37            ErrorSeverity::Info => write!(f, "info"),
38            ErrorSeverity::Warning => write!(f, "warning"),
39            ErrorSeverity::Error => write!(f, "error"),
40            ErrorSeverity::Critical => write!(f, "critical"),
41        }
42    }
43}
44
45/// Error kinds for categorization
46#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum AgentErrorKind {
49    /// Timeout error
50    Timeout,
51    /// API call error
52    ApiCall,
53    /// Tool execution error
54    ToolExecution,
55    /// Context error
56    Context,
57    /// Configuration error
58    Configuration,
59    /// Resource limit error
60    ResourceLimit,
61    /// Network error
62    Network,
63    /// Serialization error
64    Serialization,
65    /// Internal error
66    Internal,
67    /// Custom error type
68    Custom(String),
69}
70
71impl std::fmt::Display for AgentErrorKind {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        match self {
74            AgentErrorKind::Timeout => write!(f, "timeout"),
75            AgentErrorKind::ApiCall => write!(f, "api_call"),
76            AgentErrorKind::ToolExecution => write!(f, "tool_execution"),
77            AgentErrorKind::Context => write!(f, "context"),
78            AgentErrorKind::Configuration => write!(f, "configuration"),
79            AgentErrorKind::ResourceLimit => write!(f, "resource_limit"),
80            AgentErrorKind::Network => write!(f, "network"),
81            AgentErrorKind::Serialization => write!(f, "serialization"),
82            AgentErrorKind::Internal => write!(f, "internal"),
83            AgentErrorKind::Custom(name) => write!(f, "custom:{}", name),
84        }
85    }
86}
87
88/// Context information for an error
89#[derive(Debug, Clone, Default, Serialize, Deserialize)]
90#[serde(rename_all = "camelCase")]
91pub struct ErrorContext {
92    /// Agent ID that encountered the error
93    pub agent_id: Option<String>,
94    /// Phase of execution (e.g., "tool_call", "api_call", "initialization")
95    pub phase: Option<String>,
96    /// Tool name if error occurred during tool execution
97    pub tool_name: Option<String>,
98    /// Tool call ID if applicable
99    pub tool_call_id: Option<String>,
100    /// Additional context data
101    pub metadata: HashMap<String, serde_json::Value>,
102}
103
104impl ErrorContext {
105    /// Create a new empty error context
106    pub fn new() -> Self {
107        Self::default()
108    }
109
110    /// Set the agent ID
111    pub fn with_agent_id(mut self, agent_id: impl Into<String>) -> Self {
112        self.agent_id = Some(agent_id.into());
113        self
114    }
115
116    /// Set the phase
117    pub fn with_phase(mut self, phase: impl Into<String>) -> Self {
118        self.phase = Some(phase.into());
119        self
120    }
121
122    /// Set the tool name
123    pub fn with_tool_name(mut self, tool_name: impl Into<String>) -> Self {
124        self.tool_name = Some(tool_name.into());
125        self
126    }
127
128    /// Set the tool call ID
129    pub fn with_tool_call_id(mut self, tool_call_id: impl Into<String>) -> Self {
130        self.tool_call_id = Some(tool_call_id.into());
131        self
132    }
133
134    /// Add metadata
135    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
136        self.metadata.insert(key.into(), value);
137        self
138    }
139
140    /// Check if context has any information
141    pub fn is_empty(&self) -> bool {
142        self.agent_id.is_none()
143            && self.phase.is_none()
144            && self.tool_name.is_none()
145            && self.tool_call_id.is_none()
146            && self.metadata.is_empty()
147    }
148}
149
150/// Unified error record with full context
151#[derive(Debug, Clone, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct ErrorRecord {
154    /// Unique error ID
155    pub id: String,
156    /// Error kind
157    pub kind: AgentErrorKind,
158    /// Error severity
159    pub severity: ErrorSeverity,
160    /// Error message
161    pub message: String,
162    /// Error timestamp
163    pub timestamp: DateTime<Utc>,
164    /// Error context
165    pub context: ErrorContext,
166    /// Stack trace if available
167    pub stack_trace: Option<String>,
168    /// Whether the error is recoverable
169    pub recoverable: bool,
170    /// Number of retry attempts made
171    pub retry_count: u32,
172}
173
174impl ErrorRecord {
175    /// Create a new error record
176    pub fn new(kind: AgentErrorKind, message: impl Into<String>) -> Self {
177        Self {
178            id: uuid::Uuid::new_v4().to_string(),
179            kind,
180            severity: ErrorSeverity::Error,
181            message: message.into(),
182            timestamp: Utc::now(),
183            context: ErrorContext::new(),
184            stack_trace: None,
185            recoverable: true,
186            retry_count: 0,
187        }
188    }
189
190    /// Set the severity
191    pub fn with_severity(mut self, severity: ErrorSeverity) -> Self {
192        self.severity = severity;
193        self
194    }
195
196    /// Set the context
197    pub fn with_context(mut self, context: ErrorContext) -> Self {
198        self.context = context;
199        self
200    }
201
202    /// Set the stack trace
203    pub fn with_stack_trace(mut self, stack_trace: impl Into<String>) -> Self {
204        self.stack_trace = Some(stack_trace.into());
205        self
206    }
207
208    /// Set whether the error is recoverable
209    pub fn with_recoverable(mut self, recoverable: bool) -> Self {
210        self.recoverable = recoverable;
211        self
212    }
213
214    /// Set the retry count
215    pub fn with_retry_count(mut self, count: u32) -> Self {
216        self.retry_count = count;
217        self
218    }
219
220    /// Create a timeout error
221    pub fn timeout(message: impl Into<String>) -> Self {
222        Self::new(AgentErrorKind::Timeout, message)
223            .with_severity(ErrorSeverity::Error)
224            .with_recoverable(false)
225    }
226
227    /// Create an API call error
228    pub fn api_call(message: impl Into<String>) -> Self {
229        Self::new(AgentErrorKind::ApiCall, message).with_severity(ErrorSeverity::Error)
230    }
231
232    /// Create a tool execution error
233    pub fn tool_execution(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
234        let tool_name = tool_name.into();
235        Self::new(AgentErrorKind::ToolExecution, message)
236            .with_context(ErrorContext::new().with_tool_name(&tool_name))
237    }
238
239    /// Check if this error has context
240    pub fn has_context(&self) -> bool {
241        !self.context.is_empty()
242    }
243
244    /// Check if this error has a stack trace
245    pub fn has_stack_trace(&self) -> bool {
246        self.stack_trace.is_some()
247    }
248}
249
250/// Agent error type for Result handling
251#[derive(Debug, Clone)]
252pub struct AgentError {
253    /// The error record
254    pub record: ErrorRecord,
255    /// Source error message if wrapped
256    pub source: Option<String>,
257}
258
259impl AgentError {
260    /// Create a new agent error
261    pub fn new(kind: AgentErrorKind, message: impl Into<String>) -> Self {
262        Self {
263            record: ErrorRecord::new(kind, message),
264            source: None,
265        }
266    }
267
268    /// Create from an error record
269    pub fn from_record(record: ErrorRecord) -> Self {
270        Self {
271            record,
272            source: None,
273        }
274    }
275
276    /// Set the source error
277    pub fn with_source(mut self, source: impl Into<String>) -> Self {
278        self.source = Some(source.into());
279        self
280    }
281
282    /// Set the context
283    pub fn with_context(mut self, context: ErrorContext) -> Self {
284        self.record = self.record.with_context(context);
285        self
286    }
287
288    /// Get the error kind
289    pub fn kind(&self) -> &AgentErrorKind {
290        &self.record.kind
291    }
292
293    /// Get the error message
294    pub fn message(&self) -> &str {
295        &self.record.message
296    }
297
298    /// Check if the error is recoverable
299    pub fn is_recoverable(&self) -> bool {
300        self.record.recoverable
301    }
302}
303
304impl std::fmt::Display for AgentError {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        write!(f, "[{}] {}", self.record.kind, self.record.message)?;
307        if let Some(source) = &self.source {
308            write!(f, " (caused by: {})", source)?;
309        }
310        Ok(())
311    }
312}
313
314impl std::error::Error for AgentError {}
315
316/// Error handler for recording and managing errors
317#[derive(Debug)]
318pub struct ErrorHandler {
319    /// All recorded errors indexed by ID
320    errors: HashMap<String, ErrorRecord>,
321    /// Errors indexed by agent ID
322    errors_by_agent: HashMap<String, Vec<String>>,
323    /// Maximum number of errors to keep
324    max_errors: usize,
325    /// Whether to capture stack traces
326    capture_stack_traces: bool,
327}
328
329impl Default for ErrorHandler {
330    fn default() -> Self {
331        Self::new()
332    }
333}
334
335impl ErrorHandler {
336    /// Create a new error handler
337    pub fn new() -> Self {
338        Self {
339            errors: HashMap::new(),
340            errors_by_agent: HashMap::new(),
341            max_errors: 10000,
342            capture_stack_traces: false,
343        }
344    }
345
346    /// Create with configuration
347    pub fn with_config(max_errors: usize, capture_stack_traces: bool) -> Self {
348        Self {
349            errors: HashMap::new(),
350            errors_by_agent: HashMap::new(),
351            max_errors,
352            capture_stack_traces,
353        }
354    }
355
356    /// Record an error
357    pub fn record(&mut self, mut error: ErrorRecord) -> String {
358        // Capture stack trace if enabled and not already present
359        if self.capture_stack_traces && error.stack_trace.is_none() {
360            error.stack_trace = Some(Self::capture_backtrace());
361        }
362
363        let id = error.id.clone();
364
365        // Track by agent ID if present
366        if let Some(agent_id) = &error.context.agent_id {
367            self.errors_by_agent
368                .entry(agent_id.clone())
369                .or_default()
370                .push(id.clone());
371        }
372
373        // Enforce max errors limit
374        if self.errors.len() >= self.max_errors {
375            self.remove_oldest();
376        }
377
378        self.errors.insert(id.clone(), error);
379        id
380    }
381
382    /// Record an error with context
383    pub fn record_with_context(
384        &mut self,
385        kind: AgentErrorKind,
386        message: impl Into<String>,
387        context: ErrorContext,
388    ) -> String {
389        let error = ErrorRecord::new(kind, message).with_context(context);
390        self.record(error)
391    }
392
393    /// Record a tool execution error
394    pub fn record_tool_error(
395        &mut self,
396        agent_id: &str,
397        tool_name: &str,
398        tool_call_id: Option<&str>,
399        message: impl Into<String>,
400    ) -> String {
401        let mut context = ErrorContext::new()
402            .with_agent_id(agent_id)
403            .with_phase("tool_execution")
404            .with_tool_name(tool_name);
405
406        if let Some(call_id) = tool_call_id {
407            context = context.with_tool_call_id(call_id);
408        }
409
410        let error = ErrorRecord::tool_execution(tool_name, message).with_context(context);
411        self.record(error)
412    }
413
414    /// Get an error by ID
415    pub fn get(&self, error_id: &str) -> Option<&ErrorRecord> {
416        self.errors.get(error_id)
417    }
418
419    /// Get all errors for an agent
420    pub fn get_by_agent(&self, agent_id: &str) -> Vec<&ErrorRecord> {
421        self.errors_by_agent
422            .get(agent_id)
423            .map(|ids| ids.iter().filter_map(|id| self.errors.get(id)).collect())
424            .unwrap_or_default()
425    }
426
427    /// Get all errors of a specific kind
428    pub fn get_by_kind(&self, kind: &AgentErrorKind) -> Vec<&ErrorRecord> {
429        self.errors.values().filter(|e| &e.kind == kind).collect()
430    }
431
432    /// Get all errors with severity >= threshold
433    pub fn get_by_severity(&self, min_severity: ErrorSeverity) -> Vec<&ErrorRecord> {
434        self.errors
435            .values()
436            .filter(|e| e.severity >= min_severity)
437            .collect()
438    }
439
440    /// Get all errors
441    pub fn get_all(&self) -> Vec<&ErrorRecord> {
442        self.errors.values().collect()
443    }
444
445    /// Get error count
446    pub fn count(&self) -> usize {
447        self.errors.len()
448    }
449
450    /// Get error count for an agent
451    pub fn count_by_agent(&self, agent_id: &str) -> usize {
452        self.errors_by_agent
453            .get(agent_id)
454            .map(|ids| ids.len())
455            .unwrap_or(0)
456    }
457
458    /// Clear all errors
459    pub fn clear(&mut self) {
460        self.errors.clear();
461        self.errors_by_agent.clear();
462    }
463
464    /// Clear errors for an agent
465    pub fn clear_by_agent(&mut self, agent_id: &str) {
466        if let Some(ids) = self.errors_by_agent.remove(agent_id) {
467            for id in ids {
468                self.errors.remove(&id);
469            }
470        }
471    }
472
473    /// Remove oldest error
474    fn remove_oldest(&mut self) {
475        if let Some(oldest_id) = self
476            .errors
477            .values()
478            .min_by_key(|e| e.timestamp)
479            .map(|e| e.id.clone())
480        {
481            if let Some(error) = self.errors.remove(&oldest_id) {
482                if let Some(agent_id) = &error.context.agent_id {
483                    if let Some(ids) = self.errors_by_agent.get_mut(agent_id) {
484                        ids.retain(|id| id != &oldest_id);
485                    }
486                }
487            }
488        }
489    }
490
491    /// Capture a backtrace
492    fn capture_backtrace() -> String {
493        std::backtrace::Backtrace::capture().to_string()
494    }
495
496    /// Enable or disable stack trace capture
497    pub fn set_capture_stack_traces(&mut self, capture: bool) {
498        self.capture_stack_traces = capture;
499    }
500
501    /// Set maximum number of errors to keep
502    pub fn set_max_errors(&mut self, max: usize) {
503        self.max_errors = max;
504    }
505}
506
507/// Thread-safe error handler wrapper
508#[allow(dead_code)]
509pub type SharedErrorHandler = Arc<RwLock<ErrorHandler>>;
510
511/// Create a new shared error handler
512#[allow(dead_code)]
513pub fn new_shared_error_handler() -> SharedErrorHandler {
514    Arc::new(RwLock::new(ErrorHandler::new()))
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an API-call record attributed to `agent`.
    fn api_error_for(agent: &str, message: &str) -> ErrorRecord {
        ErrorRecord::new(AgentErrorKind::ApiCall, message)
            .with_context(ErrorContext::new().with_agent_id(agent))
    }

    /// Builds a record of `kind` with the given severity.
    fn error_with_severity(
        kind: AgentErrorKind,
        message: &str,
        severity: ErrorSeverity,
    ) -> ErrorRecord {
        ErrorRecord::new(kind, message).with_severity(severity)
    }

    #[test]
    fn test_error_context_builder() {
        let context = ErrorContext::new()
            .with_agent_id("agent-1")
            .with_phase("tool_execution")
            .with_tool_name("bash")
            .with_tool_call_id("call-1")
            .with_metadata("key", serde_json::json!("value"));

        assert_eq!(context.agent_id.as_deref(), Some("agent-1"));
        assert_eq!(context.phase.as_deref(), Some("tool_execution"));
        assert_eq!(context.tool_name.as_deref(), Some("bash"));
        assert_eq!(context.tool_call_id.as_deref(), Some("call-1"));
        assert!(!context.is_empty());
    }

    #[test]
    fn test_error_record_creation() {
        let record = ErrorRecord::new(AgentErrorKind::ApiCall, "API call failed");

        assert!(!record.id.is_empty());
        assert_eq!(record.kind, AgentErrorKind::ApiCall);
        assert_eq!(record.message, "API call failed");
        assert_eq!(record.severity, ErrorSeverity::Error);
        assert!(record.recoverable);
    }

    #[test]
    fn test_error_record_timeout() {
        let record = ErrorRecord::timeout("Operation timed out after 30s");

        assert_eq!(record.kind, AgentErrorKind::Timeout);
        assert!(!record.recoverable);
    }

    #[test]
    fn test_error_record_tool_execution() {
        let record = ErrorRecord::tool_execution("bash", "Command failed");

        assert_eq!(record.kind, AgentErrorKind::ToolExecution);
        assert_eq!(record.context.tool_name.as_deref(), Some("bash"));
    }

    #[test]
    fn test_error_handler_record() {
        let mut handler = ErrorHandler::new();

        let id = handler.record(api_error_for("agent-1", "Test error"));

        assert_eq!(handler.count(), 1);
        assert!(handler.get(&id).is_some());
        assert_eq!(handler.count_by_agent("agent-1"), 1);
    }

    #[test]
    fn test_error_handler_record_tool_error() {
        let mut handler = ErrorHandler::new();

        let id = handler.record_tool_error("agent-1", "bash", Some("call-1"), "Command failed");

        let stored = handler.get(&id).unwrap();
        assert_eq!(stored.kind, AgentErrorKind::ToolExecution);
        assert_eq!(stored.context.agent_id.as_deref(), Some("agent-1"));
        assert_eq!(stored.context.tool_name.as_deref(), Some("bash"));
        assert_eq!(stored.context.tool_call_id.as_deref(), Some("call-1"));
    }

    #[test]
    fn test_error_handler_get_by_kind() {
        let mut handler = ErrorHandler::new();

        let inputs = [
            (AgentErrorKind::ApiCall, "Error 1"),
            (AgentErrorKind::Timeout, "Error 2"),
            (AgentErrorKind::ApiCall, "Error 3"),
        ];
        for (kind, message) in inputs {
            handler.record(ErrorRecord::new(kind, message));
        }

        let api_errors = handler.get_by_kind(&AgentErrorKind::ApiCall);
        assert_eq!(api_errors.len(), 2);
    }

    #[test]
    fn test_error_handler_get_by_severity() {
        let mut handler = ErrorHandler::new();

        handler.record(error_with_severity(
            AgentErrorKind::ApiCall,
            "Error 1",
            ErrorSeverity::Warning,
        ));
        handler.record(error_with_severity(
            AgentErrorKind::Timeout,
            "Error 2",
            ErrorSeverity::Critical,
        ));
        handler.record(error_with_severity(
            AgentErrorKind::ApiCall,
            "Error 3",
            ErrorSeverity::Error,
        ));

        // Only the `Error` and `Critical` records reach the threshold.
        let severe_errors = handler.get_by_severity(ErrorSeverity::Error);
        assert_eq!(severe_errors.len(), 2);
    }

    #[test]
    fn test_error_handler_clear_by_agent() {
        let mut handler = ErrorHandler::new();

        handler.record(api_error_for("agent-1", "Error 1"));
        handler.record(api_error_for("agent-2", "Error 2"));

        handler.clear_by_agent("agent-1");

        assert_eq!(handler.count(), 1);
        assert_eq!(handler.count_by_agent("agent-1"), 0);
        assert_eq!(handler.count_by_agent("agent-2"), 1);
    }

    #[test]
    fn test_error_handler_max_errors() {
        let mut handler = ErrorHandler::with_config(3, false);

        (0..5).for_each(|i| {
            handler.record(ErrorRecord::new(
                AgentErrorKind::ApiCall,
                format!("Error {}", i),
            ));
        });

        assert_eq!(handler.count(), 3);
    }

    #[test]
    fn test_agent_error_display() {
        let error = AgentError::new(AgentErrorKind::ApiCall, "API call failed")
            .with_source("Connection refused");

        let display = format!("{}", error);
        for expected in ["api_call", "API call failed", "Connection refused"] {
            assert!(display.contains(expected));
        }
    }
}