Skip to main content

byokey_proxy/
error.rs

//! API error type that maps [`ByokError`] variants to HTTP status codes and JSON error bodies.
2
3use axum::{
4    Json,
5    http::StatusCode,
6    response::{IntoResponse, Response},
7};
8use byokey_types::ByokError;
9use serde_json::json;
10
11/// Wrapper around [`ByokError`] that implements [`IntoResponse`].
12pub struct ApiError(pub ByokError);
13
14impl ApiError {
15    /// Returns `(status, error_type, error_code)` for the wrapped error.
16    fn classify(&self) -> (StatusCode, &'static str, &'static str) {
17        match &self.0 {
18            ByokError::Auth(_) => (
19                StatusCode::UNAUTHORIZED,
20                "authentication_error",
21                "invalid_api_key",
22            ),
23            ByokError::TokenNotFound(_) | ByokError::TokenExpired(_) => (
24                StatusCode::UNAUTHORIZED,
25                "authentication_error",
26                "token_not_found",
27            ),
28            ByokError::UnsupportedModel(_) => (
29                StatusCode::BAD_REQUEST,
30                "invalid_request_error",
31                "model_not_found",
32            ),
33            ByokError::Translation(_) => (
34                StatusCode::BAD_REQUEST,
35                "invalid_request_error",
36                "translation_error",
37            ),
38            ByokError::Upstream { status, .. } => classify_upstream(*status),
39            ByokError::Http(_) => (StatusCode::BAD_GATEWAY, "server_error", "upstream_error"),
40            _ => (
41                StatusCode::INTERNAL_SERVER_ERROR,
42                "server_error",
43                "internal_error",
44            ),
45        }
46    }
47}
48
49fn classify_upstream(status: u16) -> (StatusCode, &'static str, &'static str) {
50    match status {
51        429 => (
52            StatusCode::TOO_MANY_REQUESTS,
53            "rate_limit_error",
54            "rate_limit_exceeded",
55        ),
56        401 => (
57            StatusCode::UNAUTHORIZED,
58            "authentication_error",
59            "invalid_api_key",
60        ),
61        403 => (
62            StatusCode::FORBIDDEN,
63            "permission_error",
64            "insufficient_quota",
65        ),
66        _ => (StatusCode::BAD_GATEWAY, "server_error", "upstream_error"),
67    }
68}
69
70impl IntoResponse for ApiError {
71    fn into_response(self) -> Response {
72        let (status, error_type, error_code) = self.classify();
73        let msg = self.0.to_string();
74        (
75            status,
76            Json(json!({
77                "error": {
78                    "message": msg,
79                    "type": error_type,
80                    "code": error_code,
81                }
82            })),
83        )
84            .into_response()
85    }
86}
87
88impl From<ByokError> for ApiError {
89    fn from(e: ByokError) -> Self {
90        Self(e)
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use byokey_types::ProviderId;
    use http_body_util::BodyExt as _;

    /// Renders `err` into a response and returns its status plus the parsed JSON body.
    async fn extract_error_body(err: ApiError) -> (StatusCode, serde_json::Value) {
        let response = err.into_response();
        let status = response.status();
        let collected = response.into_body().collect().await.unwrap();
        let body: serde_json::Value = serde_json::from_slice(&collected.to_bytes()).unwrap();
        (status, body)
    }

    /// Asserts that `err` renders with the expected status, `type` and `code`.
    async fn assert_classified(
        err: ByokError,
        expected_status: StatusCode,
        expected_type: &str,
        expected_code: &str,
    ) {
        let (status, body) = extract_error_body(ApiError(err)).await;
        assert_eq!(status, expected_status);
        assert_eq!(body["error"]["type"], expected_type);
        assert_eq!(body["error"]["code"], expected_code);
    }

    /// Shorthand for an upstream error carrying `status` and a text body.
    fn upstream(status: u16, body: &'static str) -> ByokError {
        ByokError::Upstream {
            status,
            body: body.into(),
        }
    }

    #[tokio::test]
    async fn test_auth_error() {
        assert_classified(
            ByokError::Auth("bad creds".into()),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_token_not_found_error() {
        assert_classified(
            ByokError::TokenNotFound(ProviderId::Claude),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "token_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_unsupported_model_error() {
        assert_classified(
            ByokError::UnsupportedModel("xyz".into()),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "model_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_translation_error() {
        assert_classified(
            ByokError::Translation("bad format".into()),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "translation_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_429_error() {
        assert_classified(
            upstream(429, "rate limited"),
            StatusCode::TOO_MANY_REQUESTS,
            "rate_limit_error",
            "rate_limit_exceeded",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_401_error() {
        assert_classified(
            upstream(401, "unauthorized"),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_403_error() {
        assert_classified(
            upstream(403, "forbidden"),
            StatusCode::FORBIDDEN,
            "permission_error",
            "insufficient_quota",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_500_error() {
        assert_classified(
            upstream(500, "server error"),
            StatusCode::BAD_GATEWAY,
            "server_error",
            "upstream_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_transport_error() {
        assert_classified(
            ByokError::Http("connection refused".into()),
            StatusCode::BAD_GATEWAY,
            "server_error",
            "upstream_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_internal_error() {
        assert_classified(
            ByokError::Config("bad config".into()),
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "internal_error",
        )
        .await;
    }
}