use crate::{endpoint::EndpointImpl, error::BoxHttpError, Endpoint, HttpError, Request, Response};
use alloc::boxed::Box;
use core::{
any::type_name,
convert::Infallible,
fmt::{Debug, Display},
future::Future,
ops::DerefMut,
pin::Pin,
};
use http::StatusCode;
pub trait Middleware: Send {
type Error: HttpError;
fn handle<E: Endpoint>(
&mut self,
request: &mut Request,
next: E,
) -> impl Future<Output = Result<Response, MiddlewareError<E::Error, Self::Error>>> + Send;
}
/// Error returned by [`Middleware::handle`].
///
/// Separates failures of the downstream endpoint (`N`) from failures raised
/// by the middleware itself (`E`). `Display`, `source` and `status` all
/// forward to whichever error is wrapped.
#[derive(Debug)]
pub enum MiddlewareError<N, E> {
    /// The downstream endpoint (`next`) failed.
    Endpoint(N),
    /// The middleware itself failed.
    Middleware(E),
}

impl<N, E> Display for MiddlewareError<N, E>
where
    N: HttpError,
    E: HttpError,
{
    /// Prefixes the wrapped error's message with the side that failed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Endpoint(inner) => write!(f, "Endpoint error: {inner}"),
            Self::Middleware(inner) => write!(f, "Middleware error: {inner}"),
        }
    }
}

impl<N, E> core::error::Error for MiddlewareError<N, E>
where
    N: HttpError,
    E: HttpError,
{
    /// Returns the wrapped error's own source rather than the wrapped error.
    ///
    /// The wrapped error's message is already part of this error's `Display`
    /// output, so it is not reported a second time.
    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
        match self {
            Self::Endpoint(inner) => inner.source(),
            Self::Middleware(inner) => inner.source(),
        }
    }
}

impl<N, E> HttpError for MiddlewareError<N, E>
where
    N: HttpError,
    E: HttpError,
{
    /// Uses the HTTP status of the wrapped error.
    fn status(&self) -> StatusCode {
        match self {
            Self::Endpoint(inner) => inner.status(),
            Self::Middleware(inner) => inner.status(),
        }
    }
}
/// Object-safe counterpart of [`Middleware`], so middlewares can be stored as
/// `dyn MiddlewareImpl` (as `AnyMiddleware` does).
///
/// `handle_inner` differs from the generic method in three ways:
/// - it takes the downstream endpoint as `&mut dyn EndpointImpl`;
/// - it returns a boxed, pinned, `Send` future;
/// - the future's error is erased to [`BoxHttpError`].
pub(crate) trait MiddlewareImpl: Send {
    /// Type-erased version of [`Middleware::handle`].
    ///
    /// The returned future may borrow `self`, `request` and `next`. Each of
    /// them must therefore outlive `'fut`.
    fn handle_inner<'this, 'req, 'next, 'fut>(
        &'this mut self,
        request: &'req mut Request,
        next: &'next mut dyn EndpointImpl,
    ) -> Pin<Box<dyn Future<Output = Result<Response, BoxHttpError>> + Send + 'fut>>
    where
        'next: 'fut,
        'req: 'fut,
        'this: 'fut;

    /// Name of the implementing type, e.g. for `AnyMiddleware`'s `Debug` output.
    ///
    /// Defaults to the compiler-provided type name of `Self`.
    fn name(&self) -> &'static str {
        type_name::<Self>()
    }
}
/// Lets a type-erased endpoint be used where a generic [`Endpoint`] is
/// expected, for example as the `next` argument passed to
/// [`Middleware::handle`].
///
/// Errors come back already boxed as [`BoxHttpError`].
impl<'a> Endpoint for &mut (dyn EndpointImpl + 'a) {
    type Error = BoxHttpError;

    async fn respond(&mut self, request: &mut Request) -> Result<Response, Self::Error> {
        // `self` is `&mut &mut dyn EndpointImpl`. Dereference twice so the call
        // goes through the trait object's vtable.
        //
        // A plain `self.respond_inner(request)` looks for a receiver of the
        // outer type first, i.e. `respond_inner` on `&mut dyn EndpointImpl`
        // itself. That type is an `Endpoint` through this very impl. It is
        // presumably also an `EndpointImpl`, because `EndpointImpl` is
        // blanket-implemented for every `Endpoint` (required for unsizing a
        // generic `E: Endpoint` to `dyn EndpointImpl`). That impl would call
        // back into this `respond` and recurse without end.
        (**self).respond_inner(request).await
    }
}
impl<T: Middleware> MiddlewareImpl for T {
fn handle_inner<'this, 'req, 'next, 'fut>(
&'this mut self,
request: &'req mut Request,
next: &'next mut dyn EndpointImpl,
) -> Pin<Box<dyn 'fut + Future<Output = Result<Response, BoxHttpError>> + Send>>
where
'this: 'fut,
'req: 'fut,
'next: 'fut,
{
Box::pin(async move {
self.handle(request, next)
.await
.map_err(|e| Box::new(e) as BoxHttpError)
})
}
}
/// Forwards to the referenced middleware, so a `&mut M` can be passed wherever
/// a middleware value is expected.
impl<M: Middleware> Middleware for &mut M {
    type Error = M::Error;

