Skip to main content

bob_core/
error_classifier.rs

//! # Error Classifier
//!
//! Classifies errors into actionable categories for retry and failover decisions.
//!
//! Inspired by rustic-ai's error classification approach, this module maps each
//! error to a stable category string: typed error variants map directly to a
//! category, and free-form provider messages are classified by keyword matching.
//!
//! ## Categories
//!
//! | Category           | Meaning                              | Failover?  |
//! |--------------------|--------------------------------------|------------|
//! | `"timeout"`        | Operation exceeded time limit        | Yes        |
//! | `"rate_limited"`   | Provider returned HTTP 429           | Yes        |
//! | `"http_5xx"`       | Provider returned server error       | Yes        |
//! | `"connect_error"`  | Network/connection failure           | Yes        |
//! | `"context_length"` | Input exceeded model context window  | No         |
//! | `"other"`          | Unknown or unclassified error        | Per config |
18
19use crate::error::{AgentError, LlmError};
20
21/// Classify an [`AgentError`] into a stable category string.
22#[must_use]
23pub fn classify_agent_error(err: &AgentError) -> &'static str {
24    match err {
25        AgentError::Llm(llm) => classify_llm_error(llm),
26        AgentError::Timeout => "timeout",
27        AgentError::Tool(crate::error::ToolError::Timeout { .. }) => "timeout",
28        AgentError::Tool(_) => "other",
29        _ => "other",
30    }
31}
32
33/// Classify an [`LlmError`] into a stable category string.
34#[must_use]
35pub fn classify_llm_error(err: &LlmError) -> &'static str {
36    match err {
37        LlmError::RateLimited => "rate_limited",
38        LlmError::ContextLengthExceeded => "context_length",
39        LlmError::Provider(msg) => classify_provider_message(msg),
40        LlmError::Stream(msg) => classify_provider_message(msg),
41        LlmError::Other(_) => "other",
42    }
43}
44
/// Classify a free-form error message string.
///
/// Matching is case-insensitive. The checks run in this order and the first
/// match wins:
///
/// 1. `"rate_limited"` — status code `429`, a word starting with `rate`
///    (`rate limit`, `rate_limited`, `ratelimit`, `Rate exceeded`, ...), or the
///    fragment `throttl`.
/// 2. `"http_5xx"` — status code `500`, `502` or `503`.
/// 3. `"timeout"` — `timeout` or `timed out`.
/// 4. `"connect_error"` — `connect`, `dns` or `network`.
/// 5. `"context_length"` — `context length` or `maximum`.
///
/// Anything else is `"other"`.
///
/// Status codes are matched as complete digit runs and `rate` only at the start
/// of a word. Plain substring matching would misclassify messages such as
/// "you requested 4293 tokens" (contains `429`), "150000 tokens" (contains
/// `500`) or "failed to generate" (contains `rate`), turning non-failover
/// errors into failover ones.
fn classify_provider_message(msg: &str) -> &'static str {
    let lower = msg.to_lowercase();

    if has_status_code(&lower, &["429"])
        || has_word_starting_with(&lower, "rate")
        || lower.contains("throttl")
    {
        "rate_limited"
    } else if has_status_code(&lower, &["500", "502", "503"]) {
        "http_5xx"
    } else if lower.contains("timeout") || lower.contains("timed out") {
        "timeout"
    } else if lower.contains("connect") || lower.contains("dns") || lower.contains("network") {
        "connect_error"
    } else if lower.contains("context length") || lower.contains("maximum") {
        "context_length"
    } else {
        "other"
    }
}

/// Returns `true` when `text` contains one of `codes` as a complete run of
/// ASCII digits (so `503` matches in `"HTTP 503"` and `"error503"`, but not in
/// `"15030"`).
fn has_status_code(text: &str, codes: &[&str]) -> bool {
    text.split(|c: char| !c.is_ascii_digit())
        .any(|run| codes.contains(&run))
}

