// athene 2.0.4
//
// A simple and lightweight Rust web framework based on hyper.
use crate::{context::Context, error::Error, resolver::UriPath};
use async_trait::async_trait;
use std::{future::Future,sync::Arc};

#[doc(hidden)]
#[async_trait]
pub trait Middleware {
    async fn next(&'static self, ctx: Context, next: &'static dyn Next) -> Result<Context, Error>;
}

#[doc(hidden)]
#[async_trait]
impl<Fut, F> Middleware for F
where
    Fut: 'static + Future<Output = Result<Context, Error>> + Send,
    F: Fn(Context, &'static dyn Next) -> Fut + Send + Sync,
{
    #[inline]
    async fn next(&'static self, ctx: Context, next: &'static dyn Next) -> Result<Context, Error> {
        self(ctx, next).await
    }
}

/// The remainder of a middleware chain.
///
/// Calling [`Next::next`] passes the [`Context`] on to whatever follows the
/// current link: another `MiddlewareChain` layer or, at the end of the
/// stack, `ImplNext`, which dispatches to the router.
#[doc(hidden)]
#[async_trait]
pub trait Next: Send + Sync {
    async fn next(&self, ctx: Context) -> Result<Context, Error>;
}

/// Builder to apply middleware onto the http stack.
///
/// The stack starts as [`ImplNext`] (see the [`Default`] impl); each call to
/// [`MiddlewareBuilder::apply`] wraps one more [`MiddlewareChain`] layer
/// around it.
pub struct MiddlewareBuilder<N>
where
    N: Next,
{
    next: N,
}

impl Default for MiddlewareBuilder<ImplNext> {
    /// An empty stack: requests go straight to [`ImplNext`].
    #[inline]
    fn default() -> Self {
        MiddlewareBuilder { next: ImplNext }
    }
}

impl<T> MiddlewareBuilder<T>
where
    T: Next + 'static,
{
    /// Wraps `middleware` around the current stack.
    ///
    /// The middleware runs only when the request path matches one of the
    /// `include_path` patterns and none of the `exclude_path` patterns.
    /// `exclude_path` accepts either a `Vec` of patterns or `None`. Patterns
    /// that fail to parse are printed to stdout and skipped rather than
    /// aborting the build.
    #[inline]
    pub fn apply<'a, M, E>(
        self,
        middleware: M,
        include_path: Vec<&str>,
        exclude_path: E,
    ) -> MiddlewareBuilder<MiddlewareChain<M, T>>
    where
        M: 'static + Middleware + Sync + Send,
        E: Into<Option<Vec<&'a str>>>,
    {
        let excluded: Option<Vec<&'a str>> = exclude_path.into();
        let layer = MiddlewareChain {
            rule: Rule::new(include_path, excluded),
            middleware,
            next: self.next,
        };
        MiddlewareBuilder { next: layer }
    }

    /// Freezes the assembled stack into a shared, type-erased handle.
    #[inline]
    pub(crate) fn build(self) -> Arc<dyn Next> {
        Arc::<T>::new(self.next)
    }
}

/// Path filter that decides whether a middleware applies to a request.
pub(crate) struct Rule {
    /// At least one of these must match for the middleware to run.
    included_path: Vec<UriPath>,
    /// When present, a match against any of these vetoes the middleware.
    excluded_path: Option<Vec<UriPath>>,
}

impl Rule {
    /// Parses the include/exclude route patterns into a `Rule`.
    ///
    /// Patterns that `UriPath::new` rejects are reported on stdout and
    /// dropped, so a single bad pattern never aborts construction.
    #[doc(hidden)]
    #[inline]
    pub fn new(include_path: Vec<&str>, exclude_path: Option<Vec<&str>>) -> Self {
        Rule {
            included_path: Self::parse_paths(&include_path, "included"),
            excluded_path: exclude_path.map(|ex| Self::parse_paths(&ex, "excluded")),
        }
    }

    /// Parses every pattern in `paths`, printing (in red) and skipping the
    /// ones that fail; `kind` ("included"/"excluded") labels the message.
    fn parse_paths(paths: &[&str], kind: &str) -> Vec<UriPath> {
        paths
            .iter()
            .filter_map(|p| {
                UriPath::new(p)
                    .map_err(|e| {
                        // `\x1b[0m` resets the colour so output printed after
                        // this message is not left red.
                        println!(
                            "\x1b[31mUnable to construct {} middleware route: {}\x1b[0m",
                            kind, e
                        )
                    })
                    .ok()
            })
            .collect()
    }

    /// Returns `true` when `path` matches at least one included pattern and
    /// no excluded pattern. The exclusions are only consulted after an
    /// inclusion has matched.
    #[doc(hidden)]
    #[inline]
    pub fn validate_path(&self, path: &str) -> bool {
        self.included_path
            .iter()
            .any(|inc| inc.match_non_exhaustive(path))
            && !self.excluded_path.as_ref().map_or(false, |excluded| {
                excluded.iter().any(|exc| exc.match_non_exhaustive(path))
            })
    }
}

#[doc(hidden)]
pub struct ImplNext;

#[async_trait]
impl Next for ImplNext {
    #[doc(hidden)]
    #[inline]
    async fn next(&self, mut ctx: Context) -> Result<Context, Error> {
        let router = ctx
            .router
            .take()
            .ok_or_else(|| Error::Other(String::from("Router not found")))?;
        router.dispatch(ctx).await
    }
}

#[doc(hidden)]
pub struct MiddlewareChain<M: Middleware, N: Next> {
    rule: Rule,
    middleware: M,
    next: N,
}

#[doc(hidden)]
#[async_trait]
impl<M, N> Next for MiddlewareChain<M, N>
where
    M: Middleware + Sync + Send + 'static,
    N: Next,
{
    #[doc(hidden)]
    #[inline]
    async fn next(&self, ctx: Context) -> Result<Context, Error> {
        let (middleware, next) = unsafe {
            (
                std::mem::transmute::<&'_ M, &'static M>(&self.middleware),
                std::mem::transmute::<&'_ dyn Next, &'static dyn Next>(&self.next),
            )
        };

        let req = ctx.state.request()?;
        if self.rule.validate_path(req.uri().path())
        {
            middleware.next(ctx, next).await
        } else {
            next.next(ctx).await
        }
    }
}