1use std::error::Error;
7
8#[cfg(feature = "diesel")]
9use crate::sql::NoRowsFound;
10use crate::{
11 http::{self, NotFound},
12 Result,
13};
14
/// Private sealing module for the extension traits in this file.
///
/// `Sealed` is used as a supertrait of the public extension traits below.
/// This module is private, so code outside this crate cannot name `Sealed`.
/// That means it also cannot implement those extension traits for its own types.
mod sealed {
    /// Marker trait that exposes the "success"/"present" payload type.
    pub trait Sealed {
        /// The value carried by `Some(..)` / `Ok(..)`.
        type Value;
    }

    // Spelled out with full std paths: inside this module the bare names
    // already resolve to the std prelude types, not to the parent's
    // one-parameter `crate::Result` alias.
    impl<T> Sealed for std::option::Option<T> {
        type Value = T;
    }

    impl<T, E> Sealed for std::result::Result<T, E> {
        type Value = T;
    }
}
28
29pub trait ResultExt: sealed::Sealed + Sized {
33 fn internal(self) -> Result<Self::Value>;
45}
46
47impl<T, E> ResultExt for Result<T, E>
48where
49 E: Error + Send + Sync + 'static,
50{
51 #[track_caller]
52 fn internal(self) -> Result<Self::Value> {
53 self.map_err(http::internal_error)
54 }
55}
56
57pub trait ProblemResultExt: ResultExt {
59 fn catch_err<E>(self) -> Result<Result<Self::Value, E>>
75 where
76 E: Error + Send + Sync + 'static;
77
78 fn optional(self) -> Result<Option<Self::Value>>;
88}
89
90impl<T> ProblemResultExt for Result<T> {
91 fn catch_err<E>(self) -> Result<Result<Self::Value, E>>
92 where
93 E: Error + Send + Sync + 'static,
94 {
95 Ok(match self {
96 Ok(ok) => Ok(ok),
97 Err(err) => Err(err.downcast::<E>()?),
98 })
99 }
100
101 fn optional(self) -> Result<Option<Self::Value>> {
102 match self {
103 Ok(ok) => Ok(Some(ok)),
104 Err(err) => {
105 #[allow(clippy::question_mark)]
106 if let Err(err) = err.downcast::<NotFound>() {
107 #[cfg(feature = "diesel")]
108 err.downcast::<NoRowsFound>()?;
109 #[cfg(not(feature = "diesel"))]
110 return Err(err);
111 }
112
113 Ok(None)
114 }
115 }
116 }
117}
118
119pub trait OptionExt: sealed::Sealed + Sized {
121 fn or_not_found<I>(self, entity: &'static str, identifier: I) -> Result<Self::Value>
128 where
129 I: std::fmt::Display;
130
131 fn or_not_found_unknown(self, entity: &'static str) -> Result<Self::Value>;
141}
142
143impl<T> OptionExt for Option<T> {
144 #[track_caller]
145 fn or_not_found<I>(self, entity: &'static str, identifier: I) -> Result<Self::Value>
146 where
147 I: std::fmt::Display,
148 {
149 if let Some(value) = self {
151 Ok(value)
152 } else {
153 Err(http::not_found(entity, identifier))
154 }
155 }
156
157 #[track_caller]
158 fn or_not_found_unknown(self, entity: &'static str) -> Result<Self::Value> {
159 self.or_not_found(entity, "<unknown>")
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http;

    #[test]
    fn test_internal() {
        let res =
            Err(std::io::Error::new(std::io::ErrorKind::Other, "oh no")) as std::io::Result<()>;

        let res = res.internal().unwrap_err();

        assert!(res.is::<http::InternalError>());
    }

    #[test]
    fn test_catch_err() {
        let res =
            Err(std::io::Error::new(std::io::ErrorKind::Other, "oh no")) as std::io::Result<()>;

        let res = res.internal();

        // Catching a non-matching type hands the original (internal) error back
        // as the outer `Err`. It is *not* a `NotFound`, despite the type parameter.
        let uncaught = res.catch_err::<http::NotFound>().unwrap_err();
        let res = Err(uncaught) as crate::Result<()>;

        // Catching the matching type moves the error into the inner `Result`.
        let res = res.catch_err::<http::InternalError>().unwrap();

        assert!(res.is_err());

        let ok = Ok(()) as crate::Result<()>;

        assert!(ok.catch_err::<http::InternalError>().unwrap().is_ok());
    }

    #[test]
    fn test_optional() {
        let res = Err(http::not_found("user", "bla")) as crate::Result<()>;
        assert!(res.optional().unwrap().is_none());

        let res = Err(http::failed_precondition()) as crate::Result<()>;
        assert!(res.optional().is_err());

        let res = Ok(()) as crate::Result<()>;
        assert!(res.optional().unwrap().is_some());
    }

    #[test]
    fn test_or_not_found() {
        let res = None.or_not_found_unknown("bla") as crate::Result<()>;
        let err = res.unwrap_err();

        assert!(err.is::<http::NotFound>());

        let res = Some(()).or_not_found_unknown("bla");

        assert!(res.is_ok());

        // Explicit-identifier variant: both the `None` and `Some` paths.
        let err = None::<()>.or_not_found("user", 42).unwrap_err();
        assert!(err.is::<http::NotFound>());

        assert_eq!(Some(5).or_not_found("user", 1).unwrap(), 5);
    }
}