use std::time::Duration;
use axum::{
Json,
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::Serialize;
/// Errors the proxy can report to a client.
///
/// Every variant is mapped to an HTTP status by `status_code`, to a stable
/// machine-readable code by `error_code`, and to a JSON response by
/// `to_response`. The `Display` text of each variant comes from its
/// `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
/// A failure reported by, or while talking to, the upstream service.
///
/// Reported as 429 when `source` mentions status 429, otherwise as 502.
#[error("upstream error: {source}")]
Upstream {
/// Human-readable description; interpolated into the `Display` text and
/// echoed to the client for the 502 case.
source: String,
/// Optional underlying error; `#[source]` marks this field (not the
/// `source` string above) as the error's source.
#[source]
inner: Option<anyhow::Error>,
},
/// Translating between protocols failed; reported as 502.
#[error("protocol conversion error: {0}")]
ProtocolConversion(String),
/// No channel could be selected for the requested model; reported as 503.
#[error("no channel available for model '{model}'")]
ChannelSelection {
/// Name of the model that was requested.
model: String,
},
/// Compression failed; reported as 500 with the message included.
#[error("compression error: {0}")]
Compression(String),
/// The client request was invalid; reported as 400 with the message included.
#[error("bad request: {0}")]
BadRequest(String),
/// The request was not authorized; reported as 401.
#[error("unauthorized")]
Unauthorized,
/// The proxy itself rate limited the caller; reported as 429.
#[error("rate limited, retry after {retry_after:?}")]
RateLimited {
/// How long the caller is told to wait before retrying.
retry_after: Duration,
},
/// The circuit breaker for a channel is open; reported as 503.
#[error("circuit open for channel '{channel}'")]
CircuitOpen {
/// Name of the channel whose circuit is open.
channel: String,
},
/// Any other failure; reported as 500 with a generic message, so the wrapped
/// error text is not sent to the client. `#[from]` lets `?` and
/// `ProxyError::from` convert an `anyhow::Error` into this variant.
#[error("internal error: {0}")]
Internal(#[from] anyhow::Error),
}
/// Top-level JSON envelope of an error response.
///
/// Serializes as an object with a single `error` key holding an [`ErrorBody`].
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
/// The error payload.
pub error: ErrorBody,
}
/// Payload describing a single error inside an [`ErrorResponse`].
#[derive(Debug, Serialize)]
pub struct ErrorBody {
/// Machine-readable error code such as `"bad_request"`.
pub code: &'static str,
/// Human-readable description of the error.
pub message: String,
/// Optional extra information; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
impl ErrorBody {
    /// Creates a body holding only `code` and `message`; `detail` stays `None`.
    pub fn new(code: &'static str, message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            code,
            message,
            detail: None,
        }
    }

    /// Creates a body holding `code`, `message` and an additional `detail` string.
    pub fn with_detail(
        code: &'static str,
        message: impl Into<String>,
        detail: impl Into<String>,
    ) -> Self {
        // Build the plain body first so `message` is converted before `detail`.
        let mut body = Self::new(code, message);
        body.detail = Some(detail.into());
        body
    }
}
impl ProxyError {
#[must_use]
pub fn to_response(&self) -> Response {
let (status, body) = match self {
Self::BadRequest(msg) => (
StatusCode::BAD_REQUEST,
ErrorBody::new("bad_request", msg.clone()),
),
Self::Unauthorized => (
StatusCode::UNAUTHORIZED,
ErrorBody::new("unauthorized", "invalid proxy API key"),
),
Self::RateLimited { retry_after } => {
let secs = retry_after.as_secs_f64();
let mut resp = ErrorBody::new(
"rate_limited",
format!("rate limit exceeded, retry after {secs:.1}s"),
);
resp.detail = Some(format!("retry_after_seconds: {secs:.0}"));
(StatusCode::TOO_MANY_REQUESTS, resp)
}
Self::Upstream { source, .. } if source.contains("429") => (
StatusCode::TOO_MANY_REQUESTS,
ErrorBody::new("upstream_rate_limited", "upstream rate limited"),
),
Self::Upstream { source, .. } => (
StatusCode::BAD_GATEWAY,
ErrorBody::new("upstream_error", source.clone()),
),
Self::ProtocolConversion(msg) => (
StatusCode::BAD_GATEWAY,
ErrorBody::new("protocol_conversion", msg.clone()),
),
Self::ChannelSelection { model } => (
StatusCode::SERVICE_UNAVAILABLE,
ErrorBody::new(
"no_channel",
format!("no channel available for model '{model}'"),
),
),
Self::CircuitOpen { channel } => (
StatusCode::SERVICE_UNAVAILABLE,
ErrorBody::new(
"circuit_open",
format!("circuit breaker open for channel '{channel}'"),
),
),
Self::Compression(msg) => (
StatusCode::INTERNAL_SERVER_ERROR,
ErrorBody::new("compression_error", msg.clone()),
),
Self::Internal(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
ErrorBody::new("internal_error", "internal server error"),
),
};
let mut response = Json(ErrorResponse { error: body }).into_response();
*response.status_mut() = status;
response
}
#[must_use]
pub fn status_code(&self) -> StatusCode {
match self {
Self::BadRequest(_) => StatusCode::BAD_REQUEST,
Self::Unauthorized => StatusCode::UNAUTHORIZED,
Self::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
Self::Upstream { source, .. } if source.contains("429") => {
StatusCode::TOO_MANY_REQUESTS
}
Self::Upstream { .. } | Self::ProtocolConversion(_) => StatusCode::BAD_GATEWAY,
Self::CircuitOpen { .. } | Self::ChannelSelection { .. } => {
StatusCode::SERVICE_UNAVAILABLE
}
Self::Compression(_) | Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
#[must_use]
pub fn error_code(&self) -> &'static str {
match self {
Self::BadRequest(_) => "bad_request",
Self::Unauthorized => "unauthorized",
Self::RateLimited { .. } => "rate_limited",
Self::Upstream { source, .. } if source.contains("429") => "upstream_rate_limited",
Self::Upstream { .. } => "upstream_error",
Self::ProtocolConversion(_) => "protocol_conversion",
Self::CircuitOpen { .. } => "circuit_open",
Self::ChannelSelection { .. } => "no_channel",
Self::Compression(_) => "compression_error",
Self::Internal(_) => "internal_error",
}
}
}
impl IntoResponse for ProxyError {
fn into_response(self) -> Response {
self.to_response()
}
}
/// Unit tests for the status, error-code and response mapping of `ProxyError`.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashSet;
    use std::time::Duration;

