1use axum::{
4 Json,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use byokey_types::ByokError;
9use serde_json::json;
10
/// HTTP-facing wrapper around [`ByokError`].
///
/// Converting it into a response (through its `IntoResponse` impl) produces a
/// JSON body of the form `{"error": {"message", "type", "code"}}`. The status
/// code is derived from the wrapped error variant.
pub struct ApiError(pub ByokError);
13
14impl ApiError {
15 fn classify(&self) -> (StatusCode, &'static str, &'static str) {
17 match &self.0 {
18 ByokError::Auth(_) => (
19 StatusCode::UNAUTHORIZED,
20 "authentication_error",
21 "invalid_api_key",
22 ),
23 ByokError::TokenNotFound(_) | ByokError::TokenExpired(_) => (
24 StatusCode::UNAUTHORIZED,
25 "authentication_error",
26 "token_not_found",
27 ),
28 ByokError::UnsupportedModel(_) => (
29 StatusCode::BAD_REQUEST,
30 "invalid_request_error",
31 "model_not_found",
32 ),
33 ByokError::UnsupportedProvider(_) => (
34 StatusCode::BAD_REQUEST,
35 "invalid_request_error",
36 "provider_not_found",
37 ),
38 ByokError::Translation(_) => (
39 StatusCode::BAD_REQUEST,
40 "invalid_request_error",
41 "translation_error",
42 ),
43 ByokError::Upstream { status, .. } => classify_upstream(*status),
44 ByokError::Http(_) => (StatusCode::BAD_GATEWAY, "server_error", "upstream_error"),
45 _ => (
46 StatusCode::INTERNAL_SERVER_ERROR,
47 "server_error",
48 "internal_error",
49 ),
50 }
51 }
52}
53
54fn classify_upstream(status: u16) -> (StatusCode, &'static str, &'static str) {
55 match status {
56 429 => (
57 StatusCode::TOO_MANY_REQUESTS,
58 "rate_limit_error",
59 "rate_limit_exceeded",
60 ),
61 401 => (
62 StatusCode::UNAUTHORIZED,
63 "authentication_error",
64 "invalid_api_key",
65 ),
66 403 => (
67 StatusCode::FORBIDDEN,
68 "permission_error",
69 "insufficient_quota",
70 ),
71 _ => (StatusCode::BAD_GATEWAY, "server_error", "upstream_error"),
72 }
73}
74
75impl IntoResponse for ApiError {
76 fn into_response(self) -> Response {
77 let (status, error_type, error_code) = self.classify();
78 let msg = match &self.0 {
82 ByokError::Upstream {
83 status: s, body, ..
84 } => {
85 format!("upstream error: status={s}, body={body}")
86 }
87 other => other.to_string(),
88 };
89 (
90 status,
91 Json(json!({
92 "error": {
93 "message": msg,
94 "type": error_type,
95 "code": error_code,
96 }
97 })),
98 )
99 .into_response()
100 }
101}
102
103impl From<ByokError> for ApiError {
104 fn from(e: ByokError) -> Self {
105 Self(e)
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use byokey_types::ProviderId;
    use http_body_util::BodyExt as _;

    /// Renders `err` through `IntoResponse` and returns the status together
    /// with the parsed JSON body.
    async fn extract_error_body(err: ApiError) -> (StatusCode, serde_json::Value) {
        let response = err.into_response();
        let status = response.status();
        let collected = response.into_body().collect().await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&collected.to_bytes()).unwrap();
        (status, json)
    }

    /// Asserts that `err` renders with the given HTTP status, error `type`,
    /// and error `code`.
    async fn assert_error(err: ApiError, status: StatusCode, kind: &str, code: &str) {
        let (actual_status, body) = extract_error_body(err).await;
        assert_eq!(actual_status, status);
        assert_eq!(body["error"]["type"], kind);
        assert_eq!(body["error"]["code"], code);
    }

    /// Builds an `Upstream` error with no `retry_after` hint.
    fn upstream(status: u16, body: &'static str) -> ApiError {
        ApiError(ByokError::Upstream {
            retry_after: None,
            status,
            body: body.into(),
        })
    }

    #[tokio::test]
    async fn test_auth_error() {
        assert_error(
            ApiError(ByokError::Auth("bad creds".into())),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_token_not_found_error() {
        assert_error(
            ApiError(ByokError::TokenNotFound(ProviderId::Claude)),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "token_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_unsupported_model_error() {
        assert_error(
            ApiError(ByokError::UnsupportedModel("xyz".into())),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "model_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_translation_error() {
        assert_error(
            ApiError(ByokError::Translation("bad format".into())),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "translation_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_429_error() {
        assert_error(
            upstream(429, "rate limited"),
            StatusCode::TOO_MANY_REQUESTS,
            "rate_limit_error",
            "rate_limit_exceeded",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_401_error() {
        assert_error(
            upstream(401, "unauthorized"),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_403_error() {
        assert_error(
            upstream(403, "forbidden"),
            StatusCode::FORBIDDEN,
            "permission_error",
            "insufficient_quota",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_500_error() {
        assert_error(
            upstream(500, "server error"),
            StatusCode::BAD_GATEWAY,
            "server_error",
            "upstream_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_transport_error() {
        assert_error(
            ApiError(ByokError::Http("connection refused".into())),
            StatusCode::BAD_GATEWAY,
            "server_error",
            "upstream_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_internal_error() {
        assert_error(
            ApiError(ByokError::Config("bad config".into())),
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "internal_error",
        )
        .await;
    }
}