axum_route_error/
route_error.rs

1use anyhow::Error as AnyhowError;
2use axum::http::StatusCode;
3use axum::response::IntoResponse;
4use axum::response::Response;
5use axum::Json;
6use serde::Deserialize;
7use serde::Serialize;
8use std::fmt::Debug;
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::fmt::Result as FmtResult;
12
13use super::RouteErrorOutput;
14use crate::RouteInternalErrorOutput;
15
16/// This Rust module provides a standard error type for routes.
17/// It encapsulates information about errors that occur while handling requests.
18/// It includes a status code, error details, any extra data,
19/// and a public error message.
20///
21/// It includes the means to output these as Json for the user.
22///
23/// The output will be in the form:
24/// ```json
25///     {
26///         "error": "My public error message"
27///     }
28/// ```
29///
30/// Most of the time you will want to simply return one of:
31///
32///  - `RouteError::new_unauthorised()`
33///  - `RouteError::new_not_found()`
34///  - `RouteError::new_bad_request()`
35///  - `RouteError::new_internal_server()`
36///
37/// Depending on which is the most appropriate.
38///
39pub struct RouteError<S = (), const EXPOSE_INTERNAL_ERROR: bool = false>
40where
41    S: Serialize + for<'a> Deserialize<'a> + Debug,
42{
43    status_code: StatusCode,
44    error: Option<AnyhowError>,
45    extra_data: Option<Box<S>>,
46    public_error_message: Option<String>,
47}
48
49impl RouteError<()> {
50    pub fn new_unauthorized() -> RouteError<()> {
51        Self::new_from_status(StatusCode::UNAUTHORIZED)
52    }
53
54    pub fn new_not_found() -> RouteError<()> {
55        Self::new_from_status(StatusCode::NOT_FOUND)
56    }
57
58    pub fn new_bad_request() -> RouteError<()> {
59        Self::new_from_status(StatusCode::BAD_REQUEST)
60    }
61
62    pub fn new_internal_server() -> RouteError<()> {
63        Self::new_from_status(StatusCode::INTERNAL_SERVER_ERROR)
64    }
65
66    pub fn new_conflict() -> RouteError<()> {
67        Self::new_from_status(StatusCode::CONFLICT)
68    }
69
70    pub fn new_forbidden() -> RouteError<()> {
71        Self::new_from_status(StatusCode::FORBIDDEN)
72    }
73
74    pub fn new_from_status(status_code: StatusCode) -> RouteError<()> {
75        Self {
76            status_code,
77            ..Self::default()
78        }
79    }
80}
81
82impl<S, const EXPOSE_INTERNAL_ERROR: bool> RouteError<S, EXPOSE_INTERNAL_ERROR>
83where
84    S: Serialize + for<'a> Deserialize<'a> + Debug,
85{
86    /// Set a new status code for the error response.
87    pub fn set_status_code(self, status_code: StatusCode) -> Self {
88        Self {
89            status_code,
90            ..self
91        }
92    }
93
94    /// Set an internal error.
95    ///
96    /// This is used for tracking the source of the error internally.
97    pub fn set_error(self, error: AnyhowError) -> Self {
98        Self {
99            error: Some(error),
100            ..self
101        }
102    }
103
104    ///
105    /// Sets additional error data to be added to the output.
106    /// Data here must be serialisable into Json.
107    ///
108    /// # Example Code
109    ///
110    /// ```rust
111    /// use axum_route_error::RouteError;
112    /// use serde::Deserialize;
113    /// use serde::Serialize;
114    ///
115    /// #[derive(Deserialize, Serialize, Debug)]
116    /// pub struct UserErrorInformation {
117    ///     pub guid: String
118    /// }
119    ///
120    /// let guid = "abc123".to_string();
121    /// let err = RouteError::new_not_found()
122    ///     .set_error_data(UserErrorInformation {
123    ///         guid,
124    ///     });
125    /// ```
126    ///
127    /// This will return a response with the JSON format:
128    ///
129    /// ```json
130    /// {
131    ///   "error": "The resource was not found",
132    ///   "username": "<the-username>"
133    /// }
134    /// ```
135    ///
136    pub fn set_error_data<NewS>(self, extra_data: NewS) -> RouteError<NewS>
137    where
138        NewS: Serialize + for<'a> Deserialize<'a> + Debug,
139    {
140        RouteError {
141            extra_data: Some(Box::new(extra_data)),
142            status_code: self.status_code,
143            error: self.error,
144            public_error_message: self.public_error_message,
145        }
146    }
147
148    /// Set the error message to display within the error.
149    ///
150    /// If this is not set, then an appropriate message is provided
151    /// based on the status code.
152    pub fn set_public_error_message(self, public_error_message: &str) -> Self {
153        Self {
154            public_error_message: Some(public_error_message.to_string()),
155            ..self
156        }
157    }
158
159    /// Returns the error message that will be shown to the end user.
160    pub fn public_error_message<'a>(&'a self) -> &'a str {
161        if let Some(public_error_message) = self.public_error_message.as_ref() {
162            return public_error_message;
163        }
164
165        status_code_to_public_message(self.status_code())
166    }
167
168    /// Returns the status code for the response.
169    pub fn status_code(&self) -> StatusCode {
170        self.status_code
171    }
172}
173
174impl<S, const EXPOSE_INTERNAL_ERROR: bool> Default for RouteError<S, EXPOSE_INTERNAL_ERROR>
175where
176    S: Serialize + for<'a> Deserialize<'a> + Debug,
177{
178    fn default() -> Self {
179        Self {
180            status_code: StatusCode::INTERNAL_SERVER_ERROR,
181            error: None,
182            extra_data: None,
183            public_error_message: None,
184        }
185    }
186}
187
188impl<S, const EXPOSE_INTERNAL_ERROR: bool> IntoResponse for RouteError<S, EXPOSE_INTERNAL_ERROR>
189where
190    S: Serialize + for<'a> Deserialize<'a> + Debug,
191{
192    fn into_response(self) -> Response {
193        let status = self.status_code();
194        let extra_data = self.extra_data;
195        let error = match self.public_error_message {
196            Some(public_error_message) => public_error_message,
197            None => status_code_to_public_message(status).to_string(),
198        };
199
200        let internal_error = if EXPOSE_INTERNAL_ERROR {
201            self.error.map(|err| RouteInternalErrorOutput {
202                name: format!("{}", err),
203                debug: format!("{:?}", err),
204            })
205        } else {
206            None
207        };
208
209        let output = RouteErrorOutput {
210            error,
211            internal_error,
212            extra_data,
213            ..RouteErrorOutput::default()
214        };
215        let body = Json(output);
216
217        (status, body).into_response()
218    }
219}
220
221impl<S, const EXPOSE_INTERNAL_ERROR: bool> Debug for RouteError<S, EXPOSE_INTERNAL_ERROR>
222where
223    S: Serialize + for<'a> Deserialize<'a> + Debug,
224{
225    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
226        write!(f, "{}, {:?}", self.public_error_message(), self.error)
227    }
228}
229
230impl<S, const EXPOSE_INTERNAL_ERROR: bool> Display for RouteError<S, EXPOSE_INTERNAL_ERROR>
231where
232    S: Serialize + for<'a> Deserialize<'a> + Debug,
233{
234    fn fmt(&self, f: &mut Formatter) -> FmtResult {
235        write!(f, "{}", self.public_error_message())
236    }
237}
238
239/// This essentially means if you can turn it into an Anyhow,
240/// then you can turn it into a RouteError.
241impl<S, const EXPOSE_INTERNAL_ERROR: bool, FE> From<FE> for RouteError<S, EXPOSE_INTERNAL_ERROR>
242where
243    S: Serialize + for<'a> Deserialize<'a> + Debug,
244    FE: Into<AnyhowError>,
245{
246    fn from(error: FE) -> Self {
247        let anyhow_error: AnyhowError = error.into();
248        ::tracing::error!("{:?}", anyhow_error);
249
250        RouteError {
251            status_code: StatusCode::INTERNAL_SERVER_ERROR,
252            error: Some(anyhow_error),
253            ..Self::default()
254        }
255    }
256}
257
258fn status_code_to_public_message(status_code: StatusCode) -> &'static str {
259    match status_code {
260        StatusCode::CONFLICT => "The request is not allowed",
261        StatusCode::UNAUTHORIZED => "You are not authorised to access this endpoint",
262        StatusCode::NOT_FOUND => "The resource was not found",
263        StatusCode::BAD_REQUEST => "Bad request made",
264        StatusCode::FORBIDDEN => "Request is forbidden",
265        StatusCode::IM_A_TEAPOT => "I'm a teapot",
266        StatusCode::TOO_MANY_REQUESTS => "Too many requests",
267        StatusCode::BAD_GATEWAY => "Bad gateway",
268        StatusCode::SERVICE_UNAVAILABLE => "Service unavailable",
269        StatusCode::GATEWAY_TIMEOUT => "Gateway timeout",
270        StatusCode::INTERNAL_SERVER_ERROR => "An unexpected error occurred",
271        _ => "An unknown error occurred",
272    }
273}
274
#[cfg(test)]
mod test_route_error {
    use super::*;
    use crate::RouteErrorOutput;
    use anyhow::anyhow;
    use axum::response::IntoResponse;
    use http_body_util::BodyExt;
    use serde_json::from_slice;

