api_tools/server/axum/
response.rs

1//! API response module
2
3use axum::Json;
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use opentelemetry::TraceId;
7use opentelemetry::trace::TraceContextExt;
8use serde::Serialize;
9use thiserror::Error;
10use tracing_opentelemetry::OpenTelemetrySpanExt;
11
12/// API response success
13#[derive(Debug, Clone)]
14pub struct ApiSuccess<T: Serialize + PartialEq>(StatusCode, Json<T>);
15
16impl<T> PartialEq for ApiSuccess<T>
17where
18    T: Serialize + PartialEq,
19{
20    fn eq(&self, other: &Self) -> bool {
21        self.0 == other.0 && self.1.0 == other.1.0
22    }
23}
24
25impl<T: Serialize + PartialEq> ApiSuccess<T> {
26    pub fn new(status: StatusCode, data: T) -> Self {
27        ApiSuccess(status, Json(data))
28    }
29}
30
31impl<T: Serialize + PartialEq> IntoResponse for ApiSuccess<T> {
32    fn into_response(self) -> Response {
33        (self.0, self.1).into_response()
34    }
35}
36
37/// Generic response structure shared by all API responses.
38#[derive(Debug, Clone, PartialEq, Serialize)]
39pub(crate) struct ApiErrorResponse<T: Serialize + PartialEq> {
40    code: u16,
41    message: T,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    trace_id: Option<String>,
44}
45
46impl<T: Serialize + PartialEq> ApiErrorResponse<T> {
47    pub(crate) fn new(status_code: StatusCode, message: T, trace_id: Option<String>) -> Self {
48        Self {
49            code: status_code.as_u16(),
50            message,
51            trace_id,
52        }
53    }
54}
55
56/// API error
57#[derive(Debug, Clone, PartialEq, Error)]
58pub enum ApiError {
59    #[error("Bad request: {0}")]
60    BadRequest(String),
61
62    #[error("Unauthorized: {0}")]
63    Unauthorized(String),
64
65    #[error("Forbidden: {0}")]
66    Forbidden(String),
67
68    #[error("Not found: {0}")]
69    NotFound(String),
70
71    #[error("Unprocessable entity: {0}")]
72    UnprocessableEntity(String),
73
74    #[error("Internal server error: {0}")]
75    InternalServerError(String),
76
77    #[error("Timeout")]
78    Timeout,
79
80    #[error("Too many requests")]
81    TooManyRequests,
82
83    #[error("Method not allowed")]
84    MethodNotAllowed,
85
86    #[error("Payload too large")]
87    PayloadTooLarge,
88
89    #[error("Service unavailable")]
90    ServiceUnavailable,
91}
92
93impl ApiError {
94    fn response(code: StatusCode, message: &str) -> impl IntoResponse + '_ {
95        let ctx = tracing::Span::current().context();
96        let trace_id = ctx.span().span_context().trace_id();
97        let trace_id = if trace_id == TraceId::INVALID {
98            None
99        } else {
100            Some(trace_id.to_string())
101        };
102
103        match code {
104            StatusCode::REQUEST_TIMEOUT => (
105                StatusCode::REQUEST_TIMEOUT,
106                Json(ApiErrorResponse::new(StatusCode::REQUEST_TIMEOUT, message, trace_id)),
107            ),
108            StatusCode::TOO_MANY_REQUESTS => (
109                StatusCode::TOO_MANY_REQUESTS,
110                Json(ApiErrorResponse::new(StatusCode::TOO_MANY_REQUESTS, message, trace_id)),
111            ),
112            StatusCode::METHOD_NOT_ALLOWED => (
113                StatusCode::METHOD_NOT_ALLOWED,
114                Json(ApiErrorResponse::new(StatusCode::METHOD_NOT_ALLOWED, message, trace_id)),
115            ),
116            StatusCode::PAYLOAD_TOO_LARGE => (
117                StatusCode::PAYLOAD_TOO_LARGE,
118                Json(ApiErrorResponse::new(StatusCode::PAYLOAD_TOO_LARGE, message, trace_id)),
119            ),
120            StatusCode::BAD_REQUEST => (
121                StatusCode::BAD_REQUEST,
122                Json(ApiErrorResponse::new(StatusCode::BAD_REQUEST, message, trace_id)),
123            ),
124            StatusCode::UNAUTHORIZED => (
125                StatusCode::UNAUTHORIZED,
126                Json(ApiErrorResponse::new(StatusCode::UNAUTHORIZED, message, None)),
127            ),
128            StatusCode::FORBIDDEN => (
129                StatusCode::FORBIDDEN,
130                Json(ApiErrorResponse::new(StatusCode::FORBIDDEN, message, trace_id)),
131            ),
132            StatusCode::NOT_FOUND => (
133                StatusCode::NOT_FOUND,
134                Json(ApiErrorResponse::new(StatusCode::NOT_FOUND, message, trace_id)),
135            ),
136            StatusCode::SERVICE_UNAVAILABLE => (
137                StatusCode::SERVICE_UNAVAILABLE,
138                Json(ApiErrorResponse::new(
139                    StatusCode::SERVICE_UNAVAILABLE,
140                    message,
141                    trace_id,
142                )),
143            ),
144            StatusCode::UNPROCESSABLE_ENTITY => (
145                StatusCode::UNPROCESSABLE_ENTITY,
146                Json(ApiErrorResponse::new(
147                    StatusCode::UNPROCESSABLE_ENTITY,
148                    message,
149                    trace_id,
150                )),
151            ),
152            _ => (
153                StatusCode::INTERNAL_SERVER_ERROR,
154                Json(ApiErrorResponse::new(
155                    StatusCode::INTERNAL_SERVER_ERROR,
156                    message,
157                    trace_id,
158                )),
159            ),
160        }
161    }
162}
163
164impl IntoResponse for ApiError {
165    fn into_response(self) -> Response {
166        match self {
167            ApiError::Timeout => Self::response(StatusCode::REQUEST_TIMEOUT, "Request timeout").into_response(),
168            ApiError::TooManyRequests => {
169                Self::response(StatusCode::TOO_MANY_REQUESTS, "Too many requests").into_response()
170            }
171            ApiError::MethodNotAllowed => {
172                Self::response(StatusCode::METHOD_NOT_ALLOWED, "Method not allowed").into_response()
173            }
174            ApiError::PayloadTooLarge => {
175                Self::response(StatusCode::PAYLOAD_TOO_LARGE, "Payload too large").into_response()
176            }
177            ApiError::ServiceUnavailable => {
178                Self::response(StatusCode::SERVICE_UNAVAILABLE, "Service unavailable").into_response()
179            }
180            ApiError::BadRequest(message) => Self::response(StatusCode::BAD_REQUEST, &message).into_response(),
181            ApiError::Unauthorized(message) => Self::response(StatusCode::UNAUTHORIZED, &message).into_response(),
182            ApiError::Forbidden(message) => Self::response(StatusCode::FORBIDDEN, &message).into_response(),
183            ApiError::NotFound(message) => Self::response(StatusCode::NOT_FOUND, &message).into_response(),
184            ApiError::UnprocessableEntity(message) => {
185                Self::response(StatusCode::UNPROCESSABLE_ENTITY, &message).into_response()
186            }
187            ApiError::InternalServerError(message) => {
188                Self::response(StatusCode::INTERNAL_SERVER_ERROR, &message).into_response()
189            }
190        }
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Upper bound on the body size read by these tests; every body here is tiny.
    const BODY_LIMIT: usize = 1_024;

    /// Collects a response body into a UTF-8 string.
    async fn body_string(response: Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), BODY_LIMIT)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Checks an error's `Display` text, the HTTP status it converts to, and its JSON body.
    ///
    /// These tests install no tracing subscriber, so the body is expected to omit `trace_id`.
    async fn assert_error_response(
        error: ApiError,
        display: &str,
        status: StatusCode,
        message: &str,
    ) {
        assert_eq!(error.to_string(), display);

        let response = error.into_response();
        assert_eq!(response.status(), status);
        assert_eq!(
            body_string(response).await,
            json!({ "code": status.as_u16(), "message": message }).to_string()
        );
    }

    #[test]
    fn test_api_success_partial_eq() {
        let success1 = ApiSuccess::new(StatusCode::OK, json!({"data": "test"}));
        let success2 = ApiSuccess::new(StatusCode::OK, json!({"data": "test"}));
        assert_eq!(success1, success2);

        let success3 = ApiSuccess::new(StatusCode::BAD_REQUEST, json!({"data": "test"}));
        assert_ne!(success1, success3);

        // Same status, different payload.
        let success4 = ApiSuccess::new(StatusCode::OK, json!({"data": "other"}));
        assert_ne!(success1, success4);
    }

    #[tokio::test]
    async fn test_api_success_into_response() {
        let data = json!({"hello": "world"});
        let response = ApiSuccess::new(StatusCode::OK, data.clone()).into_response();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, data.to_string());
    }

    #[test]
    fn test_new_api_error_response() {
        let error = ApiErrorResponse::new(StatusCode::BAD_REQUEST, "Bad request", None);
        assert_eq!(error.code, 400);
        assert_eq!(error.message, "Bad request");
        assert_eq!(error.trace_id, None);
    }

    #[test]
    fn test_api_error_response_serializes_trace_id_only_when_present() {
        let untraced = ApiErrorResponse::new(StatusCode::NOT_FOUND, "missing", None);
        assert_eq!(
            serde_json::to_value(&untraced).unwrap(),
            json!({ "code": 404, "message": "missing" })
        );

        let traced =
            ApiErrorResponse::new(StatusCode::NOT_FOUND, "missing", Some("abc".to_string()));
        assert_eq!(
            serde_json::to_value(&traced).unwrap(),
            json!({ "code": 404, "message": "missing", "trace_id": "abc" })
        );
    }

    #[tokio::test]
    async fn test_api_error_into_response_bad_request() {
        assert_error_response(
            ApiError::BadRequest("Invalid input".to_string()),
            "Bad request: Invalid input",
            StatusCode::BAD_REQUEST,
            "Invalid input",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_unauthorized() {
        assert_error_response(
            ApiError::Unauthorized("Not authorized".to_string()),
            "Unauthorized: Not authorized",
            StatusCode::UNAUTHORIZED,
            "Not authorized",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_forbidden() {
        assert_error_response(
            ApiError::Forbidden("Access denied".to_string()),
            "Forbidden: Access denied",
            StatusCode::FORBIDDEN,
            "Access denied",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_not_found() {
        assert_error_response(
            ApiError::NotFound("Resource missing".to_string()),
            "Not found: Resource missing",
            StatusCode::NOT_FOUND,
            "Resource missing",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_unprocessable_entity() {
        assert_error_response(
            ApiError::UnprocessableEntity("Invalid data".to_string()),
            "Unprocessable entity: Invalid data",
            StatusCode::UNPROCESSABLE_ENTITY,
            "Invalid data",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_internal_server_error() {
        assert_error_response(
            ApiError::InternalServerError("Unexpected".to_string()),
            "Internal server error: Unexpected",
            StatusCode::INTERNAL_SERVER_ERROR,
            "Unexpected",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_timeout() {
        assert_error_response(
            ApiError::Timeout,
            "Timeout",
            StatusCode::REQUEST_TIMEOUT,
            "Request timeout",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_too_many_requests() {
        assert_error_response(
            ApiError::TooManyRequests,
            "Too many requests",
            StatusCode::TOO_MANY_REQUESTS,
            "Too many requests",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_method_not_allowed() {
        assert_error_response(
            ApiError::MethodNotAllowed,
            "Method not allowed",
            StatusCode::METHOD_NOT_ALLOWED,
            "Method not allowed",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_payload_too_large() {
        assert_error_response(
            ApiError::PayloadTooLarge,
            "Payload too large",
            StatusCode::PAYLOAD_TOO_LARGE,
            "Payload too large",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_service_unavailable() {
        assert_error_response(
            ApiError::ServiceUnavailable,
            "Service unavailable",
            StatusCode::SERVICE_UNAVAILABLE,
            "Service unavailable",
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_response() {
        let response =
            ApiError::response(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error")
                .into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body_string(response).await,
            json!({ "code": 500, "message": "Internal server error" }).to_string()
        );
    }

    #[tokio::test]
    async fn test_api_error_response_unmapped_status_falls_back_to_500() {
        let response = ApiError::response(StatusCode::IM_A_TEAPOT, "teapot").into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body_string(response).await,
            json!({ "code": 500, "message": "teapot" }).to_string()
        );
    }
}