axum_route_error/
route_error.rs1use anyhow::Error as AnyhowError;
2use axum::http::StatusCode;
3use axum::response::IntoResponse;
4use axum::response::Response;
5use axum::Json;
6use serde::Deserialize;
7use serde::Serialize;
8use std::fmt::Debug;
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::fmt::Result as FmtResult;
12
13use super::RouteErrorOutput;
14use crate::RouteInternalErrorOutput;
15
16pub struct RouteError<S = (), const EXPOSE_INTERNAL_ERROR: bool = false>
40where
41 S: Serialize + for<'a> Deserialize<'a> + Debug,
42{
43 status_code: StatusCode,
44 error: Option<AnyhowError>,
45 extra_data: Option<Box<S>>,
46 public_error_message: Option<String>,
47}
48
49impl RouteError<()> {
50 pub fn new_unauthorized() -> RouteError<()> {
51 Self::new_from_status(StatusCode::UNAUTHORIZED)
52 }
53
54 pub fn new_not_found() -> RouteError<()> {
55 Self::new_from_status(StatusCode::NOT_FOUND)
56 }
57
58 pub fn new_bad_request() -> RouteError<()> {
59 Self::new_from_status(StatusCode::BAD_REQUEST)
60 }
61
62 pub fn new_internal_server() -> RouteError<()> {
63 Self::new_from_status(StatusCode::INTERNAL_SERVER_ERROR)
64 }
65
66 pub fn new_conflict() -> RouteError<()> {
67 Self::new_from_status(StatusCode::CONFLICT)
68 }
69
70 pub fn new_forbidden() -> RouteError<()> {
71 Self::new_from_status(StatusCode::FORBIDDEN)
72 }
73
74 pub fn new_from_status(status_code: StatusCode) -> RouteError<()> {
75 Self {
76 status_code,
77 ..Self::default()
78 }
79 }
80}
81
82impl<S, const EXPOSE_INTERNAL_ERROR: bool> RouteError<S, EXPOSE_INTERNAL_ERROR>
83where
84 S: Serialize + for<'a> Deserialize<'a> + Debug,
85{
86 pub fn set_status_code(self, status_code: StatusCode) -> Self {
88 Self {
89 status_code,
90 ..self
91 }
92 }
93
94 pub fn set_error(self, error: AnyhowError) -> Self {
98 Self {
99 error: Some(error),
100 ..self
101 }
102 }
103
104 pub fn set_error_data<NewS>(self, extra_data: NewS) -> RouteError<NewS>
137 where
138 NewS: Serialize + for<'a> Deserialize<'a> + Debug,
139 {
140 RouteError {
141 extra_data: Some(Box::new(extra_data)),
142 status_code: self.status_code,
143 error: self.error,
144 public_error_message: self.public_error_message,
145 }
146 }
147
148 pub fn set_public_error_message(self, public_error_message: &str) -> Self {
153 Self {
154 public_error_message: Some(public_error_message.to_string()),
155 ..self
156 }
157 }
158
159 pub fn public_error_message<'a>(&'a self) -> &'a str {
161 if let Some(public_error_message) = self.public_error_message.as_ref() {
162 return public_error_message;
163 }
164
165 status_code_to_public_message(self.status_code())
166 }
167
168 pub fn status_code(&self) -> StatusCode {
170 self.status_code
171 }
172}
173
174impl<S, const EXPOSE_INTERNAL_ERROR: bool> Default for RouteError<S, EXPOSE_INTERNAL_ERROR>
175where
176 S: Serialize + for<'a> Deserialize<'a> + Debug,
177{
178 fn default() -> Self {
179 Self {
180 status_code: StatusCode::INTERNAL_SERVER_ERROR,
181 error: None,
182 extra_data: None,
183 public_error_message: None,
184 }
185 }
186}
187
188impl<S, const EXPOSE_INTERNAL_ERROR: bool> IntoResponse for RouteError<S, EXPOSE_INTERNAL_ERROR>
189where
190 S: Serialize + for<'a> Deserialize<'a> + Debug,
191{
192 fn into_response(self) -> Response {
193 let status = self.status_code();
194 let extra_data = self.extra_data;
195 let error = match self.public_error_message {
196 Some(public_error_message) => public_error_message,
197 None => status_code_to_public_message(status).to_string(),
198 };
199
200 let internal_error = if EXPOSE_INTERNAL_ERROR {
201 self.error.map(|err| RouteInternalErrorOutput {
202 name: format!("{}", err),
203 debug: format!("{:?}", err),
204 })
205 } else {
206 None
207 };
208
209 let output = RouteErrorOutput {
210 error,
211 internal_error,
212 extra_data,
213 ..RouteErrorOutput::default()
214 };
215 let body = Json(output);
216
217 (status, body).into_response()
218 }
219}
220
221impl<S, const EXPOSE_INTERNAL_ERROR: bool> Debug for RouteError<S, EXPOSE_INTERNAL_ERROR>
222where
223 S: Serialize + for<'a> Deserialize<'a> + Debug,
224{
225 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
226 write!(f, "{}, {:?}", self.public_error_message(), self.error)
227 }
228}
229
230impl<S, const EXPOSE_INTERNAL_ERROR: bool> Display for RouteError<S, EXPOSE_INTERNAL_ERROR>
231where
232 S: Serialize + for<'a> Deserialize<'a> + Debug,
233{
234 fn fmt(&self, f: &mut Formatter) -> FmtResult {
235 write!(f, "{}", self.public_error_message())
236 }
237}
238
239impl<S, const EXPOSE_INTERNAL_ERROR: bool, FE> From<FE> for RouteError<S, EXPOSE_INTERNAL_ERROR>
242where
243 S: Serialize + for<'a> Deserialize<'a> + Debug,
244 FE: Into<AnyhowError>,
245{
246 fn from(error: FE) -> Self {
247 let anyhow_error: AnyhowError = error.into();
248 ::tracing::error!("{:?}", anyhow_error);
249
250 RouteError {
251 status_code: StatusCode::INTERNAL_SERVER_ERROR,
252 error: Some(anyhow_error),
253 ..Self::default()
254 }
255 }
256}
257
258fn status_code_to_public_message(status_code: StatusCode) -> &'static str {
259 match status_code {
260 StatusCode::CONFLICT => "The request is not allowed",
261 StatusCode::UNAUTHORIZED => "You are not authorised to access this endpoint",
262 StatusCode::NOT_FOUND => "The resource was not found",
263 StatusCode::BAD_REQUEST => "Bad request made",
264 StatusCode::FORBIDDEN => "Request is forbidden",
265 StatusCode::IM_A_TEAPOT => "I'm a teapot",
266 StatusCode::TOO_MANY_REQUESTS => "Too many requests",
267 StatusCode::BAD_GATEWAY => "Bad gateway",
268 StatusCode::SERVICE_UNAVAILABLE => "Service unavailable",
269 StatusCode::GATEWAY_TIMEOUT => "Gateway timeout",
270 StatusCode::INTERNAL_SERVER_ERROR => "An unexpected error occurred",
271 _ => "An unknown error occurred",
272 }
273}
274
#[cfg(test)]
mod test_route_error {
    use super::*;
    use crate::RouteErrorOutput;
    use anyhow::anyhow;
    use axum::response::IntoResponse;
    use http_body_util::BodyExt;
    use serde_json::from_slice;

    /// A default `RouteError` (internal errors not exposed) must not put the internal
    /// error into the response body.
    #[tokio::test]
    async fn it_should_not_output_internal_error() {
        // Goes through `?` so the `From` conversion into `RouteError` is exercised.
        fn raise_error() -> Result<(), RouteError> {
            Err::<(), _>(anyhow!("Too many foxes in the DB"))?;

            Ok(())
        }

        let response_bytes = raise_error()
            .unwrap_err()
            .into_response()
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes();
        let body: RouteErrorOutput<()> = from_slice(&response_bytes).unwrap();

        assert_eq!(body.internal_error, None);
    }
}