use async_recursion::async_recursion;
use async_trait::async_trait;
use std::sync::Arc;
use crate::{Error, Request, Response, Result};
/// An asynchronous request handler: produces a [`Response`] (or an error)
/// from a mutable [`Request`].
///
/// The `Send + Sync + 'static` supertraits let handlers be stored as
/// `Box<dyn Handler>`, which is how `Middlewares` holds its core handler.
#[async_trait]
pub trait Handler: Send + Sync + 'static {
/// Handles `req`, returning the response or the error that occurred.
async fn handle(&self, req: &mut Request) -> Result<Response>;
}
/// Lets any synchronous function or closure
/// `Fn(&mut Request) -> Result<Response>` be used as a [`Handler`].
///
/// The callable is invoked directly and its result is returned as-is; no
/// awaiting happens inside.
#[async_trait]
impl<F> Handler for F
where
    F: Fn(&mut Request) -> Result<Response> + Send + Sync + 'static,
{
    async fn handle(&self, req: &mut Request) -> Result<Response> {
        let callback = self;
        callback(req)
    }
}
/// Forwards to the boxed handler, so a `Box<dyn Handler>` is itself a
/// [`Handler`].
#[async_trait]
impl Handler for Box<dyn Handler> {
    async fn handle(&self, req: &mut Request) -> Result<Response> {
        // `&**self` names the inner trait object explicitly. Letting the
        // compiler coerce `&Box<dyn Handler>` to `&dyn Handler` could pick
        // this very impl and recurse forever.
        let inner: &dyn Handler = &**self;
        inner.handle(req).await
    }
}
#[async_trait]
pub trait BeforeMiddleware: Send + Sync + 'static {
async fn before(&self, _: &mut Request) -> Result<()> {
Ok(())
}
async fn catch(&self, _: &mut Request, err: Error) -> Result<()> {
Err(err)
}
}
#[async_trait]
pub trait AfterMiddleware: Send + Sync + 'static {
async fn after(&self, _: &mut Request, res: Response) -> Result<Response> {
Ok(res)
}
async fn catch(&self, _: &mut Request, err: Error) -> Result<Response> {
Err(err)
}
}
/// Middleware that wraps the chain's core handler in a new handler.
///
/// `around` consumes both `self` and the current handler and returns the
/// handler that replaces it (see `Middlewares::link_around`). Unlike the
/// before/after traits, this trait has no `Send + Sync + 'static` supertraits
/// and is declared with `#[async_trait(?Send)]`, so its future is not required
/// to be `Send`.
#[async_trait(?Send)]
pub trait AroundMiddleware {
/// Wraps `handler`, returning the handler to use in its place.
async fn around(self, handler: Box<dyn Handler>) -> Box<dyn Handler>;
}
/// A [`Handler`] assembled from a core handler plus before- and after-middleware.
///
/// A request passes through every before-middleware in the order linked, then
/// the core handler, then every after-middleware in the order linked. An error
/// at any stage is offered to the `catch` of the middleware that would have run
/// later; a `catch` that returns `Ok` resumes the normal path from that point.
pub struct Middlewares {
// Before-middleware, run in the order it was linked.
befores: Vec<Box<dyn BeforeMiddleware>>,
// After-middleware, run in the order it was linked.
afters: Vec<Box<dyn AfterMiddleware>>,
// The core handler. It is an `Option` only so `link_around` can move it out
// while wrapping it; `new` sets it and `link_around` stores the wrapped one back.
handler: Option<Box<dyn Handler>>,
}
impl Middlewares {
    /// Creates a chain around `handler` with no middleware linked yet.
    pub fn new<H>(handler: H) -> Self
    where
        H: Handler,
    {
        Self {
            befores: vec![],
            afters: vec![],
            handler: Some(Box::new(handler) as Box<dyn Handler>),
        }
    }

    /// Links a before-middleware and an after-middleware in one call.
    ///
    /// Equivalent to `link_before(before)` followed by `link_after(after)`.
    /// Returns `self` for chaining.
    pub fn link<B, A>(&mut self, link: (B, A)) -> &mut Middlewares
    where
        A: AfterMiddleware,
        B: BeforeMiddleware,
    {
        let (before, after) = link;
        self.link_before(before).link_after(after)
    }

    /// Appends a before-middleware; befores run in the order they are linked.
    ///
    /// Returns `self` for chaining.
    pub fn link_before<B>(&mut self, before: B) -> &mut Middlewares
    where
        B: BeforeMiddleware,
    {
        // `push` coerces `Box<B>` to `Box<dyn BeforeMiddleware>` on its own.
        self.befores.push(Box::new(before));
        self
    }

    /// Appends an after-middleware; afters run in the order they are linked.
    ///
    /// Returns `self` for chaining.
    pub fn link_after<A>(&mut self, after: A) -> &mut Middlewares
    where
        A: AfterMiddleware,
    {
        self.afters.push(Box::new(after));
        self
    }

