Skip to main content

ati/core/
error.rs

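//! Structured error codes and JSON error formatting for agent-first UX.
//!
//! Error code taxonomy (dot-separated):
//! - `input.missing_arg`, `input.invalid_value`
//! - `auth.expired`, `auth.scope_denied`, `auth.missing_key`
//! - `provider.timeout`, `provider.upstream_error`, `provider.not_found`
//! - `rate.exceeded`
//! - `tool.not_found`, `tool.execution_failed`
//!
//! The category (the part before the first `.`) also selects the process exit
//! code: `input` → 2, `auth` → 3, `provider` → 4, `rate` → 5, anything else → 1.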
8
/// Classify an error into a dot-separated error code by inspecting its message.
///
/// This is a heuristic. It lowercases the top-level message (`err.to_string()`)
/// and matches it against known substrings; the `source()` chain is not
/// consulted. Rules are tried in order and the first match wins, so specific
/// phrases are checked before broad ones such as `missing` or `invalid`.
/// Anything that matches no rule falls back to `tool.execution_failed`.
pub fn classify_error(err: &dyn std::error::Error) -> &'static str {
    let msg = err.to_string().to_lowercase();
    let has = |needle: &str| msg.contains(needle);

    if has("unknown tool") || (has("not found") && has("tool")) {
        "tool.not_found"
    } else if has("rate limit") || has("rate.exceeded") || has("too many requests") {
        // Checked before the auth/provider/input rules. Rate-limit messages may
        // also contain words like "required" or "invalid", which would otherwise
        // pull them into an `input.*` bucket.
        "rate.exceeded"
    } else if has("scope") || has("access denied") {
        "auth.scope_denied"
    } else if has("expired") {
        "auth.expired"
    } else if has("key not found") || has("missing key") || has("no keys found") {
        "auth.missing_key"
    } else if has("timeout") || has("timed out") {
        // `std::io::ErrorKind::TimedOut` displays as "timed out", not "timeout".
        "provider.timeout"
    } else if has("upstream") || has("bad gateway") || has("mcp error") {
        "provider.upstream_error"
    } else if has("provider") && has("not found") {
        "provider.not_found"
    } else if has("missing") || has("required") {
        "input.missing_arg"
    } else if has("invalid") || has("parse") {
        "input.invalid_value"
    } else {
        "tool.execution_failed"
    }
}
40
41/// Map an error code to a process exit code.
42pub fn exit_code_for_error(err: &dyn std::error::Error) -> i32 {
43    let code = classify_error(err);
44    match code.split('.').next().unwrap_or("") {
45        "input" => 2,
46        "auth" => 3,
47        "provider" => 4,
48        "rate" => 5,
49        _ => 1,
50    }
51}
52
53/// Format a structured JSON error string for --output json mode.
54pub fn format_structured_error(err: &dyn std::error::Error, verbose: bool) -> String {
55    let code = classify_error(err);
56    let exit = exit_code_for_error(err);
57    let message = err.to_string();
58
59    let mut error_obj = serde_json::json!({
60        "error": {
61            "code": code,
62            "message": message,
63            "exit_code": exit,
64        }
65    });
66
67    if verbose {
68        let mut chain = Vec::new();
69        let mut source = std::error::Error::source(err);
70        while let Some(cause) = source {
71            chain.push(cause.to_string());
72            source = std::error::Error::source(cause);
73        }
74        if !chain.is_empty() {
75            error_obj["error"]["chain"] = serde_json::json!(chain);
76        }
77    }
78
79    serde_json::to_string(&error_obj).unwrap_or_else(|_| {
80        format!(
81            "{{\"error\":{{\"code\":\"{code}\",\"message\":\"{message}\",\"exit_code\":{exit}}}}}"
82        )
83    })
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal error with one underlying cause, for exercising `source()` chains.
    #[derive(Debug)]
    struct Wrapped(std::io::Error);

    impl std::fmt::Display for Wrapped {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("request failed")
        }
    }

    impl std::error::Error for Wrapped {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            Some(&self.0)
        }
    }

    #[test]
    fn test_classify_unknown_tool() {
        let err: Box<dyn std::error::Error> = "Unknown tool: 'foo'".into();
        assert_eq!(classify_error(&*err), "tool.not_found");
    }

    #[test]
    fn test_classify_scope_denied() {
        let err: Box<dyn std::error::Error> = "Access denied: scope check failed".into();
        assert_eq!(classify_error(&*err), "auth.scope_denied");
    }

    #[test]
    fn test_classify_expired() {
        let err: Box<dyn std::error::Error> = "Token expired".into();
        assert_eq!(classify_error(&*err), "auth.expired");
    }

    #[test]
    fn test_classify_missing_key() {
        let err: Box<dyn std::error::Error> = "No keys found in keyring".into();
        assert_eq!(classify_error(&*err), "auth.missing_key");
    }

    #[test]
    fn test_classify_provider_not_found() {
        let err: Box<dyn std::error::Error> = "Provider 'openai' not found".into();
        assert_eq!(classify_error(&*err), "provider.not_found");
    }

    #[test]
    fn test_classify_generic() {
        let err: Box<dyn std::error::Error> = "something went wrong".into();
        assert_eq!(classify_error(&*err), "tool.execution_failed");
    }

    #[test]
    fn test_exit_codes() {
        let input_err: Box<dyn std::error::Error> = "missing required argument".into();
        assert_eq!(exit_code_for_error(&*input_err), 2);

        let auth_err: Box<dyn std::error::Error> = "Token expired at 12345".into();
        assert_eq!(exit_code_for_error(&*auth_err), 3);

        let provider_err: Box<dyn std::error::Error> = "upstream API timeout".into();
        assert_eq!(exit_code_for_error(&*provider_err), 4);

        let rate_err: Box<dyn std::error::Error> = "Rate limit exceeded".into();
        assert_eq!(exit_code_for_error(&*rate_err), 5);

        let generic_err: Box<dyn std::error::Error> = "something went wrong".into();
        assert_eq!(exit_code_for_error(&*generic_err), 1);
    }

    #[test]
    fn test_format_structured_error() {
        let err: Box<dyn std::error::Error> = "Unknown tool: 'nonexistent'".into();
        let json_str = format_structured_error(&*err, false);
        let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
        assert_eq!(parsed["error"]["code"], "tool.not_found");
        assert_eq!(parsed["error"]["exit_code"], 1);
        assert!(parsed["error"]["message"]
            .as_str()
            .unwrap()
            .contains("nonexistent"));
    }

    #[test]
    fn test_format_escapes_special_characters() {
        let msg = "invalid value: \"a\\b\"\nnext line";
        let err: Box<dyn std::error::Error> = msg.into();
        let parsed: serde_json::Value =
            serde_json::from_str(&format_structured_error(&*err, false)).unwrap();
        assert_eq!(parsed["error"]["message"], msg);
        assert_eq!(parsed["error"]["code"], "input.invalid_value");
        assert_eq!(parsed["error"]["exit_code"], 2);
    }

    #[test]
    fn test_format_verbose_includes_chain() {
        let err = Wrapped(std::io::Error::new(
            std::io::ErrorKind::Other,
            "connection reset",
        ));

        let verbose: serde_json::Value =
            serde_json::from_str(&format_structured_error(&err, true)).unwrap();
        assert_eq!(verbose["error"]["chain"], serde_json::json!(["connection reset"]));

        let terse: serde_json::Value =
            serde_json::from_str(&format_structured_error(&err, false)).unwrap();
        assert!(terse["error"]["chain"].is_null());
    }
}