1use super::VmValue;
2
/// Errors produced while running the VM.
///
/// Besides failure conditions, two variants carry a [`VmValue`] payload:
/// [`VmError::Thrown`] (a value raised by user code) and [`VmError::Return`]
/// (see the note on that variant).
#[derive(Debug, Clone)]
pub enum VmError {
    /// Stack underflow; displayed as `Stack underflow`.
    StackUnderflow,
    /// Call nesting exceeded the VM's limit; displayed as
    /// `Stack overflow: too many nested calls`.
    StackOverflow,
    /// A variable name that could not be found; the payload is the name.
    UndefinedVariable(String),
    /// A builtin name that could not be found; the payload is the name.
    UndefinedBuiltin(String),
    /// An assignment targeted an immutable binding; the payload is the binding name.
    ImmutableAssignment(String),
    /// A type mismatch; the payload is a human-readable description.
    TypeError(String),
    /// A generic runtime failure. [`error_to_category`] classifies its message
    /// text with [`classify_error_message`].
    Runtime(String),
    /// Division by zero; displayed as `Division by zero`.
    DivisionByZero,
    /// A value raised by user code. [`error_to_category`] reads the `"category"`
    /// entry when the value is a dict, and classifies the text when it is a string.
    Thrown(VmValue),
    /// An error explicitly tagged with an [`ErrorCategory`]; build it with
    /// [`categorized_error`]. Displayed as `Error [<category>]: <message>`.
    CategorizedError {
        /// Human-readable description of the failure.
        message: String,
        /// Classification reported as-is by [`error_to_category`].
        category: ErrorCategory,
    },
    /// Carries a returned value; displayed as `Return from function`.
    /// NOTE(review): presumably used to unwind a function return through the
    /// error path — confirm against the interpreter loop.
    Return(VmValue),
    /// An instruction byte the VM could not execute; displayed in hex as
    /// `Invalid instruction: 0x..`. Presumably an undecodable opcode — confirm.
    InvalidInstruction(u8),
}
22
/// Coarse classification of a failure.
///
/// Every category has a stable snake_case name (see [`ErrorCategory::as_str`] and
/// [`ErrorCategory::parse`]), and [`ErrorCategory::is_transient`] marks the subset
/// treated as transient. `Hash` is derived alongside `Eq` so categories can be used
/// as `HashMap`/`HashSet` keys (for example per-category counters).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ErrorCategory {
    /// Deadline/timeout failures (HTTP 408, 504, 522, 524, or "Deadline exceeded" text).
    Timeout,
    /// Authentication/authorization failures (HTTP 401, 403, `invalid_api_key`,
    /// `authentication_error`).
    Auth,
    /// Rate or quota limits (HTTP 429, `insufficient_quota`,
    /// `billing_hard_limit_reached`).
    RateLimit,
    /// Upstream overload (HTTP 503, 529, `overloaded_error`).
    Overloaded,
    /// Upstream server failures (HTTP 500, 502, `api_error`).
    ServerError,
    /// Low-level network failures such as connection resets, broken pipes, DNS
    /// errors, or unexpected EOF.
    TransientNetwork,
    /// Schema validation failure; not produced by message classification in this module.
    SchemaValidation,
    /// Tool failure; not produced by message classification in this module.
    ToolError,
    /// Tool rejection; not produced by message classification in this module.
    ToolRejected,
    /// Cancellation; not produced by message classification in this module.
    Cancelled,
    /// Missing resource (HTTP 404, 410, `not_found_error`, `model_not_found`).
    NotFound,
    /// Message text contained `circuit_open`.
    CircuitOpen,
    /// Fallback for anything that matched no other category.
    Generic,
}
56
57impl ErrorCategory {
58 pub fn as_str(&self) -> &'static str {
59 match self {
60 ErrorCategory::Timeout => "timeout",
61 ErrorCategory::Auth => "auth",
62 ErrorCategory::RateLimit => "rate_limit",
63 ErrorCategory::Overloaded => "overloaded",
64 ErrorCategory::ServerError => "server_error",
65 ErrorCategory::TransientNetwork => "transient_network",
66 ErrorCategory::SchemaValidation => "schema_validation",
67 ErrorCategory::ToolError => "tool_error",
68 ErrorCategory::ToolRejected => "tool_rejected",
69 ErrorCategory::Cancelled => "cancelled",
70 ErrorCategory::NotFound => "not_found",
71 ErrorCategory::CircuitOpen => "circuit_open",
72 ErrorCategory::Generic => "generic",
73 }
74 }
75
76 pub fn parse(s: &str) -> Self {
77 match s {
78 "timeout" => ErrorCategory::Timeout,
79 "auth" => ErrorCategory::Auth,
80 "rate_limit" => ErrorCategory::RateLimit,
81 "overloaded" => ErrorCategory::Overloaded,
82 "server_error" => ErrorCategory::ServerError,
83 "transient_network" => ErrorCategory::TransientNetwork,
84 "schema_validation" => ErrorCategory::SchemaValidation,
85 "tool_error" => ErrorCategory::ToolError,
86 "tool_rejected" => ErrorCategory::ToolRejected,
87 "cancelled" => ErrorCategory::Cancelled,
88 "not_found" => ErrorCategory::NotFound,
89 "circuit_open" => ErrorCategory::CircuitOpen,
90 _ => ErrorCategory::Generic,
91 }
92 }
93
94 pub fn is_transient(&self) -> bool {
98 matches!(
99 self,
100 ErrorCategory::Timeout
101 | ErrorCategory::RateLimit
102 | ErrorCategory::Overloaded
103 | ErrorCategory::ServerError
104 | ErrorCategory::TransientNetwork
105 )
106 }
107}
108
109pub fn categorized_error(message: impl Into<String>, category: ErrorCategory) -> VmError {
111 VmError::CategorizedError {
112 message: message.into(),
113 category,
114 }
115}
116
117pub fn error_to_category(err: &VmError) -> ErrorCategory {
126 match err {
127 VmError::CategorizedError { category, .. } => category.clone(),
128 VmError::Thrown(VmValue::Dict(d)) => d
129 .get("category")
130 .map(|v| ErrorCategory::parse(&v.display()))
131 .unwrap_or(ErrorCategory::Generic),
132 VmError::Thrown(VmValue::String(s)) => classify_error_message(s),
133 VmError::Runtime(msg) => classify_error_message(msg),
134 _ => ErrorCategory::Generic,
135 }
136}
137
138pub fn classify_error_message(msg: &str) -> ErrorCategory {
141 if let Some(cat) = classify_by_http_status(msg) {
143 return cat;
144 }
145 if msg.contains("Deadline exceeded") || msg.contains("context deadline exceeded") {
148 return ErrorCategory::Timeout;
149 }
150 if msg.contains("overloaded_error") {
151 return ErrorCategory::Overloaded;
153 }
154 if msg.contains("api_error") {
155 return ErrorCategory::ServerError;
157 }
158 if msg.contains("insufficient_quota") || msg.contains("billing_hard_limit_reached") {
159 return ErrorCategory::RateLimit;
161 }
162 if msg.contains("invalid_api_key") || msg.contains("authentication_error") {
163 return ErrorCategory::Auth;
164 }
165 if msg.contains("not_found_error") || msg.contains("model_not_found") {
166 return ErrorCategory::NotFound;
167 }
168 if msg.contains("circuit_open") {
169 return ErrorCategory::CircuitOpen;
170 }
171 let lower = msg.to_lowercase();
173 if lower.contains("connection reset")
174 || lower.contains("connection refused")
175 || lower.contains("connection closed")
176 || lower.contains("broken pipe")
177 || lower.contains("dns error")
178 || lower.contains("stream error")
179 || lower.contains("unexpected eof")
180 {
181 return ErrorCategory::TransientNetwork;
182 }
183 ErrorCategory::Generic
184}
185
186fn classify_by_http_status(msg: &str) -> Option<ErrorCategory> {
190 for code in extract_http_status_codes(msg) {
193 return Some(match code {
194 401 | 403 => ErrorCategory::Auth,
195 404 | 410 => ErrorCategory::NotFound,
196 408 | 504 | 522 | 524 => ErrorCategory::Timeout,
197 429 => ErrorCategory::RateLimit,
198 503 | 529 => ErrorCategory::Overloaded,
199 500 | 502 => ErrorCategory::ServerError,
200 _ => continue,
201 });
202 }
203 None
204}
205
/// Returns every standalone three-digit number in `msg` within `400..=599`, in
/// order of appearance and keeping duplicates.
///
/// A number counts only when it is exactly three digits long, meaning the bytes on
/// either side are not ASCII digits. So `1500` and `4290` contribute nothing.
fn extract_http_status_codes(msg: &str) -> Vec<u16> {
    let bytes = msg.as_bytes();
    let is_digit_at = |idx: usize| matches!(bytes.get(idx), Some(b) if b.is_ascii_digit());

    bytes
        .windows(3)
        .enumerate()
        .filter(|(start, window)| {
            window.iter().all(u8::is_ascii_digit)
                && (*start == 0 || !is_digit_at(start - 1))
                && !is_digit_at(start + 3)
        })
        .map(|(_, window)| {
            window
                .iter()
                .fold(0u16, |acc, digit| acc * 10 + u16::from(digit - b'0'))
        })
        .filter(|code| (400..=599).contains(code))
        .collect()
}
230
231impl std::fmt::Display for VmError {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 VmError::StackUnderflow => write!(f, "Stack underflow"),
235 VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
236 VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
237 VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
238 VmError::ImmutableAssignment(n) => {
239 write!(f, "Cannot assign to immutable binding: {n}")
240 }
241 VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
242 VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
243 VmError::DivisionByZero => write!(f, "Division by zero"),
244 VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
245 VmError::CategorizedError { message, category } => {
246 write!(f, "Error [{}]: {}", category.as_str(), message)
247 }
248 VmError::Return(_) => write!(f, "Return from function"),
249 VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
250 }
251 }
252}
253
254impl std::error::Error for VmError {}