    /// Replaces the core handler with the handler produced by `around`.
    ///
    /// Repeated calls nest: each `around` receives the handler produced by the
    /// previous call. Linked before/after middleware is unaffected and still
    /// runs outside the wrapped handler.
    ///
    /// # Panics
    ///
    /// The handler is moved out while `around` runs. If this future is dropped
    /// before it completes (or `around` panics), the chain is left without a
    /// handler. A later `link_around`, or handling a request, then panics.
    pub async fn link_around<A>(&mut self, around: A) -> &mut Middlewares
    where
        A: AroundMiddleware,
    {
        let handler = self
            .handler
            .take()
            .expect("Middlewares handler missing (an earlier link_around did not complete)");
        self.handler = Some(around.around(handler).await);
        self
    }
}
#[async_trait]
impl Handler for Middlewares {
async fn handle(&self, req: &mut Request) -> Result<Response> {
self.continue_from_before(req, 0).await
}
}
impl Middlewares {
    /// Error path of the before phase, starting at before-middleware `index`.
    ///
    /// Offers `err` to each remaining before-middleware's `catch`, in order.
    /// The first one that returns `Ok(())` recovers, and normal processing
    /// resumes at the before-middleware after it. If none recovers, the error
    /// moves on to the after phase via `fail_from_handler`.
    ///
    /// Boxed with `#[async_recursion]` because this function and
    /// `continue_from_before` call each other.
    #[async_recursion]
    async fn fail_from_before(
        &self,
        req: &mut Request,
        index: usize,
        mut err: Error,
    ) -> Result<Response> {
        // `index` equals `len()` when the last before-middleware failed; `>=`
        // also keeps the slice below from panicking on an out-of-range start.
        if index >= self.befores.len() {
            return self.fail_from_handler(req, err).await;
        }
        for (i, before) in self.befores[index..].iter().enumerate() {
            err = match before.catch(req, err).await {
                Err(err) => err,
                Ok(()) => return self.continue_from_before(req, index + i + 1).await,
            };
        }
        self.fail_from_handler(req, err).await
    }

    /// Normal path of the before phase, starting at before-middleware `index`.
    ///
    /// Runs each remaining `before` in order, then the core handler. If a
    /// `before` fails, its error goes to the `catch` of the before-middleware
    /// that follow it.
    async fn continue_from_before(&self, req: &mut Request, index: usize) -> Result<Response> {
        if index >= self.befores.len() {
            return self.continue_from_handler(req).await;
        }
        for (i, before) in self.befores[index..].iter().enumerate() {
            if let Err(err) = before.before(req).await {
                return self.fail_from_before(req, index + i + 1, err).await;
            }
        }
        self.continue_from_handler(req).await
    }

    /// Entry to the after phase's error path. The error is offered to the
    /// after-middleware's `catch`, starting from the first one.
    async fn fail_from_handler(&self, req: &mut Request, err: Error) -> Result<Response> {
        self.fail_from_after(req, 0, err).await
    }

    /// Error path of the after phase, starting at after-middleware `index`.
    ///
    /// Offers `err` to each remaining after-middleware's `catch`, in order.
    /// The first one that returns `Ok(res)` recovers, and `res` continues
    /// through the after-middleware that follow it. If none recovers, the
    /// final error is returned to the caller.
    async fn fail_from_after(
        &self,
        req: &mut Request,
        index: usize,
        mut err: Error,
    ) -> Result<Response> {
        // `>=` (previously `==`) matches the other phase functions and keeps
        // an out-of-range `index` from panicking in the slice below.
        if index >= self.afters.len() {
            return Err(err);
        }
        for (i, after) in self.afters[index..].iter().enumerate() {
            err = match after.catch(req, err).await {
                Err(err) => err,
                Ok(res) => return self.continue_from_after(req, index + i + 1, res).await,
            };
        }
        Err(err)
    }

    /// Runs the core handler, then routes its result into the after phase.
    ///
    /// `Ok` continues through every `after`; `Err` goes to the `catch` chain.
    ///
    /// # Panics
    ///
    /// Panics if the handler slot is empty. That happens only when a
    /// `link_around` call did not run to completion.
    async fn continue_from_handler(&self, req: &mut Request) -> Result<Response> {
        let handler = self
            .handler
            .as_ref()
            .expect("Middlewares handler missing (an interrupted link_around?)");
        match handler.handle(req).await {
            Ok(res) => self.continue_from_after(req, 0, res).await,
            Err(err) => self.fail_from_handler(req, err).await,
        }
    }

