use std::fmt::Display;
use std::marker::PhantomData;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use derive_more::with_trait::Debug;
use crate::Error;
use crate::handler::handle_all_parameters;
use crate::request::extractors::FromRequestHead;
use crate::request::{Request, RequestHead};
use crate::response::Response;
/// A handler that builds the response for an error page.
///
/// It is implemented automatically for `async` functions whose parameters all
/// implement [`FromRequestHead`] and whose return type implements
/// [`IntoResponse`](crate::response::IntoResponse).
///
/// `Params` is a marker for the handler's parameter tuple. It only serves to
/// keep the generated blanket implementations for different arities apart, so
/// callers normally never name it.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid error page handler",
    label = "not a valid error page handler",
    note = "make sure the function is marked `async`",
    note = "make sure all parameters implement `FromRequestHead`",
    note = "make sure the function takes no more than 10 parameters",
    note = "make sure the function returns a type that implements `IntoResponse`"
)]
pub trait ErrorPageHandler<Params = ()> {
    /// Produces the error page response from the request head alone.
    ///
    /// # Errors
    ///
    /// Returns an error if the response cannot be produced.
    fn handle(&self, head: &RequestHead) -> impl Future<Output = crate::Result<Response>> + Send;
}
/// Object-safe counterpart of [`ErrorPageHandler`].
///
/// `ErrorPageHandler::handle` returns an `impl Future`, which cannot be called
/// through a trait object. This trait returns a pinned, boxed future instead, so
/// handlers can be stored as `dyn BoxErrorPageHandler`.
pub(crate) trait BoxErrorPageHandler: Send + Sync {
    /// Runs the handler. The returned future may borrow both `self` and `head`.
    fn handle<'h>(
        &'h self,
        head: &'h RequestHead,
    ) -> Pin<Box<dyn Future<Output = crate::Result<Response>> + Send + 'h>>;
}
/// A type-erased error page handler.
///
/// Any [`ErrorPageHandler`] can be stored behind an [`Arc`], so cloning this
/// value only bumps a reference count.
#[derive(Debug, Clone)]
pub struct DynErrorPageHandler {
    /// The erased handler. Its `Debug` output is elided as `..`.
    #[debug("..")]
    handler: Arc<dyn BoxErrorPageHandler>,
}
impl DynErrorPageHandler {
pub fn new<HandlerParams, H>(handler: H) -> Self
where
HandlerParams: 'static,
H: ErrorPageHandler<HandlerParams> + Send + Sync + 'static,
{
struct Inner<T, H>(H, PhantomData<fn() -> T>);
impl<T, H: ErrorPageHandler<T> + Send + Sync> BoxErrorPageHandler for Inner<T, H> {
fn handle<'a>(
&'a self,
head: &'a RequestHead,
) -> Pin<Box<dyn Future<Output = cot::Result<Response>> + Send + 'a>> {
Box::pin(self.0.handle(head))
}
}
Self {
handler: Arc::new(Inner(handler, PhantomData)),
}
}
}
impl tower::Service<Request> for DynErrorPageHandler {
type Response = Response;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = cot::Result<Self::Response>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request) -> Self::Future {
let handler = self.handler.clone();
let (head, _) = req.into_parts();
Box::pin(async move { handler.handle(&head).await })
}
}
// Implements `ErrorPageHandler` for `async` functions taking the given
// parameter types. Each parameter is extracted from the request head via
// `FromRequestHead`. `handle_all_parameters!` invokes this once per supported
// arity.
macro_rules! impl_request_handler {
    ($($ty:ident),*) => {
        impl<Func, $($ty,)* Fut, R> ErrorPageHandler<($($ty,)*)> for Func
        where
            Func: FnOnce($($ty,)*) -> Fut + Clone + Send + Sync + 'static,
            $($ty: FromRequestHead + Send,)*
            Fut: Future<Output = R> + Send,
            R: crate::response::IntoResponse,
        {
            #[allow(
                clippy::allow_attributes,
                non_snake_case,
                unused_variables,
                reason = "extracted values are bound to variables named after their type \
                          parameters, and `head` is unused when there are no params"
            )]
            async fn handle(&self, head: &RequestHead) -> crate::Result<Response> {
                // Extract every parameter first, failing fast on the first
                // extractor error. Then call a clone of the handler, because
                // invoking an `FnOnce` consumes it.
                $(
                    let $ty = <$ty as FromRequestHead>::from_request_head(head).await?;
                )*
                self.clone()($($ty,)*).await.into_response()
            }
        }
    };
}
handle_all_parameters!(impl_request_handler);
/// The error that led to the error page being rendered, as it is stored in the
/// request head's extensions.
///
/// Usable as an extractor in error page handlers. Dereferences to the stored
/// [`Error`] itself. See [`RequestError`] for an extractor that dereferences to
/// [`Error::inner`] of that error instead.
#[derive(Debug, Clone)]
pub struct RequestOuterError(Arc<Error>);

