/// Maps an error to a stable, dotted error code by keyword-matching its message.
///
/// The `Display` text of `err` is lowercased and tested against keyword groups in a
/// fixed priority order; the first group that matches wins. More specific phrases are
/// tested before generic ones (e.g. `"missing key"` before `"missing"`). The part of the
/// returned code before the first `.` is the category used by [`exit_code_for_error`].
/// Messages matching no group are reported as `"tool.execution_failed"`.
pub fn classify_error(err: &dyn std::error::Error) -> &'static str {
    let msg = err.to_string().to_lowercase();

    // `&&` binds tighter than `||`; the parentheses make the intended grouping explicit.
    if msg.contains("unknown tool") || (msg.contains("not found") && msg.contains("tool")) {
        "tool.not_found"
    } else if msg.contains("scope") || msg.contains("access denied") {
        "auth.scope_denied"
    } else if msg.contains("expired") {
        "auth.expired"
    } else if msg.contains("key not found")
        || msg.contains("missing key")
        || msg.contains("no keys found")
    {
        "auth.missing_key"
    } else if msg.contains("timeout") {
        "provider.timeout"
    } else if msg.contains("upstream")
        || msg.contains("bad gateway")
        || msg.contains("mcp error")
    {
        "provider.upstream_error"
    } else if msg.contains("provider") && msg.contains("not found") {
        "provider.not_found"
    } else if msg.contains("rate limit")
        || msg.contains("rate.exceeded")
        || msg.contains("too many requests")
    {
        // Must precede the generic input keywords below: rate-limit messages frequently
        // contain words such as "required" or "invalid" and are not input errors.
        // "too many requests" is the reason phrase of HTTP 429.
        "rate.exceeded"
    } else if msg.contains("missing") || msg.contains("required") {
        "input.missing_arg"
    } else if msg.contains("invalid") || msg.contains("parse") {
        "input.invalid_value"
    } else {
        "tool.execution_failed"
    }
}
40
41pub fn exit_code_for_error(err: &dyn std::error::Error) -> i32 {
43 let code = classify_error(err);
44 match code.split('.').next().unwrap_or("") {
45 "input" => 2,
46 "auth" => 3,
47 "provider" => 4,
48 "rate" => 5,
49 _ => 1,
50 }
51}
52
53pub fn format_structured_error(err: &dyn std::error::Error, verbose: bool) -> String {
55 let code = classify_error(err);
56 let exit = exit_code_for_error(err);
57 let message = err.to_string();
58
59 let mut error_obj = serde_json::json!({
60 "error": {
61 "code": code,
62 "message": message,
63 "exit_code": exit,
64 }
65 });
66
67 if verbose {
68 let mut chain = Vec::new();
69 let mut source = std::error::Error::source(err);
70 while let Some(cause) = source {
71 chain.push(cause.to_string());
72 source = std::error::Error::source(cause);
73 }
74 if !chain.is_empty() {
75 error_obj["error"]["chain"] = serde_json::json!(chain);
76 }
77 }
78
79 serde_json::to_string(&error_obj).unwrap_or_else(|_| {
80 format!(
81 "{{\"error\":{{\"code\":\"{code}\",\"message\":\"{message}\",\"exit_code\":{exit}}}}}"
82 )
83 })
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only error whose `source()` is a single inner error, used to exercise the
    /// verbose cause chain of `format_structured_error`.
    #[derive(Debug)]
    struct Wrapped {
        inner: Box<dyn std::error::Error>,
    }

    impl std::fmt::Display for Wrapped {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "outer failure")
        }
    }

    impl std::error::Error for Wrapped {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            Some(&*self.inner)
        }
    }

    fn wrapped() -> Wrapped {
        Wrapped {
            inner: "disk unavailable".into(),
        }
    }

    #[test]
    fn test_classify_unknown_tool() {
        let err: Box<dyn std::error::Error> = "Unknown tool: 'foo'".into();
        assert_eq!(classify_error(&*err), "tool.not_found");
    }

    #[test]
    fn test_classify_scope_denied() {
        let err: Box<dyn std::error::Error> = "Access denied: scope check failed".into();
        assert_eq!(classify_error(&*err), "auth.scope_denied");
    }

    #[test]
    fn test_classify_expired() {
        let err: Box<dyn std::error::Error> = "Token expired".into();
        assert_eq!(classify_error(&*err), "auth.expired");
    }

    #[test]
    fn test_classify_generic() {
        let err: Box<dyn std::error::Error> = "something went wrong".into();
        assert_eq!(classify_error(&*err), "tool.execution_failed");
    }

    #[test]
    fn test_exit_codes() {
        let input_err: Box<dyn std::error::Error> = "missing required argument".into();
        assert_eq!(exit_code_for_error(&*input_err), 2);

        let auth_err: Box<dyn std::error::Error> = "Token expired at 12345".into();
        assert_eq!(exit_code_for_error(&*auth_err), 3);

        let provider_err: Box<dyn std::error::Error> = "upstream API timeout".into();
        assert_eq!(exit_code_for_error(&*provider_err), 4);
    }

    #[test]
    fn test_rate_limit_exit_code() {
        let err: Box<dyn std::error::Error> = "Rate limit exceeded".into();
        assert_eq!(classify_error(&*err), "rate.exceeded");
        assert_eq!(exit_code_for_error(&*err), 5);
    }

    #[test]
    fn test_format_structured_error() {
        let err: Box<dyn std::error::Error> = "Unknown tool: 'nonexistent'".into();
        let json_str = format_structured_error(&*err, false);
        let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
        assert_eq!(parsed["error"]["code"], "tool.not_found");
        assert_eq!(parsed["error"]["exit_code"], 1);
        assert!(parsed["error"]["message"]
            .as_str()
            .unwrap()
            .contains("nonexistent"));
    }

    #[test]
    fn test_format_escapes_special_characters() {
        // Quotes and backslashes must round-trip through valid JSON.
        let raw = "Unknown tool: \"quoted\\path\"";
        let err: Box<dyn std::error::Error> = raw.into();
        let parsed: serde_json::Value =
            serde_json::from_str(&format_structured_error(&*err, false)).unwrap();
        assert_eq!(parsed["error"]["message"], raw);
    }

    #[test]
    fn test_format_verbose_includes_source_chain() {
        let err = wrapped();
        let parsed: serde_json::Value =
            serde_json::from_str(&format_structured_error(&err, true)).unwrap();
        assert_eq!(parsed["error"]["message"], "outer failure");
        assert_eq!(
            parsed["error"]["chain"],
            serde_json::json!(["disk unavailable"])
        );
    }

    #[test]
    fn test_format_non_verbose_omits_chain() {
        let err = wrapped();
        let parsed: serde_json::Value =
            serde_json::from_str(&format_structured_error(&err, false)).unwrap();
        assert!(parsed["error"].get("chain").is_none());
    }
}