Skip to main content

axum_api_kit/
error.rs

1use axum::{http::StatusCode, Json};
2use serde::Serialize;
3use serde_json::Value;
4use std::fmt;
5
6/// A machine-readable JSON error body.
7///
8/// Serializes as:
9/// ```json
10/// { "code": "NOT_FOUND", "message": "item not found" }
11/// { "code": "VALIDATION_ERROR", "message": "invalid input", "details": { "field": "name" } }
12/// ```
13///
14/// Use the factory methods to get a `(StatusCode, Json<ApiError>)` tuple, which implements
15/// [`IntoResponse`] and can be returned directly from Axum handlers.
16///
17/// # Example
18///
19/// ```rust
20/// use axum::response::IntoResponse;
21/// use axum_api_kit::ApiError;
22///
23/// async fn handler() -> impl IntoResponse {
24///     ApiError::not_found("item not found")
25/// }
26/// ```
27#[derive(Debug, Clone, Serialize)]
28pub struct ApiError {
29    /// A short, stable, machine-readable error identifier. Use `SCREAMING_SNAKE_CASE`.
30    pub code: String,
31    /// A human-readable description of the error.
32    pub message: String,
33    /// Optional structured details (field-level validation errors, etc.).
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub details: Option<Value>,
36}
37
38impl ApiError {
39    /// Construct a bare `ApiError` without a bundled status code.
40    ///
41    /// Prefer the factory methods ([`not_found`](Self::not_found), etc.) when returning
42    /// responses directly from handlers.
43    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
44        Self {
45            code: code.into(),
46            message: message.into(),
47            details: None,
48        }
49    }
50
51    /// Attach structured details to this error.
52    pub fn with_details(mut self, details: Value) -> Self {
53        self.details = Some(details);
54        self
55    }
56
57    // --- Factory helpers ---
58    // Each returns (StatusCode, Json<ApiError>) which implements IntoResponse.
59
60    /// `400 Bad Request` with the provided `code` and `message`.
61    pub fn bad_request(
62        code: impl Into<String>,
63        message: impl Into<String>,
64    ) -> (StatusCode, Json<Self>) {
65        (StatusCode::BAD_REQUEST, Json(Self::new(code, message)))
66    }
67
68    /// `401 Unauthorized` - `code` defaults to `"AUTH_REQUIRED"`.
69    pub fn unauthorized(message: impl Into<String>) -> (StatusCode, Json<Self>) {
70        (
71            StatusCode::UNAUTHORIZED,
72            Json(Self::new("AUTH_REQUIRED", message)),
73        )
74    }
75
76    /// `403 Forbidden` - `code` defaults to `"FORBIDDEN"`.
77    pub fn forbidden(message: impl Into<String>) -> (StatusCode, Json<Self>) {
78        (StatusCode::FORBIDDEN, Json(Self::new("FORBIDDEN", message)))
79    }
80
81    /// `404 Not Found` - `code` defaults to `"NOT_FOUND"`.
82    pub fn not_found(message: impl Into<String>) -> (StatusCode, Json<Self>) {
83        (StatusCode::NOT_FOUND, Json(Self::new("NOT_FOUND", message)))
84    }
85
86    /// `409 Conflict` - `code` defaults to `"CONFLICT"`.
87    pub fn conflict(message: impl Into<String>) -> (StatusCode, Json<Self>) {
88        (StatusCode::CONFLICT, Json(Self::new("CONFLICT", message)))
89    }
90
91    /// `422 Unprocessable Entity` - `code` defaults to `"VALIDATION_ERROR"`.
92    pub fn unprocessable(message: impl Into<String>) -> (StatusCode, Json<Self>) {
93        (
94            StatusCode::UNPROCESSABLE_ENTITY,
95            Json(Self::new("VALIDATION_ERROR", message)),
96        )
97    }
98
99    /// `500 Internal Server Error` - `code` defaults to `"INTERNAL_ERROR"`.
100    pub fn internal(message: impl Into<String>) -> (StatusCode, Json<Self>) {
101        (
102            StatusCode::INTERNAL_SERVER_ERROR,
103            Json(Self::new("INTERNAL_ERROR", message)),
104        )
105    }
106
107    /// `500 Internal Server Error` for database failures - `code` is `"DB_ERROR"`.
108    pub fn db_error() -> (StatusCode, Json<Self>) {
109        (
110            StatusCode::INTERNAL_SERVER_ERROR,
111            Json(Self::new("DB_ERROR", "database error")),
112        )
113    }
114
115    /// `429 Too Many Requests` - `code` defaults to `"RATE_LIMITED"`.
116    pub fn too_many_requests(message: impl Into<String>) -> (StatusCode, Json<Self>) {
117        (
118            StatusCode::TOO_MANY_REQUESTS,
119            Json(Self::new("RATE_LIMITED", message)),
120        )
121    }
122
123    /// `503 Service Unavailable` - `code` defaults to `"SERVICE_UNAVAILABLE"`.
124    pub fn service_unavailable(message: impl Into<String>) -> (StatusCode, Json<Self>) {
125        (
126            StatusCode::SERVICE_UNAVAILABLE,
127            Json(Self::new("SERVICE_UNAVAILABLE", message)),
128        )
129    }
130
131    /// `501 Not Implemented` - `code` defaults to `"NOT_IMPLEMENTED"`.
132    pub fn not_implemented(message: impl Into<String>) -> (StatusCode, Json<Self>) {
133        (
134            StatusCode::NOT_IMPLEMENTED,
135            Json(Self::new("NOT_IMPLEMENTED", message)),
136        )
137    }
138
139    /// Attach a source error message to this error.
140    ///
141    /// Stores the source in the details field under the `"source"` key.
142    /// Can be chained with other builder methods.
143    ///
144    /// # Example
145    ///
146    /// ```rust
147    /// use axum_api_kit::ApiError;
148    ///
149    /// let err = ApiError::new("NOT_FOUND", "user not found")
150    ///     .with_source("SELECT * FROM users WHERE id = ?")
151    ///     .with_details(serde_json::json!({ "user_id": 42 }));
152    /// ```
153    pub fn with_source(mut self, source: &str) -> Self {
154        let mut details = self.details.take().unwrap_or_else(|| serde_json::json!({}));
155        if let serde_json::Value::Object(ref mut map) = details {
156            map.insert(
157                "source".to_string(),
158                serde_json::Value::String(source.to_string()),
159            );
160        }
161        self.details = Some(details);
162        self
163    }
164}
165
166/// Convert `std::io::Error` to `ApiError` with HTTP 500.
167///
168/// Maps `std::io::Error` to `ApiError::internal()` with the error message.
169/// Enables using the `?` operator in handlers:
170///
171/// ```rust,ignore
172/// async fn handler() -> impl IntoResponse {
173///     let content = std::fs::read_to_string("/data.txt")?;  // auto-converts to ApiError
174///     Ok((StatusCode::OK, content))
175/// }
176/// ```
177impl From<std::io::Error> for ApiError {
178    fn from(err: std::io::Error) -> Self {
179        Self::new("IO_ERROR", format!("IO error: {}", err))
180    }
181}
182
183/// Convert `serde_json::Error` to `ApiError` with HTTP 500.
184///
185/// Maps JSON errors to `ApiError::internal()` with the error message.
186impl From<serde_json::Error> for ApiError {
187    fn from(err: serde_json::Error) -> Self {
188        Self::new("JSON_ERROR", format!("JSON error: {}", err))
189    }
190}
191
/// Convert `sqlx::Error` into an `ApiError` whose `code` reflects the kind of failure.
///
/// Requires the `sqlx` feature flag. `ApiError` itself carries no status. The HTTP column
/// lists the status that the matching factory helper on [`ApiError`] pairs with each code.
///
/// | `sqlx::Error` variant | `code` | HTTP |
/// |---|---|---|
/// | `RowNotFound` | `NOT_FOUND` | 404 |
/// | `Database` (unique/FK violation) | `CONFLICT` | 409 |
/// | `Database` (check violation) | `VALIDATION_ERROR` | 422 |
/// | `Database` (other) | `DB_ERROR` | 500 |
/// | `PoolTimedOut` / `PoolClosed` / `WorkerCrashed` | `SERVICE_UNAVAILABLE` | 503 |
/// | everything else | `DB_ERROR` | 500 |
///
/// Constraint violations keep the database's own message, because it describes what the
/// client did wrong. Every `DB_ERROR` uses the fixed message `"database error"`, the same
/// one [`ApiError::db_error`] uses, so driver and SQL internals never reach the client.
/// If you need the original error for diagnostics, log it before converting.
#[cfg(feature = "sqlx")]
impl From<sqlx::Error> for ApiError {
    fn from(err: sqlx::Error) -> Self {
        match err {
            sqlx::Error::RowNotFound => Self::new("NOT_FOUND", "record not found"),
            sqlx::Error::PoolTimedOut | sqlx::Error::PoolClosed | sqlx::Error::WorkerCrashed => {
                Self::new("SERVICE_UNAVAILABLE", "database unavailable")
            }
            sqlx::Error::Database(db_err) => {
                if db_err.is_unique_violation() || db_err.is_foreign_key_violation() {
                    Self::new("CONFLICT", db_err.message().to_string())
                } else if db_err.is_check_violation() {
                    Self::new("VALIDATION_ERROR", db_err.message().to_string())
                } else {
                    // Unclassified database failure: do not echo the raw DB message.
                    Self::new("DB_ERROR", "database error")
                }
            }
            // `sqlx::Error` is non-exhaustive; anything else is an opaque server-side failure.
            _ => Self::new("DB_ERROR", "database error"),
        }
    }
}
225
/// Recursively flatten `errors` into `out`, adding one entry per failing field.
///
/// Keys are field paths. Nested structs are joined with `.` (e.g. `address.city`), and
/// list elements get an index suffix (e.g. `items[2].name`). Each value is a JSON array
/// holding one object per failed rule on that field: always `"code"`, plus `"message"`
/// when the validator error has one, plus `"params"` when its params are non-empty.
/// `params` are forwarded as-is; if they cannot be converted, `null` is stored instead.
/// NOTE(review): validator's derive may include the rejected input as a `value` param;
/// confirm that echoing it back to clients is acceptable.
#[cfg(feature = "validator")]
fn collect_validation_errors(
    prefix: Option<&str>,
    errors: &validator::ValidationErrors,
    out: &mut serde_json::Map<String, serde_json::Value>,
) {
    use serde_json::{Map, Value};
    use validator::ValidationErrorsKind;

    for (field, kind) in errors.errors() {
        let path = match prefix {
            Some(parent) => format!("{}.{}", parent, field),
            None => field.to_string(),
        };

        match kind {
            ValidationErrorsKind::Struct(inner) => {
                collect_validation_errors(Some(&path), inner, out);
            }
            ValidationErrorsKind::List(elements) => {
                for (idx, inner) in elements {
                    let element_path = format!("{}[{}]", path, idx);
                    collect_validation_errors(Some(&element_path), inner, out);
                }
            }
            ValidationErrorsKind::Field(failures) => {
                let mut entries = Vec::with_capacity(failures.len());
                for failure in failures {
                    let mut entry = Map::new();
                    entry.insert(
                        "code".to_string(),
                        Value::String(failure.code.to_string()),
                    );
                    if let Some(text) = &failure.message {
                        entry.insert("message".to_string(), Value::String(text.to_string()));
                    }
                    if !failure.params.is_empty() {
                        let params =
                            serde_json::to_value(&failure.params).unwrap_or(Value::Null);
                        entry.insert("params".to_string(), params);
                    }
                    entries.push(Value::Object(entry));
                }
                out.insert(path, Value::Array(entries));
            }
        }
    }
}
281
/// Convert `validator` errors into a `"VALIDATION_ERROR"` with message `"validation failed"`.
///
/// Every failing field is flattened into `details.fields`, keyed by its dotted/indexed
/// path, e.g. `{ "fields": { "email": [ { "code": "email", ... } ] } }`.
#[cfg(feature = "validator")]
impl From<validator::ValidationErrors> for ApiError {
    fn from(errors: validator::ValidationErrors) -> Self {
        let mut by_path = serde_json::Map::new();
        collect_validation_errors(None, &errors, &mut by_path);

        let mut details = serde_json::Map::new();
        details.insert("fields".to_string(), serde_json::Value::Object(by_path));

        Self::new("VALIDATION_ERROR", "validation failed")
            .with_details(serde_json::Value::Object(details))
    }
}
293
294impl fmt::Display for ApiError {
295    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296        write!(f, "{}: {}", self.code, self.message)
297    }
298}
299
300impl std::error::Error for ApiError {}
301
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Destructure a factory's `(StatusCode, Json<ApiError>)` and check status and `code`.
    fn check_factory(
        (status, Json(body)): (StatusCode, Json<ApiError>),
        expected_status: StatusCode,
        expected_code: &str,
    ) {
        assert_eq!(status, expected_status);
        assert_eq!(body.code, expected_code);
    }

    #[test]
    fn new_sets_fields() {
        let ApiError {
            code,
            message,
            details,
        } = ApiError::new("MY_CODE", "my message");
        assert_eq!(code, "MY_CODE");
        assert_eq!(message, "my message");
        assert!(details.is_none());
    }

    #[test]
    fn with_details_sets_details() {
        let details = ApiError::new("CODE", "msg")
            .with_details(json!({ "field": "name" }))
            .details
            .expect("with_details should populate details");
        assert_eq!(details["field"], "name");
    }

    #[test]
    fn serializes_without_details() {
        let v = serde_json::to_value(ApiError::new("NOT_FOUND", "item not found")).unwrap();
        assert_eq!(v["code"], "NOT_FOUND");
        assert_eq!(v["message"], "item not found");
        assert_eq!(v.get("details"), None);
    }

    #[test]
    fn serializes_with_details() {
        let v = serde_json::to_value(
            ApiError::new("VALIDATION_ERROR", "invalid").with_details(json!({ "x": 1 })),
        )
        .unwrap();
        assert_eq!(v["details"]["x"], 1);
    }

    #[test]
    fn display_formats_code_and_message() {
        let rendered = ApiError::new("NOT_FOUND", "item not found").to_string();
        assert_eq!(rendered, "NOT_FOUND: item not found");
    }

    #[test]
    fn implements_std_error() {
        fn assert_std_error<E: std::error::Error>(_: &E) {}
        assert_std_error(&ApiError::new("ERR", "something failed"));
    }

    #[test]
    fn bad_request_status_and_code() {
        check_factory(
            ApiError::bad_request("INVALID_FIELD", "bad"),
            StatusCode::BAD_REQUEST,
            "INVALID_FIELD",
        );
    }

    #[test]
    fn unauthorized_status_and_code() {
        check_factory(
            ApiError::unauthorized("please log in"),
            StatusCode::UNAUTHORIZED,
            "AUTH_REQUIRED",
        );
    }

    #[test]
    fn forbidden_status_and_code() {
        check_factory(
            ApiError::forbidden("no access"),
            StatusCode::FORBIDDEN,
            "FORBIDDEN",
        );
    }

    #[test]
    fn not_found_status_and_code() {
        check_factory(
            ApiError::not_found("missing"),
            StatusCode::NOT_FOUND,
            "NOT_FOUND",
        );
    }

    #[test]
    fn conflict_status_and_code() {
        check_factory(
            ApiError::conflict("already exists"),
            StatusCode::CONFLICT,
            "CONFLICT",
        );
    }

    #[test]
    fn unprocessable_status_and_code() {
        check_factory(
            ApiError::unprocessable("invalid input"),
            StatusCode::UNPROCESSABLE_ENTITY,
            "VALIDATION_ERROR",
        );
    }

    #[test]
    fn internal_status_and_code() {
        check_factory(
            ApiError::internal("oops"),
            StatusCode::INTERNAL_SERVER_ERROR,
            "INTERNAL_ERROR",
        );
    }

    #[test]
    fn db_error_status_and_code() {
        check_factory(
            ApiError::db_error(),
            StatusCode::INTERNAL_SERVER_ERROR,
            "DB_ERROR",
        );
    }

    #[test]
    fn too_many_requests_status_and_code() {
        check_factory(
            ApiError::too_many_requests("slow down"),
            StatusCode::TOO_MANY_REQUESTS,
            "RATE_LIMITED",
        );
    }

    #[test]
    fn service_unavailable_status_and_code() {
        check_factory(
            ApiError::service_unavailable("down for maintenance"),
            StatusCode::SERVICE_UNAVAILABLE,
            "SERVICE_UNAVAILABLE",
        );
    }

    #[test]
    fn not_implemented_status_and_code() {
        check_factory(
            ApiError::not_implemented("coming soon"),
            StatusCode::NOT_IMPLEMENTED,
            "NOT_IMPLEMENTED",
        );
    }

    #[test]
    fn with_source_adds_source_to_details() {
        let v = serde_json::to_value(ApiError::new("NOT_FOUND", "missing").with_source("db query"))
            .unwrap();
        assert_eq!(v["details"]["source"], "db query");
        assert_eq!(v["code"], "NOT_FOUND");
    }

    #[test]
    fn with_source_and_with_details_both_present() {
        let chained = ApiError::new("ERROR", "msg")
            .with_details(json!({ "user_id": 123 }))
            .with_source("from somewhere");
        let details = &serde_json::to_value(&chained).unwrap()["details"];
        assert_eq!(details["source"], "from somewhere");
        assert_eq!(details["user_id"], 123);
    }

    #[test]
    fn from_io_error_creates_io_error_code() {
        use std::io::{Error, ErrorKind};

        let api_err = ApiError::from(Error::new(ErrorKind::NotFound, "file not found"));
        assert_eq!(api_err.code, "IO_ERROR");
        assert!(api_err.message.contains("IO error"));
    }

    #[test]
    fn from_serde_json_error_creates_json_error_code() {
        let parse_err =
            serde_json::from_str::<serde_json::Value>("{ invalid json }").unwrap_err();
        let api_err = ApiError::from(parse_err);
        assert_eq!(api_err.code, "JSON_ERROR");
        assert!(api_err.message.contains("JSON error"));
    }

    #[test]
    fn io_error_conversion_captures_kind() {
        use std::io::{Error, ErrorKind};

        let api_err = ApiError::from(Error::new(
            ErrorKind::PermissionDenied,
            "permission denied",
        ));
        assert!(api_err.message.contains("permission denied"));
    }

    #[cfg(feature = "validator")]
    #[test]
    fn from_validation_errors_single_field() {
        use std::borrow::Cow;
        use validator::{ValidationError, ValidationErrors};

        let mut email = ValidationError::new("email");
        email.message = Some(Cow::Borrowed("invalid email"));

        let mut errors = ValidationErrors::new();
        errors.add("email", email);

        let v = serde_json::to_value(ApiError::from(errors)).unwrap();
        let first = &v["details"]["fields"]["email"][0];

        assert_eq!(v["code"], "VALIDATION_ERROR");
        assert_eq!(v["message"], "validation failed");
        assert_eq!(first["code"], "email");
        assert_eq!(first["message"], "invalid email");
    }

    #[cfg(feature = "validator")]
    #[test]
    fn from_validation_errors_multiple_fields_with_params() {
        use std::borrow::Cow;
        use validator::{ValidationError, ValidationErrors};

        let mut username = ValidationError::new("length");
        username.message = Some(Cow::Borrowed("username too short"));
        username.add_param(Cow::Borrowed("min"), &3);

        let mut age = ValidationError::new("range");
        age.add_param(Cow::Borrowed("min"), &18);

        let mut errors = ValidationErrors::new();
        errors.add("username", username);
        errors.add("age", age);

        let v = serde_json::to_value(ApiError::from(errors)).unwrap();
        let fields = &v["details"]["fields"];

        assert_eq!(fields["username"][0]["code"], "length");
        assert_eq!(fields["username"][0]["params"]["min"], 3);
        assert_eq!(fields["age"][0]["code"], "range");
        assert_eq!(fields["age"][0]["params"]["min"], 18);
    }

    #[cfg(feature = "sqlx")]
    #[test]
    fn sqlx_row_not_found_maps_to_not_found() {
        let ApiError { code, message, .. } = ApiError::from(sqlx::Error::RowNotFound);
        assert_eq!(code, "NOT_FOUND");
        assert_eq!(message, "record not found");
    }

    #[cfg(feature = "sqlx")]
    #[test]
    fn sqlx_pool_timed_out_maps_to_service_unavailable() {
        assert_eq!(
            ApiError::from(sqlx::Error::PoolTimedOut).code,
            "SERVICE_UNAVAILABLE"
        );
    }

    #[cfg(feature = "sqlx")]
    #[test]
    fn sqlx_pool_closed_maps_to_service_unavailable() {
        assert_eq!(
            ApiError::from(sqlx::Error::PoolClosed).code,
            "SERVICE_UNAVAILABLE"
        );
    }

    #[cfg(feature = "sqlx")]
    #[test]
    fn sqlx_unknown_variant_maps_to_db_error() {
        // `Protocol` is neither a pool nor a database variant, so it hits the catch-all arm.
        let api_err = ApiError::from(sqlx::Error::Protocol("unexpected packet".into()));
        assert_eq!(api_err.code, "DB_ERROR");
        assert!(api_err.message.contains("database error"));
    }
}