1use serde::{Deserialize, Deserializer};
6
7#[derive(Debug, Clone, Default, Deserialize)]
9pub struct ApiErrorDetail {
10 pub message: Option<String>,
12 #[serde(rename = "type")]
14 pub error_type: Option<String>,
15 pub param: Option<String>,
17 #[serde(default, deserialize_with = "lenient_string")]
20 pub code: Option<String>,
21}
22
23fn lenient_string<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Option<String>, D::Error> {
24 let value = Option::<serde_json::Value>::deserialize(deserializer)?;
25 Ok(match value {
26 None | Some(serde_json::Value::Null) => None,
27 Some(serde_json::Value::String(s)) => Some(s),
28 Some(other) => Some(other.to_string()),
29 })
30}
31
/// Coarse category of an API error, derived from the HTTP status code.
///
/// Marked `#[non_exhaustive]`: downstream matches need a wildcard arm so that
/// new categories can be added later.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiErrorKind {
    /// HTTP 400.
    BadRequest,
    /// HTTP 401.
    Authentication,
    /// HTTP 403.
    PermissionDenied,
    /// HTTP 404.
    NotFound,
    /// HTTP 409.
    Conflict,
    /// HTTP 422.
    UnprocessableEntity,
    /// HTTP 429.
    RateLimit,
    /// Any status of 500 or above.
    InternalServer,
    /// Any status not covered by another variant.
    Other,
}

impl ApiErrorKind {
    /// Classifies an HTTP status code.
    ///
    /// Selected 4xx codes map to dedicated variants, every code `>= 500` maps
    /// to [`ApiErrorKind::InternalServer`], and everything else (including
    /// 2xx/3xx and unlisted 4xx codes) maps to [`ApiErrorKind::Other`].
    pub fn from_status(status: u16) -> Self {
        match status {
            500..=u16::MAX => Self::InternalServer,
            400 => Self::BadRequest,
            401 => Self::Authentication,
            403 => Self::PermissionDenied,
            404 => Self::NotFound,
            409 => Self::Conflict,
            422 => Self::UnprocessableEntity,
            429 => Self::RateLimit,
            _ => Self::Other,
        }
    }
}
72
73#[derive(Debug, Clone)]
75pub struct ApiError {
76 pub status: u16,
78 pub kind: ApiErrorKind,
80 pub message: String,
82 pub detail: Option<ApiErrorDetail>,
84 pub request_id: Option<String>,
86}
87
88impl ApiError {
89 pub fn is_retryable(&self) -> bool {
92 matches!(self.status, 408 | 409 | 429) || self.status >= 500
93 }
94}
95
96impl std::fmt::Display for ApiError {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.write_str(&self.message)?;
99 if let Some(id) = &self.request_id {
100 write!(f, " (request_id: {id})")?;
101 }
102 Ok(())
103 }
104}
105
106#[derive(Debug, thiserror::Error)]
108#[non_exhaustive]
109pub enum OpenAIError {
110 #[error("configuration error: {0}")]
112 Config(String),
113 #[error("connection error: {0}")]
115 Connection(String),
116 #[error("request timed out")]
118 Timeout,
119 #[error("{0}")]
122 Api(Box<ApiError>),
123 #[error("stream error: {0}")]
125 Stream(String),
126 #[error("JSON error: {0}")]
128 Json(#[from] serde_json::Error),
129 #[error("HTTP error: {0}")]
131 Http(#[from] reqwest::Error),
132}
133
134impl OpenAIError {
135 pub(crate) fn from_response(status: u16, request_id: Option<String>, body: &str) -> Self {
139 let body = body.trim();
140 let json = serde_json::from_str::<serde_json::Value>(body).ok();
141 let detail = json
142 .as_ref()
143 .and_then(|v| {
144 let error = v.get("error").cloned().unwrap_or_else(|| v.clone());
145 serde_json::from_value::<ApiErrorDetail>(error).ok()
146 })
147 .filter(|d| {
148 d.message.is_some() || d.error_type.is_some() || d.param.is_some() || d.code.is_some()
149 });
150
151 let message = if json.is_some() {
154 format!("Error code: {status} - {body}")
155 } else if body.is_empty() {
156 format!("Error code: {status}")
157 } else {
158 body.to_string()
159 };
160
161 Self::Api(Box::new(ApiError {
162 status,
163 kind: ApiErrorKind::from_status(status),
164 message,
165 detail,
166 request_id,
167 }))
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the [`ApiError`] from `err`, failing the test for any other variant.
    fn unwrap_api(err: OpenAIError) -> ApiError {
        match err {
            OpenAIError::Api(api) => *api,
            other => panic!("expected Api error, got {other:?}"),
        }
    }

    #[test]
    fn maps_statuses_to_kinds() {
        let expectations = [
            (400, ApiErrorKind::BadRequest),
            (401, ApiErrorKind::Authentication),
            (403, ApiErrorKind::PermissionDenied),
            (404, ApiErrorKind::NotFound),
            (409, ApiErrorKind::Conflict),
            (422, ApiErrorKind::UnprocessableEntity),
            (429, ApiErrorKind::RateLimit),
            (500, ApiErrorKind::InternalServer),
            (503, ApiErrorKind::InternalServer),
            (418, ApiErrorKind::Other),
        ];
        for (status, expected) in expectations {
            assert_eq!(ApiErrorKind::from_status(status), expected, "status {status}");
        }
    }

    #[test]
    fn parses_error_body() {
        let body = r#"{"error": {"message": "Invalid API key", "type": "invalid_request_error", "param": null, "code": "invalid_api_key"}}"#;
        let api = unwrap_api(OpenAIError::from_response(401, Some("req_123".into()), body));

        assert_eq!(api.status, 401);
        assert_eq!(api.kind, ApiErrorKind::Authentication);
        assert_eq!(api.request_id.as_deref(), Some("req_123"));
        assert!(!api.is_retryable());

        let detail = api.detail.expect("detail parsed");
        assert_eq!(detail.message.as_deref(), Some("Invalid API key"));
        assert_eq!(detail.code.as_deref(), Some("invalid_api_key"));
    }

    #[test]
    fn handles_non_json_body_and_numeric_code() {
        let plain = unwrap_api(OpenAIError::from_response(502, None, "Bad Gateway"));
        assert!(plain.detail.is_none());
        assert!(plain.is_retryable());
        assert_eq!(plain.message, "Bad Gateway");

        let empty = unwrap_api(OpenAIError::from_response(502, None, ""));
        assert_eq!(empty.message, "Error code: 502");

        let numeric = unwrap_api(OpenAIError::from_response(
            429,
            None,
            r#"{"error": {"message": "slow down", "code": 42}}"#,
        ));
        let code = numeric.detail.and_then(|d| d.code);
        assert_eq!(code.as_deref(), Some("42"));
    }
}