Skip to main content

byokey_proxy/
error.rs

//! API error type that maps [`ByokError`] variants to HTTP status codes.
2
3use axum::{
4    Json,
5    http::StatusCode,
6    response::{IntoResponse, Response},
7};
8use byokey_types::ByokError;
9use serde_json::json;
10
11/// Wrapper around [`ByokError`] that implements [`IntoResponse`].
12///
13/// Maps error variants to appropriate HTTP status codes:
14/// - `UnsupportedModel` -> 400
15/// - `TokenNotFound` / `TokenExpired` -> 401
16/// - `Http` -> 502
17/// - Everything else -> 500
18pub struct ApiError(pub ByokError);
19
20impl IntoResponse for ApiError {
21    fn into_response(self) -> Response {
22        let (code, msg) = match &self.0 {
23            ByokError::UnsupportedModel(m) => {
24                (StatusCode::BAD_REQUEST, format!("unsupported model: {m}"))
25            }
26            ByokError::TokenNotFound(_) | ByokError::TokenExpired(_) => {
27                (StatusCode::UNAUTHORIZED, self.0.to_string())
28            }
29            ByokError::Http(m) => (StatusCode::BAD_GATEWAY, m.clone()),
30            _ => (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()),
31        };
32        (
33            code,
34            Json(json!({"error": {"message": msg, "type": "byokey_error"}})),
35        )
36            .into_response()
37    }
38}
39
40impl From<ByokError> for ApiError {
41    fn from(e: ByokError) -> Self {
42        Self(e)
43    }
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use byokey_types::ProviderId;

    /// Renders `err` through [`ApiError`] and returns only the response status.
    fn status_for(err: ByokError) -> StatusCode {
        ApiError::from(err).into_response().status()
    }

    #[test]
    fn test_unsupported_model_is_bad_request() {
        assert_eq!(
            status_for(ByokError::UnsupportedModel("xyz".into())),
            StatusCode::BAD_REQUEST
        );
    }

    #[test]
    fn test_token_not_found_is_unauthorized() {
        assert_eq!(
            status_for(ByokError::TokenNotFound(ProviderId::Claude)),
            StatusCode::UNAUTHORIZED
        );
    }

    #[test]
    fn test_http_error_is_bad_gateway() {
        assert_eq!(
            status_for(ByokError::Http("upstream error".into())),
            StatusCode::BAD_GATEWAY
        );
    }
}