use crate::{context::Context, error::Error, resolver::UriPath};
use async_trait::async_trait;
use std::{future::Future,sync::Arc};
/// A single step of the request-processing pipeline.
///
/// An implementor receives the request `Context` together with `next`, a
/// handle to the remainder of the chain. It can inspect or modify the
/// context, continue with `next.next(ctx)`, or produce a result on its own.
///
/// Both the receiver and `next` are required to be `&'static`;
/// `MiddlewareChain` satisfies this by unsafely extending the lifetime of
/// its own fields.
#[doc(hidden)]
#[async_trait]
pub trait Middleware {
    /// Handles `ctx`, optionally delegating to `next`, and yields the
    /// resulting context or an error.
    async fn next(&'static self, ctx: Context, next: &'static dyn Next) -> Result<Context, Error>;
}
#[doc(hidden)]
#[async_trait]
impl<Fut, F> Middleware for F
where
Fut: 'static + Future<Output = Result<Context, Error>> + Send,
F: Fn(Context, &'static dyn Next) -> Fut + Send + Sync,
{
#[inline]
async fn next(&'static self, ctx: Context, next: &'static dyn Next) -> Result<Context, Error> {
self(ctx, next).await
}
}
/// The remainder of a middleware chain, as seen from one middleware.
///
/// Calling `next` hands the context to the following link of the chain.
/// The innermost link (`ImplNext`) dispatches to the router.
#[doc(hidden)]
#[async_trait]
pub trait Next: Send + Sync {
    /// Continues processing `ctx` and returns the resulting context or an error.
    async fn next(&self, ctx: Context) -> Result<Context, Error>;
}
/// Builder that stacks middleware layers on top of a chain of type `N`.
///
/// Start from `MiddlewareBuilder::default()` (whose only link dispatches to
/// the router), add layers with `apply`, and finish with `build`.
pub struct MiddlewareBuilder<N: Next> {
    /// The chain assembled so far.
    next: N,
}
impl Default for MiddlewareBuilder<ImplNext> {
#[inline]
fn default() -> Self {
Self { next: ImplNext }
}
}
impl<T: Next + 'static> MiddlewareBuilder<T> {
    /// Wraps the current chain in a new layer that runs `middleware`.
    ///
    /// The layer runs only for request paths that match at least one
    /// `include_path` pattern and no `exclude_path` pattern. Pass `None` as
    /// `exclude_path` to exclude nothing. Patterns that fail to parse are
    /// reported and skipped by `Rule::new`.
    ///
    /// Each call wraps the chain built so far, so the most recently applied
    /// middleware is the first to see a request.
    #[inline]
    pub fn apply<'a, M, E>(
        self,
        middleware: M,
        include_path: Vec<&str>,
        exclude_path: E,
    ) -> MiddlewareBuilder<MiddlewareChain<M, T>>
    where
        M: 'static + Middleware + Sync + Send,
        E: Into<Option<Vec<&'a str>>>,
    {
        let MiddlewareBuilder { next: inner } = self;
        let layered = MiddlewareChain {
            rule: Rule::new(include_path, exclude_path.into()),
            middleware,
            next: inner,
        };
        MiddlewareBuilder { next: layered }
    }

    /// Finalizes the chain into a shared, type-erased `Next` handle.
    #[inline]
    pub(crate) fn build(self) -> Arc<dyn Next> {
        let MiddlewareBuilder { next: chain } = self;
        Arc::new(chain)
    }
}
/// Path filter that decides whether a middleware runs for a request.
///
/// A path is accepted when it matches at least one pattern in
/// `included_path` and, if `excluded_path` is present, none of its patterns.
pub(crate) struct Rule {
    /// Patterns the middleware applies to.
    included_path: Vec<UriPath>,
    /// Optional patterns that take precedence over `included_path`.
    excluded_path: Option<Vec<UriPath>>,
}
impl Rule {
#[doc(hidden)]
#[inline]
pub fn new(include_path: Vec<&str>, exclude_path: Option<Vec<&str>>) -> Self {
Rule {
included_path: include_path
.iter()
.filter_map(|p| {
UriPath::new(p)
.map_err(|e| {
println!(
"{}Unable to construct included middleware route: {}",
"\x1b[31m", e
)
})
.ok()
})
.collect(),
excluded_path: exclude_path.map(|ex| {
ex.iter()
.filter_map(|p| {
UriPath::new(p)
.map_err(|e| {
println!(
"{}Unable to construct excluded middleware route: {}",
"\x1b[31m", e
)
})
.ok()
})
.collect()
}),
}
}
#[doc(hidden)]
#[inline]
pub fn validate_path(&self, path: &str) -> bool {
if self
.included_path
.iter()
.any(|m_p| m_p.match_non_exhaustive(path))
{
if let Some(ref excluded_path) = self.excluded_path {
return !excluded_path
.iter()
.any(|m_e_p| m_e_p.match_non_exhaustive(path));
} else {
return true;
}
}
false
}
}
#[doc(hidden)]
pub struct ImplNext;
#[async_trait]
impl Next for ImplNext {
#[doc(hidden)]
#[inline]
async fn next(&self, mut ctx: Context) -> Result<Context, Error> {
let router = ctx
.router
.take()
.ok_or_else(|| Error::Other(String::from("Router not found")))?;
router.dispatch(ctx).await
}
}
/// One link of a middleware chain: `middleware`, guarded by `rule`, placed
/// in front of the remaining chain `next`.
#[doc(hidden)]
pub struct MiddlewareChain<M: Middleware, N: Next> {
    /// Decides, per request path, whether `middleware` runs.
    rule: Rule,
    /// The middleware executed when `rule` accepts the request path.
    middleware: M,
    /// The rest of the chain; called directly when `rule` rejects the path.
    next: N,
}
#[doc(hidden)]
#[async_trait]
impl<M, N> Next for MiddlewareChain<M, N>
where
    M: Middleware + Sync + Send + 'static,
    N: Next,
{
    /// Runs this link of the chain for `ctx`.
    ///
    /// If this link's rule accepts the request path, the wrapped middleware
    /// is invoked with the rest of the chain as its continuation. Otherwise
    /// the middleware is skipped and `ctx` goes straight to the next link.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ctx.state.request()`, and whatever the
    /// middleware or the downstream links return.
    #[doc(hidden)]
    #[inline]
    async fn next(&self, ctx: Context) -> Result<Context, Error> {
        let req = ctx.state.request()?;
        if !self.rule.validate_path(req.uri().path()) {
            // Bypass: this borrow of `self.next` only needs to live for this
            // call, so no lifetime extension is required on this path.
            return self.next.next(ctx).await;
        }
        // SAFETY: `Middleware::next` demands `&'static` references, but these
        // fields are only borrowed for as long as `self`. The chain is
        // expected to be built once (via `MiddlewareBuilder::build`) into an
        // `Arc<dyn Next>` that outlives every request.
        // NOTE(review): this holds only if that `Arc` is never dropped while
        // a request is in flight, and no middleware stores `next` somewhere
        // that outlives the chain (e.g. a detached task). Confirm with the
        // owner of the `Arc`.
        let (middleware, next) = unsafe {
            (
                std::mem::transmute::<&'_ M, &'static M>(&self.middleware),
                std::mem::transmute::<&'_ dyn Next, &'static dyn Next>(&self.next),
            )
        };
        middleware.next(ctx, next).await
    }
}