Skip to main content

byokey_proxy/
error.rs

1//! API error type that maps [`ByokError`] variants to HTTP status codes.
2
3use axum::{
4    Json,
5    http::StatusCode,
6    response::{IntoResponse, Response},
7};
8use byokey_types::ByokError;
9use serde_json::json;
10
/// Wrapper around [`ByokError`] that implements [`IntoResponse`].
///
/// The response is a JSON body of the form
/// `{"error": {"message": ..., "type": ..., "code": ...}}`, with the HTTP status,
/// `type` and `code` chosen per wrapped variant. A `From<ByokError>` impl is
/// provided so `?` can convert a `ByokError` into an `ApiError`.
pub struct ApiError(pub ByokError);
13
14impl ApiError {
15    /// Returns `(status, error_type, error_code)` for the wrapped error.
16    fn classify(&self) -> (StatusCode, &'static str, &'static str) {
17        match &self.0 {
18            ByokError::Auth(_) => (
19                StatusCode::UNAUTHORIZED,
20                "authentication_error",
21                "invalid_api_key",
22            ),
23            ByokError::TokenNotFound(_) | ByokError::TokenExpired(_) => (
24                StatusCode::UNAUTHORIZED,
25                "authentication_error",
26                "token_not_found",
27            ),
28            ByokError::UnsupportedModel(_) => (
29                StatusCode::BAD_REQUEST,
30                "invalid_request_error",
31                "model_not_found",
32            ),
33            ByokError::UnsupportedProvider(_) => (
34                StatusCode::BAD_REQUEST,
35                "invalid_request_error",
36                "provider_not_found",
37            ),
38            ByokError::Translation(_) => (
39                StatusCode::BAD_REQUEST,
40                "invalid_request_error",
41                "translation_error",
42            ),
43            ByokError::Upstream { status, .. } => classify_upstream(*status),
44            ByokError::Http(_) => (StatusCode::BAD_GATEWAY, "server_error", "upstream_error"),
45            _ => (
46                StatusCode::INTERNAL_SERVER_ERROR,
47                "server_error",
48                "internal_error",
49            ),
50        }
51    }
52}
53
54fn classify_upstream(status: u16) -> (StatusCode, &'static str, &'static str) {
55    match status {
56        429 => (
57            StatusCode::TOO_MANY_REQUESTS,
58            "rate_limit_error",
59            "rate_limit_exceeded",
60        ),
61        401 => (
62            StatusCode::UNAUTHORIZED,
63            "authentication_error",
64            "invalid_api_key",
65        ),
66        403 => (
67            StatusCode::FORBIDDEN,
68            "permission_error",
69            "insufficient_quota",
70        ),
71        _ => (StatusCode::BAD_GATEWAY, "server_error", "upstream_error"),
72    }
73}
74
75impl IntoResponse for ApiError {
76    fn into_response(self) -> Response {
77        let (status, error_type, error_code) = self.classify();
78        // Upstream errors: build the client message from fields directly so
79        // the body is forwarded to the original caller. `Display` omits the
80        // body to keep it out of logs/Sentry, so we don't want to use it here.
81        let msg = match &self.0 {
82            ByokError::Upstream {
83                status: s, body, ..
84            } => {
85                format!("upstream error: status={s}, body={body}")
86            }
87            other => other.to_string(),
88        };
89        (
90            status,
91            Json(json!({
92                "error": {
93                    "message": msg,
94                    "type": error_type,
95                    "code": error_code,
96                }
97            })),
98        )
99            .into_response()
100    }
101}
102
103impl From<ByokError> for ApiError {
104    fn from(e: ByokError) -> Self {
105        Self(e)
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use byokey_types::ProviderId;
    use http_body_util::BodyExt as _;

    /// Renders `err` into a response and returns its status plus the parsed JSON body.
    async fn extract_error_body(err: ApiError) -> (StatusCode, serde_json::Value) {
        let resp = err.into_response();
        let status = resp.status();
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        (status, body)
    }

    #[tokio::test]
    async fn test_auth_error() {
        let (status, body) =
            extract_error_body(ApiError(ByokError::Auth("bad creds".into()))).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body["error"]["type"], "authentication_error");
        assert_eq!(body["error"]["code"], "invalid_api_key");
    }

    #[tokio::test]
    async fn test_token_not_found_error() {
        let (status, body) =
            extract_error_body(ApiError(ByokError::TokenNotFound(ProviderId::Claude))).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body["error"]["type"], "authentication_error");
        assert_eq!(body["error"]["code"], "token_not_found");
    }

    #[tokio::test]
    async fn test_unsupported_model_error() {
        let (status, body) =
            extract_error_body(ApiError(ByokError::UnsupportedModel("xyz".into()))).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body["error"]["type"], "invalid_request_error");
        assert_eq!(body["error"]["code"], "model_not_found");
    }

    #[tokio::test]
    async fn test_translation_error() {
        let (status, body) =
            extract_error_body(ApiError(ByokError::Translation("bad format".into()))).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body["error"]["type"], "invalid_request_error");
        assert_eq!(body["error"]["code"], "translation_error");
    }

    #[tokio::test]
    async fn test_upstream_429_error() {
        let (status, body) = extract_error_body(ApiError(ByokError::Upstream {
            retry_after: None,
            status: 429,
            body: "rate limited".into(),
        }))
        .await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(body["error"]["type"], "rate_limit_error");
        assert_eq!(body["error"]["code"], "rate_limit_exceeded");
    }

    #[tokio::test]
    async fn test_upstream_401_error() {
        let (status, body) = extract_error_body(ApiError(ByokError::Upstream {
            retry_after: None,
            status: 401,
            body: "unauthorized".into(),
        }))
        .await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body["error"]["type"], "authentication_error");
        assert_eq!(body["error"]["code"], "invalid_api_key");
    }

    #[tokio::test]
    async fn test_upstream_403_error() {
        let (status, body) = extract_error_body(ApiError(ByokError::Upstream {
            retry_after: None,
            status: 403,
            body: "forbidden".into(),
        }))
        .await;
        assert_eq!(status, StatusCode::FORBIDDEN);
        assert_eq!(body["error"]["type"], "permission_error");
        assert_eq!(body["error"]["code"], "insufficient_quota");
    }

    #[tokio::test]
    async fn test_upstream_500_error() {
        let (status, body) = extract_error_body(ApiError(ByokError::Upstream {
            retry_after: None,
            status: 500,
            body: "server error".into(),
        }))
        .await;
        assert_eq!(status, StatusCode::BAD_GATEWAY);
        assert_eq!(body["error"]["type"], "server_error");
        assert_eq!(body["error"]["code"], "upstream_error");
    }

    /// The upstream body must reach the client even though `Display` omits it.
    #[tokio::test]
    async fn test_upstream_message_forwards_body() {
        let (_, body) = extract_error_body(ApiError(ByokError::Upstream {
            retry_after: None,
            status: 500,
            body: "server error".into(),
        }))
        .await;
        assert_eq!(
            body["error"]["message"],
            "upstream error: status=500, body=server error"
        );
    }

    #[tokio::test]
    async fn test_http_transport_error() {
        let (status, body) =
            extract_error_body(ApiError(ByokError::Http("connection refused".into()))).await;
        assert_eq!(status, StatusCode::BAD_GATEWAY);
        assert_eq!(body["error"]["type"], "server_error");
        assert_eq!(body["error"]["code"], "upstream_error");
    }

    #[tokio::test]
    async fn test_internal_error() {
        let (status, body) =
            extract_error_body(ApiError(ByokError::Config("bad config".into()))).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body["error"]["type"], "server_error");
        assert_eq!(body["error"]["code"], "internal_error");
    }

    /// `From<ByokError>` must wrap the error without altering its classification.
    #[tokio::test]
    async fn test_from_byok_error() {
        let err: ApiError = ByokError::Auth("bad creds".into()).into();
        let (status, body) = extract_error_body(err).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body["error"]["code"], "invalid_api_key");
    }
}