    async fn handle<E: Endpoint>(
        &mut self,
        request: &mut Request,
        next: E,
    ) -> Result<Response, MiddlewareError<E::Error, Self::Error>> {
        // Reborrow the inner `&mut M` and name `M`'s impl explicitly. Method
        // syntax on `self` (an `&mut &mut M`) would select this impl again.
        let inner: &mut M = &mut **self;
        M::handle(inner, request, next).await
    }
}
/// Forwards to the boxed middleware.
impl<M: Middleware> Middleware for Box<M> {
    type Error = M::Error;

    async fn handle<E: Endpoint>(
        &mut self,
        request: &mut Request,
        next: E,
    ) -> Result<Response, MiddlewareError<E::Error, Self::Error>> {
        let inner: &mut M = DerefMut::deref_mut(self);
        M::handle(inner, request, next).await
    }
}
/// Error from a pair of middlewares, recording which of the two failed.
///
/// `First` wraps an error from the first middleware and `Second` wraps one
/// from the second. NOTE(review): the name suggests this is produced by tuple
/// middleware compositions; confirm against those impls.
///
/// The type parameters are deliberately unbounded on the definition. Only the
/// `Display`, `Error` and `HttpError` impls require both to be [`HttpError`]s.
/// This matches [`MiddlewareError`] and follows the Rust API guideline of
/// keeping trait bounds on impls rather than on data structures.
#[derive(Debug)]
pub enum MiddlewareTupleError<E1, E2> {
    /// The first middleware failed.
    First(E1),
    /// The second middleware failed.
    Second(E2),
}

impl<A, B> Display for MiddlewareTupleError<A, B>
where
    A: HttpError,
    B: HttpError,
{
    /// Prefixes the wrapped error's message with the middleware that failed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::First(e) => write!(f, "First middleware error: {e}"),
            Self::Second(e) => write!(f, "Second middleware error: {e}"),
        }
    }
}

impl<A, B> core::error::Error for MiddlewareTupleError<A, B>
where
    A: HttpError,
    B: HttpError,
{
    /// Returns the wrapped error's own source rather than the wrapped error.
    ///
    /// The wrapped error's message is already included in `Display`.
    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
        match self {
            Self::First(e) => e.source(),
            Self::Second(e) => e.source(),
        }
    }
}

impl<A, B> HttpError for MiddlewareTupleError<A, B>
where
    A: HttpError,
    B: HttpError,
{
    /// Uses the HTTP status of the wrapped error.
    fn status(&self) -> StatusCode {
        match self {
            Self::First(e) => e.status(),
            Self::Second(e) => e.status(),
        }
    }
}
/// A type-erased, heap-allocated [`Middleware`].
///
/// Wraps any `'static` middleware behind a single concrete type, so
/// middlewares of different types can be handled uniformly.
pub struct AnyMiddleware(Box<dyn MiddlewareImpl>);

impl Debug for AnyMiddleware {
    /// Formats as `AnyMiddleware[<name>]`, where `<name>` is the wrapped
    /// middleware's type name.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "AnyMiddleware[{}]", self.name())
    }
}

impl AnyMiddleware {
    /// Erases the concrete type of `middleware` by boxing it.
    pub fn new(middleware: impl Middleware + 'static) -> Self {
        Self(Box::new(middleware))
    }

    /// Returns the type name of the wrapped middleware.
    pub fn name(&self) -> &'static str {
        // Dereference the box so the call dispatches through the trait object.
        (*self.0).name()
    }
}
impl Middleware for AnyMiddleware {
type Error = BoxHttpError;
async fn handle<E: Endpoint>(
&mut self,
request: &mut Request,
mut next: E,
) -> Result<Response, MiddlewareError<E::Error, Self::Error>> {
self.0
.handle_inner(request, &mut next)
.await
.map_err(MiddlewareError::<E::Error, _>::Middleware)
}
}
impl Middleware for () {
type Error = Infallible;
async fn handle<E: Endpoint>(
&mut self,
request: &mut Request,
mut next: E,
) -> Result<Response, MiddlewareError<E::Error, Self::Error>> {
next.respond(request)
.await
.map_err(MiddlewareError::<_, Self::Error>::Endpoint)
}
}