1use thiserror::Error;
2
3#[derive(Debug, Error)]
20#[non_exhaustive]
21pub enum GenaiError {
22 #[error("HTTP request error: {0}")]
23 Http(#[from] reqwest::Error),
24 #[error("SSE parsing error: {0}")]
25 Parse(String),
26 #[error("JSON deserialization error: {0}")]
27 Json(#[from] serde_json::Error),
28 #[error("UTF-8 decoding error: {0}")]
29 Utf8(#[from] std::str::Utf8Error),
30 #[error("API error (HTTP {status_code}): {message}")]
35 Api {
36 status_code: u16,
38 message: String,
40 request_id: Option<String>,
42 retry_after: Option<std::time::Duration>,
51 },
52 #[error("Internal client error: {0}")]
53 Internal(String),
54 #[error("Invalid input: {0}")]
55 InvalidInput(String),
56 #[error("Malformed API response: {0}")]
63 MalformedResponse(String),
64 #[error("Request timed out after {0:?}")]
70 Timeout(std::time::Duration),
71 #[error("Failed to build HTTP client: {0}")]
76 ClientBuild(String),
77}
78
79impl GenaiError {
80 #[must_use]
137 pub fn is_retryable(&self) -> bool {
138 match self {
139 GenaiError::Http(_) => true,
141
142 GenaiError::Api { status_code, .. } => *status_code == 429 || *status_code >= 500,
144
145 GenaiError::Timeout(_) => true,
147
148 GenaiError::Parse(_)
150 | GenaiError::Json(_)
151 | GenaiError::Utf8(_)
152 | GenaiError::Internal(_)
153 | GenaiError::InvalidInput(_)
154 | GenaiError::MalformedResponse(_)
155 | GenaiError::ClientBuild(_) => false,
156 }
157 }
158
159 #[must_use]
186 pub fn retry_after(&self) -> Option<std::time::Duration> {
187 match self {
188 GenaiError::Api { retry_after, .. } => *retry_after,
189 _ => None,
190 }
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an `Api` error with no request id and no retry hint.
    fn api_error(status_code: u16, message: &str) -> GenaiError {
        GenaiError::Api {
            status_code,
            message: message.to_string(),
            request_id: None,
            retry_after: None,
        }
    }

    /// Produces a real `serde_json::Error` by parsing invalid JSON.
    fn json_error() -> serde_json::Error {
        serde_json::from_str::<serde_json::Value>("not valid json").unwrap_err()
    }

    /// Produces a real `Utf8Error` from a byte sequence that is not UTF-8.
    fn utf8_error() -> std::str::Utf8Error {
        let bytes: [u8; 2] = [0xff, 0xfe];
        std::str::from_utf8(&bytes).unwrap_err()
    }

    // ------------------------------------------------------------------
    // Display / Debug formatting
    // ------------------------------------------------------------------

    #[test]
    fn test_genai_error_parse_display() {
        let display = GenaiError::Parse("Invalid SSE format".to_string()).to_string();
        assert!(display.contains("SSE parsing error"));
        assert!(display.contains("Invalid SSE format"));
    }

    #[test]
    fn test_genai_error_api_display() {
        let error = GenaiError::Api {
            status_code: 429,
            message: "Rate limited".to_string(),
            request_id: Some("req-123".to_string()),
            retry_after: None,
        };
        let display = error.to_string();
        assert!(display.contains("429"));
        assert!(display.contains("Rate limited"));
    }

    #[test]
    fn test_genai_error_api_without_request_id() {
        let display = api_error(500, "Internal error").to_string();
        assert!(display.contains("500"));
        assert!(display.contains("Internal error"));
    }

    #[test]
    fn test_genai_error_internal_display() {
        let display = GenaiError::Internal("Max function loops exceeded".to_string()).to_string();
        assert!(display.contains("Internal client error"));
        assert!(display.contains("Max function loops exceeded"));
    }

    #[test]
    fn test_genai_error_invalid_input_display() {
        let display = GenaiError::InvalidInput("Missing model or agent".to_string()).to_string();
        assert!(display.contains("Invalid input"));
        assert!(display.contains("Missing model or agent"));
    }

    #[test]
    fn test_genai_error_json_from() {
        let genai_err: GenaiError = json_error().into();
        assert!(genai_err.to_string().contains("JSON deserialization error"));
    }

    #[test]
    fn test_genai_error_utf8_from() {
        let genai_err: GenaiError = utf8_error().into();
        assert!(genai_err.to_string().contains("UTF-8 decoding error"));
    }

    #[test]
    fn test_genai_error_debug_format() {
        let error = GenaiError::Api {
            status_code: 400,
            message: "Bad request".to_string(),
            request_id: Some("req-456".to_string()),
            retry_after: None,
        };
        let debug = format!("{:?}", error);
        assert!(debug.contains("Api"));
        assert!(debug.contains("400"));
        assert!(debug.contains("req-456"));
    }

    #[test]
    fn test_genai_error_api_status_codes() {
        let status_codes = [
            (400, "Bad Request"),
            (401, "Unauthorized"),
            (403, "Forbidden"),
            (404, "Not Found"),
            (429, "Too Many Requests"),
            (500, "Internal Server Error"),
            (503, "Service Unavailable"),
        ];

        for (code, message) in status_codes {
            let display = api_error(code, message).to_string();
            assert!(
                display.contains(&code.to_string()),
                "Expected {} in display: {}",
                code,
                display
            );
            assert!(
                display.contains(message),
                "Expected {:?} in display: {}",
                message,
                display
            );
        }
    }

    #[test]
    fn test_genai_error_api_with_empty_message() {
        let display = api_error(500, "").to_string();
        assert!(display.contains("500"));
        assert!(display.contains("API error"));
    }

    #[test]
    fn test_genai_error_malformed_response_display() {
        let error = GenaiError::MalformedResponse(
            "Function call 'get_weather' is missing required call_id field".to_string(),
        );
        let display = error.to_string();
        assert!(display.contains("Malformed API response"));
        assert!(display.contains("call_id"));
    }

    #[test]
    fn test_genai_error_malformed_response_stream() {
        let error =
            GenaiError::MalformedResponse("Stream ended without Complete event".to_string());
        let display = error.to_string();
        assert!(display.contains("Malformed API response"));
        assert!(display.contains("Complete event"));
    }

    #[test]
    fn test_genai_error_timeout_display() {
        let display = GenaiError::Timeout(Duration::from_secs(30)).to_string();
        assert!(display.contains("Request timed out"));
        assert!(display.contains("30s"));
    }

    #[test]
    fn test_genai_error_timeout_debug() {
        let debug = format!("{:?}", GenaiError::Timeout(Duration::from_millis(500)));
        assert!(debug.contains("Timeout"));
        assert!(debug.contains("500ms"));
    }

    #[test]
    fn test_genai_error_client_build_display() {
        let display = GenaiError::ClientBuild("TLS initialization failed".to_string()).to_string();
        assert!(display.contains("Failed to build HTTP client"));
        assert!(display.contains("TLS initialization failed"));
    }

    #[test]
    fn test_genai_error_client_build_debug() {
        let debug = format!("{:?}", GenaiError::ClientBuild("some error".to_string()));
        assert!(debug.contains("ClientBuild"));
        assert!(debug.contains("some error"));
    }

    // ------------------------------------------------------------------
    // is_retryable
    // ------------------------------------------------------------------

    #[test]
    fn test_is_retryable_rate_limit_429() {
        let error = GenaiError::Api {
            status_code: 429,
            message: "Resource exhausted".to_string(),
            request_id: None,
            retry_after: Some(Duration::from_secs(60)),
        };
        assert!(error.is_retryable(), "429 errors should be retryable");
    }

    #[test]
    fn test_is_retryable_server_errors_5xx() {
        for status_code in [500, 502, 503, 504] {
            assert!(
                api_error(status_code, "Server error").is_retryable(),
                "{} errors should be retryable",
                status_code
            );
        }
    }

    #[test]
    fn test_is_retryable_client_errors_4xx_not_retryable() {
        for status_code in [400, 401, 403, 404, 422] {
            assert!(
                !api_error(status_code, "Client error").is_retryable(),
                "{} errors should NOT be retryable",
                status_code
            );
        }
    }

    #[test]
    fn test_is_retryable_status_boundaries() {
        // 429 is the only retryable 4xx; its neighbours and the top of the
        // 4xx range are not.
        for status_code in [428, 430, 499] {
            assert!(
                !api_error(status_code, "edge").is_retryable(),
                "{} errors should NOT be retryable",
                status_code
            );
        }
        // Both ends of the 5xx range are retryable.
        for status_code in [500, 599] {
            assert!(
                api_error(status_code, "edge").is_retryable(),
                "{} errors should be retryable",
                status_code
            );
        }
    }

    #[test]
    fn test_is_retryable_timeout() {
        let error = GenaiError::Timeout(Duration::from_secs(30));
        assert!(error.is_retryable(), "Timeout errors should be retryable");
    }

    #[test]
    fn test_is_retryable_parse_error_not_retryable() {
        let error = GenaiError::Parse("Invalid SSE".to_string());
        assert!(
            !error.is_retryable(),
            "Parse errors should NOT be retryable"
        );
    }

    #[test]
    fn test_is_retryable_json_error_not_retryable() {
        let error: GenaiError = json_error().into();
        assert!(!error.is_retryable(), "JSON errors should NOT be retryable");
    }

    #[test]
    fn test_is_retryable_invalid_input_not_retryable() {
        let error = GenaiError::InvalidInput("Missing model".to_string());
        assert!(
            !error.is_retryable(),
            "InvalidInput errors should NOT be retryable"
        );
    }

    #[test]
    fn test_is_retryable_malformed_response_not_retryable() {
        let error = GenaiError::MalformedResponse("Missing call_id".to_string());
        assert!(
            !error.is_retryable(),
            "MalformedResponse errors should NOT be retryable"
        );
    }

    #[test]
    fn test_is_retryable_internal_error_not_retryable() {
        let error = GenaiError::Internal("Max loops exceeded".to_string());
        assert!(
            !error.is_retryable(),
            "Internal errors should NOT be retryable"
        );
    }

    #[test]
    fn test_is_retryable_client_build_not_retryable() {
        let error = GenaiError::ClientBuild("TLS init failed".to_string());
        assert!(
            !error.is_retryable(),
            "ClientBuild errors should NOT be retryable"
        );
    }

    #[test]
    fn test_is_retryable_utf8_error_not_retryable() {
        let error: GenaiError = utf8_error().into();
        assert!(
            !error.is_retryable(),
            "UTF-8 errors should NOT be retryable"
        );
    }

    // ------------------------------------------------------------------
    // retry_after
    // ------------------------------------------------------------------

    #[test]
    fn test_retry_after_returns_api_hint() {
        let error = GenaiError::Api {
            status_code: 429,
            message: "Resource exhausted".to_string(),
            request_id: None,
            retry_after: Some(Duration::from_secs(60)),
        };
        assert_eq!(error.retry_after(), Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_retry_after_none_when_api_has_no_hint() {
        assert_eq!(api_error(503, "Service Unavailable").retry_after(), None);
    }

    #[test]
    fn test_retry_after_none_for_non_api_variants() {
        // A Timeout's duration is the elapsed limit, not a retry hint.
        assert_eq!(GenaiError::Timeout(Duration::from_secs(30)).retry_after(), None);
        assert_eq!(GenaiError::Parse("x".to_string()).retry_after(), None);
        let json: GenaiError = json_error().into();
        assert_eq!(json.retry_after(), None);
    }
}