1use axum::Json;
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use opentelemetry::TraceId;
7use opentelemetry::trace::TraceContextExt;
8use serde::Serialize;
9use thiserror::Error;
10use tracing_opentelemetry::OpenTelemetrySpanExt;
11
12#[derive(Debug, Clone)]
14pub struct ApiSuccess<T: Serialize + PartialEq>(StatusCode, Json<T>);
15
16impl<T> PartialEq for ApiSuccess<T>
17where
18 T: Serialize + PartialEq,
19{
20 fn eq(&self, other: &Self) -> bool {
21 self.0 == other.0 && self.1.0 == other.1.0
22 }
23}
24
25impl<T: Serialize + PartialEq> ApiSuccess<T> {
26 pub fn new(status: StatusCode, data: T) -> Self {
27 ApiSuccess(status, Json(data))
28 }
29}
30
31impl<T: Serialize + PartialEq> IntoResponse for ApiSuccess<T> {
32 fn into_response(self) -> Response {
33 (self.0, self.1).into_response()
34 }
35}
36
37#[derive(Debug, Clone, PartialEq, Serialize)]
39pub(crate) struct ApiErrorResponse<T: Serialize + PartialEq> {
40 code: u16,
41 message: T,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 trace_id: Option<String>,
44}
45
46impl<T: Serialize + PartialEq> ApiErrorResponse<T> {
47 pub(crate) fn new(status_code: StatusCode, message: T, trace_id: Option<String>) -> Self {
48 Self {
49 code: status_code.as_u16(),
50 message,
51 trace_id,
52 }
53 }
54}
55
/// Errors an API handler can return.
///
/// The `Display` text (from the `#[error]` attributes) prefixes the variant
/// name for variants that carry a message. The `IntoResponse` implementation
/// in this file sends only the inner message (or a fixed phrase for unit
/// variants) in the JSON body, together with the matching HTTP status.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum ApiError {
    /// Rendered as HTTP 400 with the given message.
    #[error("Bad request: {0}")]
    BadRequest(String),

    /// Rendered as HTTP 401 with the given message.
    #[error("Unauthorized: {0}")]
    Unauthorized(String),

    /// Rendered as HTTP 403 with the given message.
    #[error("Forbidden: {0}")]
    Forbidden(String),

    /// Rendered as HTTP 404 with the given message.
    #[error("Not found: {0}")]
    NotFound(String),

    /// Rendered as HTTP 422 with the given message.
    #[error("Unprocessable entity: {0}")]
    UnprocessableEntity(String),

    /// Rendered as HTTP 500 with the given message.
    #[error("Internal server error: {0}")]
    InternalServerError(String),

    /// Rendered as HTTP 408 with the body message "Request timeout".
    #[error("Timeout")]
    Timeout,

    /// Rendered as HTTP 429.
    #[error("Too many requests")]
    TooManyRequests,

    /// Rendered as HTTP 405.
    #[error("Method not allowed")]
    MethodNotAllowed,

    /// Rendered as HTTP 413.
    #[error("Payload too large")]
    PayloadTooLarge,

    /// Rendered as HTTP 503.
    #[error("Service unavailable")]
    ServiceUnavailable,
}
92
93impl ApiError {
94 fn response(code: StatusCode, message: &str) -> impl IntoResponse + '_ {
95 let ctx = tracing::Span::current().context();
96 let trace_id = ctx.span().span_context().trace_id();
97 let trace_id = if trace_id == TraceId::INVALID {
98 None
99 } else {
100 Some(trace_id.to_string())
101 };
102
103 match code {
104 StatusCode::REQUEST_TIMEOUT => (
105 StatusCode::REQUEST_TIMEOUT,
106 Json(ApiErrorResponse::new(StatusCode::REQUEST_TIMEOUT, message, trace_id)),
107 ),
108 StatusCode::TOO_MANY_REQUESTS => (
109 StatusCode::TOO_MANY_REQUESTS,
110 Json(ApiErrorResponse::new(StatusCode::TOO_MANY_REQUESTS, message, trace_id)),
111 ),
112 StatusCode::METHOD_NOT_ALLOWED => (
113 StatusCode::METHOD_NOT_ALLOWED,
114 Json(ApiErrorResponse::new(StatusCode::METHOD_NOT_ALLOWED, message, trace_id)),
115 ),
116 StatusCode::PAYLOAD_TOO_LARGE => (
117 StatusCode::PAYLOAD_TOO_LARGE,
118 Json(ApiErrorResponse::new(StatusCode::PAYLOAD_TOO_LARGE, message, trace_id)),
119 ),
120 StatusCode::BAD_REQUEST => (
121 StatusCode::BAD_REQUEST,
122 Json(ApiErrorResponse::new(StatusCode::BAD_REQUEST, message, trace_id)),
123 ),
124 StatusCode::UNAUTHORIZED => (
125 StatusCode::UNAUTHORIZED,
126 Json(ApiErrorResponse::new(StatusCode::UNAUTHORIZED, message, None)),
127 ),
128 StatusCode::FORBIDDEN => (
129 StatusCode::FORBIDDEN,
130 Json(ApiErrorResponse::new(StatusCode::FORBIDDEN, message, trace_id)),
131 ),
132 StatusCode::NOT_FOUND => (
133 StatusCode::NOT_FOUND,
134 Json(ApiErrorResponse::new(StatusCode::NOT_FOUND, message, trace_id)),
135 ),
136 StatusCode::SERVICE_UNAVAILABLE => (
137 StatusCode::SERVICE_UNAVAILABLE,
138 Json(ApiErrorResponse::new(
139 StatusCode::SERVICE_UNAVAILABLE,
140 message,
141 trace_id,
142 )),
143 ),
144 StatusCode::UNPROCESSABLE_ENTITY => (
145 StatusCode::UNPROCESSABLE_ENTITY,
146 Json(ApiErrorResponse::new(
147 StatusCode::UNPROCESSABLE_ENTITY,
148 message,
149 trace_id,
150 )),
151 ),
152 _ => (
153 StatusCode::INTERNAL_SERVER_ERROR,
154 Json(ApiErrorResponse::new(
155 StatusCode::INTERNAL_SERVER_ERROR,
156 message,
157 trace_id,
158 )),
159 ),
160 }
161 }
162}
163
164impl IntoResponse for ApiError {
165 fn into_response(self) -> Response {
166 match self {
167 ApiError::Timeout => Self::response(StatusCode::REQUEST_TIMEOUT, "Request timeout").into_response(),
168 ApiError::TooManyRequests => {
169 Self::response(StatusCode::TOO_MANY_REQUESTS, "Too many requests").into_response()
170 }
171 ApiError::MethodNotAllowed => {
172 Self::response(StatusCode::METHOD_NOT_ALLOWED, "Method not allowed").into_response()
173 }
174 ApiError::PayloadTooLarge => {
175 Self::response(StatusCode::PAYLOAD_TOO_LARGE, "Payload too large").into_response()
176 }
177 ApiError::ServiceUnavailable => {
178 Self::response(StatusCode::SERVICE_UNAVAILABLE, "Service unavailable").into_response()
179 }
180 ApiError::BadRequest(message) => Self::response(StatusCode::BAD_REQUEST, &message).into_response(),
181 ApiError::Unauthorized(message) => Self::response(StatusCode::UNAUTHORIZED, &message).into_response(),
182 ApiError::Forbidden(message) => Self::response(StatusCode::FORBIDDEN, &message).into_response(),
183 ApiError::NotFound(message) => Self::response(StatusCode::NOT_FOUND, &message).into_response(),
184 ApiError::UnprocessableEntity(message) => {
185 Self::response(StatusCode::UNPROCESSABLE_ENTITY, &message).into_response()
186 }
187 ApiError::InternalServerError(message) => {
188 Self::response(StatusCode::INTERNAL_SERVER_ERROR, &message).into_response()
189 }
190 }
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Reads a response body (limited to 1_024 bytes) and decodes it as UTF-8.
    async fn read_body(response: Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), 1_024).await.unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Checks an error's `Display` text, then the status and JSON body of the
    /// response it converts into.
    async fn assert_error_response(error: ApiError, display: &str, status: StatusCode, body: Value) {
        assert_eq!(error.to_string(), display);

        let response = error.into_response();
        assert_eq!(response.status(), status);

        assert_eq!(read_body(response).await, body.to_string());
    }

    #[test]
    fn test_api_success_partial_eq() {
        let first = ApiSuccess::new(StatusCode::OK, json!({"data": "test"}));
        let same = ApiSuccess::new(StatusCode::OK, json!({"data": "test"}));
        assert_eq!(first, same);

        let other_status = ApiSuccess::new(StatusCode::BAD_REQUEST, json!({"data": "test"}));
        assert_ne!(first, other_status);
    }

    #[tokio::test]
    async fn test_api_success_into_response() {
        let payload = json!({"hello": "world"});
        let response = ApiSuccess::new(StatusCode::OK, payload.clone()).into_response();
        assert_eq!(response.status(), StatusCode::OK);

        assert_eq!(read_body(response).await, payload.to_string());
    }

    #[test]
    fn test_new_api_error_response() {
        let body = ApiErrorResponse::new(StatusCode::BAD_REQUEST, "Bad request", None);
        assert_eq!(body.code, 400);
        assert_eq!(body.message, "Bad request");
    }

    #[tokio::test]
    async fn test_api_error_into_response_bad_request() {
        assert_error_response(
            ApiError::BadRequest("Invalid input".to_string()),
            "Bad request: Invalid input",
            StatusCode::BAD_REQUEST,
            json!({ "code": 400, "message": "Invalid input" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_unauthorized() {
        assert_error_response(
            ApiError::Unauthorized("Not authorized".to_string()),
            "Unauthorized: Not authorized",
            StatusCode::UNAUTHORIZED,
            json!({ "code": 401, "message": "Not authorized" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_forbidden() {
        assert_error_response(
            ApiError::Forbidden("Access denied".to_string()),
            "Forbidden: Access denied",
            StatusCode::FORBIDDEN,
            json!({ "code": 403, "message": "Access denied" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_not_found() {
        assert_error_response(
            ApiError::NotFound("Resource missing".to_string()),
            "Not found: Resource missing",
            StatusCode::NOT_FOUND,
            json!({ "code": 404, "message": "Resource missing" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_unprocessable_entity() {
        assert_error_response(
            ApiError::UnprocessableEntity("Invalid data".to_string()),
            "Unprocessable entity: Invalid data",
            StatusCode::UNPROCESSABLE_ENTITY,
            json!({ "code": 422, "message": "Invalid data" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_internal_server_error() {
        assert_error_response(
            ApiError::InternalServerError("Unexpected".to_string()),
            "Internal server error: Unexpected",
            StatusCode::INTERNAL_SERVER_ERROR,
            json!({ "code": 500, "message": "Unexpected" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_timeout() {
        assert_error_response(
            ApiError::Timeout,
            "Timeout",
            StatusCode::REQUEST_TIMEOUT,
            json!({ "code": 408, "message": "Request timeout" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_too_many_requests() {
        assert_error_response(
            ApiError::TooManyRequests,
            "Too many requests",
            StatusCode::TOO_MANY_REQUESTS,
            json!({ "code": 429, "message": "Too many requests" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_method_not_allowed() {
        assert_error_response(
            ApiError::MethodNotAllowed,
            "Method not allowed",
            StatusCode::METHOD_NOT_ALLOWED,
            json!({ "code": 405, "message": "Method not allowed" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_payload_too_large() {
        assert_error_response(
            ApiError::PayloadTooLarge,
            "Payload too large",
            StatusCode::PAYLOAD_TOO_LARGE,
            json!({ "code": 413, "message": "Payload too large" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_into_response_service_unavailable() {
        assert_error_response(
            ApiError::ServiceUnavailable,
            "Service unavailable",
            StatusCode::SERVICE_UNAVAILABLE,
            json!({ "code": 503, "message": "Service unavailable" }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_api_error_response() {
        let response =
            ApiError::response(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error").into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        assert_eq!(
            read_body(response).await,
            json!({ "code": 500, "message": "Internal server error" }).to_string()
        );
    }
}