Skip to main content

byokey_proxy/
error.rs

1//! API error type that maps [`ByokError`] variants to HTTP status codes.
2
3use axum::{
4    Json,
5    http::StatusCode,
6    response::{IntoResponse, Response},
7};
8use byokey_types::ByokError;
9use serde_json::json;
10
/// Wrapper around [`ByokError`] that implements [`IntoResponse`].
///
/// Converting it into a response yields the HTTP status chosen for the wrapped
/// error together with a JSON body of the form
/// `{"error": {"message": ..., "type": ..., "code": ...}}`.
pub struct ApiError(pub ByokError);
13
14impl ApiError {
15    /// Returns `(status, error_type, error_code)` for the wrapped error.
16    fn classify(&self) -> (StatusCode, &'static str, &'static str) {
17        match &self.0 {
18            ByokError::Auth(_) => (
19                StatusCode::UNAUTHORIZED,
20                "authentication_error",
21                "invalid_api_key",
22            ),
23            ByokError::TokenNotFound(_) | ByokError::TokenExpired(_) => (
24                StatusCode::UNAUTHORIZED,
25                "authentication_error",
26                "token_not_found",
27            ),
28            ByokError::UnsupportedModel(_) => (
29                StatusCode::BAD_REQUEST,
30                "invalid_request_error",
31                "model_not_found",
32            ),
33            ByokError::Translation(_) => (
34                StatusCode::BAD_REQUEST,
35                "invalid_request_error",
36                "translation_error",
37            ),
38            ByokError::Http(m) => classify_http(m),
39            _ => (
40                StatusCode::INTERNAL_SERVER_ERROR,
41                "server_error",
42                "internal_error",
43            ),
44        }
45    }
46}
47
48fn classify_http(msg: &str) -> (StatusCode, &'static str, &'static str) {
49    if msg.contains("429") {
50        (
51            StatusCode::TOO_MANY_REQUESTS,
52            "rate_limit_error",
53            "rate_limit_exceeded",
54        )
55    } else if msg.contains("401") {
56        (
57            StatusCode::UNAUTHORIZED,
58            "authentication_error",
59            "invalid_api_key",
60        )
61    } else if msg.contains("403") {
62        (
63            StatusCode::FORBIDDEN,
64            "permission_error",
65            "insufficient_quota",
66        )
67    } else {
68        (StatusCode::BAD_GATEWAY, "server_error", "upstream_error")
69    }
70}
71
72impl IntoResponse for ApiError {
73    fn into_response(self) -> Response {
74        let (status, error_type, error_code) = self.classify();
75        let msg = self.0.to_string();
76        (
77            status,
78            Json(json!({
79                "error": {
80                    "message": msg,
81                    "type": error_type,
82                    "code": error_code,
83                }
84            })),
85        )
86            .into_response()
87    }
88}
89
90impl From<ByokError> for ApiError {
91    fn from(e: ByokError) -> Self {
92        Self(e)
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use byokey_types::ProviderId;
    use http_body_util::BodyExt as _;

    /// Renders `err` into a response and returns its status and parsed JSON body.
    async fn extract_error_body(err: ApiError) -> (StatusCode, serde_json::Value) {
        let response = err.into_response();
        let status = response.status();
        let collected = response.into_body().collect().await.unwrap();
        let raw = collected.to_bytes();
        let body: serde_json::Value = serde_json::from_slice(&raw).unwrap();
        (status, body)
    }

    /// Asserts that `source` is rendered with the given status, error type and error code.
    async fn assert_classified(
        source: ByokError,
        expected_status: StatusCode,
        expected_type: &str,
        expected_code: &str,
    ) {
        let (status, body) = extract_error_body(ApiError(source)).await;
        assert_eq!(status, expected_status);
        assert_eq!(body["error"]["type"], expected_type);
        assert_eq!(body["error"]["code"], expected_code);
    }

    #[tokio::test]
    async fn test_auth_error() {
        assert_classified(
            ByokError::Auth("bad creds".into()),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_token_not_found_error() {
        assert_classified(
            ByokError::TokenNotFound(ProviderId::Claude),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "token_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_unsupported_model_error() {
        assert_classified(
            ByokError::UnsupportedModel("xyz".into()),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "model_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_translation_error() {
        assert_classified(
            ByokError::Translation("bad format".into()),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "translation_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_429_error() {
        assert_classified(
            ByokError::Http("status 429 too many".into()),
            StatusCode::TOO_MANY_REQUESTS,
            "rate_limit_error",
            "rate_limit_exceeded",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_401_error() {
        assert_classified(
            ByokError::Http("status 401 unauthorized".into()),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_403_error() {
        assert_classified(
            ByokError::Http("status 403 forbidden".into()),
            StatusCode::FORBIDDEN,
            "permission_error",
            "insufficient_quota",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_other_error() {
        assert_classified(
            ByokError::Http("upstream error".into()),
            StatusCode::BAD_GATEWAY,
            "server_error",
            "upstream_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_internal_error() {
        assert_classified(
            ByokError::Config("bad config".into()),
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "internal_error",
        )
        .await;
    }
}