1use axum::http::StatusCode;
2use axum::response::{IntoResponse, Response};
3use serde::Serialize;
4
5#[derive(Debug, Serialize)]
9pub struct ErrorResponse {
10 pub error: ErrorInfo,
11}
12
13#[derive(Debug, Serialize)]
15pub struct ErrorInfo {
16 pub code: &'static str,
18 pub message: String,
20}
21
22pub trait ApiError: std::fmt::Display {
30 fn status_code(&self) -> StatusCode;
32
33 fn error_code(&self) -> &'static str;
35
36 fn retry_after_secs(&self) -> Option<u64> {
38 None
39 }
40}
41
42pub fn into_error_response(err: &(impl ApiError + std::fmt::Display)) -> Response {
52 let status = err.status_code();
53
54 if status == StatusCode::INTERNAL_SERVER_ERROR {
55 tracing::error!(error = %err, "internal server error");
56 }
57
58 let body = ErrorResponse {
59 error: ErrorInfo {
60 code: err.error_code(),
61 message: err.to_string(),
62 },
63 };
64
65 let mut response = (status, axum::Json(body)).into_response();
66
67 if let Some(secs) = err.retry_after_secs() {
68 response.headers_mut().insert(
69 axum::http::header::RETRY_AFTER,
70 axum::http::HeaderValue::from(secs),
71 );
72 }
73
74 response
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;

    /// Minimal error type covering a client error, a rate-limit error carrying a
    /// retry hint, and an internal error.
    #[derive(Debug)]
    enum TestError {
        BadInput(String),
        RateLimited { retry_after: u64 },
        Internal(String),
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::BadInput(msg) => write!(f, "bad input: {msg}"),
                Self::RateLimited { .. } => write!(f, "rate limited"),
                Self::Internal(msg) => write!(f, "internal error: {msg}"),
            }
        }
    }

    impl ApiError for TestError {
        fn status_code(&self) -> StatusCode {
            match self {
                Self::BadInput(_) => StatusCode::BAD_REQUEST,
                Self::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
                Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
            }
        }

        fn error_code(&self) -> &'static str {
            match self {
                Self::BadInput(_) => "BAD_INPUT",
                Self::RateLimited { .. } => "RATE_LIMITED",
                Self::Internal(_) => "INTERNAL_ERROR",
            }
        }

        fn retry_after_secs(&self) -> Option<u64> {
            // Exhaustive on purpose: a new variant must decide whether it carries a
            // retry hint instead of silently falling into a catch-all.
            match self {
                Self::RateLimited { retry_after } => Some(*retry_after),
                Self::BadInput(_) | Self::Internal(_) => None,
            }
        }
    }

    /// Renders `err` and returns its status, headers and JSON-decoded body.
    async fn into_parts(
        err: TestError,
    ) -> (StatusCode, axum::http::HeaderMap, serde_json::Value) {
        // Splitting the response moves the headers out instead of cloning them.
        let (parts, body) = into_error_response(&err).into_parts();
        let bytes = to_bytes(body, usize::MAX).await.unwrap();
        let json = serde_json::from_slice(&bytes).unwrap();
        (parts.status, parts.headers, json)
    }

    /// Renders `err` and returns only its JSON-decoded body.
    async fn body_json(err: TestError) -> serde_json::Value {
        let (_, _, body) = into_parts(err).await;
        body
    }

    // Status-only checks never await, so they need no async runtime.
    #[test]
    fn bad_input_is_400() {
        let response = into_error_response(&TestError::BadInput("oops".into()));
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn rate_limited_is_429() {
        let response = into_error_response(&TestError::RateLimited { retry_after: 5 });
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn internal_is_500() {
        let response = into_error_response(&TestError::Internal("boom".into()));
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn body_has_error_code_and_message() {
        let body = body_json(TestError::BadInput("test".into())).await;
        assert_eq!(body["error"]["code"], "BAD_INPUT");
        assert!(body["error"]["message"].as_str().unwrap().contains("test"));
        assert_eq!(body.as_object().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn body_is_served_as_json() {
        let (_, headers, _) = into_parts(TestError::BadInput("x".into())).await;
        let content_type = headers
            .get(axum::http::header::CONTENT_TYPE)
            .expect("Content-Type header must be present");
        assert_eq!(content_type.to_str().unwrap(), "application/json");
    }

    #[tokio::test]
    async fn rate_limited_body_has_code_and_message() {
        let body = body_json(TestError::RateLimited { retry_after: 1 }).await;
        assert_eq!(body["error"]["code"], "RATE_LIMITED");
        assert_eq!(body["error"]["message"], "rate limited");
    }

    #[tokio::test]
    async fn internal_body_has_internal_code() {
        let (status, _, body) = into_parts(TestError::Internal("boom".into())).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body["error"]["code"], "INTERNAL_ERROR");
    }

    #[tokio::test]
    async fn rate_limited_includes_retry_after_header() {
        let (status, headers, _) =
            into_parts(TestError::RateLimited { retry_after: 42 }).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        let retry_after = headers
            .get(axum::http::header::RETRY_AFTER)
            .expect("Retry-After header must be present");
        let value: u64 = retry_after.to_str().unwrap().parse().unwrap();
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn non_rate_limited_has_no_retry_after() {
        let (_, headers, _) = into_parts(TestError::BadInput("x".into())).await;
        assert!(headers.get(axum::http::header::RETRY_AFTER).is_none());
    }
}