use std::error::Error;
#[cfg(feature = "diesel")]
use crate::sql::NoRowsFound;
use crate::{
http::{self, NotFound},
Result,
};
/// Private home of the supertrait that seals the extension traits in this file.
///
/// `Sealed` is `pub` but its module is private, so code outside this module
/// cannot name it and therefore cannot implement the extension traits that
/// require it for additional types.
mod sealed {
    /// Exposes the payload type of a fallible/optional container.
    pub trait Sealed {
        /// The type carried by `Some(_)` / `Ok(_)`.
        type Value;
    }

    impl<T> Sealed for Option<T> {
        type Value = T;
    }

    // This nested module does not see the parent's `crate::Result` import, so the
    // prelude `Result` is what applies here; spell it out to make that explicit.
    impl<T, E> Sealed for std::result::Result<T, E> {
        type Value = T;
    }
}
/// Extension methods for converting arbitrary `std::error::Error`s into this
/// crate's error type.
///
/// Sealed via [`sealed::Sealed`]; implemented here only for `Result`.
pub trait ResultExt: sealed::Sealed + Sized {
    /// Maps any error into the crate error produced by `http::internal_error`,
    /// leaving an `Ok` value untouched.
    fn internal(self) -> Result<Self::Value>;
}

impl<T, E> ResultExt for Result<T, E>
where
    E: Error + Send + Sync + 'static,
{
    #[track_caller]
    fn internal(self) -> Result<Self::Value> {
        // Call `http::internal_error` directly rather than handing it to
        // `map_err`. When a function is invoked from inside `Result::map_err`
        // (through the `FnOnce` shim), any `#[track_caller]` location it
        // records points into `core`, not at our caller. A direct call keeps
        // the chain intact. (Assumes `internal_error` records the caller
        // location; if it does not, this is behavior-identical.)
        match self {
            Ok(value) => Ok(value),
            Err(err) => Err(http::internal_error(err)),
        }
    }
}
/// Extension methods for this crate's own `Result`, inspecting the concrete
/// error type held inside the crate error.
pub trait ProblemResultExt: ResultExt {
    /// Separates out errors of type `E`.
    ///
    /// * `Ok(v)` becomes `Ok(Ok(v))`.
    /// * An error that downcasts to `E` becomes `Ok(Err(e))`.
    /// * Any other error is propagated as the outer `Err`.
    fn catch_err<E: Error + Send + Sync + 'static>(self) -> Result<Result<Self::Value, E>>;

    /// Treats "not found" errors as absence.
    ///
    /// * `Ok(v)` becomes `Ok(Some(v))`.
    /// * A `NotFound` error becomes `Ok(None)`; with the `diesel` feature, so
    ///   does a `NoRowsFound` error.
    /// * Any other error is propagated as `Err`.
    fn optional(self) -> Result<Option<Self::Value>>;
}

impl<T> ProblemResultExt for Result<T> {
    fn catch_err<E: Error + Send + Sync + 'static>(self) -> Result<Result<Self::Value, E>> {
        let err = match self {
            Ok(value) => return Ok(Ok(value)),
            Err(err) => err,
        };
        // If the downcast fails, its `Err` (a crate error) is propagated by `?`.
        let caught = err.downcast::<E>()?;
        Ok(Err(caught))
    }

    fn optional(self) -> Result<Option<Self::Value>> {
        let err = match self {
            Ok(value) => return Ok(Some(value)),
            Err(err) => err,
        };
        match err.downcast::<NotFound>() {
            Ok(_) => Ok(None),
            // With `diesel`, an empty query result also means "absent";
            // anything else is propagated by `?`.
            #[cfg(feature = "diesel")]
            Err(err) => {
                err.downcast::<NoRowsFound>()?;
                Ok(None)
            }
            #[cfg(not(feature = "diesel"))]
            Err(err) => Err(err),
        }
    }
}
/// Extension methods that turn a missing `Option` value into a "not found" error.
///
/// Sealed via [`sealed::Sealed`]; implemented here only for `Option`.
pub trait OptionExt: sealed::Sealed + Sized {
    /// Returns the contained value, or the error built by
    /// `http::not_found(entity, identifier)` when `None`.
    fn or_not_found<I: std::fmt::Display>(
        self,
        entity: &'static str,
        identifier: I,
    ) -> Result<Self::Value>;

    /// Like [`OptionExt::or_not_found`], reporting the identifier as `"<unknown>"`.
    fn or_not_found_unknown(self, entity: &'static str) -> Result<Self::Value>;
}

impl<T> OptionExt for Option<T> {
    #[track_caller]
    fn or_not_found<I: std::fmt::Display>(
        self,
        entity: &'static str,
        identifier: I,
    ) -> Result<Self::Value> {
        match self {
            Some(value) => Ok(value),
            // Built directly rather than via `ok_or_else`: a closure would not
            // forward the `#[track_caller]` location to `http::not_found`.
            None => Err(http::not_found(entity, identifier)),
        }
    }

    #[track_caller]
    fn or_not_found_unknown(self, entity: &'static str) -> Result<Self::Value> {
        const UNKNOWN_IDENTIFIER: &str = "<unknown>";
        self.or_not_found(entity, UNKNOWN_IDENTIFIER)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http;

    #[test]
    fn test_internal() {
        let res =
            Err(std::io::Error::new(std::io::ErrorKind::Other, "oh no")) as std::io::Result<()>;
        let res = res.internal().unwrap_err();
        assert!(res.is::<http::InternalError>());
    }

    #[test]
    fn test_catch_err() {
        let res =
            Err(std::io::Error::new(std::io::ErrorKind::Other, "oh no")) as std::io::Result<()>;
        let res = res.internal();
        // Catching a non-matching type leaves the original (internal) error in
        // the outer `Err`.
        let internal = res.catch_err::<http::NotFound>().unwrap_err();
        // Catching the matching type moves the error into the inner `Err`.
        let res = Err(internal) as crate::Result<()>;
        let res = res.catch_err::<http::InternalError>().unwrap();
        assert!(res.is_err());
        // `Ok` values pass through as `Ok(Ok(_))`.
        let ok = Ok(()) as crate::Result<()>;
        assert!(ok.catch_err::<http::InternalError>().unwrap().is_ok());
    }

    #[test]
    fn test_optional() {
        let res = Err(http::not_found("user", "bla")) as crate::Result<()>;
        assert!(res.optional().unwrap().is_none());
        let res = Err(http::failed_precondition()) as crate::Result<()>;
        assert!(res.optional().is_err());
        let res = Ok(()) as crate::Result<()>;
        assert!(res.optional().unwrap().is_some());
    }

    #[test]
    fn test_or_not_found() {
        let res = None.or_not_found_unknown("bla") as crate::Result<()>;
        let err = res.unwrap_err();
        assert!(err.is::<http::NotFound>());
        let res = Some(()).or_not_found_unknown("bla");
        assert!(res.is_ok());

        // The explicit-identifier variant behaves the same way.
        let res = None.or_not_found("user", 42) as crate::Result<()>;
        assert!(res.unwrap_err().is::<http::NotFound>());
        assert_eq!(Some(5).or_not_found("user", 42).unwrap(), 5);
    }
}