error_ext/
axum.rs

1//! Error utilities for axum.
2
3use axum::{
4    Json,
5    http::StatusCode,
6    response::{IntoResponse, Response},
7};
8
/// Error that can be used as an axum response, with an appropriate HTTP status code and – except
/// for [`Error::Internal`] and [`Error::ServiceUnavailable`] – with one or more error messages
/// conveyed as a JSON string array in the response body.
#[derive(Debug)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Error {
    /// `400 Bad Request`, e.g. because of invalid path or query arguments.
    InvalidArgs(Vec<String>),

    /// `401 Unauthorized`.
    Unauthorized(String),

    /// `403 Forbidden`.
    Forbidden(String),

    /// `404 Not Found`.
    NotFound(String),

    /// `409 Conflict`, e.g. because of an already existing resource.
    Conflict(String),

    /// `422 Unprocessable Entity`, e.g. because the JSON payload could not be parsed.
    InvalidEntity(Vec<String>),

    /// `500 Internal Server Error`, sent with an empty body.
    Internal,

    /// `503 Service Unavailable`, sent with an empty body.
    ServiceUnavailable,
}
38
39impl Error {
40    /// Create [Error::InvalidArgs] with the given error.
41    pub fn invalid_args<T>(error: T) -> Self
42    where
43        T: ToString,
44    {
45        let errors = vec![error.to_string()];
46        Error::InvalidArgs(errors)
47    }
48
49    /// Create [Error::InvalidArgs] with the given errors.
50    pub fn invalid_args_all<I, T>(errors: I) -> Self
51    where
52        I: IntoIterator<Item = T>,
53        T: ToString,
54    {
55        let errors = errors.into_iter().map(|e| e.to_string()).collect();
56        Error::InvalidArgs(errors)
57    }
58
59    /// Create [Error::Unauthorized] with the given error.
60    pub fn unauthorized<T>(error: T) -> Self
61    where
62        T: ToString,
63    {
64        Error::Unauthorized(error.to_string())
65    }
66
67    /// Create [Error::Forbidden] with the given error.
68    pub fn forbidden<T>(error: T) -> Self
69    where
70        T: ToString,
71    {
72        Error::Forbidden(error.to_string())
73    }
74
75    /// Create [Error::NotFound] with the given error.
76    pub fn not_found<T>(error: T) -> Self
77    where
78        T: ToString,
79    {
80        Error::NotFound(error.to_string())
81    }
82
83    /// Create [Error::Conflict] with the given error.
84    pub fn conflict<T>(error: T) -> Self
85    where
86        T: ToString,
87    {
88        Error::Conflict(error.to_string())
89    }
90
91    /// Create [Error::InvalidEntity] with the given error.
92    pub fn invalid_entity<T>(error: T) -> Self
93    where
94        T: ToString,
95    {
96        let errors = vec![error.to_string()];
97        Error::InvalidEntity(errors)
98    }
99
100    /// Create [Error::InvalidEntity] with the given errors.
101    pub fn invalid_entity_all<I, T>(errors: I) -> Self
102    where
103        I: IntoIterator<Item = T>,
104        T: ToString,
105    {
106        let errors = errors.into_iter().map(|e| e.to_string()).collect();
107        Error::InvalidEntity(errors)
108    }
109}
110
111impl IntoResponse for Error {
112    fn into_response(self) -> Response {
113        match self {
114            Error::InvalidArgs(errors) => {
115                let errors = Json(
116                    errors
117                        .into_iter()
118                        .map(|e| e.to_string())
119                        .collect::<Vec<_>>(),
120                );
121                (StatusCode::BAD_REQUEST, errors).into_response()
122            }
123
124            Error::Unauthorized(error) => {
125                let errors = Json(vec![error.to_string()]);
126                (StatusCode::UNAUTHORIZED, errors).into_response()
127            }
128
129            Error::Forbidden(error) => {
130                let errors = Json(vec![error.to_string()]);
131                (StatusCode::FORBIDDEN, errors).into_response()
132            }
133
134            Error::NotFound(error) => {
135                let errors = Json(vec![error.to_string()]);
136                (StatusCode::NOT_FOUND, errors).into_response()
137            }
138
139            Error::Conflict(error) => {
140                let errors = Json(vec![error.to_string()]);
141                (StatusCode::CONFLICT, errors).into_response()
142            }
143
144            Error::InvalidEntity(errors) => {
145                let errors = Json(
146                    errors
147                        .into_iter()
148                        .map(|e| e.to_string())
149                        .collect::<Vec<_>>(),
150                );
151                (StatusCode::UNPROCESSABLE_ENTITY, errors).into_response()
152            }
153
154            Error::Internal => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
155
156            Error::ServiceUnavailable => StatusCode::SERVICE_UNAVAILABLE.into_response(),
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use axum::http::header::CONTENT_TYPE;
    use std::iter;
    // Renamed so it cannot be confused with the `Error` enum under test.
    use thiserror::Error as ThisError;

    #[derive(Debug, ThisError)]
    #[error("test")]
    struct TestError;

    /// Converts `error` into a response and asserts its status and JSON content type.
    fn assert_json_response(error: Error, status: StatusCode) {
        let response = error.into_response();
        assert_eq!(response.status(), status);
        let content_type = response
            .headers()
            .get(CONTENT_TYPE)
            .expect("JSON response must carry a content-type header");
        assert_eq!(content_type, "application/json");
    }

    #[test]
    fn test_invalid_args() {
        assert_json_response(Error::invalid_args("test"), StatusCode::BAD_REQUEST);
        assert_json_response(Error::invalid_args(anyhow!("test")), StatusCode::BAD_REQUEST);
        assert_json_response(Error::invalid_args(TestError), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_invalid_args_all() {
        assert_json_response(Error::invalid_args_all(vec!["test"]), StatusCode::BAD_REQUEST);
        assert_json_response(
            Error::invalid_args_all(iter::once(TestError)),
            StatusCode::BAD_REQUEST,
        );
    }

    #[test]
    fn test_unauthorized() {
        assert_json_response(Error::unauthorized("test"), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_forbidden() {
        assert_json_response(Error::forbidden("test"), StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_not_found() {
        assert_json_response(Error::not_found("test"), StatusCode::NOT_FOUND);
        assert_json_response(Error::not_found(anyhow!("test")), StatusCode::NOT_FOUND);
        assert_json_response(Error::not_found(TestError), StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_conflict() {
        assert_json_response(Error::conflict("test"), StatusCode::CONFLICT);
        assert_json_response(Error::conflict(anyhow!("test")), StatusCode::CONFLICT);
        assert_json_response(Error::conflict(TestError), StatusCode::CONFLICT);
    }

    #[test]
    fn test_invalid_entity() {
        assert_json_response(Error::invalid_entity("test"), StatusCode::UNPROCESSABLE_ENTITY);
        assert_json_response(
            Error::invalid_entity(anyhow!("test")),
            StatusCode::UNPROCESSABLE_ENTITY,
        );
        assert_json_response(Error::invalid_entity(TestError), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[test]
    fn test_invalid_entity_all() {
        assert_json_response(
            Error::invalid_entity_all(vec!["test"]),
            StatusCode::UNPROCESSABLE_ENTITY,
        );
        assert_json_response(
            Error::invalid_entity_all(iter::once(TestError)),
            StatusCode::UNPROCESSABLE_ENTITY,
        );
    }

    #[test]
    fn test_constructor_messages() {
        assert!(matches!(Error::invalid_args(TestError), Error::InvalidArgs(e) if e == ["test"]));
        assert!(matches!(
            Error::invalid_args_all(vec!["a", "b"]),
            Error::InvalidArgs(e) if e == ["a", "b"]
        ));
        assert!(matches!(Error::unauthorized("test"), Error::Unauthorized(e) if e == "test"));
        assert!(matches!(Error::forbidden(TestError), Error::Forbidden(e) if e == "test"));
        assert!(matches!(Error::not_found(anyhow!("test")), Error::NotFound(e) if e == "test"));
        assert!(matches!(Error::conflict("test"), Error::Conflict(e) if e == "test"));
        assert!(matches!(
            Error::invalid_entity_all(iter::once(TestError)),
            Error::InvalidEntity(e) if e == ["test"]
        ));
    }

    #[test]
    fn test_internal() {
        let response = Error::Internal.into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_service_unavailable() {
        let response = Error::ServiceUnavailable.into_response();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }
}