    /// Normal path of the after phase, starting at after-middleware `index`.
    ///
    /// Threads `res` through each remaining `after`, in order. If one fails,
    /// its error goes to the `catch` of the after-middleware that follow it.
    ///
    /// Boxed with `#[async_recursion]` because this function and
    /// `fail_from_after` call each other.
    #[async_recursion]
    async fn continue_from_after(
        &self,
        req: &mut Request,
        index: usize,
        mut res: Response,
    ) -> Result<Response> {
        if index >= self.afters.len() {
            return Ok(res);
        }
        for (i, after) in self.afters[index..].iter().enumerate() {
            res = match after.after(req, res).await {
                Ok(res) => res,
                Err(err) => return self.fail_from_after(req, index + i + 1, err).await,
            };
        }
        Ok(res)
    }
}
/// Lets any synchronous closure `Fn(&mut Request) -> Result<()>` be used as a
/// [`BeforeMiddleware`].
///
/// Only `before` is provided; `catch` keeps the trait's default.
#[async_trait]
impl<F> BeforeMiddleware for F
where
    F: Fn(&mut Request) -> Result<()> + Send + Sync + 'static,
{
    async fn before(&self, req: &mut Request) -> Result<()> {
        let callback = self;
        callback(req)
    }
}
/// Forwards both methods to the boxed middleware, so a
/// `Box<dyn BeforeMiddleware>` is itself a [`BeforeMiddleware`].
#[async_trait]
impl BeforeMiddleware for Box<dyn BeforeMiddleware> {
    async fn before(&self, req: &mut Request) -> Result<()> {
        // Take the inner trait object explicitly (`&**self`) so the call
        // dispatches to the boxed value rather than back into this impl.
        let inner: &dyn BeforeMiddleware = &**self;
        inner.before(req).await
    }

    async fn catch(&self, req: &mut Request, err: Error) -> Result<()> {
        let inner: &dyn BeforeMiddleware = &**self;
        inner.catch(req, err).await
    }
}
/// Forwards both methods to the shared value, so a before-middleware wrapped
/// in an `Arc` can be linked directly.
///
/// `T: ?Sized` also admits shared trait objects such as
/// `Arc<dyn BeforeMiddleware>`, which the previous `Sized`-only bound rejected.
#[async_trait]
impl<T> BeforeMiddleware for Arc<T>
where
    T: BeforeMiddleware + ?Sized,
{
    async fn before(&self, req: &mut Request) -> Result<()> {
        (**self).before(req).await
    }

    async fn catch(&self, req: &mut Request, err: Error) -> Result<()> {
        (**self).catch(req, err).await
    }
}
/// Lets any synchronous closure `Fn(&mut Request, Response) -> Result<Response>`
/// be used as an [`AfterMiddleware`].
///
/// Only `after` is provided; `catch` keeps the trait's default.
#[async_trait]
impl<F> AfterMiddleware for F
where
    F: Fn(&mut Request, Response) -> Result<Response> + Send + Sync + 'static,
{
    async fn after(&self, req: &mut Request, res: Response) -> Result<Response> {
        let callback = self;
        callback(req, res)
    }
}
/// Forwards both methods to the boxed middleware, so a
/// `Box<dyn AfterMiddleware>` is itself an [`AfterMiddleware`].
#[async_trait]
impl AfterMiddleware for Box<dyn AfterMiddleware> {
    async fn after(&self, req: &mut Request, res: Response) -> Result<Response> {
        // Take the inner trait object explicitly (`&**self`) so the call
        // dispatches to the boxed value rather than back into this impl.
        let inner: &dyn AfterMiddleware = &**self;
        inner.after(req, res).await
    }

    async fn catch(&self, req: &mut Request, err: Error) -> Result<Response> {
        let inner: &dyn AfterMiddleware = &**self;
        inner.catch(req, err).await
    }
}
/// Forwards both methods to the shared value, so an after-middleware wrapped
/// in an `Arc` can be linked directly.
///
/// `T: ?Sized` also admits shared trait objects such as
/// `Arc<dyn AfterMiddleware>`, which the previous `Sized`-only bound rejected.
#[async_trait]
impl<T> AfterMiddleware for Arc<T>
where
    T: AfterMiddleware + ?Sized,
{
    async fn after(&self, req: &mut Request, res: Response) -> Result<Response> {
        (**self).after(req, res).await
    }

    async fn catch(&self, req: &mut Request, err: Error) -> Result<Response> {
        (**self).catch(req, err).await
    }
}
/// Lets a closure `FnOnce(Box<dyn Handler>) -> Box<dyn Handler>` be used as an
/// [`AroundMiddleware`].
///
/// The closure is called once with the current handler, and its return value
/// becomes the new handler.
#[async_trait(?Send)]
impl<F> AroundMiddleware for F
where
    F: FnOnce(Box<dyn Handler>) -> Box<dyn Handler>,
{
    async fn around(self, handler: Box<dyn Handler>) -> Box<dyn Handler> {
        let wrap = self;
        wrap(handler)
    }
}