impl RequestOuterError {
    /// Wraps `error` in a shared handle; clones only bump a reference count.
    #[must_use]
    pub(crate) fn new(error: Error) -> Self {
        let shared = Arc::new(error);
        Self(shared)
    }
}
impl Deref for RequestOuterError {
type Target = Error;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Display for RequestOuterError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.0.inner(), f)
}
}
impl FromRequestHead for RequestOuterError {
async fn from_request_head(head: &RequestHead) -> crate::Result<Self> {
let error = head.extensions.get::<RequestOuterError>();
error
.ok_or_else(|| {
Error::internal("No error found in request head. Make sure you use this extractor in an error handler.")
}).cloned()
}
}
/// The error returned by [`Error::inner`] for the error that led to the error
/// page being rendered. This is presumably the underlying cause; see
/// `Error::inner` for the exact semantics.
///
/// Usable as an extractor in error page handlers. It reads the same stored
/// error as [`RequestOuterError`], but dereferences to its `inner()` error.
#[derive(Debug, Clone)]
pub struct RequestError(Arc<Error>);
impl Deref for RequestError {
type Target = Error;
fn deref(&self) -> &Self::Target {
self.0.inner()
}
}
impl Display for RequestError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.0.inner(), f)
}
}
impl FromRequestHead for RequestError {
async fn from_request_head(head: &RequestHead) -> crate::Result<Self> {
let error = head.extensions.get::<RequestOuterError>();
error
.ok_or_else(|| {
Error::internal(
"No error found in request head. \
Make sure you use this extractor in an error handler.",
)
})
.map(|request_error| Self(request_error.0.clone()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request head whose extensions carry a `RequestOuterError`
    /// wrapping `Error::internal(message)`.
    fn head_with_error(message: &'static str) -> RequestHead {
        let (mut head, _) = Request::default().into_parts();
        head.extensions
            .insert(RequestOuterError::new(Error::internal(message)));
        head
    }

    #[test]
    fn request_outer_error_display() {
        let wrapped = RequestOuterError::new(Error::internal("Test error"));
        assert_eq!(wrapped.to_string(), "Test error");
    }

    #[test]
    fn request_error_display() {
        let wrapped = RequestError(Arc::new(Error::internal("Test error")));
        assert_eq!(wrapped.to_string(), "Test error");
    }

    #[cot::test]
    async fn request_outer_error_from_request_head() {
        let head = head_with_error("Test error");
        let extracted = RequestOuterError::from_request_head(&head).await.unwrap();
        assert_eq!(extracted.to_string(), "Test error");
    }

    #[cot::test]
    async fn request_error_from_request_head() {
        let head = head_with_error("Test error");
        let extracted = RequestError::from_request_head(&head).await.unwrap();
        assert_eq!(extracted.to_string(), "Test error");
    }
}