http_problem/
ext.rs

//! # Standard Error Handling Types Extensions
//!
//! This module provides traits that extend the common error
//! handling types `Result` and `Option` with methods that
//! integrate them with the types defined in this crate.
6use std::error::Error;
7
8#[cfg(feature = "diesel")]
9use crate::sql::NoRowsFound;
10use crate::{
11    http::{self, NotFound},
12    Result,
13};
14
mod sealed {
    /// Sealing supertrait for the extension traits in this module.
    ///
    /// `Sealed` is `pub` but lives in a private module, so code outside the
    /// enclosing module cannot name it and therefore cannot implement the
    /// extension traits that require it.
    pub trait Sealed {
        /// The value carried on the success / present side of the container.
        type Value;
    }

    impl<V> Sealed for Option<V> {
        type Value = V;
    }

    impl<V, Failure> Sealed for Result<V, Failure> {
        type Value = V;
    }
}
28
29/// Extension methods on [`Result`].
30///
31/// [`Result`]: std::result::Result
32pub trait ResultExt: sealed::Sealed + Sized {
33    /// Converts this result to an internal error.
34    ///
35    /// Used when an unrecoverable error happens.
36    ///
37    /// See [`internal_error`] for more information.
38    ///
39    /// # Errors
40    ///
41    /// Returns `Err` if `self` is `Err`.
42    ///
43    /// [`internal_error`]: crate::http::internal_error
44    fn internal(self) -> Result<Self::Value>;
45}
46
47impl<T, E> ResultExt for Result<T, E>
48where
49    E: Error + Send + Sync + 'static,
50{
51    #[track_caller]
52    fn internal(self) -> Result<Self::Value> {
53        self.map_err(http::internal_error)
54    }
55}
56
57/// Extension methods on `Result<T, Problem>`.
58pub trait ProblemResultExt: ResultExt {
59    /// Catch a specific error type `E`.
60    ///
61    /// Returns `Ok(Ok(_))` when the source `Result<T>` is a `T`, returns
62    /// `Ok(Err(_))` when its is a `Problem` that can downcast to an `E`,
63    /// and returns `Err(_)` otherwise.
64    ///
65    /// Useful when there is a need to handle a specific error differently,
66    /// e.g. a [`NotFound`] error.
67    ///
68    /// # Errors
69    ///
70    /// Returns `Err` if `self` contains a [`Problem`] which is not an `E`.
71    ///
72    /// [`Problem`]: crate::Problem
73    /// [`NotFound`]: crate::http::NotFound
74    fn catch_err<E>(self) -> Result<Result<Self::Value, E>>
75    where
76        E: Error + Send + Sync + 'static;
77
78    /// Catch a [`NotFound`] and convert it to `None`.
79    ///
80    /// # Errors
81    ///
82    /// Returns `Err` if `self` contains a [`Problem`] which is not a
83    /// [`NotFound`].
84    ///
85    /// [`Problem`]: crate::Problem
86    /// [`NotFound`]: crate::http::NotFound
87    fn optional(self) -> Result<Option<Self::Value>>;
88}
89
90impl<T> ProblemResultExt for Result<T> {
91    fn catch_err<E>(self) -> Result<Result<Self::Value, E>>
92    where
93        E: Error + Send + Sync + 'static,
94    {
95        Ok(match self {
96            Ok(ok) => Ok(ok),
97            Err(err) => Err(err.downcast::<E>()?),
98        })
99    }
100
101    fn optional(self) -> Result<Option<Self::Value>> {
102        match self {
103            Ok(ok) => Ok(Some(ok)),
104            Err(err) => {
105                #[allow(clippy::question_mark)]
106                if let Err(err) = err.downcast::<NotFound>() {
107                    #[cfg(feature = "diesel")]
108                    err.downcast::<NoRowsFound>()?;
109                    #[cfg(not(feature = "diesel"))]
110                    return Err(err);
111                }
112
113                Ok(None)
114            }
115        }
116    }
117}
118
119/// Extension methods on `Option<T>`.
120pub trait OptionExt: sealed::Sealed + Sized {
121    /// Returns `Ok(_)` if the source is `Some(_)`, otherwise, returns a
122    /// `Problem` that can downcast to `NotFound`.
123    ///
124    /// # Errors
125    ///
126    /// Returns `Err` when `self` is `None`.
127    fn or_not_found<I>(self, entity: &'static str, identifier: I) -> Result<Self::Value>
128    where
129        I: std::fmt::Display;
130
131    /// It's a wrapper to `or_not_found` to be used when
132    /// there is no a specific identifier to entity message.
133    ///
134    /// Returns `Ok(_)` if the source is `Some(_)`, otherwise, returns a
135    /// `Problem` that can downcast to `NotFound`.
136    ///
137    /// # Errors
138    ///
139    /// Returns `Err` when `self` is `None`.
140    fn or_not_found_unknown(self, entity: &'static str) -> Result<Self::Value>;
141}
142
143impl<T> OptionExt for Option<T> {
144    #[track_caller]
145    fn or_not_found<I>(self, entity: &'static str, identifier: I) -> Result<Self::Value>
146    where
147        I: std::fmt::Display,
148    {
149        // Cannot use Option::ok_or_else as it isn't annotated with track_caller.
150        if let Some(value) = self {
151            Ok(value)
152        } else {
153            Err(http::not_found(entity, identifier))
154        }
155    }
156
157    #[track_caller]
158    fn or_not_found_unknown(self, entity: &'static str) -> Result<Self::Value> {
159        self.or_not_found(entity, "<unknown>")
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http;

    /// A plain `std::io::Error` used as the source error in these tests.
    fn io_error() -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::Other, "oh no")
    }

    #[test]
    fn test_internal() {
        let res = Err(io_error()) as std::io::Result<()>;
        let problem = res.internal().unwrap_err();
        assert!(problem.is::<http::InternalError>());

        // `Ok` values pass through unchanged.
        let res = Ok(5) as std::io::Result<i32>;
        assert_eq!(res.internal().unwrap(), 5);
    }

    #[test]
    fn test_catch_err() {
        let res = (Err(io_error()) as std::io::Result<()>).internal();

        // The problem is an `InternalError`, not a `NotFound`, so it is
        // handed back as the outer `Err`.
        let internal_problem = res.catch_err::<http::NotFound>().unwrap_err();
        let res = Err(internal_problem) as crate::Result<()>;

        // Catching the matching type moves it into the inner `Err`.
        let caught = res.catch_err::<http::InternalError>().unwrap();
        assert!(caught.is_err());

        let ok = Ok(()) as crate::Result<()>;
        assert!(ok.catch_err::<http::InternalError>().unwrap().is_ok());
    }

    #[test]
    fn test_optional() {
        let res = Err(http::not_found("user", "bla")) as crate::Result<()>;
        assert!(res.optional().unwrap().is_none());

        let res = Err(http::failed_precondition()) as crate::Result<()>;
        assert!(res.optional().is_err());

        let res = Ok(()) as crate::Result<()>;
        assert!(res.optional().unwrap().is_some());

        let res = Ok(3) as crate::Result<i32>;
        assert_eq!(res.optional().unwrap(), Some(3));
    }

    #[test]
    fn test_or_not_found() {
        let res = None::<()>.or_not_found("user", 42);
        assert!(res.unwrap_err().is::<http::NotFound>());

        assert_eq!(Some(7).or_not_found("user", 42).unwrap(), 7);
    }

    #[test]
    fn test_or_not_found_unknown() {
        let res = None.or_not_found_unknown("bla") as crate::Result<()>;
        let err = res.unwrap_err();

        assert!(err.is::<http::NotFound>());

        let res = Some(()).or_not_found_unknown("bla");

        assert!(res.is_ok());
    }
}