use crate::{
context::Context,
error::Error,
resolver::UriPath,
};
use futures::{future::BoxFuture, Future, FutureExt};
/// An asynchronous step in the request pipeline.
///
/// Any `Fn(Context, &'static dyn Next) -> Fut` whose future is `Send + 'static`
/// also implements this trait through a blanket impl.
pub trait Middleware {
/// Runs this middleware for `ctx`.
///
/// `chain` is the remainder of the pipeline: the next middleware layer or, at
/// the end, the link that dispatches to the router. Calling `chain.next(ctx)`
/// continues the request; returning without calling it short-circuits every
/// later layer.
///
/// `self` and `chain` are `'static` because the returned future is `'static`
/// and may keep borrowing them.
fn next(
&'static self,
ctx: Context,
chain: &'static dyn Next,
) -> BoxFuture<'static, Result<Context, Error>>;
}
impl<Fut, F> Middleware for F
where
Fut: 'static + Future<Output = Result<Context, Error>> + Send,
F: Fn(Context, &'static dyn Next) -> Fut,
{
#[inline]
fn next(
&'static self,
ctx: Context,
next: &'static dyn Next,
) -> BoxFuture<'static, Result<Context, Error>> {
(*self)(ctx, next).boxed()
}
}
/// A link in a built middleware chain.
///
/// `next` consumes the context and returns a `'static` boxed future that
/// resolves to the (possibly modified) context or an error. Implemented by
/// `MiddlewareChain` (one middleware layer) and `ImplMiddleware` (the terminal
/// router dispatch).
///
/// `Sync + Send` is required so that a `&'static dyn Next` can be captured by
/// the `Send` futures (`BoxFuture`) that middleware return.
#[doc(hidden)]
pub trait Next: Sync + Send {
fn next(&self, ctx: Context) -> BoxFuture<'static, Result<Context, Error>>;
}
pub struct MiddlewareBuilder<N: Next> {
next: N,
}
impl Default for MiddlewareBuilder<ImplMiddleware> {
#[inline]
fn default() -> Self {
Self {
next: ImplMiddleware,
}
}
}
impl<T: Next + 'static> MiddlewareBuilder<T> {
    /// Wraps the current chain with `middleware`, limited to matching routes.
    ///
    /// `middleware` runs only for request paths that match one of
    /// `include_path` and none of `exclude_path`. `exclude_path` may be `None`
    /// or a `Vec<&str>`; both convert through `Into<Option<_>>`. The layer
    /// added last becomes the outermost one, so it runs first.
    #[inline]
    pub fn apply<'a, M, E>(
        self,
        middleware: M,
        include_path: Vec<&str>,
        exclude_path: E,
    ) -> MiddlewareBuilder<MiddlewareChain<M, T>>
    where
        M: 'static + Middleware + Sync + Send,
        E: Into<Option<Vec<&'a str>>>,
    {
        let excluded: Option<Vec<&'a str>> = exclude_path.into();
        let layer = MiddlewareChain {
            middleware,
            rule: Rule::new(include_path, excluded),
            next: self.next,
        };
        MiddlewareBuilder { next: layer }
    }

    /// Erases the concrete nested chain type into a boxed [`Next`] trait object.
    #[inline]
    pub(crate) fn build(self) -> Box<dyn Next> {
        let MiddlewareBuilder { next } = self;
        Box::new(next)
    }
}
/// Decides which request paths a middleware layer applies to.
///
/// A path is accepted when it matches at least one `included_path` pattern
/// and, if `excluded_path` is set, none of the excluded patterns.
pub(crate) struct Rule {
    included_path: Vec<UriPath>,
    excluded_path: Option<Vec<UriPath>>,
}
impl Rule {
    /// Builds a rule from raw include/exclude route patterns.
    ///
    /// Patterns rejected by `UriPath::new` are reported on stderr and skipped,
    /// so a malformed pattern never aborts construction.
    #[doc(hidden)]
    #[inline]
    pub fn new(include_path: Vec<&str>, exclude_path: Option<Vec<&str>>) -> Self {
        Rule {
            included_path: Self::parse_paths(&include_path, "included"),
            excluded_path: exclude_path.map(|ex| Self::parse_paths(&ex, "excluded")),
        }
    }

    /// Parses each pattern into a `UriPath`, logging and dropping failures.
    ///
    /// `kind` ("included"/"excluded") is only used in the diagnostic message.
    fn parse_paths(paths: &[&str], kind: &str) -> Vec<UriPath> {
        paths
            .iter()
            .filter_map(|p| {
                UriPath::new(p)
                    .map_err(|e| {
                        // Print in red, then reset (`\x1b[0m`) so the colour does
                        // not leak into every later line of terminal output.
                        eprintln!(
                            "\x1b[31mUnable to construct {} middleware route: {}\x1b[0m",
                            kind, e
                        )
                    })
                    .ok()
            })
            .collect()
    }

    /// Returns `true` if any of `patterns` matches `path`.
    fn matches_any(patterns: &[UriPath], path: &str) -> bool {
        patterns
            .iter()
            .any(|pattern| pattern.match_non_exhaustive(path))
    }

    /// Returns `true` when `path` matches an included pattern and no excluded one.
    #[doc(hidden)]
    #[inline]
    pub fn validate_path(&self, path: &str) -> bool {
        if !Self::matches_any(&self.included_path, path) {
            return false;
        }
        match &self.excluded_path {
            Some(excluded) => !Self::matches_any(excluded, path),
            None => true,
        }
    }
}
#[doc(hidden)]
pub struct ImplMiddleware;
impl Next for ImplMiddleware {
#[doc(hidden)]
#[inline]
fn next(&self, mut ctx: Context) -> BoxFuture<'static, Result<Context, Error>> {
async {
let router = ctx
.router
.take()
.ok_or_else(||Error::Other(String::from("Router not found")))?;
router.dispatch(ctx).await
}
.boxed()
}
}
#[doc(hidden)]
pub struct MiddlewareChain<M: Middleware, N: Next> {
rule: Rule,
middleware: M,
next: N,
}
#[doc(hidden)]
impl<M, N> Next for MiddlewareChain<M, N>
where
M: Middleware + Sync + Send + 'static,
N: Next,
{
#[doc(hidden)]
#[inline]
fn next(&self, ctx: Context) -> BoxFuture<'static, Result<Context, Error>> {
let (middleware, next) = unsafe {
(
std::mem::transmute::<&'_ M, &'static M>(&self.middleware),
std::mem::transmute::<&'_ dyn Next, &'static dyn Next>(&self.next),
)
};
if ctx
.state
.request()
.filter(|req| self.rule.validate_path(req.uri().path()))
.is_some()
{
middleware.next(ctx, next)
} else {
next.next(ctx)
}
}
}