1use axum::http::StatusCode;
2use axum::response::{IntoResponse, Response};
3use serde::Serialize;
4
5#[derive(Debug, Serialize)]
9#[cfg_attr(feature = "schema", derive(utoipa::ToSchema))]
10pub struct ErrorResponse {
11 pub error: ErrorInfo,
12}
13
14#[derive(Debug, Serialize)]
16#[cfg_attr(feature = "schema", derive(utoipa::ToSchema))]
17pub struct ErrorInfo {
18 pub code: &'static str,
20 pub message: String,
22}
23
24pub trait ApiError: std::fmt::Display {
32 fn status_code(&self) -> StatusCode;
34
35 fn error_code(&self) -> &'static str;
37
38 fn retry_after_secs(&self) -> Option<u64> {
40 None
41 }
42}
43
44pub fn into_error_response(err: &impl ApiError) -> Response {
54 let status = err.status_code();
55
56 if status.is_server_error() {
57 tracing::error!(error = %err, "internal server error");
58 } else if status.is_client_error() {
59 tracing::warn!(error = %err, "client error");
60 }
61
62 build_error_response(err)
63}
64
65pub fn build_error_response(err: &impl ApiError) -> Response {
70 let status = err.status_code();
71
72 let body = ErrorResponse {
73 error: ErrorInfo {
74 code: err.error_code(),
75 message: err.to_string(),
76 },
77 };
78
79 let mut response = (status, axum::Json(body)).into_response();
80
81 if let Some(secs) = err.retry_after_secs() {
82 response.headers_mut().insert(
83 axum::http::header::RETRY_AFTER,
84 axum::http::HeaderValue::from(secs),
85 );
86 }
87
88 response
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::http::header::{CONTENT_TYPE, RETRY_AFTER};
    use axum::http::HeaderMap;
    use serde_json::Value;

    /// Minimal error type covering a 4xx, a 429 with a retry hint, and a 5xx.
    #[derive(Debug)]
    enum TestError {
        BadInput(String),
        RateLimited { retry_after: u64 },
        Internal(String),
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::BadInput(msg) => write!(f, "bad input: {msg}"),
                Self::RateLimited { .. } => f.write_str("rate limited"),
                Self::Internal(msg) => write!(f, "internal error: {msg}"),
            }
        }
    }

    impl ApiError for TestError {
        fn status_code(&self) -> StatusCode {
            match self {
                Self::BadInput(_) => StatusCode::BAD_REQUEST,
                Self::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
                Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
            }
        }

        fn error_code(&self) -> &'static str {
            match self {
                Self::BadInput(_) => "BAD_INPUT",
                Self::RateLimited { .. } => "RATE_LIMITED",
                Self::Internal(_) => "INTERNAL_ERROR",
            }
        }

        fn retry_after_secs(&self) -> Option<u64> {
            // Exhaustive on purpose: a new variant must decide explicitly.
            match self {
                Self::RateLimited { retry_after } => Some(*retry_after),
                Self::BadInput(_) | Self::Internal(_) => None,
            }
        }
    }

    /// Runs `err` through `into_error_response` and splits the result into
    /// status, headers, and the body parsed as JSON.
    async fn into_parts(err: TestError) -> (StatusCode, HeaderMap, Value) {
        let (parts, body) = into_error_response(&err).into_parts();
        let bytes = to_bytes(body, usize::MAX).await.unwrap();
        let json = serde_json::from_slice(&bytes).unwrap();
        (parts.status, parts.headers, json)
    }

    /// Same as `into_parts`, keeping only the parsed JSON body.
    async fn body_json(err: TestError) -> Value {
        let (_, _, json) = into_parts(err).await;
        json
    }

    /// Status of the response produced for `err`.
    fn status_of(err: TestError) -> StatusCode {
        into_error_response(&err).status()
    }

    #[tokio::test]
    async fn bad_input_is_400() {
        assert_eq!(
            status_of(TestError::BadInput("oops".into())),
            StatusCode::BAD_REQUEST
        );
    }

    #[tokio::test]
    async fn rate_limited_is_429() {
        assert_eq!(
            status_of(TestError::RateLimited { retry_after: 5 }),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    #[tokio::test]
    async fn internal_is_500() {
        assert_eq!(
            status_of(TestError::Internal("boom".into())),
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[tokio::test]
    async fn body_has_error_code_and_message() {
        let body = body_json(TestError::BadInput("test".into())).await;
        let error = &body["error"];
        assert_eq!(error["code"], "BAD_INPUT");
        assert!(error["message"].as_str().unwrap().contains("test"));
        // The envelope has exactly one top-level key.
        assert_eq!(body.as_object().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn rate_limited_includes_retry_after_header() {
        let (status, headers, _) = into_parts(TestError::RateLimited { retry_after: 42 }).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);

        let secs: u64 = headers
            .get(RETRY_AFTER)
            .expect("Retry-After header must be present")
            .to_str()
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(secs, 42);
    }

    #[tokio::test]
    async fn non_rate_limited_has_no_retry_after() {
        let (_, headers, _) = into_parts(TestError::BadInput("x".into())).await;
        assert!(!headers.contains_key(RETRY_AFTER));
    }

    #[tokio::test]
    async fn error_response_has_json_content_type() {
        let response = into_error_response(&TestError::BadInput("test".into()));
        let header = response
            .headers()
            .get(CONTENT_TYPE)
            .expect("Content-Type header must be present");
        let ct = header.to_str().unwrap();
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );
    }
}