    use super::*;

    /// One instance of every variant, with `Upstream` in both its plain and 429 form.
    fn sample_errors() -> Vec<ProxyError> {
        vec![
            ProxyError::BadRequest("x".into()),
            ProxyError::Unauthorized,
            ProxyError::RateLimited {
                retry_after: Duration::from_secs(1),
            },
            ProxyError::Upstream {
                source: "timeout".into(),
                inner: None,
            },
            ProxyError::Upstream {
                source: "429".into(),
                inner: None,
            },
            ProxyError::ProtocolConversion("x".into()),
            ProxyError::CircuitOpen {
                channel: "x".into(),
            },
            ProxyError::ChannelSelection { model: "x".into() },
            ProxyError::Compression("x".into()),
            ProxyError::Internal(anyhow::anyhow!("x")),
        ]
    }

    #[test]
    fn test_bad_request_status_and_code() {
        let err = ProxyError::BadRequest("invalid JSON".into());
        assert_eq!(err.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(err.error_code(), "bad_request");
    }

    #[test]
    fn test_unauthorized_status() {
        let err = ProxyError::Unauthorized;
        assert_eq!(err.status_code(), StatusCode::UNAUTHORIZED);
        assert_eq!(err.error_code(), "unauthorized");
    }

    #[test]
    fn test_rate_limited_status() {
        let err = ProxyError::RateLimited {
            retry_after: Duration::from_secs(5),
        };
        assert_eq!(err.status_code(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(err.error_code(), "rate_limited");
    }

    #[test]
    fn test_upstream_429_passthrough() {
        let err = ProxyError::Upstream {
            source: "upstream 429 too many requests".into(),
            inner: None,
        };
        assert_eq!(err.status_code(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(err.error_code(), "upstream_rate_limited");
    }

    #[test]
    fn test_upstream_error_status() {
        let err = ProxyError::Upstream {
            source: "connection refused".into(),
            inner: None,
        };
        assert_eq!(err.status_code(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.error_code(), "upstream_error");
    }

    #[test]
    fn test_protocol_conversion_status() {
        let err = ProxyError::ProtocolConversion("unsupported field".into());
        assert_eq!(err.status_code(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.error_code(), "protocol_conversion");
    }

    #[test]
    fn test_channel_selection_status() {
        let err = ProxyError::ChannelSelection {
            model: "gpt-5".into(),
        };
        assert_eq!(err.status_code(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(err.error_code(), "no_channel");
    }

    #[test]
    fn test_circuit_open_status() {
        let err = ProxyError::CircuitOpen {
            channel: "primary".into(),
        };
        assert_eq!(err.status_code(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(err.error_code(), "circuit_open");
    }

    #[test]
    fn test_compression_status() {
        let err = ProxyError::Compression("bad stream".into());
        assert_eq!(err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.error_code(), "compression_error");
    }

    #[test]
    fn test_internal_error_status() {
        let err = ProxyError::Internal(anyhow::anyhow!("db connection failed"));
        assert_eq!(err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.error_code(), "internal_error");
    }

    #[test]
    fn test_error_to_response_returns_json() {
        let err = ProxyError::BadRequest("test".into());
        let response = err.to_response();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert!(
            response
                .headers()
                .get("content-type")
                .and_then(|v| v.to_str().ok())
                .is_some_and(|v| v.contains("application/json"))
        );
    }

    #[test]
    fn test_response_status_matches_status_code() {
        // `to_response` and `status_code` must agree for every variant.
        for err in sample_errors() {
            assert_eq!(
                err.to_response().status(),
                err.status_code(),
                "status mismatch for {err:?}"
            );
        }
    }

    #[test]
    fn test_all_variants_have_distinct_codes() {
        let codes: Vec<&'static str> = sample_errors()
            .iter()
            .map(ProxyError::error_code)
            .collect();
        // Compare every code against every other one, not just a single pair.
        let unique: HashSet<&str> = codes.iter().copied().collect();
        assert_eq!(
            unique.len(),
            codes.len(),
            "duplicate error code in {codes:?}"
        );
    }

    #[test]
    fn test_internal_from_anyhow() {
        let source = anyhow::anyhow!("something broke");
        let err = ProxyError::from(source);
        assert!(matches!(err, ProxyError::Internal(_)));
    }
}