axum_anyhow/
extensions.rs

1use crate::{ApiError, ApiResult};
2use anyhow::Result;
3use axum::http::StatusCode;
4
5/// Extension trait for `anyhow::Result` to convert errors into `ApiError` with HTTP
6/// status codes.
7///
8/// This trait provides methods to convert `Result<T, E>` where `E: Into<anyhow::Error>`
9/// into `ApiResult<T>`, attaching HTTP status codes and error details.
10///
11/// # Example
12///
13/// ```rust
14/// use anyhow::{anyhow, Result};
15/// use axum_anyhow::{ApiResult, ResultExt};
16/// use axum::http::StatusCode;
17///
18/// fn validate_email(email: &str) -> Result<String> {
19///     if email.contains('@') {
20///         Ok(email.to_string())
21///     } else {
22///         Err(anyhow!("Invalid email format"))
23///     }
24/// }
25///
26/// async fn handler(email: String) -> ApiResult<String> {
27///     let validated = validate_email(&email)
28///         .context_bad_request("Invalid Email", "Email must contain @")?;
29///     Ok(validated)
30/// }
31///
32/// # tokio_test::block_on(async {
33/// let api_result = handler("not-an-email".to_string()).await;
34/// assert!(api_result.is_err());
35/// let err = api_result.unwrap_err();
36/// assert_eq!(err.status, StatusCode::BAD_REQUEST);
37/// assert_eq!(err.title, "Invalid Email");
38/// assert_eq!(err.detail, "Email must contain @");
39/// # })
40/// ```
41pub trait ResultExt<T> {
42    /// Converts an error to an `ApiError` with a custom status code.
43    ///
44    /// # Arguments
45    ///
46    /// * `status` - The HTTP status code to use
47    /// * `title` - A short, human-readable summary of the error
48    /// * `detail` - A detailed explanation of the error
49    fn context_status(self, status: StatusCode, title: &str, detail: &str) -> ApiResult<T>;
50
51    /// Converts an error to a 400 Bad Request error.
52    ///
53    /// # Arguments
54    ///
55    /// * `title` - A short, human-readable summary of the error
56    /// * `detail` - A detailed explanation of the error
57    fn context_bad_request(self, title: &str, detail: &str) -> ApiResult<T>;
58
59    /// Converts an error to a 401 Unauthorized error (missing or invalid credentials).
60    ///
61    /// # Arguments
62    ///
63    /// * `title` - A short, human-readable summary of the error
64    /// * `detail` - A detailed explanation of the error
65    fn context_unauthorized(self, title: &str, detail: &str) -> ApiResult<T>;
66
67    /// Converts an error to a 403 Forbidden error (authenticated but lacks permissions).
68    ///
69    /// # Arguments
70    ///
71    /// * `title` - A short, human-readable summary of the error
72    /// * `detail` - A detailed explanation of the error
73    fn context_forbidden(self, title: &str, detail: &str) -> ApiResult<T>;
74
75    /// Converts an error to a 404 Not Found error.
76    ///
77    /// # Arguments
78    ///
79    /// * `title` - A short, human-readable summary of the error
80    /// * `detail` - A detailed explanation of the error
81    fn context_not_found(self, title: &str, detail: &str) -> ApiResult<T>;
82
83    /// Converts an error to a 500 Internal Server Error.
84    ///
85    /// # Arguments
86    ///
87    /// * `title` - A short, human-readable summary of the error
88    /// * `detail` - A detailed explanation of the error
89    fn context_internal(self, title: &str, detail: &str) -> ApiResult<T>;
90}
91
92impl<T, E> ResultExt<T> for Result<T, E>
93where
94    E: IntoApiError,
95{
96    fn context_status(self, status: StatusCode, title: &str, detail: &str) -> ApiResult<T> {
97        self.map_err(|err| err.context_status(status, title, detail))
98    }
99
100    fn context_bad_request(self, title: &str, detail: &str) -> ApiResult<T> {
101        self.map_err(|err| err.context_bad_request(title, detail))
102    }
103
104    fn context_unauthorized(self, title: &str, detail: &str) -> ApiResult<T> {
105        self.map_err(|err| err.context_unauthorized(title, detail))
106    }
107
108    fn context_forbidden(self, title: &str, detail: &str) -> ApiResult<T> {
109        self.map_err(|err| err.context_forbidden(title, detail))
110    }
111
112    fn context_not_found(self, title: &str, detail: &str) -> ApiResult<T> {
113        self.map_err(|err| err.context_not_found(title, detail))
114    }
115
116    fn context_internal(self, title: &str, detail: &str) -> ApiResult<T> {
117        self.map_err(|err| err.context_internal(title, detail))
118    }
119}
120
121/// Extension trait for `Option<T>` to convert `None` into `ApiError` with HTTP status codes.
122///
123/// This trait provides methods to convert `Option<T>` into `ApiResult<T>`, converting
124/// `None` values into errors with appropriate HTTP status codes and error details.
125///
126/// # Example
127///
128/// ```rust
129/// use axum_anyhow::{ApiResult, OptionExt};
130/// use axum::http::StatusCode;
131///
132///
133/// fn find_user(id: u32) -> Option<String> {
134///     if (id == 0) {
135///         Some("Alice".to_string())
136///     } else {
137///         None
138///     }
139/// }
140///
141/// async fn handler(id: u32) -> ApiResult<String> {
142///     let user = find_user(id)
143///         .context_not_found("User Not Found", "No user with that ID exists")?;
144///     Ok(user)
145/// }
146///
147/// # tokio_test::block_on(async {
148/// let api_result = handler(1).await;
149/// assert!(api_result.is_err());
150/// let err = api_result.unwrap_err();
151/// assert_eq!(err.status, StatusCode::NOT_FOUND);
152/// assert_eq!(err.title, "User Not Found");
153/// assert_eq!(err.detail, "No user with that ID exists");
154/// # })
155/// ```
156pub trait OptionExt<T> {
157    /// Converts `None` to an `ApiError` with a custom status code.
158    ///
159    /// # Arguments
160    ///
161    /// * `status` - The HTTP status code to use
162    /// * `title` - A short, human-readable summary of the error
163    /// * `detail` - A detailed explanation of the error
164    fn context_status(self, status: StatusCode, title: &str, detail: &str) -> ApiResult<T>;
165
166    /// Converts `None` to a 400 Bad Request error.
167    ///
168    /// # Arguments
169    ///
170    /// * `title` - A short, human-readable summary of the error
171    /// * `detail` - A detailed explanation of the error
172    fn context_bad_request(self, title: &str, detail: &str) -> ApiResult<T>;
173
174    /// Converts `None` to a 401 Unauthorized error (missing or invalid credentials).
175    ///
176    /// # Arguments
177    ///
178    /// * `title` - A short, human-readable summary of the error
179    /// * `detail` - A detailed explanation of the error
180    fn context_unauthorized(self, title: &str, detail: &str) -> ApiResult<T>;
181
182    /// Converts `None` to a 403 Forbidden error (authenticated but lacks permissions).
183    ///
184    /// # Arguments
185    ///
186    /// * `title` - A short, human-readable summary of the error
187    /// * `detail` - A detailed explanation of the error
188    fn context_forbidden(self, title: &str, detail: &str) -> ApiResult<T>;
189
190    /// Converts `None` to a 404 Not Found error.
191    ///
192    /// # Arguments
193    ///
194    /// * `title` - A short, human-readable summary of the error
195    /// * `detail` - A detailed explanation of the error
196    fn context_not_found(self, title: &str, detail: &str) -> ApiResult<T>;
197
198    /// Converts `None` to a 500 Internal Server Error.
199    ///
200    /// # Arguments
201    ///
202    /// * `title` - A short, human-readable summary of the error
203    /// * `detail` - A detailed explanation of the error
204    fn context_internal(self, title: &str, detail: &str) -> ApiResult<T>;
205}
206
207impl<T> OptionExt<T> for Option<T> {
208    fn context_status(self, status: StatusCode, title: &str, detail: &str) -> ApiResult<T> {
209        self.ok_or_else(|| {
210            ApiError::builder()
211                .status(status)
212                .title(title)
213                .detail(detail)
214                .build()
215        })
216    }
217
218    fn context_bad_request(self, title: &str, detail: &str) -> ApiResult<T> {
219        self.context_status(StatusCode::BAD_REQUEST, title, detail)
220    }
221
222    fn context_unauthorized(self, title: &str, detail: &str) -> ApiResult<T> {
223        self.context_status(StatusCode::UNAUTHORIZED, title, detail)
224    }
225
226    fn context_forbidden(self, title: &str, detail: &str) -> ApiResult<T> {
227        self.context_status(StatusCode::FORBIDDEN, title, detail)
228    }
229
230    fn context_not_found(self, title: &str, detail: &str) -> ApiResult<T> {
231        self.context_status(StatusCode::NOT_FOUND, title, detail)
232    }
233
234    fn context_internal(self, title: &str, detail: &str) -> ApiResult<T> {
235        self.context_status(StatusCode::INTERNAL_SERVER_ERROR, title, detail)
236    }
237}
238
239/// Extension trait for converting any error type into `ApiError` with HTTP status codes.
240///
241/// This trait is implemented for all types that can be converted into `anyhow::Error`.
242/// It provides methods to directly convert errors into `ApiError` instances with
243/// specific HTTP status codes.
244///
245/// # Example
246///
247/// ```rust
248/// use anyhow::anyhow;
249/// use axum_anyhow::{ApiError, IntoApiError};
250///
251/// let error = anyhow!("Something went wrong");
252/// let api_error: ApiError = error.context_internal("Internal Error", "Database failed");
253/// ```
254pub trait IntoApiError {
255    /// Converts an error to an `ApiError` with a custom status code.
256    ///
257    /// # Arguments
258    ///
259    /// * `status` - The HTTP status code to use
260    /// * `title` - A short, human-readable summary of the error
261    /// * `detail` - A detailed explanation of the error
262    fn context_status(self, status: StatusCode, title: &str, detail: &str) -> ApiError;
263
264    /// Converts an error to a 400 Bad Request error.
265    ///
266    /// # Arguments
267    ///
268    /// * `title` - A short, human-readable summary of the error
269    /// * `detail` - A detailed explanation of the error
270    fn context_bad_request(self, title: &str, detail: &str) -> ApiError;
271
272    /// Converts an error to a 401 Unauthorized error (missing or invalid credentials).
273    ///
274    /// # Arguments
275    ///
276    /// * `title` - A short, human-readable summary of the error
277    /// * `detail` - A detailed explanation of the error
278    fn context_unauthorized(self, title: &str, detail: &str) -> ApiError;
279
280    /// Converts an error to a 403 Forbidden error (authenticated but lacks permissions).
281    ///
282    /// # Arguments
283    ///
284    /// * `title` - A short, human-readable summary of the error
285    /// * `detail` - A detailed explanation of the error
286    fn context_forbidden(self, title: &str, detail: &str) -> ApiError;
287
288    /// Converts an error to a 404 Not Found error.
289    ///
290    /// # Arguments
291    ///
292    /// * `title` - A short, human-readable summary of the error
293    /// * `detail` - A detailed explanation of the error
294    fn context_not_found(self, title: &str, detail: &str) -> ApiError;
295
296    /// Converts an error to a 500 Internal Server Error.
297    ///
298    /// # Arguments
299    ///
300    /// * `title` - A short, human-readable summary of the error
301    /// * `detail` - A detailed explanation of the error
302    fn context_internal(self, title: &str, detail: &str) -> ApiError;
303}
304
305impl<E> IntoApiError for E
306where
307    E: Into<anyhow::Error>,
308{
309    fn context_status(self, status: StatusCode, title: &str, detail: &str) -> ApiError {
310        ApiError::builder()
311            .status(status)
312            .title(title)
313            .detail(detail)
314            .error(self)
315            .build()
316    }
317
318    fn context_bad_request(self, title: &str, detail: &str) -> ApiError {
319        self.context_status(StatusCode::BAD_REQUEST, title, detail)
320    }
321
322    fn context_unauthorized(self, title: &str, detail: &str) -> ApiError {
323        self.context_status(StatusCode::UNAUTHORIZED, title, detail)
324    }
325
326    fn context_forbidden(self, title: &str, detail: &str) -> ApiError {
327        self.context_status(StatusCode::FORBIDDEN, title, detail)
328    }
329
330    fn context_not_found(self, title: &str, detail: &str) -> ApiError {
331        self.context_status(StatusCode::NOT_FOUND, title, detail)
332    }
333
334    fn context_internal(self, title: &str, detail: &str) -> ApiError {
335        self.context_status(StatusCode::INTERNAL_SERVER_ERROR, title, detail)
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;

    /// Asserts that `err` carries exactly the expected status, title and detail.
    fn assert_api_error(err: &ApiError, status: StatusCode, title: &str, detail: &str) {
        assert_eq!(err.status, status);
        assert_eq!(err.title, title);
        assert_eq!(err.detail, detail);
    }

    // ---------------------------------------------------------------- ResultExt

    #[test]
    fn test_result_ext_context_bad_request_on_err() {
        let result: Result<i32> = Err(anyhow!("Original error"));
        let api_result = result.context_bad_request("Bad Request", "Invalid data");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::BAD_REQUEST,
            "Bad Request",
            "Invalid data",
        );
    }

    #[test]
    fn test_result_ext_context_bad_request_on_ok() {
        let result: Result<i32> = Ok(42);
        let api_result = result.context_bad_request("Bad Request", "Invalid data");

        assert!(api_result.is_ok());
        assert_eq!(api_result.unwrap(), 42);
    }

    #[test]
    fn test_result_ext_context_unauthorized() {
        let result: Result<String> = Err(anyhow!("Token missing"));
        let api_result = result.context_unauthorized("Unauthorized", "No token");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::UNAUTHORIZED,
            "Unauthorized",
            "No token",
        );
    }

    #[test]
    fn test_result_ext_context_forbidden() {
        let result: Result<String> = Err(anyhow!("Permission denied"));
        let api_result = result.context_forbidden("Forbidden", "Insufficient permissions");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::FORBIDDEN,
            "Forbidden",
            "Insufficient permissions",
        );
    }

    #[test]
    fn test_result_ext_context_not_found() {
        let result: Result<String> = Err(anyhow!("Resource missing"));
        let api_result = result.context_not_found("Not Found", "User not found");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::NOT_FOUND,
            "Not Found",
            "User not found",
        );
    }

    #[test]
    fn test_result_ext_context_internal() {
        let result: Result<String> = Err(anyhow!("Database error"));
        let api_result = result.context_internal("Internal Error", "Database failed");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::INTERNAL_SERVER_ERROR,
            "Internal Error",
            "Database failed",
        );
    }

    #[test]
    fn test_result_ext_context_status() {
        let result: Result<String> = Err(anyhow!("Custom error"));
        let api_result = result.context_status(StatusCode::CONFLICT, "Conflict", "Duplicate entry");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::CONFLICT,
            "Conflict",
            "Duplicate entry",
        );
    }

    #[test]
    fn test_result_ext_with_non_anyhow_error() {
        let result = "not_a_number".parse::<i32>();
        let api_result = result.context_bad_request("Bad Request", "Value must be a number");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::BAD_REQUEST,
            "Bad Request",
            "Value must be a number",
        );
    }

    #[test]
    fn test_result_ext_ok_passes_through_every_variant() {
        fn ok() -> Result<i32> {
            Ok(7)
        }

        assert_eq!(ok().context_status(StatusCode::CONFLICT, "t", "d").unwrap(), 7);
        assert_eq!(ok().context_bad_request("t", "d").unwrap(), 7);
        assert_eq!(ok().context_unauthorized("t", "d").unwrap(), 7);
        assert_eq!(ok().context_forbidden("t", "d").unwrap(), 7);
        assert_eq!(ok().context_not_found("t", "d").unwrap(), 7);
        assert_eq!(ok().context_internal("t", "d").unwrap(), 7);
    }

    // ---------------------------------------------------------------- OptionExt

    #[test]
    fn test_option_ext_context_bad_request_on_none() {
        let option: Option<i32> = None;
        let api_result = option.context_bad_request("Bad Request", "Value is required");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::BAD_REQUEST,
            "Bad Request",
            "Value is required",
        );
    }

    #[test]
    fn test_option_ext_context_bad_request_on_some() {
        let option: Option<i32> = Some(42);
        let api_result = option.context_bad_request("Bad Request", "Value is required");

        assert!(api_result.is_ok());
        assert_eq!(api_result.unwrap(), 42);
    }

    #[test]
    fn test_option_ext_context_unauthorized() {
        let option: Option<String> = None;
        let api_result = option.context_unauthorized("Unauthorized", "Token missing");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::UNAUTHORIZED,
            "Unauthorized",
            "Token missing",
        );
    }

    #[test]
    fn test_option_ext_context_forbidden() {
        let option: Option<String> = None;
        let api_result = option.context_forbidden("Forbidden", "No access");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::FORBIDDEN,
            "Forbidden",
            "No access",
        );
    }

    #[test]
    fn test_option_ext_context_not_found() {
        let option: Option<String> = None;
        let api_result = option.context_not_found("Not Found", "Resource missing");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::NOT_FOUND,
            "Not Found",
            "Resource missing",
        );
    }

    #[test]
    fn test_option_ext_context_internal() {
        let option: Option<String> = None;
        let api_result = option.context_internal("Internal Error", "Config missing");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::INTERNAL_SERVER_ERROR,
            "Internal Error",
            "Config missing",
        );
    }

    #[test]
    fn test_option_ext_context_status() {
        let option: Option<String> = None;
        let api_result = option.context_status(StatusCode::CONFLICT, "Conflict", "Already exists");

        assert!(api_result.is_err());
        assert_api_error(
            &api_result.unwrap_err(),
            StatusCode::CONFLICT,
            "Conflict",
            "Already exists",
        );
    }

    #[test]
    fn test_option_ext_some_passes_through_every_variant() {
        let some = || Some(7_i32);

        assert_eq!(some().context_status(StatusCode::CONFLICT, "t", "d").unwrap(), 7);
        assert_eq!(some().context_bad_request("t", "d").unwrap(), 7);
        assert_eq!(some().context_unauthorized("t", "d").unwrap(), 7);
        assert_eq!(some().context_forbidden("t", "d").unwrap(), 7);
        assert_eq!(some().context_not_found("t", "d").unwrap(), 7);
        assert_eq!(some().context_internal("t", "d").unwrap(), 7);
    }

    // ------------------------------------------------------------- IntoApiError

    #[test]
    fn test_into_api_error_context_bad_request() {
        let anyhow_err = anyhow!("Invalid input");
        let api_err = anyhow_err.context_bad_request("Bad Request", "Field validation failed");

        assert_api_error(
            &api_err,
            StatusCode::BAD_REQUEST,
            "Bad Request",
            "Field validation failed",
        );
    }

    #[test]
    fn test_into_api_error_every_variant() {
        assert_api_error(
            &anyhow!("e").context_status(StatusCode::CONFLICT, "Conflict", "Duplicate"),
            StatusCode::CONFLICT,
            "Conflict",
            "Duplicate",
        );
        assert_api_error(
            &anyhow!("e").context_unauthorized("Unauthorized", "No token"),
            StatusCode::UNAUTHORIZED,
            "Unauthorized",
            "No token",
        );
        assert_api_error(
            &anyhow!("e").context_forbidden("Forbidden", "No access"),
            StatusCode::FORBIDDEN,
            "Forbidden",
            "No access",
        );
        assert_api_error(
            &anyhow!("e").context_not_found("Not Found", "Missing"),
            StatusCode::NOT_FOUND,
            "Not Found",
            "Missing",
        );
        assert_api_error(
            &anyhow!("e").context_internal("Internal Error", "Boom"),
            StatusCode::INTERNAL_SERVER_ERROR,
            "Internal Error",
            "Boom",
        );
    }

    #[test]
    fn test_into_api_error_with_non_anyhow_error() {
        // ParseIntError is not an anyhow::Error but converts into one, so the blanket
        // impl applies to it directly.
        let parse_err = "not_a_number".parse::<i32>().unwrap_err();
        let api_err = parse_err.context_bad_request("Bad Request", "Value must be a number");

        assert_api_error(
            &api_err,
            StatusCode::BAD_REQUEST,
            "Bad Request",
            "Value must be a number",
        );
    }

    // ------------------------------------------------------------------ chaining

    #[test]
    fn test_chaining_result_operations() {
        fn get_value() -> Result<i32> {
            Err(anyhow!("Failed to get value"))
        }

        let result = get_value().context_bad_request("Bad Request", "Could not retrieve value");

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status, StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_chaining_option_operations() {
        fn get_value() -> Option<i32> {
            None
        }

        let result = get_value().context_not_found("Not Found", "Value does not exist");

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status, StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_question_mark_operator_with_result() {
        fn helper() -> ApiResult<i32> {
            let value: Result<i32> = Err(anyhow!("error"));
            value.context_bad_request("Bad Request", "Invalid")?;
            Ok(42)
        }

        let result = helper();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status, StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_question_mark_operator_with_option() {
        fn helper() -> ApiResult<i32> {
            let value: Option<i32> = None;
            value.context_not_found("Not Found", "Missing")?;
            Ok(42)
        }

        let result = helper();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status, StatusCode::NOT_FOUND);
    }
}