1use axum::http::StatusCode;
2use axum::response::{IntoResponse, Response};
3use serde::Serialize;
4
5#[derive(Debug, Serialize)]
9#[cfg_attr(feature = "schema", derive(utoipa::ToSchema))]
10pub struct ErrorResponse {
11 pub error: ErrorInfo,
12}
13
14#[derive(Debug, Serialize)]
16#[cfg_attr(feature = "schema", derive(utoipa::ToSchema))]
17pub struct ErrorInfo {
18 pub code: &'static str,
20 pub message: String,
22}
23
24pub trait ApiError: std::fmt::Display {
32 fn status_code(&self) -> StatusCode;
34
35 fn error_code(&self) -> &'static str;
37
38 fn retry_after_secs(&self) -> Option<u64> {
40 None
41 }
42}
43
44pub fn into_error_response(err: &impl ApiError) -> Response {
54 let status = err.status_code();
55
56 if status.is_server_error() {
57 tracing::error!(error = %err, "internal server error");
58 } else if status.is_client_error() {
59 tracing::warn!(error = %err, "client error");
60 }
61
62 let body = ErrorResponse {
63 error: ErrorInfo {
64 code: err.error_code(),
65 message: err.to_string(),
66 },
67 };
68
69 let mut response = (status, axum::Json(body)).into_response();
70
71 if let Some(secs) = err.retry_after_secs() {
72 response.headers_mut().insert(
73 axum::http::header::RETRY_AFTER,
74 axum::http::HeaderValue::from(secs),
75 );
76 }
77
78 response
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;

    /// Minimal error type covering a 4xx, a 429 with a retry hint, and a 5xx.
    #[derive(Debug)]
    enum TestError {
        BadInput(String),
        RateLimited { retry_after: u64 },
        Internal(String),
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::BadInput(msg) => write!(f, "bad input: {msg}"),
                Self::RateLimited { .. } => write!(f, "rate limited"),
                Self::Internal(msg) => write!(f, "internal error: {msg}"),
            }
        }
    }

    impl ApiError for TestError {
        fn status_code(&self) -> StatusCode {
            match self {
                Self::BadInput(_) => StatusCode::BAD_REQUEST,
                Self::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
                Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
            }
        }

        fn error_code(&self) -> &'static str {
            match self {
                Self::BadInput(_) => "BAD_INPUT",
                Self::RateLimited { .. } => "RATE_LIMITED",
                Self::Internal(_) => "INTERNAL_ERROR",
            }
        }

        fn retry_after_secs(&self) -> Option<u64> {
            // Explicit variants (no `_`) so adding a variant forces a decision here.
            match self {
                Self::RateLimited { retry_after } => Some(*retry_after),
                Self::BadInput(_) | Self::Internal(_) => None,
            }
        }
    }

    /// Renders `err` and splits the response into status, headers and parsed JSON body.
    async fn into_parts(
        err: TestError,
    ) -> (StatusCode, axum::http::HeaderMap, serde_json::Value) {
        let (parts, body) = into_error_response(&err).into_parts();
        let bytes = to_bytes(body, usize::MAX)
            .await
            .expect("in-memory response body should be readable");
        let json: serde_json::Value =
            serde_json::from_slice(&bytes).expect("error body should be valid JSON");
        (parts.status, parts.headers, json)
    }

    /// Renders `err` and returns only its parsed JSON body.
    async fn body_json(err: TestError) -> serde_json::Value {
        let (_, _, body) = into_parts(err).await;
        body
    }

    #[test]
    fn bad_input_is_400() {
        let response = into_error_response(&TestError::BadInput("oops".into()));
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn rate_limited_is_429() {
        let response = into_error_response(&TestError::RateLimited { retry_after: 5 });
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn internal_is_500() {
        let response = into_error_response(&TestError::Internal("boom".into()));
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn body_has_error_code_and_message() {
        let body = body_json(TestError::BadInput("test".into())).await;
        assert_eq!(body["error"]["code"], "BAD_INPUT");
        assert!(body["error"]["message"].as_str().unwrap().contains("test"));
        assert_eq!(body.as_object().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn internal_body_has_internal_error_code() {
        let body = body_json(TestError::Internal("boom".into())).await;
        assert_eq!(body["error"]["code"], "INTERNAL_ERROR");
        assert_eq!(body["error"]["message"], "internal error: boom");
    }

    #[tokio::test]
    async fn rate_limited_includes_retry_after_header() {
        let (status, headers, _) =
            into_parts(TestError::RateLimited { retry_after: 42 }).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        let retry_after = headers
            .get(axum::http::header::RETRY_AFTER)
            .expect("Retry-After header must be present");
        let value: u64 = retry_after.to_str().unwrap().parse().unwrap();
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn non_rate_limited_has_no_retry_after() {
        let (_, headers, _) = into_parts(TestError::BadInput("x".into())).await;
        assert!(headers.get(axum::http::header::RETRY_AFTER).is_none());
    }

    #[test]
    fn error_response_has_json_content_type() {
        let response = into_error_response(&TestError::BadInput("test".into()));
        let ct = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .expect("Content-Type header must be present")
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );
    }
}