    /// Renders `err` into a response and decodes its JSON body back into a
    /// `RouteErrorOutput`, returning it alongside the response status.
    async fn render<const EXPOSE_INTERNAL_ERROR: bool>(
        err: RouteError<(), EXPOSE_INTERNAL_ERROR>,
    ) -> (StatusCode, RouteErrorOutput<()>) {
        let response = err.into_response();
        let status = response.status();
        let response_bytes = response.into_body().collect().await.unwrap().to_bytes();
        let body = from_slice::<RouteErrorOutput<()>>(&response_bytes).unwrap();

        (status, body)
    }

    #[tokio::test]
    async fn it_should_not_output_internal_error() {
        fn raise_error() -> Result<(), RouteError> {
            Err(anyhow!("Too many foxes in the DB"))?;

            Ok(())
        }

        let err = raise_error().unwrap_err();
        let (status, body) = render(err).await;

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body.internal_error, None);
    }

    #[tokio::test]
    async fn it_should_output_internal_error_when_exposed() {
        fn raise_error() -> Result<(), RouteError<(), true>> {
            Err(anyhow!("Too many foxes in the DB"))?;

            Ok(())
        }

        let err = raise_error().unwrap_err();
        let (_, body) = render(err).await;
        let internal_error = body
            .internal_error
            .expect("internal error should be exposed");

        assert_eq!(internal_error.name, "Too many foxes in the DB");
    }

    #[tokio::test]
    async fn it_should_default_public_message_from_status_code() {
        let (status, body) = render(RouteError::new_not_found()).await;

        assert_eq!(status, StatusCode::NOT_FOUND);
        assert_eq!(body.error, "The resource was not found");
    }

    #[tokio::test]
    async fn it_should_prefer_explicit_public_message() {
        let err = RouteError::new_bad_request().set_public_error_message("Missing username");
        let (status, body) = render(err).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body.error, "Missing username");
    }
}
300}