1use axum::{
4 Json,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use byokey_types::ByokError;
9use serde_json::json;
10
/// Wrapper around [`ByokError`] that is rendered as an HTTP JSON error response through its
/// `IntoResponse` implementation.
pub struct ApiError(pub ByokError);
13
14impl ApiError {
15 fn classify(&self) -> (StatusCode, &'static str, &'static str) {
17 match &self.0 {
18 ByokError::Auth(_) => (
19 StatusCode::UNAUTHORIZED,
20 "authentication_error",
21 "invalid_api_key",
22 ),
23 ByokError::TokenNotFound(_) | ByokError::TokenExpired(_) => (
24 StatusCode::UNAUTHORIZED,
25 "authentication_error",
26 "token_not_found",
27 ),
28 ByokError::UnsupportedModel(_) => (
29 StatusCode::BAD_REQUEST,
30 "invalid_request_error",
31 "model_not_found",
32 ),
33 ByokError::Translation(_) => (
34 StatusCode::BAD_REQUEST,
35 "invalid_request_error",
36 "translation_error",
37 ),
38 ByokError::Http(m) => classify_http(m),
39 _ => (
40 StatusCode::INTERNAL_SERVER_ERROR,
41 "server_error",
42 "internal_error",
43 ),
44 }
45 }
46}
47
48fn classify_http(msg: &str) -> (StatusCode, &'static str, &'static str) {
49 if msg.contains("429") {
50 (
51 StatusCode::TOO_MANY_REQUESTS,
52 "rate_limit_error",
53 "rate_limit_exceeded",
54 )
55 } else if msg.contains("401") {
56 (
57 StatusCode::UNAUTHORIZED,
58 "authentication_error",
59 "invalid_api_key",
60 )
61 } else if msg.contains("403") {
62 (
63 StatusCode::FORBIDDEN,
64 "permission_error",
65 "insufficient_quota",
66 )
67 } else {
68 (StatusCode::BAD_GATEWAY, "server_error", "upstream_error")
69 }
70}
71
72impl IntoResponse for ApiError {
73 fn into_response(self) -> Response {
74 let (status, error_type, error_code) = self.classify();
75 let msg = self.0.to_string();
76 (
77 status,
78 Json(json!({
79 "error": {
80 "message": msg,
81 "type": error_type,
82 "code": error_code,
83 }
84 })),
85 )
86 .into_response()
87 }
88}
89
90impl From<ByokError> for ApiError {
91 fn from(e: ByokError) -> Self {
92 Self(e)
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use byokey_types::ProviderId;
    use http_body_util::BodyExt as _;

    /// Renders `err` into a response and returns its status with the parsed JSON body.
    async fn extract_error_body(err: ApiError) -> (StatusCode, serde_json::Value) {
        let response = err.into_response();
        let status = response.status();
        let collected = response.into_body().collect().await.unwrap();
        let body = serde_json::from_slice(&collected.to_bytes()).unwrap();
        (status, body)
    }

    /// Asserts that `err` renders with `status` and the given `type` / `code` fields.
    async fn assert_error(err: ApiError, status: StatusCode, error_type: &str, error_code: &str) {
        let (actual_status, body) = extract_error_body(err).await;
        assert_eq!(actual_status, status);
        assert_eq!(body["error"]["type"], error_type);
        assert_eq!(body["error"]["code"], error_code);
    }

    #[tokio::test]
    async fn test_auth_error() {
        assert_error(
            ApiError(ByokError::Auth("bad creds".into())),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_token_not_found_error() {
        assert_error(
            ApiError(ByokError::TokenNotFound(ProviderId::Claude)),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "token_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_unsupported_model_error() {
        assert_error(
            ApiError(ByokError::UnsupportedModel("xyz".into())),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "model_not_found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_translation_error() {
        assert_error(
            ApiError(ByokError::Translation("bad format".into())),
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "translation_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_429_error() {
        assert_error(
            ApiError(ByokError::Http("status 429 too many".into())),
            StatusCode::TOO_MANY_REQUESTS,
            "rate_limit_error",
            "rate_limit_exceeded",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_401_error() {
        assert_error(
            ApiError(ByokError::Http("status 401 unauthorized".into())),
            StatusCode::UNAUTHORIZED,
            "authentication_error",
            "invalid_api_key",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_403_error() {
        assert_error(
            ApiError(ByokError::Http("status 403 forbidden".into())),
            StatusCode::FORBIDDEN,
            "permission_error",
            "insufficient_quota",
        )
        .await;
    }

    #[tokio::test]
    async fn test_http_other_error() {
        assert_error(
            ApiError(ByokError::Http("upstream error".into())),
            StatusCode::BAD_GATEWAY,
            "server_error",
            "upstream_error",
        )
        .await;
    }

    #[tokio::test]
    async fn test_internal_error() {
        assert_error(
            ApiError(ByokError::Config("bad config".into())),
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "internal_error",
        )
        .await;
    }
}