use crate::{
endpoint::{BoxedEnpoint, Endpoint, Handler, HandlerOutput},
http,
request::Request,
response::Response,
server::state::State,
};
use async_std::sync::{Arc, RwLock};
use async_trait::async_trait;
use std::{future::Future, pin::Pin};
/// Lock-protected list of middleware; entries run in vector order (front first).
pub(crate) type MiddlewareLock<S, Args, O> = RwLock<Vec<Arc<dyn Middleware<S, Args, O>>>>;
/// Shared, reference-counted handle to a [`MiddlewareLock`], as stored in `MiddlewareHandler`.
pub(crate) type MiddlewareArc<S, Args, O> = Arc<MiddlewareLock<S, Args, O>>;
/// An asynchronous interceptor placed in front of an endpoint.
///
/// Implementors receive the incoming `args` plus a [`Next`] handle describing
/// the remainder of the chain. Forwarding through `next` continues the chain;
/// returning without doing so short-circuits it.
#[async_trait]
pub trait Middleware<S: State, Args, O>: Send + Sync + 'static {
    /// Processes `args`, optionally delegating to `next`, and yields the output.
    async fn handle(&self, args: Args, next: Next<'_, S, Args, O>) -> O;

    /// Human-readable identifier for this middleware (used e.g. in debug logging).
    ///
    /// Defaults to the implementing type's name as reported by
    /// [`std::any::type_name`].
    fn name(&self) -> &str {
        let type_label: &'static str = std::any::type_name::<Self>();
        type_label
    }
}
/// Lets a plain closure act as middleware.
///
/// The closure receives the request and the [`Next`] handle and must return a
/// boxed, `Send` future that may borrow the `Next` for its lifetime `'a`.
#[async_trait]
impl<S, FN> Middleware<S, Request<S>, HandlerOutput> for FN
where
    S: State,
    FN: for<'a> Fn(
            Request<S>,
            Next<'a, S, Request<S>, HandlerOutput>,
        ) -> Pin<Box<dyn Future<Output = HandlerOutput> + 'a + Send>>
        + Send
        + Sync
        + 'static,
{
    async fn handle(
        &self,
        args: Request<S>,
        next: Next<'_, S, Request<S>, HandlerOutput>,
    ) -> HandlerOutput {
        let callback = self;
        log::debug!("inside middleware call named: {}", callback.name());
        let pending = callback(args, next);
        pending.await
    }
}
/// Cursor over the remainder of a middleware chain.
///
/// Holds the middleware that have not run yet and the endpoint that sits at
/// the end of the chain.
pub struct Next<'a, S: State, Args, O> {
    /// Endpoint at the end of the chain.
    pub(crate) handler: &'a Endpoint<S>,
    /// Middleware still to run; the first element is the next one invoked.
    pub(crate) next_middleware: &'a [Arc<dyn Middleware<S, Args, O>>],
}
/// Debug formatting for any `Next`, regardless of its `Args`/`O` parameters.
///
/// Only the number of middleware still pending is shown; the endpoint and the
/// trait-object middleware are not formatted.
impl<S: State, Args, O> std::fmt::Debug for Next<'_, S, Args, O> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Next")
            .field("remaining_middleware", &self.next_middleware.len())
            .finish()
    }
}
impl<S: State> Next<'_, S, Request<S>, HandlerOutput> {
pub async fn run(mut self, req: Request<S>) -> HandlerOutput {
if let Some((current, next)) = self.next_middleware.split_first() {
self.next_middleware = next;
current.handle(req, self).await
} else {
self.handler.call(req).await
}
}
}
/// Endpoint wrapper that routes each request through a shared middleware
/// list before it reaches `handler`.
#[derive(Clone)]
pub(crate) struct MiddlewareHandler<H, S>
where
    S: State,
{
    /// The wrapped endpoint, invoked after all middleware.
    handler: H,
    /// Shared middleware list, consulted on every call.
    middleware: MiddlewareArc<S, Request<S>, HandlerOutput>,
}
/// Formats the handler as the list of its registered middleware names.
impl<E, S: State> std::fmt::Debug for MiddlewareHandler<E, S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MiddlewareHandler: ")?;
        // `fmt` is synchronous, so the async lock cannot be awaited here;
        // `try_read` avoids blocking. `dyn Middleware` is not `Debug`, so each
        // entry is rendered through its `name()` instead.
        match self.middleware.try_read() {
            Some(stack) => f
                .debug_list()
                .entries(stack.iter().map(|m| m.name()))
                .finish(),
            // A writer currently holds the lock; don't wait for it.
            None => f.write_str("<locked>"),
        }
    }
}
impl<H, S> MiddlewareHandler<H, S>
where
    H: Handler<Request<S>, HandlerOutput>,
    S: State,
{
    /// Boxes `ep`, inserting a middleware layer only when needed.
    ///
    /// The emptiness check happens once, at wrap time: if no middleware is
    /// registered at that moment, `ep` is boxed directly and later additions to
    /// `middleware` will not apply to it.
    pub(crate) async fn wrap_with_middleware(
        ep: H,
        middleware: MiddlewareArc<S, Request<S>, HandlerOutput>,
    ) -> BoxedEnpoint<S> {
        // The read guard is a temporary and is released at the end of this statement.
        let nothing_registered = middleware.read().await.is_empty();
        if nothing_registered {
            return Box::new(ep);
        }
        Box::new(Self {
            handler: ep,
            middleware,
        })
    }
}
#[async_trait]
impl<H, S> Handler<Request<S>, HandlerOutput> for MiddlewareHandler<H, S>
where
    S: State,
    H: Handler<Request<S>, HandlerOutput>,
{
    /// Runs `args` through the registered middleware, then the wrapped handler.
    ///
    /// The middleware list is snapshotted (cheap `Arc` clones plus one `Vec`
    /// allocation) and the read lock is released *before* the chain runs.
    /// Holding the guard for the whole request would make any writer wait for
    /// every in-flight request. It would also deadlock if a handler or middleware
    /// on this chain tried to take the write lock itself.
    async fn call(&self, args: Request<S>) -> HandlerOutput {
        // The temporary read guard is dropped at the end of this statement.
        let stack = self.middleware.read().await.to_vec();
        let next = Next {
            handler: &self.handler,
            next_middleware: stack.as_slice(),
        };
        next.run(args).await
    }
}
/// Middleware that runs an async closure on the request before the rest of the chain.
///
/// The closure either returns `Ok(req)`, and that (possibly modified) request is
/// forwarded to `next`, or it rejects the request with an error. A rejection
/// is converted into a response via `Response::from` and returned immediately,
/// so later middleware and the endpoint never run.
#[derive(Debug)]
pub struct Before<F>(pub F);

