// axum_api_kit/error.rs

1use axum::{http::StatusCode, Json};
2use serde::Serialize;
3use serde_json::Value;
4use std::fmt;
5
6/// A machine-readable JSON error body.
7///
8/// Serializes as:
9/// ```json
10/// { "code": "NOT_FOUND", "message": "item not found" }
11/// { "code": "VALIDATION_ERROR", "message": "invalid input", "details": { "field": "name" } }
12/// ```
13///
14/// Use the factory methods to get a `(StatusCode, Json<ApiError>)` tuple, which implements
15/// [`IntoResponse`] and can be returned directly from Axum handlers.
16///
17/// # Example
18///
19/// ```rust
20/// use axum::response::IntoResponse;
21/// use axum_api_kit::ApiError;
22///
23/// async fn handler() -> impl IntoResponse {
24///     ApiError::not_found("item not found")
25/// }
26/// ```
27#[derive(Debug, Clone, Serialize)]
28#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
29pub struct ApiError {
30    /// A short, stable, machine-readable error identifier. Use `SCREAMING_SNAKE_CASE`.
31    pub code: String,
32    /// A human-readable description of the error.
33    pub message: String,
34    /// Optional structured details (field-level validation errors, etc.).
35    #[serde(skip_serializing_if = "Option::is_none")]
36    #[cfg_attr(feature = "openapi", schema(value_type = Option<Object>))]
37    pub details: Option<Value>,
38}
39
40impl ApiError {
41    /// Construct a bare `ApiError` without a bundled status code.
42    ///
43    /// Prefer the factory methods ([`not_found`](Self::not_found), etc.) when returning
44    /// responses directly from handlers.
45    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
46        Self {
47            code: code.into(),
48            message: message.into(),
49            details: None,
50        }
51    }
52
53    /// Attach structured details to this error.
54    pub fn with_details(mut self, details: Value) -> Self {
55        self.details = Some(details);
56        self
57    }
58
59    // --- Factory helpers ---
60    // Each returns (StatusCode, Json<ApiError>) which implements IntoResponse.
61
62    /// `400 Bad Request` with the provided `code` and `message`.
63    pub fn bad_request(
64        code: impl Into<String>,
65        message: impl Into<String>,
66    ) -> (StatusCode, Json<Self>) {
67        (StatusCode::BAD_REQUEST, Json(Self::new(code, message)))
68    }
69
70    /// `401 Unauthorized` - `code` defaults to `"AUTH_REQUIRED"`.
71    pub fn unauthorized(message: impl Into<String>) -> (StatusCode, Json<Self>) {
72        (
73            StatusCode::UNAUTHORIZED,
74            Json(Self::new("AUTH_REQUIRED", message)),
75        )
76    }
77
78    /// `403 Forbidden` - `code` defaults to `"FORBIDDEN"`.
79    pub fn forbidden(message: impl Into<String>) -> (StatusCode, Json<Self>) {
80        (StatusCode::FORBIDDEN, Json(Self::new("FORBIDDEN", message)))
81    }
82
83    /// `404 Not Found` - `code` defaults to `"NOT_FOUND"`.
84    pub fn not_found(message: impl Into<String>) -> (StatusCode, Json<Self>) {
85        (StatusCode::NOT_FOUND, Json(Self::new("NOT_FOUND", message)))
86    }
87
88    /// `409 Conflict` - `code` defaults to `"CONFLICT"`.
89    pub fn conflict(message: impl Into<String>) -> (StatusCode, Json<Self>) {
90        (StatusCode::CONFLICT, Json(Self::new("CONFLICT", message)))
91    }
92
93    /// `422 Unprocessable Entity` - `code` defaults to `"VALIDATION_ERROR"`.
94    pub fn unprocessable(message: impl Into<String>) -> (StatusCode, Json<Self>) {
95        (
96            StatusCode::UNPROCESSABLE_ENTITY,
97            Json(Self::new("VALIDATION_ERROR", message)),
98        )
99    }
100
101    /// `500 Internal Server Error` - `code` defaults to `"INTERNAL_ERROR"`.
102    pub fn internal(message: impl Into<String>) -> (StatusCode, Json<Self>) {
103        (
104            StatusCode::INTERNAL_SERVER_ERROR,
105            Json(Self::new("INTERNAL_ERROR", message)),
106        )
107    }
108
109    /// `500 Internal Server Error` for database failures - `code` is `"DB_ERROR"`.
110    pub fn db_error() -> (StatusCode, Json<Self>) {
111        (
112            StatusCode::INTERNAL_SERVER_ERROR,
113            Json(Self::new("DB_ERROR", "database error")),
114        )
115    }
116
117    /// `429 Too Many Requests` - `code` defaults to `"RATE_LIMITED"`.
118    pub fn too_many_requests(message: impl Into<String>) -> (StatusCode, Json<Self>) {
119        (
120            StatusCode::TOO_MANY_REQUESTS,
121            Json(Self::new("RATE_LIMITED", message)),
122        )
123    }
124
125    /// `503 Service Unavailable` - `code` defaults to `"SERVICE_UNAVAILABLE"`.
126    pub fn service_unavailable(message: impl Into<String>) -> (StatusCode, Json<Self>) {
127        (
128            StatusCode::SERVICE_UNAVAILABLE,
129            Json(Self::new("SERVICE_UNAVAILABLE", message)),
130        )
131    }
132
133    /// `501 Not Implemented` - `code` defaults to `"NOT_IMPLEMENTED"`.
134    pub fn not_implemented(message: impl Into<String>) -> (StatusCode, Json<Self>) {
135        (
136            StatusCode::NOT_IMPLEMENTED,
137            Json(Self::new("NOT_IMPLEMENTED", message)),
138        )
139    }
140
141    /// Attach a source error message to this error.
142    ///
143    /// Stores the source in the details field under the `"source"` key.
144    /// Can be chained with other builder methods.
145    ///
146    /// # Example
147    ///
148    /// ```rust
149    /// use axum_api_kit::ApiError;
150    ///
151    /// let err = ApiError::new("NOT_FOUND", "user not found")
152    ///     .with_source("SELECT * FROM users WHERE id = ?")
153    ///     .with_details(serde_json::json!({ "user_id": 42 }));
154    /// ```
155    pub fn with_source(mut self, source: &str) -> Self {
156        let mut details = self.details.take().unwrap_or_else(|| serde_json::json!({}));
157        if let serde_json::Value::Object(ref mut map) = details {
158            map.insert(
159                "source".to_string(),
160                serde_json::Value::String(source.to_string()),
161            );
162        }
163        self.details = Some(details);
164        self
165    }
166}
167
168/// Convert `std::io::Error` to `ApiError` with HTTP 500.
169///
170/// Maps `std::io::Error` to `ApiError::internal()` with the error message.
171/// Enables using the `?` operator in handlers:
172///
173/// ```rust,ignore
174/// async fn handler() -> impl IntoResponse {
175///     let content = std::fs::read_to_string("/data.txt")?;  // auto-converts to ApiError
176///     Ok((StatusCode::OK, content))
177/// }
178/// ```
179impl From<std::io::Error> for ApiError {
180    fn from(err: std::io::Error) -> Self {
181        Self::new("IO_ERROR", format!("IO error: {}", err))
182    }
183}
184
185/// Convert `serde_json::Error` to `ApiError` with HTTP 500.
186///
187/// Maps JSON errors to `ApiError::internal()` with the error message.
188impl From<serde_json::Error> for ApiError {
189    fn from(err: serde_json::Error) -> Self {
190        Self::new("JSON_ERROR", format!("JSON error: {}", err))
191    }
192}
193
/// Convert `sqlx::Error` to an `ApiError` whose `code` reflects the kind of failure.
///
/// Requires the `sqlx` feature flag.
///
/// Only `code` and `message` are set here. The HTTP column lists the status each code is
/// meant to be sent with.
/// NOTE(review): the code-to-status mapping is not part of this file - confirm where it lives.
///
/// | `sqlx::Error` variant | `code` | HTTP |
/// |---|---|---|
/// | `RowNotFound` | `NOT_FOUND` | 404 |
/// | `Database` (unique/FK violation) | `CONFLICT` | 409 |
/// | `Database` (check violation) | `VALIDATION_ERROR` | 422 |
/// | `Database` (other) | `DB_ERROR` | 500 |
/// | `PoolTimedOut` / `PoolClosed` / `WorkerCrashed` | `SERVICE_UNAVAILABLE` | 503 |
/// | everything else | `DB_ERROR` | 500 |
///
/// For `Database` errors the driver's message is forwarded verbatim as `message`.
#[cfg(feature = "sqlx")]
impl From<sqlx::Error> for ApiError {
    fn from(err: sqlx::Error) -> Self {
        match err {
            sqlx::Error::RowNotFound => Self::new("NOT_FOUND", "record not found"),
            sqlx::Error::PoolTimedOut | sqlx::Error::PoolClosed | sqlx::Error::WorkerCrashed => {
                Self::new("SERVICE_UNAVAILABLE", "database unavailable")
            }
            sqlx::Error::Database(db_err) => {
                // Every database error shares the same message; only the code varies.
                let code = if db_err.is_unique_violation() || db_err.is_foreign_key_violation()
                {
                    "CONFLICT"
                } else if db_err.is_check_violation() {
                    "VALIDATION_ERROR"
                } else {
                    "DB_ERROR"
                };
                Self::new(code, db_err.message())
            }
            other => Self::new("DB_ERROR", format!("database error: {other}")),
        }
    }
}
227
/// Flatten a (possibly nested) `validator::ValidationErrors` into `out`.
///
/// Each key of `out` is the path to a field: nested struct fields are joined to their
/// parent with `.`, and list elements get an `[index]` suffix. Each value is an array with
/// one object per failed validation, holding `code`, plus `message` and `params` when the
/// validation error carries them.
///
/// `prefix` is the path of the enclosing field, or `None` at the top level.
#[cfg(feature = "validator")]
fn collect_validation_errors(
    prefix: Option<&str>,
    errors: &validator::ValidationErrors,
    out: &mut serde_json::Map<String, serde_json::Value>,
) {
    use serde_json::{Map, Value};
    use validator::ValidationErrorsKind as Kind;

    for (field, kind) in errors.errors() {
        let path = match prefix {
            Some(parent) => format!("{parent}.{field}"),
            None => field.to_string(),
        };

        match kind {
            Kind::Field(failures) => {
                let mut entries = Vec::with_capacity(failures.len());
                for failure in failures {
                    let mut entry = Map::new();
                    entry.insert("code".to_string(), Value::String(failure.code.to_string()));
                    if let Some(message) = &failure.message {
                        entry.insert("message".to_string(), Value::String(message.to_string()));
                    }
                    if !failure.params.is_empty() {
                        // Params that cannot be represented as JSON are reported as null.
                        let params = serde_json::to_value(&failure.params).unwrap_or(Value::Null);
                        entry.insert("params".to_string(), params);
                    }
                    entries.push(Value::Object(entry));
                }
                out.insert(path, Value::Array(entries));
            }
            // Nested struct: recurse with this field as the new prefix.
            Kind::Struct(nested) => collect_validation_errors(Some(&path), nested, out),
            // List of structs: recurse once per failing element, tagging its index.
            Kind::List(items) => {
                for (index, nested) in items {
                    let element_path = format!("{path}[{index}]");
                    collect_validation_errors(Some(&element_path), nested, out);
                }
            }
        }
    }
}
283
/// Convert `validator::ValidationErrors` into an `ApiError` with code `"VALIDATION_ERROR"`.
///
/// The message is always `"validation failed"`; the per-field failures gathered by
/// `collect_validation_errors` are placed under `details.fields`.
#[cfg(feature = "validator")]
impl From<validator::ValidationErrors> for ApiError {
    fn from(errors: validator::ValidationErrors) -> Self {
        let mut by_field = serde_json::Map::new();
        collect_validation_errors(None, &errors, &mut by_field);

        let details = serde_json::json!({ "fields": by_field });
        Self::new("VALIDATION_ERROR", "validation failed").with_details(details)
    }
}
295
296impl fmt::Display for ApiError {
297    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298        write!(f, "{}: {}", self.code, self.message)
299    }
300}
301
302impl std::error::Error for ApiError {}
303
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Destructure a factory result and check both its status and its error code.
    ///
    /// `#[track_caller]` makes a failed assertion point at the calling test.
    #[track_caller]
    fn assert_factory(
        response: (StatusCode, Json<ApiError>),
        expected_status: StatusCode,
        expected_code: &str,
    ) {
        let (status, Json(body)) = response;
        assert_eq!(status, expected_status);
        assert_eq!(body.code, expected_code);
    }

    // --- Construction, serialization and trait impls ---

    #[test]
    fn new_sets_fields() {
        let error = ApiError::new("MY_CODE", "my message");
        assert_eq!(error.code, "MY_CODE");
        assert_eq!(error.message, "my message");
        assert!(error.details.is_none());
    }

    #[test]
    fn with_details_sets_details() {
        let error = ApiError::new("CODE", "msg").with_details(json!({ "field": "name" }));
        let details = error.details.unwrap();
        assert_eq!(details["field"], "name");
    }

    #[test]
    fn serializes_without_details() {
        let body = serde_json::to_value(ApiError::new("NOT_FOUND", "item not found")).unwrap();
        assert_eq!(body["code"], "NOT_FOUND");
        assert_eq!(body["message"], "item not found");
        assert!(body.get("details").is_none());
    }

    #[test]
    fn serializes_with_details() {
        let error = ApiError::new("VALIDATION_ERROR", "invalid").with_details(json!({ "x": 1 }));
        let body = serde_json::to_value(&error).unwrap();
        assert_eq!(body["details"]["x"], 1);
    }

    #[test]
    fn display_formats_code_and_message() {
        let rendered = ApiError::new("NOT_FOUND", "item not found").to_string();
        assert_eq!(rendered, "NOT_FOUND: item not found");
    }

    #[test]
    fn implements_std_error() {
        let error = ApiError::new("ERR", "something failed");
        let _: &dyn std::error::Error = &error;
    }

    // --- Factory helpers: one test per status/code pair ---

    #[test]
    fn bad_request_status_and_code() {
        let response = ApiError::bad_request("INVALID_FIELD", "bad");
        assert_factory(response, StatusCode::BAD_REQUEST, "INVALID_FIELD");
    }

    #[test]
    fn unauthorized_status_and_code() {
        let response = ApiError::unauthorized("please log in");
        assert_factory(response, StatusCode::UNAUTHORIZED, "AUTH_REQUIRED");
    }

    #[test]
    fn forbidden_status_and_code() {
        let response = ApiError::forbidden("no access");
        assert_factory(response, StatusCode::FORBIDDEN, "FORBIDDEN");
    }

    #[test]
    fn not_found_status_and_code() {
        let response = ApiError::not_found("missing");
        assert_factory(response, StatusCode::NOT_FOUND, "NOT_FOUND");
    }

    #[test]
    fn conflict_status_and_code() {
        let response = ApiError::conflict("already exists");
        assert_factory(response, StatusCode::CONFLICT, "CONFLICT");
    }

    #[test]
    fn unprocessable_status_and_code() {
        let response = ApiError::unprocessable("invalid input");
        assert_factory(response, StatusCode::UNPROCESSABLE_ENTITY, "VALIDATION_ERROR");
    }

    #[test]
    fn internal_status_and_code() {
        let response = ApiError::internal("oops");
        assert_factory(response, StatusCode::INTERNAL_SERVER_ERROR, "INTERNAL_ERROR");
    }

    #[test]
    fn db_error_status_and_code() {
        let response = ApiError::db_error();
        assert_factory(response, StatusCode::INTERNAL_SERVER_ERROR, "DB_ERROR");
    }

    #[test]
    fn too_many_requests_status_and_code() {
        let response = ApiError::too_many_requests("slow down");
        assert_factory(response, StatusCode::TOO_MANY_REQUESTS, "RATE_LIMITED");
    }

    #[test]
    fn service_unavailable_status_and_code() {
        let response = ApiError::service_unavailable("down for maintenance");
        assert_factory(response, StatusCode::SERVICE_UNAVAILABLE, "SERVICE_UNAVAILABLE");
    }

    #[test]
    fn not_implemented_status_and_code() {
        let response = ApiError::not_implemented("coming soon");
        assert_factory(response, StatusCode::NOT_IMPLEMENTED, "NOT_IMPLEMENTED");
    }

    // --- with_source ---

    #[test]
    fn with_source_adds_source_to_details() {
        let error = ApiError::new("NOT_FOUND", "missing").with_source("db query");
        let body = serde_json::to_value(&error).unwrap();
        assert_eq!(body["details"]["source"], "db query");
        assert_eq!(body["code"], "NOT_FOUND");
    }

    #[test]
    fn with_source_and_with_details_both_present() {
        // Details first, source second: `with_source` merges into the existing object.
        let error = ApiError::new("ERROR", "msg")
            .with_details(json!({ "user_id": 123 }))
            .with_source("from somewhere");
        let body = serde_json::to_value(&error).unwrap();
        assert_eq!(body["details"]["source"], "from somewhere");
        assert_eq!(body["details"]["user_id"], 123);
    }

    // --- From conversions (always available) ---

    #[test]
    fn from_io_error_creates_io_error_code() {
        use std::io;

        let source = io::Error::new(io::ErrorKind::NotFound, "file not found");
        let converted = ApiError::from(source);
        assert_eq!(converted.code, "IO_ERROR");
        assert!(converted.message.contains("IO error"));
    }

    #[test]
    fn from_serde_json_error_creates_json_error_code() {
        let parse_failure = serde_json::from_str::<serde_json::Value>("{ invalid json }")
            .unwrap_err();
        let converted = ApiError::from(parse_failure);
        assert_eq!(converted.code, "JSON_ERROR");
        assert!(converted.message.contains("JSON error"));
    }

    #[test]
    fn io_error_conversion_captures_kind() {
        use std::io;

        // The assertion checks the custom message text carried by the io::Error.
        let source = io::Error::new(io::ErrorKind::PermissionDenied, "permission denied");
        let converted = ApiError::from(source);
        assert!(converted.message.contains("permission denied"));
    }

    // --- From<validator::ValidationErrors> ---

    #[cfg(feature = "validator")]
    #[test]
    fn from_validation_errors_single_field() {
        use std::borrow::Cow;
        use validator::{ValidationError, ValidationErrors};

        let mut email_failure = ValidationError::new("email");
        email_failure.message = Some(Cow::Borrowed("invalid email"));

        let mut failures = ValidationErrors::new();
        failures.add("email", email_failure);

        let body = serde_json::to_value(ApiError::from(failures)).unwrap();

        assert_eq!(body["code"], "VALIDATION_ERROR");
        assert_eq!(body["message"], "validation failed");

        let first = &body["details"]["fields"]["email"][0];
        assert_eq!(first["code"], "email");
        assert_eq!(first["message"], "invalid email");
    }

    #[cfg(feature = "validator")]
    #[test]
    fn from_validation_errors_multiple_fields_with_params() {
        use std::borrow::Cow;
        use validator::{ValidationError, ValidationErrors};

        let mut username_failure = ValidationError::new("length");
        username_failure.message = Some(Cow::Borrowed("username too short"));
        username_failure.add_param(Cow::Borrowed("min"), &3);

        let mut age_failure = ValidationError::new("range");
        age_failure.add_param(Cow::Borrowed("min"), &18);

        let mut failures = ValidationErrors::new();
        failures.add("username", username_failure);
        failures.add("age", age_failure);

        let body = serde_json::to_value(ApiError::from(failures)).unwrap();
        let fields = &body["details"]["fields"];

        assert_eq!(fields["username"][0]["code"], "length");
        assert_eq!(fields["username"][0]["params"]["min"], 3);
        assert_eq!(fields["age"][0]["code"], "range");
        assert_eq!(fields["age"][0]["params"]["min"], 18);
    }

    // --- From<sqlx::Error> ---

    #[cfg(feature = "sqlx")]
    #[test]
    fn sqlx_row_not_found_maps_to_not_found() {
        let converted = ApiError::from(sqlx::Error::RowNotFound);
        assert_eq!(converted.code, "NOT_FOUND");
        assert_eq!(converted.message, "record not found");
    }

    #[cfg(feature = "sqlx")]
    #[test]
    fn sqlx_pool_timed_out_maps_to_service_unavailable() {
        let converted = ApiError::from(sqlx::Error::PoolTimedOut);
        assert_eq!(converted.code, "SERVICE_UNAVAILABLE");
    }

    #[cfg(feature = "sqlx")]
    #[test]
    fn sqlx_pool_closed_maps_to_service_unavailable() {
        let converted = ApiError::from(sqlx::Error::PoolClosed);
        assert_eq!(converted.code, "SERVICE_UNAVAILABLE");
    }

    #[cfg(feature = "sqlx")]
    #[test]
    fn sqlx_unknown_variant_maps_to_db_error() {
        // Protocol is neither a pool nor a database variant, so it reaches the catch-all arm.
        let converted = ApiError::from(sqlx::Error::Protocol("unexpected packet".into()));
        assert_eq!(converted.code, "DB_ERROR");
        assert!(converted.message.contains("database error"));
    }
}