Skip to main content

liter_llm_proxy/
error.rs

1use std::fmt;
2use std::time::Duration;
3
4use axum::Json;
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use liter_llm::error::{ApiError, ErrorResponse, LiterLlmError};
8
9/// An HTTP-aware error that serialises to an OpenAI-compatible JSON body.
10///
11/// `ProxyError` carries the HTTP status code, the structured [`ErrorResponse`]
12/// payload, and an optional `Retry-After` duration so that [`IntoResponse`] can
13/// produce the correct wire representation — including headers — in a single
14/// step.
15#[derive(Debug)]
16pub struct ProxyError {
17    status: StatusCode,
18    body: ErrorResponse,
19    retry_after: Option<Duration>,
20}
21
22impl ProxyError {
23    /// Create a `ProxyError` from a status code and an error type / message
24    /// pair.
25    fn new(status: StatusCode, error_type: impl Into<String>, message: impl Into<String>) -> Self {
26        Self {
27            status,
28            body: ErrorResponse {
29                error: ApiError {
30                    message: message.into(),
31                    error_type: error_type.into(),
32                    param: None,
33                    code: None,
34                },
35            },
36            retry_after: None,
37        }
38    }
39
40    /// 401 Unauthorized.
41    pub fn authentication(message: impl Into<String>) -> Self {
42        Self::new(StatusCode::UNAUTHORIZED, "Authentication", message)
43    }
44
45    /// 404 Not Found.
46    pub fn not_found(message: impl Into<String>) -> Self {
47        Self::new(StatusCode::NOT_FOUND, "NotFound", message)
48    }
49
50    /// 400 Bad Request.
51    pub fn bad_request(message: impl Into<String>) -> Self {
52        Self::new(StatusCode::BAD_REQUEST, "BadRequest", message)
53    }
54
55    /// 500 Internal Server Error.
56    pub fn internal(message: impl Into<String>) -> Self {
57        Self::new(StatusCode::INTERNAL_SERVER_ERROR, "InternalError", message)
58    }
59
60    /// 503 Service Unavailable.
61    pub fn service_unavailable(message: impl Into<String>) -> Self {
62        Self::new(StatusCode::SERVICE_UNAVAILABLE, "ServiceUnavailable", message)
63    }
64
65    /// 403 Forbidden.
66    pub fn forbidden(message: impl Into<String>) -> Self {
67        Self::new(StatusCode::FORBIDDEN, "Forbidden", message)
68    }
69
70    /// 429 Too Many Requests.
71    pub fn rate_limited(message: impl Into<String>) -> Self {
72        Self::new(StatusCode::TOO_MANY_REQUESTS, "RateLimited", message)
73    }
74}
75
76impl fmt::Display for ProxyError {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        f.write_str(&self.body.error.message)
79    }
80}
81
82impl std::error::Error for ProxyError {}
83
84impl IntoResponse for ProxyError {
85    fn into_response(self) -> Response {
86        let mut response = (self.status, Json(self.body)).into_response();
87        if let Some(duration) = self.retry_after
88            && let Ok(value) = duration.as_secs().to_string().parse()
89        {
90            response.headers_mut().insert("retry-after", value);
91        }
92        response
93    }
94}
95
96impl From<LiterLlmError> for ProxyError {
97    fn from(err: LiterLlmError) -> Self {
98        let error_type = err.error_type().to_owned();
99        let message = err.to_string();
100
101        // Extract retry_after before we lose access to the variant fields.
102        let retry_after = if let LiterLlmError::RateLimited { retry_after, .. } = &err {
103            *retry_after
104        } else {
105            None
106        };
107
108        let status = match &err {
109            LiterLlmError::Authentication { .. } => StatusCode::UNAUTHORIZED,
110            LiterLlmError::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
111            LiterLlmError::BadRequest { .. } => StatusCode::BAD_REQUEST,
112            LiterLlmError::ContextWindowExceeded { .. } => StatusCode::BAD_REQUEST,
113            LiterLlmError::ContentPolicy { .. } => StatusCode::BAD_REQUEST,
114            LiterLlmError::NotFound { .. } => StatusCode::NOT_FOUND,
115            LiterLlmError::BudgetExceeded { .. } => StatusCode::TOO_MANY_REQUESTS,
116            LiterLlmError::HookRejected { .. } => StatusCode::FORBIDDEN,
117            LiterLlmError::Timeout => StatusCode::GATEWAY_TIMEOUT,
118            LiterLlmError::ServiceUnavailable { .. } => StatusCode::SERVICE_UNAVAILABLE,
119            LiterLlmError::ServerError { .. } => StatusCode::INTERNAL_SERVER_ERROR,
120            LiterLlmError::Network(_) => StatusCode::BAD_GATEWAY,
121            LiterLlmError::Streaming { .. } => StatusCode::INTERNAL_SERVER_ERROR,
122            LiterLlmError::EndpointNotSupported { .. } => StatusCode::NOT_IMPLEMENTED,
123            LiterLlmError::InvalidHeader { .. } => StatusCode::BAD_REQUEST,
124            LiterLlmError::Serialization(_) => StatusCode::BAD_REQUEST,
125            LiterLlmError::InternalError { .. } => StatusCode::INTERNAL_SERVER_ERROR,
126            // LiterLlmError is #[non_exhaustive]; treat unknown future variants
127            // as internal server errors.
128            _ => StatusCode::INTERNAL_SERVER_ERROR,
129        };
130
131        Self {
132            status,
133            body: ErrorResponse {
134                error: ApiError {
135                    message,
136                    error_type,
137                    param: None,
138                    code: None,
139                },
140            },
141            retry_after,
142        }
143    }
144}
145
#[cfg(test)]
mod tests {
    //! Unit tests for `ProxyError`: core-error → HTTP status mapping, JSON
    //! body shape, public constructors, `Retry-After` handling, and `Display`.

