1use axum::http::StatusCode;
2use axum::response::{IntoResponse, Response};
3use serde::Serialize;
4
5#[derive(Debug, Serialize)]
9pub struct ErrorResponse {
10 pub error: ErrorInfo,
11}
12
13#[derive(Debug, Serialize)]
15pub struct ErrorInfo {
16 pub code: &'static str,
18 pub message: String,
20}
21
22pub trait ApiError: std::fmt::Display {
30 fn status_code(&self) -> StatusCode;
32
33 fn error_code(&self) -> &'static str;
35
36 fn retry_after_secs(&self) -> Option<u64> {
38 None
39 }
40}
41
42pub fn into_error_response(err: &impl ApiError) -> Response {
52 let status = err.status_code();
53
54 if status.is_server_error() {
55 tracing::error!(error = %err, "internal server error");
56 } else if status.is_client_error() {
57 tracing::warn!(error = %err, "client error");
58 }
59
60 let body = ErrorResponse {
61 error: ErrorInfo {
62 code: err.error_code(),
63 message: err.to_string(),
64 },
65 };
66
67 let mut response = (status, axum::Json(body)).into_response();
68
69 if let Some(secs) = err.retry_after_secs() {
70 response.headers_mut().insert(
71 axum::http::header::RETRY_AFTER,
72 axum::http::HeaderValue::from(secs),
73 );
74 }
75
76 response
77}
78
79#[cfg(test)]
80mod tests {
81 use super::*;
82 use axum::body::to_bytes;
83
84 #[derive(Debug)]
85 enum TestError {
86 BadInput(String),
87 RateLimited { retry_after: u64 },
88 Internal(String),
89 }
90
91 impl std::fmt::Display for TestError {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 match self {
94 Self::BadInput(msg) => write!(f, "bad input: {msg}"),
95 Self::RateLimited { .. } => write!(f, "rate limited"),
96 Self::Internal(msg) => write!(f, "internal error: {msg}"),
97 }
98 }
99 }
100
101 impl ApiError for TestError {
102 fn status_code(&self) -> StatusCode {
103 match self {
104 Self::BadInput(_) => StatusCode::BAD_REQUEST,
105 Self::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
106 Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
107 }
108 }
109
110 fn error_code(&self) -> &'static str {
111 match self {
112 Self::BadInput(_) => "BAD_INPUT",
113 Self::RateLimited { .. } => "RATE_LIMITED",
114 Self::Internal(_) => "INTERNAL_ERROR",
115 }
116 }
117
118 fn retry_after_secs(&self) -> Option<u64> {
119 match self {
120 Self::RateLimited { retry_after } => Some(*retry_after),
121 _ => None,
122 }
123 }
124 }
125
126 async fn body_json(err: TestError) -> serde_json::Value {
127 let response = into_error_response(&err);
128 let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap();
129 serde_json::from_slice(&bytes).unwrap()
130 }
131
132 async fn into_parts(
133 err: TestError,
134 ) -> (StatusCode, axum::http::HeaderMap, serde_json::Value) {
135 let response = into_error_response(&err);
136 let status = response.status();
137 let headers = response.headers().clone();
138 let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap();
139 let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
140 (status, headers, body)
141 }
142
143 #[tokio::test]
144 async fn bad_input_is_400() {
145 let response = into_error_response(&TestError::BadInput("oops".into()));
146 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
147 }
148
149 #[tokio::test]
150 async fn rate_limited_is_429() {
151 let response = into_error_response(&TestError::RateLimited { retry_after: 5 });
152 assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
153 }
154
155 #[tokio::test]
156 async fn internal_is_500() {
157 let response = into_error_response(&TestError::Internal("boom".into()));
158 assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
159 }
160
161 #[tokio::test]
162 async fn body_has_error_code_and_message() {
163 let body = body_json(TestError::BadInput("test".into())).await;
164 assert_eq!(body["error"]["code"], "BAD_INPUT");
165 assert!(body["error"]["message"].as_str().unwrap().contains("test"));
166 assert_eq!(body.as_object().unwrap().len(), 1);
167 }
168
169 #[tokio::test]
170 async fn rate_limited_includes_retry_after_header() {
171 let (status, headers, _) =
172 into_parts(TestError::RateLimited { retry_after: 42 }).await;
173 assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
174 let retry_after = headers
175 .get(axum::http::header::RETRY_AFTER)
176 .expect("Retry-After header must be present");
177 let value: u64 = retry_after.to_str().unwrap().parse().unwrap();
178 assert_eq!(value, 42);
179 }
180
181 #[tokio::test]
182 async fn non_rate_limited_has_no_retry_after() {
183 let (_, headers, _) = into_parts(TestError::BadInput("x".into())).await;
184 assert!(headers.get(axum::http::header::RETRY_AFTER).is_none());
185 }
186
187 #[tokio::test]
188 async fn error_response_has_json_content_type() {
189 let response = into_error_response(&TestError::BadInput("test".into()));
190 let ct = response
191 .headers()
192 .get(axum::http::header::CONTENT_TYPE)
193 .expect("Content-Type header must be present")
194 .to_str()
195 .unwrap();
196 assert!(
197 ct.contains("application/json"),
198 "expected application/json, got {ct}"
199 );
200 }
201}