/// Returns `true` when any run of ASCII letters in `text` begins with `prefix`
/// (so `rate` matches `"rate_limit"` and `"x-ratelimit"`, but not `"generate"`).
fn has_word_starting_with(text: &str, prefix: &str) -> bool {
    text.split(|c: char| !c.is_ascii_alphabetic())
        .any(|word| word.starts_with(prefix))
}
62
/// Configuration for which error categories should trigger failover.
///
/// Entries in `failover_on` are compared against the category strings produced
/// by the classifiers in this module (for example `"timeout"` or
/// `"rate_limited"`), so they must be spelled exactly the same way.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FailoverConfig {
    /// Error categories that should trigger failover to a backup provider.
    pub failover_on: Vec<String>,
    /// Maximum retries on the primary provider before failing over.
    pub retry_limit: u32,
}

impl Default for FailoverConfig {
    /// Fails over on the transient categories (`"timeout"`, `"rate_limited"`,
    /// `"http_5xx"`, `"connect_error"`) and allows two retries on the primary.
    ///
    /// `"context_length"` and `"other"` are not included by default.
    fn default() -> Self {
        const TRANSIENT_CATEGORIES: [&str; 4] =
            ["timeout", "rate_limited", "http_5xx", "connect_error"];

        Self {
            failover_on: TRANSIENT_CATEGORIES
                .iter()
                .copied()
                .map(String::from)
                .collect(),
            retry_limit: 2,
        }
    }
}
85
/// Result of a failover-aware execution.
///
/// Carries the successful value together with bookkeeping about how it was
/// obtained. Equality is available whenever `T` itself supports it, so results
/// can be compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FailoverResult<T> {
    /// The successful value.
    pub value: T,
    /// Whether failover to a backup provider occurred.
    pub failed_over: bool,
    /// Number of attempts on the primary provider.
    pub primary_attempts: u32,
}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Classify an `LlmError::Provider` built from the given message.
    fn provider_category(message: &'static str) -> &'static str {
        classify_llm_error(&LlmError::Provider(message.into()))
    }

    // --- Typed LLM variants -------------------------------------------------

    #[test]
    fn classify_rate_limited() {
        assert_eq!(classify_llm_error(&LlmError::RateLimited), "rate_limited");
    }

    #[test]
    fn classify_context_length() {
        assert_eq!(
            classify_llm_error(&LlmError::ContextLengthExceeded),
            "context_length"
        );
    }

    // --- Free-form provider messages ----------------------------------------

    #[test]
    fn classify_timeout_from_provider_message() {
        assert_eq!(provider_category("request timed out"), "timeout");
    }

    #[test]
    fn classify_429_from_provider_message() {
        assert_eq!(provider_category("HTTP 429 Too Many Requests"), "rate_limited");
    }

    #[test]
    fn classify_503_from_provider_message() {
        assert_eq!(provider_category("HTTP 503 Service Unavailable"), "http_5xx");
    }

    #[test]
    fn classify_connect_error() {
        assert_eq!(provider_category("connection refused"), "connect_error");
    }

    #[test]
    fn classify_unknown_returns_other() {
        assert_eq!(provider_category("something weird"), "other");
    }

    // --- Agent-level errors -------------------------------------------------

    #[test]
    fn classify_agent_timeout() {
        assert_eq!(classify_agent_error(&AgentError::Timeout), "timeout");
    }

    #[test]
    fn classify_agent_llm_wrapped() {
        let wrapped = AgentError::Llm(LlmError::RateLimited);
        assert_eq!(classify_agent_error(&wrapped), "rate_limited");
    }

    #[test]
    fn classify_agent_tool_timeout() {
        let tool_timeout =
            AgentError::Tool(crate::error::ToolError::Timeout { name: "x".into() });
        assert_eq!(classify_agent_error(&tool_timeout), "timeout");
    }

    // --- Failover configuration ---------------------------------------------

    #[test]
    fn failover_config_defaults() {
        let defaults = FailoverConfig::default();
        assert_eq!(defaults.retry_limit, 2);

        // Every transient category must be enabled for failover out of the box.
        for category in &["timeout", "rate_limited", "http_5xx", "connect_error"] {
            assert!(
                defaults.failover_on.contains(&category.to_string()),
                "default failover_on is missing {category:?}"
            );
        }
    }
}