    use std::time::Duration;

    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use http_body_util::BodyExt;
    use liter_llm::error::{ErrorResponse, LiterLlmError};

    use super::ProxyError;

    /// Helper: convert a `ProxyError` into a response and extract status + JSON
    /// body.
    async fn extract(err: ProxyError) -> (StatusCode, ErrorResponse) {
        let response = err.into_response();
        let status = response.status();
        // `into_body()` already yields an axum `Body`; collect it directly.
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        let body: ErrorResponse = serde_json::from_slice(&bytes).unwrap();
        (status, body)
    }

    // ── Variant -> HTTP status mapping ───────────────────────────────────

    #[tokio::test]
    async fn authentication_maps_to_401() {
        let err: ProxyError = LiterLlmError::Authentication {
            message: "bad key".into(),
        }
        .into();
        let (status, body) = extract(err).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body.error.error_type, "Authentication");
    }

    #[tokio::test]
    async fn rate_limited_maps_to_429() {
        let err: ProxyError = LiterLlmError::RateLimited {
            message: "slow down".into(),
            retry_after: None,
        }
        .into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn bad_request_maps_to_400() {
        let err: ProxyError = LiterLlmError::BadRequest {
            message: "invalid".into(),
        }
        .into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn context_window_exceeded_maps_to_400() {
        let err: ProxyError = LiterLlmError::ContextWindowExceeded {
            message: "too long".into(),
        }
        .into();
        let (status, body) = extract(err).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body.error.error_type, "ContextWindowExceeded");
    }

    #[tokio::test]
    async fn content_policy_maps_to_400() {
        let err: ProxyError = LiterLlmError::ContentPolicy {
            message: "violation".into(),
        }
        .into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn not_found_maps_to_404() {
        let err: ProxyError = LiterLlmError::NotFound { message: "gone".into() }.into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn budget_exceeded_maps_to_429() {
        let err: ProxyError = LiterLlmError::BudgetExceeded {
            message: "over budget".into(),
            model: None,
        }
        .into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn hook_rejected_maps_to_403() {
        let err: ProxyError = LiterLlmError::HookRejected {
            message: "denied".into(),
        }
        .into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn timeout_maps_to_504() {
        let err: ProxyError = LiterLlmError::Timeout.into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::GATEWAY_TIMEOUT);
    }

    #[tokio::test]
    async fn service_unavailable_maps_to_503() {
        let err: ProxyError = LiterLlmError::ServiceUnavailable { message: "down".into() }.into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn server_error_maps_to_500() {
        let err: ProxyError = LiterLlmError::ServerError { message: "boom".into() }.into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn streaming_maps_to_500() {
        let err: ProxyError = LiterLlmError::Streaming {
            message: "broke".into(),
        }
        .into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn endpoint_not_supported_maps_to_501() {
        let err: ProxyError = LiterLlmError::EndpointNotSupported {
            endpoint: "images".into(),
            provider: "test".into(),
        }
        .into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::NOT_IMPLEMENTED);
    }

    #[tokio::test]
    async fn invalid_header_maps_to_400() {
        let err: ProxyError = LiterLlmError::InvalidHeader {
            name: "x-bad".into(),
            reason: "nope".into(),
        }
        .into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn serialization_maps_to_400() {
        let json_err = serde_json::from_str::<String>("not json").unwrap_err();
        let err: ProxyError = LiterLlmError::Serialization(json_err).into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn internal_error_maps_to_500() {
        let err: ProxyError = LiterLlmError::InternalError { message: "bug".into() }.into();
        let (status, _) = extract(err).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    // ── IntoResponse produces valid JSON ─────────────────────────────────

    #[tokio::test]
    async fn into_response_produces_valid_json_with_correct_fields() {
        let err: ProxyError = LiterLlmError::Authentication {
            message: "invalid api key".into(),
        }
        .into();
        let (status, body) = extract(err).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body.error.error_type, "Authentication");
        assert!(body.error.message.contains("invalid api key"));
    }

    // ── Constructor methods ──────────────────────────────────────────────

    #[tokio::test]
    async fn constructor_authentication() {
        let (status, body) = extract(ProxyError::authentication("no token")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body.error.error_type, "Authentication");
        assert_eq!(body.error.message, "no token");
    }

    #[tokio::test]
    async fn constructor_not_found() {
        let (status, body) = extract(ProxyError::not_found("missing")).await;
        assert_eq!(status, StatusCode::NOT_FOUND);
        assert_eq!(body.error.error_type, "NotFound");
    }

    #[tokio::test]
    async fn constructor_bad_request() {
        let (status, body) = extract(ProxyError::bad_request("oops")).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body.error.error_type, "BadRequest");
    }

    #[tokio::test]
    async fn constructor_internal() {
        let (status, body) = extract(ProxyError::internal("bug")).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body.error.error_type, "InternalError");
    }

    #[tokio::test]
    async fn constructor_service_unavailable() {
        let (status, body) = extract(ProxyError::service_unavailable("down")).await;
        assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(body.error.error_type, "ServiceUnavailable");
        assert_eq!(body.error.message, "down");
    }

    #[tokio::test]
    async fn constructor_forbidden() {
        let (status, body) = extract(ProxyError::forbidden("nope")).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
        assert_eq!(body.error.error_type, "Forbidden");
    }

    #[tokio::test]
    async fn constructor_rate_limited() {
        let (status, body) = extract(ProxyError::rate_limited("slow")).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(body.error.error_type, "RateLimited");
    }

    // ── Retry-After header ───────────────────────────────────────────────

    #[tokio::test]
    async fn rate_limited_with_retry_after_includes_header() {
        let err: ProxyError = LiterLlmError::RateLimited {
            message: "slow down".into(),
            retry_after: Some(Duration::from_secs(30)),
        }
        .into();
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        let retry = response
            .headers()
            .get("retry-after")
            .expect("retry-after header must be present");
        assert_eq!(retry.to_str().unwrap(), "30");
    }

    #[tokio::test]
    async fn rate_limited_without_retry_after_omits_header() {
        let err: ProxyError = LiterLlmError::RateLimited {
            message: "slow down".into(),
            retry_after: None,
        }
        .into();
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(response.headers().get("retry-after").is_none());
    }

    #[tokio::test]
    async fn constructor_rate_limited_omits_retry_after_header() {
        let response = ProxyError::rate_limited("slow").into_response();
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(response.headers().get("retry-after").is_none());
    }

    // ── Display impl ─────────────────────────────────────────────────────

    #[test]
    fn display_delegates_to_body_message() {
        let err = ProxyError::authentication("bad api key");
        assert_eq!(err.to_string(), "bad api key");
    }

    #[test]
    fn display_from_core_error() {
        let err: ProxyError = LiterLlmError::NotFound {
            message: "model gone".into(),
        }
        .into();
        assert!(err.to_string().contains("model gone"));
    }
}