#[async_trait]
impl<S, FN, FT> Middleware<S, Request<S>, HandlerOutput> for Before<FN>
where
    S: State,
    FN: Fn(Request<S>) -> FT + Send + Sync + 'static,
    FT: Future<Output = Result<Request<S>, http::Error>> + Send + 'static,
{
    async fn handle(
        &self,
        args: Request<S>,
        next: Next<'_, S, Request<S>, HandlerOutput>,
    ) -> HandlerOutput {
        let hook = &self.0;
        let forwarded = match hook(args).await {
            Ok(req) => req,
            Err(rejection) => return Ok(Response::from(rejection).into()),
        };
        next.run(forwarded).await
    }
}
/// Middleware that post-processes the response produced by the rest of the chain.
///
/// The closure receives the downstream [`Response`] and returns the final
/// [`HandlerOutput`]. If the downstream chain fails, its error is first turned
/// into a response with `Response::from`. The closure therefore runs for error
/// responses as well as successful ones.
#[derive(Debug)]
pub struct After<F>(pub F);

#[async_trait]
impl<S, FN, FT> Middleware<S, Request<S>, HandlerOutput> for After<FN>
where
    S: State,
    FN: Fn(Response) -> FT + Send + Sync + 'static,
    FT: Future<Output = HandlerOutput> + Send + 'static,
{
    async fn handle(
        &self,
        args: Request<S>,
        next: Next<'_, S, Request<S>, HandlerOutput>,
    ) -> HandlerOutput {
        // Normalize both outcomes into a `Response` so the hook is never skipped.
        // Otherwise, e.g., headers added by an `After` hook would be missing
        // from error responses.
        let response: Response = match next.run(args).await {
            Ok(res) => res.into(),
            Err(e) => Response::from(e),
        };
        (self.0)(response).await
    }
}