1pub type Result<T> = std::result::Result<T, Error>;
5
6#[derive(Debug, thiserror::Error)]
8pub enum Error {
9 #[error("HTTP request failed: {0}")]
11 Http(#[from] reqwest::Error),
12
13 #[error("JSON error: {0}")]
15 Json(#[from] serde_json::Error),
16
17 #[error("Invalid URL: {0}")]
19 Url(#[from] url::ParseError),
20
21 #[error("API error ({status}): {message}")]
23 Api {
24 status: u16,
26 message: String,
28 error_type: Option<String>,
30 },
31
32 #[error("Authentication failed: {0}")]
34 Authentication(String),
35
36 #[error(
38 "{}",
39 format_rate_limited_error(retry_after.as_ref().copied(), message.as_deref())
40 )]
41 RateLimited {
42 retry_after: Option<u64>,
44 message: Option<String>,
46 },
47
48 #[error("Request timed out")]
50 Timeout,
51
52 #[error("Invalid request: {0}")]
54 InvalidRequest(String),
55
56 #[cfg(feature = "realtime")]
58 #[error("WebSocket error: {0}")]
59 WebSocket(Box<tokio_tungstenite::tungstenite::Error>),
60
61 #[error("Stream error: {0}")]
63 Stream(String),
64
65 #[error("Base64 decode error: {0}")]
67 Base64(#[from] base64::DecodeError),
68
69 #[error("Configuration error: {0}")]
71 Config(String),
72
73 #[error("I/O error: {0}")]
75 Io(#[from] std::io::Error),
76}
77
#[cfg(feature = "realtime")]
impl From<tokio_tungstenite::tungstenite::Error> for Error {
    /// Wraps a WebSocket error in [`Error::WebSocket`], boxing it so the
    /// variant's payload is a single pointer.
    fn from(source: tokio_tungstenite::tungstenite::Error) -> Self {
        Self::WebSocket(source.into())
    }
}
84
85#[derive(Debug, serde::Deserialize)]
87pub(crate) struct ApiErrorResponse {
88 pub error: ApiErrorDetail,
90}
91
92#[derive(Debug, serde::Deserialize)]
94#[allow(dead_code)]
95pub(crate) struct ApiErrorDetail {
96 pub message: String,
98 #[serde(rename = "type")]
100 pub error_type: Option<String>,
101 pub code: Option<String>,
103}
104
/// Maximum number of characters (Unicode scalar values, not bytes) of a raw
/// response body kept in an error message; longer bodies are truncated.
const ERROR_BODY_SNIPPET_LIMIT: usize = 4 * 1024;
106
/// Builds the display text for a rate-limit error.
///
/// The text always starts with `Rate limit exceeded.` and names the wait time
/// when `retry_after` is known. A server message is appended unless it is
/// absent or whitespace-only; it is appended untrimmed.
fn format_rate_limited_error(retry_after: Option<u64>, message: Option<&str>) -> String {
    let head = retry_after.map_or_else(
        || String::from("Rate limit exceeded."),
        |secs| format!("Rate limit exceeded. Retry after {secs} seconds."),
    );

    match message {
        Some(msg) if !msg.trim().is_empty() => format!("{head} Server message: {msg}"),
        _ => head,
    }
}
120
121fn body_snippet(body: &str) -> Option<String> {
122 let trimmed = body.trim();
123 if trimmed.is_empty() {
124 return None;
125 }
126
127 let mut chars = trimmed.chars();
128 let snippet: String = chars.by_ref().take(ERROR_BODY_SNIPPET_LIMIT).collect();
129
130 if chars.next().is_some() {
131 Some(format!("{snippet}...[truncated]"))
132 } else {
133 Some(snippet)
134 }
135}
136
137impl Error {
138 pub async fn from_response(response: reqwest::Response) -> Self {
140 let status = response.status().as_u16();
141 let retry_after = response
142 .headers()
143 .get("retry-after")
144 .and_then(|v| v.to_str().ok())
145 .and_then(|v| v.parse().ok());
146
147 let body_text = match response.bytes().await {
148 Ok(bytes) => String::from_utf8_lossy(&bytes).into_owned(),
149 Err(_) => String::new(),
150 };
151
152 let parsed_error = serde_json::from_str::<ApiErrorResponse>(&body_text).ok();
153
154 if status == 429 {
156 let message = parsed_error
157 .as_ref()
158 .map(|api_error| api_error.error.message.clone())
159 .or_else(|| body_snippet(&body_text));
160
161 return Error::RateLimited {
162 retry_after,
163 message,
164 };
165 }
166
167 match parsed_error {
169 Some(api_error) => Error::Api {
170 status,
171 message: api_error.error.message,
172 error_type: api_error.error.error_type,
173 },
174 None => Error::Api {
175 status,
176 message: body_snippet(&body_text)
177 .map(|snippet| format!("HTTP {status}: {snippet}"))
178 .unwrap_or_else(|| format!("HTTP {status}")),
179 error_type: None,
180 },
181 }
182 }
183
184 pub fn is_retryable(&self) -> bool {
186 match self {
187 Error::RateLimited { .. } => true,
188 Error::Timeout => true,
189 Error::Api { status, .. } => {
190 *status >= 500 && *status < 600
192 }
193 Error::Http(e) => e.is_timeout() || e.is_connect(),
194 _ => false,
195 }
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// `Error::Api` with the given status and message and no `error_type`.
    fn api_error(status: u16, message: &str) -> Error {
        Error::Api {
            status,
            message: message.to_string(),
            error_type: None,
        }
    }

    /// Serves `template` at `GET {route}` from a fresh mock server, requests
    /// it, and converts the response with `Error::from_response`.
    async fn error_from_mock(route: &str, template: ResponseTemplate) -> Error {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path(route))
            .respond_with(template)
            .mount(&server)
            .await;

        let url = format!("{}{route}", server.uri());
        let response = reqwest::Client::new().get(url).send().await.unwrap();
        Error::from_response(response).await
    }

    // ----- is_retryable -----------------------------------------------------

    #[test]
    fn rate_limited_is_retryable() {
        let err = Error::RateLimited {
            retry_after: Some(30),
            message: None,
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn rate_limited_without_retry_after_is_retryable() {
        let err = Error::RateLimited {
            retry_after: None,
            message: None,
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn timeout_is_retryable() {
        assert!(Error::Timeout.is_retryable());
    }

    #[test]
    fn api_500_is_retryable() {
        assert!(api_error(500, "Internal Server Error").is_retryable());
    }

    #[test]
    fn api_502_is_retryable() {
        assert!(api_error(502, "Bad Gateway").is_retryable());
    }

    #[test]
    fn api_503_is_retryable() {
        assert!(api_error(503, "Service Unavailable").is_retryable());
    }

    #[test]
    fn api_400_is_not_retryable() {
        assert!(!api_error(400, "Bad Request").is_retryable());
    }

    #[test]
    fn api_401_is_not_retryable() {
        assert!(!api_error(401, "Unauthorized").is_retryable());
    }

    #[test]
    fn api_403_is_not_retryable() {
        assert!(!api_error(403, "Forbidden").is_retryable());
    }

    #[test]
    fn api_404_is_not_retryable() {
        assert!(!api_error(404, "Not Found").is_retryable());
    }

    #[test]
    fn api_422_is_not_retryable() {
        let err = Error::Api {
            status: 422,
            message: "Unprocessable Entity".to_string(),
            error_type: Some("validation_error".to_string()),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn authentication_error_is_not_retryable() {
        assert!(!Error::Authentication("Invalid token".to_string()).is_retryable());
    }

    #[test]
    fn invalid_request_is_not_retryable() {
        assert!(!Error::InvalidRequest("Missing field".to_string()).is_retryable());
    }

    #[test]
    fn json_error_is_not_retryable() {
        let parse_failure = serde_json::from_str::<serde_json::Value>("bad json").unwrap_err();
        assert!(!Error::Json(parse_failure).is_retryable());
    }

    #[test]
    fn stream_error_is_not_retryable() {
        assert!(!Error::Stream("connection reset".to_string()).is_retryable());
    }

    #[test]
    fn config_error_is_not_retryable() {
        assert!(!Error::Config("Missing API key".to_string()).is_retryable());
    }

    // ----- Display ----------------------------------------------------------

    #[test]
    fn error_display_api() {
        assert_eq!(
            api_error(404, "Not Found").to_string(),
            "API error (404): Not Found"
        );
    }

    #[test]
    fn error_display_rate_limited_with_retry_and_message() {
        let shown = Error::RateLimited {
            retry_after: Some(60),
            message: Some("Too many requests".to_string()),
        }
        .to_string();
        assert_eq!(
            shown,
            "Rate limit exceeded. Retry after 60 seconds. Server message: Too many requests"
        );
        assert!(!shown.contains("Some("));
        assert!(!shown.contains("None"));
    }

    #[test]
    fn error_display_rate_limited_without_retry_or_message() {
        let shown = Error::RateLimited {
            retry_after: None,
            message: None,
        }
        .to_string();
        assert_eq!(shown, "Rate limit exceeded.");
        assert!(!shown.contains("Some("));
        assert!(!shown.contains("None"));
    }

    // ----- from_response ----------------------------------------------------

    #[tokio::test]
    async fn from_response_parses_api_error_body() {
        let body = serde_json::json!({
            "error": {
                "message": "Model not found",
                "type": "invalid_request_error"
            }
        });
        let err = error_from_mock("/err-json", ResponseTemplate::new(400).set_body_json(body)).await;

        match err {
            Error::Api {
                status,
                message,
                error_type,
            } => {
                assert_eq!(status, 400);
                assert_eq!(message, "Model not found");
                assert_eq!(error_type.as_deref(), Some("invalid_request_error"));
            }
            _ => panic!("expected Error::Api"),
        }
    }

    #[tokio::test]
    async fn from_response_non_json_includes_snippet() {
        let template = ResponseTemplate::new(500).set_body_string("upstream exploded");
        let err = error_from_mock("/err-text", template).await;

        match err {
            Error::Api {
                status,
                message,
                error_type,
            } => {
                assert_eq!(status, 500);
                assert!(message.contains("HTTP 500"));
                assert!(message.contains("upstream exploded"));
                assert!(error_type.is_none());
            }
            _ => panic!("expected Error::Api"),
        }
    }

    #[tokio::test]
    async fn from_response_non_json_truncates_long_body() {
        let oversized = "a".repeat(ERROR_BODY_SNIPPET_LIMIT + 128);
        let template = ResponseTemplate::new(502).set_body_string(oversized);
        let err = error_from_mock("/err-long", template).await;

        match err {
            Error::Api {
                status, message, ..
            } => {
                assert_eq!(status, 502);
                assert!(message.contains("[truncated]"));
            }
            _ => panic!("expected Error::Api"),
        }
    }

    #[tokio::test]
    async fn from_response_429_includes_retry_after_and_message() {
        let template = ResponseTemplate::new(429)
            .insert_header("retry-after", "7")
            .set_body_json(serde_json::json!({
                "error": {
                    "message": "Too many requests"
                }
            }));
        let err = error_from_mock("/err-429", template).await;

        match err {
            Error::RateLimited {
                retry_after,
                message,
            } => {
                assert_eq!(retry_after, Some(7));
                assert_eq!(message.as_deref(), Some("Too many requests"));
            }
            _ => panic!("expected Error::RateLimited"),
        }
    }

    // ----- Display of simple variants ---------------------------------------

    #[test]
    fn error_display_timeout() {
        assert_eq!(Error::Timeout.to_string(), "Request timed out");
    }

    #[test]
    fn error_display_authentication() {
        assert_eq!(
            Error::Authentication("bad key".to_string()).to_string(),
            "Authentication failed: bad key"
        );
    }

    #[test]
    fn error_display_config() {
        assert_eq!(
            Error::Config("missing key".to_string()).to_string(),
            "Configuration error: missing key"
        );
    }

    #[test]
    fn error_display_stream() {
        assert_eq!(
            Error::Stream("parse failure".to_string()).to_string(),
            "Stream error: parse failure"
        );
    }

    // ----- ApiErrorResponse deserialization ---------------------------------

    #[test]
    fn api_error_response_deserialize() {
        let json = serde_json::json!({
            "error": {
                "message": "Model not found",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        });
        let parsed: ApiErrorResponse = serde_json::from_value(json).unwrap();
        let detail = &parsed.error;
        assert_eq!(detail.message, "Model not found");
        assert_eq!(detail.error_type.as_deref(), Some("invalid_request_error"));
        assert_eq!(detail.code.as_deref(), Some("model_not_found"));
    }

    #[test]
    fn api_error_response_minimal() {
        let json = serde_json::json!({
            "error": {
                "message": "Something went wrong"
            }
        });
        let parsed: ApiErrorResponse = serde_json::from_value(json).unwrap();
        assert_eq!(parsed.error.message, "Something went wrong");
        assert!(parsed.error.error_type.is_none());
    }
}