1use axum::{
4 Json,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use byokey_types::ByokError;
9use serde_json::json;
10
11pub struct ApiError(pub ByokError);
13
14impl ApiError {
15 fn classify(&self) -> (StatusCode, &'static str, &'static str) {
17 match &self.0 {
18 ByokError::Auth(_) => (
19 StatusCode::UNAUTHORIZED,
20 "authentication_error",
21 "invalid_api_key",
22 ),
23 ByokError::TokenNotFound(_) | ByokError::TokenExpired(_) => (
24 StatusCode::UNAUTHORIZED,
25 "authentication_error",
26 "token_not_found",
27 ),
28 ByokError::UnsupportedModel(_) => (
29 StatusCode::BAD_REQUEST,
30 "invalid_request_error",
31 "model_not_found",
32 ),
33 ByokError::Translation(_) => (
34 StatusCode::BAD_REQUEST,
35 "invalid_request_error",
36 "translation_error",
37 ),
38 ByokError::Upstream { status, .. } => classify_upstream(*status),
39 ByokError::Http(_) => (StatusCode::BAD_GATEWAY, "server_error", "upstream_error"),
40 _ => (
41 StatusCode::INTERNAL_SERVER_ERROR,
42 "server_error",
43 "internal_error",
44 ),
45 }
46 }
47}
48
49fn classify_upstream(status: u16) -> (StatusCode, &'static str, &'static str) {
50 match status {
51 429 => (
52 StatusCode::TOO_MANY_REQUESTS,
53 "rate_limit_error",
54 "rate_limit_exceeded",
55 ),
56 401 => (
57 StatusCode::UNAUTHORIZED,
58 "authentication_error",
59 "invalid_api_key",
60 ),
61 403 => (
62 StatusCode::FORBIDDEN,
63 "permission_error",
64 "insufficient_quota",
65 ),
66 _ => (StatusCode::BAD_GATEWAY, "server_error", "upstream_error"),
67 }
68}
69
70impl IntoResponse for ApiError {
71 fn into_response(self) -> Response {
72 let (status, error_type, error_code) = self.classify();
73 let msg = self.0.to_string();
74 (
75 status,
76 Json(json!({
77 "error": {
78 "message": msg,
79 "type": error_type,
80 "code": error_code,
81 }
82 })),
83 )
84 .into_response()
85 }
86}
87
88impl From<ByokError> for ApiError {
89 fn from(e: ByokError) -> Self {
90 Self(e)
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use byokey_types::ProviderId;
    use http_body_util::BodyExt as _;

    /// Renders `err` through `IntoResponse` and returns the status together
    /// with the decoded JSON body.
    async fn extract_error_body(err: ApiError) -> (StatusCode, serde_json::Value) {
        let response = err.into_response();
        let status = response.status();
        let collected = response.into_body().collect().await.unwrap();
        let body = serde_json::from_slice(&collected.to_bytes()).unwrap();
        (status, body)
    }

    /// Asserts that `err` is rendered with the expected HTTP status,
    /// `error.type` and `error.code`.
    async fn assert_rendered(
        err: ByokError,
        expected_status: StatusCode,
        expected_type: &str,
        expected_code: &str,
    ) {
        let (status, body) = extract_error_body(ApiError(err)).await;
        assert_eq!(status, expected_status);
        assert_eq!(body["error"]["type"], expected_type);
        assert_eq!(body["error"]["code"], expected_code);
    }

    #[tokio::test]
    async fn test_auth_error() {
        assert_rendered(
            ByokError::Auth("bad creds".into()),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_token_not_found_error() {
        assert_rendered(
            ByokError::TokenNotFound(ProviderId::Claude),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "token_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_unsupported_model_error() {
        assert_rendered(
            ByokError::UnsupportedModel("xyz".into()),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "model_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_translation_error() {
        assert_rendered(
            ByokError::Translation("bad format".into()),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "translation_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_429_error() {
        assert_rendered(
            ByokError::Upstream {
                status: 429,
                body: "rate limited".into(),
            },
            StatusCode::TOO_MANY_REQUESTS,
            "rate_limit_error",
            "rate_limit_exceeded",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_401_error() {
        assert_rendered(
            ByokError::Upstream {
                status: 401,
                body: "unauthorized".into(),
            },
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_403_error() {
        assert_rendered(
            ByokError::Upstream {
                status: 403,
                body: "forbidden".into(),
            },
            StatusCode::FORBIDDEN,
            "permission_error",
            "insufficient_quota",
        )
        .await;
    }

    #[tokio::test]
    async fn test_upstream_500_error() {
        assert_rendered(
            ByokError::Upstream {
                status: 500,
                body: "server error".into(),
            },
            StatusCode::BAD_GATEWAY,
            "server_error",
            "upstream_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_transport_error() {
        assert_rendered(
            ByokError::Http("connection refused".into()),
            StatusCode::BAD_GATEWAY,
            "server_error",
            "upstream_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_internal_error() {
        assert_rendered(
            ByokError::Config("bad config".into()),
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "internal_error",
        )
        .await;
    }
}