1use std::time::Duration;
4
5use axum::{
6 Json,
7 http::StatusCode,
8 response::{IntoResponse, Response},
9};
10use serde::Serialize;
11
/// Errors that can occur while proxying a request.
///
/// Each variant maps to an HTTP status (`ProxyError::status_code`) and a stable
/// machine-readable code (`ProxyError::error_code`). Through the `IntoResponse` impl it is
/// rendered to the client as a JSON `ErrorResponse`.
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
    /// The upstream call failed. Reported as 502 Bad Gateway, or as 429 with code
    /// `upstream_rate_limited` when the message indicates an upstream 429.
    #[error("upstream error: {source}")]
    Upstream {
        /// Human-readable failure description. It is inspected for a `429` status and,
        /// otherwise, returned to the client as the response message.
        source: String,
        /// Optional underlying error, exposed via `std::error::Error::source`.
        /// NOTE(review): the explicit `#[source]` here is what makes `inner`, rather than
        /// the `String` field named `source`, the error source under thiserror's rules.
        #[source]
        inner: Option<anyhow::Error>,
    },

    /// Protocol conversion failed (502 Bad Gateway, code `protocol_conversion`).
    #[error("protocol conversion error: {0}")]
    ProtocolConversion(String),

    /// No channel could serve `model` (503 Service Unavailable, code `no_channel`).
    #[error("no channel available for model '{model}'")]
    ChannelSelection {
        /// The requested model name.
        model: String,
    },

    /// Compression failed (500 Internal Server Error, code `compression_error`).
    #[error("compression error: {0}")]
    Compression(String),

    /// The client request was invalid (400 Bad Request); the message is returned as-is.
    #[error("bad request: {0}")]
    BadRequest(String),

    /// The caller's proxy API key was rejected (401 Unauthorized).
    #[error("unauthorized")]
    Unauthorized,

    /// The proxy's own rate limit was exceeded (429 Too Many Requests).
    #[error("rate limited, retry after {retry_after:?}")]
    RateLimited {
        /// How long the client is told to wait before retrying.
        retry_after: Duration,
    },

    /// The circuit breaker for `channel` is open (503 Service Unavailable).
    #[error("circuit open for channel '{channel}'")]
    CircuitOpen {
        /// Name of the channel whose breaker is open.
        channel: String,
    },

    /// Any other failure, converted from `anyhow::Error` via `From`. The response sent to
    /// the client only says "internal server error"; the cause is not exposed.
    #[error("internal error: {0}")]
    Internal(#[from] anyhow::Error),
}
69
/// Top-level JSON envelope for error responses, serialized as `{"error": {...}}`.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// The error payload.
    pub error: ErrorBody,
}
76
/// Error payload nested inside [`ErrorResponse`].
///
/// Field order here is the key order of the serialized JSON.
#[derive(Debug, Serialize)]
pub struct ErrorBody {
    /// Stable machine-readable error code, e.g. `"bad_request"` or `"rate_limited"`.
    pub code: &'static str,
    /// Human-readable description of the error.
    pub message: String,
    /// Optional extra information; the key is omitted from the JSON when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
88
89impl ErrorBody {
90 pub fn new(code: &'static str, message: impl Into<String>) -> Self {
92 Self {
93 code,
94 message: message.into(),
95 detail: None,
96 }
97 }
98
99 pub fn with_detail(
101 code: &'static str,
102 message: impl Into<String>,
103 detail: impl Into<String>,
104 ) -> Self {
105 Self {
106 code,
107 message: message.into(),
108 detail: Some(detail.into()),
109 }
110 }
111}
112
113impl ProxyError {
114 #[must_use]
119 pub fn to_response(&self) -> Response {
120 let (status, body) = match self {
121 Self::BadRequest(msg) => (
122 StatusCode::BAD_REQUEST,
123 ErrorBody::new("bad_request", msg.clone()),
124 ),
125 Self::Unauthorized => (
126 StatusCode::UNAUTHORIZED,
127 ErrorBody::new("unauthorized", "invalid proxy API key"),
128 ),
129 Self::RateLimited { retry_after } => {
130 let secs = retry_after.as_secs_f64();
131 let mut resp = ErrorBody::new(
132 "rate_limited",
133 format!("rate limit exceeded, retry after {secs:.1}s"),
134 );
135 resp.detail = Some(format!("retry_after_seconds: {secs:.0}"));
136 (StatusCode::TOO_MANY_REQUESTS, resp)
137 }
138 Self::Upstream { source, .. } if source.contains("429") => (
139 StatusCode::TOO_MANY_REQUESTS,
140 ErrorBody::new("upstream_rate_limited", "upstream rate limited"),
141 ),
142 Self::Upstream { source, .. } => (
143 StatusCode::BAD_GATEWAY,
144 ErrorBody::new("upstream_error", source.clone()),
145 ),
146 Self::ProtocolConversion(msg) => (
147 StatusCode::BAD_GATEWAY,
148 ErrorBody::new("protocol_conversion", msg.clone()),
149 ),
150 Self::ChannelSelection { model } => (
151 StatusCode::SERVICE_UNAVAILABLE,
152 ErrorBody::new(
153 "no_channel",
154 format!("no channel available for model '{model}'"),
155 ),
156 ),
157 Self::CircuitOpen { channel } => (
158 StatusCode::SERVICE_UNAVAILABLE,
159 ErrorBody::new(
160 "circuit_open",
161 format!("circuit breaker open for channel '{channel}'"),
162 ),
163 ),
164 Self::Compression(msg) => (
165 StatusCode::INTERNAL_SERVER_ERROR,
166 ErrorBody::new("compression_error", msg.clone()),
167 ),
168 Self::Internal(_) => (
169 StatusCode::INTERNAL_SERVER_ERROR,
170 ErrorBody::new("internal_error", "internal server error"),
171 ),
172 };
173
174 let mut response = Json(ErrorResponse { error: body }).into_response();
175 *response.status_mut() = status;
176 response
177 }
178
179 #[must_use]
181 pub fn status_code(&self) -> StatusCode {
182 match self {
183 Self::BadRequest(_) => StatusCode::BAD_REQUEST,
184 Self::Unauthorized => StatusCode::UNAUTHORIZED,
185 Self::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
186 Self::Upstream { source, .. } if source.contains("429") => {
187 StatusCode::TOO_MANY_REQUESTS
188 }
189 Self::Upstream { .. } | Self::ProtocolConversion(_) => StatusCode::BAD_GATEWAY,
190 Self::CircuitOpen { .. } | Self::ChannelSelection { .. } => {
191 StatusCode::SERVICE_UNAVAILABLE
192 }
193 Self::Compression(_) | Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
194 }
195 }
196
197 #[must_use]
199 pub fn error_code(&self) -> &'static str {
200 match self {
201 Self::BadRequest(_) => "bad_request",
202 Self::Unauthorized => "unauthorized",
203 Self::RateLimited { .. } => "rate_limited",
204 Self::Upstream { source, .. } if source.contains("429") => "upstream_rate_limited",
205 Self::Upstream { .. } => "upstream_error",
206 Self::ProtocolConversion(_) => "protocol_conversion",
207 Self::CircuitOpen { .. } => "circuit_open",
208 Self::ChannelSelection { .. } => "no_channel",
209 Self::Compression(_) => "compression_error",
210 Self::Internal(_) => "internal_error",
211 }
212 }
213}
214
215impl IntoResponse for ProxyError {
216 fn into_response(self) -> Response {
217 self.to_response()
218 }
219}
220
#[cfg(test)]
// Tests are allowed to unwrap freely; a panic is a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashSet;
    use std::time::Duration;

    use super::*;

    /// One instance of every variant, with both `Upstream` classifications (plain and 429).
    fn sample_errors() -> Vec<ProxyError> {
        vec![
            ProxyError::BadRequest("x".into()),
            ProxyError::Unauthorized,
            ProxyError::RateLimited {
                retry_after: Duration::from_secs(1),
            },
            ProxyError::Upstream {
                source: "timeout".into(),
                inner: None,
            },
            ProxyError::Upstream {
                source: "429".into(),
                inner: None,
            },
            ProxyError::ProtocolConversion("x".into()),
            ProxyError::CircuitOpen {
                channel: "x".into(),
            },
            ProxyError::ChannelSelection { model: "x".into() },
            ProxyError::Compression("x".into()),
            ProxyError::Internal(anyhow::anyhow!("x")),
        ]
    }

    #[test]
    fn test_bad_request_status_and_code() {
        let err = ProxyError::BadRequest("invalid JSON".into());
        assert_eq!(err.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(err.error_code(), "bad_request");
    }

    #[test]
    fn test_unauthorized_status() {
        let err = ProxyError::Unauthorized;
        assert_eq!(err.status_code(), StatusCode::UNAUTHORIZED);
        assert_eq!(err.error_code(), "unauthorized");
    }

    #[test]
    fn test_rate_limited_status() {
        let err = ProxyError::RateLimited {
            retry_after: Duration::from_secs(5),
        };
        assert_eq!(err.status_code(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(err.error_code(), "rate_limited");
    }

    #[test]
    fn test_upstream_429_passthrough() {
        let err = ProxyError::Upstream {
            source: "upstream 429 too many requests".into(),
            inner: None,
        };
        assert_eq!(err.status_code(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(err.error_code(), "upstream_rate_limited");
    }

    #[test]
    fn test_upstream_error_status() {
        let err = ProxyError::Upstream {
            source: "connection refused".into(),
            inner: None,
        };
        assert_eq!(err.status_code(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.error_code(), "upstream_error");
    }

    #[test]
    fn test_channel_selection_status() {
        let err = ProxyError::ChannelSelection {
            model: "gpt-5".into(),
        };
        assert_eq!(err.status_code(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(err.error_code(), "no_channel");
    }

    #[test]
    fn test_internal_error_status() {
        let err = ProxyError::Internal(anyhow::anyhow!("db connection failed"));
        assert_eq!(err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.error_code(), "internal_error");
    }

    #[test]
    fn test_error_to_response_returns_json() {
        let err = ProxyError::BadRequest("test".into());
        let response = err.to_response();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert!(
            response
                .headers()
                .get("content-type")
                .and_then(|v| v.to_str().ok())
                .is_some_and(|v| v.contains("application/json"))
        );
    }

    #[test]
    fn test_to_response_status_matches_status_code() {
        for err in sample_errors() {
            assert_eq!(err.to_response().status(), err.status_code(), "{err}");
        }
    }

    #[test]
    fn test_all_variants_have_distinct_codes() {
        let codes: Vec<&'static str> = sample_errors().iter().map(ProxyError::error_code).collect();
        let unique: HashSet<&str> = codes.iter().copied().collect();
        assert_eq!(unique.len(), codes.len(), "duplicate error codes: {codes:?}");
    }

    #[test]
    fn test_internal_from_anyhow() {
        let source = anyhow::anyhow!("something broke");
        let err = ProxyError::from(source);
        assert!(matches!(err, ProxyError::Internal(_)));
    }
}