saphir/
middleware.rs

//! A middleware is an object that is called before the request is processed by
//! the router. It can either continue or stop the processing of a given request
//! by calling or omitting `chain.next`.
4//!
5//! ```text
6//!        chain.next(_)       chain.next(_)
7//!              |               |
8//!              |               |
9//! +---------+  |  +---------+  |  +---------+
10//! |         +--+->+         +--+->+         |
11//! | Middle  |     | Middle  |     |  Router |
12//! | ware1   |     | ware2   |     |         |
13//! |         +<----+         +<----+         |
14//! +---------+     +---------+     +---------+
15//! ```
16//!
//! Once the request has been fully processed by the stack, or as soon as a
//! middleware returns an error, the request is terminated and the response is
//! generated. The response then becomes available to each middleware once its
//! call to `chain.next` returns.
20//!
//! A middleware is defined as follows:
//!
//! ```rust
//! # use saphir::prelude::*;
//! async fn example_middleware(ctx: HttpContext, chain: &dyn MiddlewareChain) -> Result<HttpContext, SaphirError> {
//!     // Do work before the request is handled by the router
//!
//!     let ctx = chain.next(ctx).await?;
//!
//!     // Do work with the response
//!
//!     Ok(ctx)
//! }
//! ```
37//!
//! *SAFETY NOTICE*
//!
//! The middleware chain needs a small amount of unsafe code. This code allows
//! the futures generated by the middlewares to be treated as `'static`. This is
//! considered safe because all middleware data lives within the server stack,
//! which lives for the whole lifetime of your application. We plan to remove
//! this unsafe code as soon as we find another solution.
45
46use crate::{
47    error::{InternalError, SaphirError},
48    http_context::HttpContext,
49    utils::UriPathMatcher,
50};
51use futures::{future::BoxFuture, FutureExt};
52use futures_util::future::Future;
53
54pub trait Middleware {
55    fn next(&'static self, ctx: HttpContext, chain: &'static dyn MiddlewareChain) -> BoxFuture<'static, Result<HttpContext, SaphirError>>;
56}
57
58impl<Fun, Fut> Middleware for Fun
59where
60    Fun: Fn(HttpContext, &'static dyn MiddlewareChain) -> Fut,
61    Fut: 'static + Future<Output = Result<HttpContext, SaphirError>> + Send,
62{
63    #[inline]
64    fn next(&'static self, ctx: HttpContext, chain: &'static dyn MiddlewareChain) -> BoxFuture<'static, Result<HttpContext, SaphirError>> {
65        (*self)(ctx, chain).boxed()
66    }
67}
68
69/// Builder to apply middleware onto the http stack
70pub struct Builder<Chain: MiddlewareChain> {
71    chain: Chain,
72}
73
74impl Default for Builder<MiddleChainEnd> {
75    fn default() -> Self {
76        Self { chain: MiddleChainEnd }
77    }
78}
79
80impl<Chain: MiddlewareChain + 'static> Builder<Chain> {
81    /// Method to apply a new middleware onto the stack where the `include_path`
82    /// vec are all path affected by the middleware, and `exclude_path` are
83    /// exclusion amongst the included paths.
84    ///
85    /// ```rust
86    /// use saphir::middleware::Builder as MBuilder;
87    /// # use saphir::prelude::*;
88    ///
89    /// # async fn log_middleware(
90    /// #     ctx: HttpContext,
91    /// #     chain: &dyn MiddlewareChain,
92    /// # ) -> Result<HttpContext, SaphirError> {
93    /// #     println!("new request on path: {}", ctx.state.request_unchecked().uri().path());
94    /// #     let ctx = chain.next(ctx).await?;
95    /// #     println!("new response with status: {}", ctx.state.response_unchecked().status());
96    /// #     Ok(ctx)
97    /// # }
98    /// #
99    /// let builder = MBuilder::default().apply(log_middleware, vec!["/"], None);
100    /// ```
101    pub fn apply<'a, Mid, E>(self, mid: Mid, include_path: Vec<&str>, exclude_path: E) -> Builder<MiddlewareChainLink<Mid, Chain>>
102    where
103        Mid: 'static + Middleware + Sync + Send,
104        E: Into<Option<Vec<&'a str>>>,
105    {
106        let rule = Rule::new(include_path, exclude_path.into());
107        Builder {
108            chain: MiddlewareChainLink { rule, mid, rest: self.chain },
109        }
110    }
111
112    pub(crate) fn build(self) -> Box<dyn MiddlewareChain> {
113        Box::new(self.chain)
114    }
115}
116
117pub(crate) struct Rule {
118    included_path: Vec<UriPathMatcher>,
119    excluded_path: Option<Vec<UriPathMatcher>>,
120}
121
122impl Rule {
123    #[doc(hidden)]
124    pub fn new(include_path: Vec<&str>, exclude_path: Option<Vec<&str>>) -> Self {
125        Rule {
126            included_path: include_path
127                .iter()
128                .filter_map(|p| {
129                    UriPathMatcher::new(p)
130                        .map_err(|e| error!("Unable to construct included middleware route: {}", e))
131                        .ok()
132                })
133                .collect(),
134            excluded_path: exclude_path.map(|ex| {
135                ex.iter()
136                    .filter_map(|p| {
137                        UriPathMatcher::new(p)
138                            .map_err(|e| error!("Unable to construct excluded middleware route: {}", e))
139                            .ok()
140                    })
141                    .collect()
142            }),
143        }
144    }
145
146    #[doc(hidden)]
147    pub fn validate_path(&self, path: &str) -> bool {
148        if self.included_path.iter().any(|m_p| m_p.match_non_exhaustive(path)) {
149            if let Some(ref excluded_path) = self.excluded_path {
150                return !excluded_path.iter().any(|m_e_p| m_e_p.match_non_exhaustive(path));
151            } else {
152                return true;
153            }
154        }
155
156        false
157    }
158}
159
160#[doc(hidden)]
161pub trait MiddlewareChain: Sync + Send {
162    fn next(&self, ctx: HttpContext) -> BoxFuture<'static, Result<HttpContext, SaphirError>>;
163}
164
165#[doc(hidden)]
166pub struct MiddleChainEnd;
167
168impl MiddlewareChain for MiddleChainEnd {
169    #[doc(hidden)]
170    #[allow(unused_mut)]
171    #[inline]
172    fn next(&self, mut ctx: HttpContext) -> BoxFuture<'static, Result<HttpContext, SaphirError>> {
173        async {
174            let router = ctx.router.take().ok_or(SaphirError::Internal(InternalError::Stack))?;
175            router.dispatch(ctx).await
176        }
177        .boxed()
178    }
179}
180
181#[doc(hidden)]
182pub struct MiddlewareChainLink<Mid: Middleware, Rest: MiddlewareChain> {
183    rule: Rule,
184    mid: Mid,
185    rest: Rest,
186}
187
188#[doc(hidden)]
189impl<Mid, Rest> MiddlewareChain for MiddlewareChainLink<Mid, Rest>
190where
191    Mid: Middleware + Sync + Send + 'static,
192    Rest: MiddlewareChain,
193{
194    #[doc(hidden)]
195    #[allow(clippy::transmute_ptr_to_ptr)]
196    #[inline]
197    fn next(&self, ctx: HttpContext) -> BoxFuture<'static, Result<HttpContext, SaphirError>> {
198        // # SAFETY #
199        // The middleware chain and data are initialized in static memory when calling
200        // run on Server.
201        let (mid, rest) = unsafe {
202            (
203                std::mem::transmute::<&'_ Mid, &'static Mid>(&self.mid),
204                std::mem::transmute::<&'_ dyn MiddlewareChain, &'static dyn MiddlewareChain>(&self.rest),
205            )
206        };
207
208        if ctx.state.request().filter(|req| self.rule.validate_path(req.uri().path())).is_some() {
209            mid.next(ctx, rest)
210        } else {
211            rest.next(ctx)
212        }
213    }
214}