1use crate::{
47 error::{InternalError, SaphirError},
48 http_context::HttpContext,
49 utils::UriPathMatcher,
50};
51use futures::{future::BoxFuture, FutureExt};
52use futures_util::future::Future;
53
/// A request-processing step that can be linked into a middleware chain.
///
/// Besides explicit implementors, any `Fn(HttpContext, &'static dyn MiddlewareChain) -> Fut`
/// whose future is `Send + 'static` and resolves to `Result<HttpContext, SaphirError>`
/// implements this trait through the blanket impl in this file.
pub trait Middleware {
    /// Processes `ctx` and returns a boxed future resolving to the resulting
    /// context or a `SaphirError`.
    ///
    /// `chain` is the remainder of the chain; calling `chain.next(ctx)` hands the
    /// context on to it. The returned future is `'static`, and both `self` and
    /// `chain` are required to be `'static` references.
    fn next(&'static self, ctx: HttpContext, chain: &'static dyn MiddlewareChain) -> BoxFuture<'static, Result<HttpContext, SaphirError>>;
}
57
58impl<Fun, Fut> Middleware for Fun
59where
60 Fun: Fn(HttpContext, &'static dyn MiddlewareChain) -> Fut,
61 Fut: 'static + Future<Output = Result<HttpContext, SaphirError>> + Send,
62{
63 #[inline]
64 fn next(&'static self, ctx: HttpContext, chain: &'static dyn MiddlewareChain) -> BoxFuture<'static, Result<HttpContext, SaphirError>> {
65 (*self)(ctx, chain).boxed()
66 }
67}
68
69pub struct Builder<Chain: MiddlewareChain> {
71 chain: Chain,
72}
73
74impl Default for Builder<MiddleChainEnd> {
75 fn default() -> Self {
76 Self { chain: MiddleChainEnd }
77 }
78}
79
80impl<Chain: MiddlewareChain + 'static> Builder<Chain> {
81 pub fn apply<'a, Mid, E>(self, mid: Mid, include_path: Vec<&str>, exclude_path: E) -> Builder<MiddlewareChainLink<Mid, Chain>>
102 where
103 Mid: 'static + Middleware + Sync + Send,
104 E: Into<Option<Vec<&'a str>>>,
105 {
106 let rule = Rule::new(include_path, exclude_path.into());
107 Builder {
108 chain: MiddlewareChainLink { rule, mid, rest: self.chain },
109 }
110 }
111
112 pub(crate) fn build(self) -> Box<dyn MiddlewareChain> {
113 Box::new(self.chain)
114 }
115}
116
/// Path filter deciding whether a middleware applies to a given request path.
pub(crate) struct Rule {
    /// Matchers of which at least one must match for the rule to accept a path.
    included_path: Vec<UriPathMatcher>,
    /// Optional matchers; a path matched by any of them is rejected even if included.
    excluded_path: Option<Vec<UriPathMatcher>>,
}
121
122impl Rule {
123 #[doc(hidden)]
124 pub fn new(include_path: Vec<&str>, exclude_path: Option<Vec<&str>>) -> Self {
125 Rule {
126 included_path: include_path
127 .iter()
128 .filter_map(|p| {
129 UriPathMatcher::new(p)
130 .map_err(|e| error!("Unable to construct included middleware route: {}", e))
131 .ok()
132 })
133 .collect(),
134 excluded_path: exclude_path.map(|ex| {
135 ex.iter()
136 .filter_map(|p| {
137 UriPathMatcher::new(p)
138 .map_err(|e| error!("Unable to construct excluded middleware route: {}", e))
139 .ok()
140 })
141 .collect()
142 }),
143 }
144 }
145
146 #[doc(hidden)]
147 pub fn validate_path(&self, path: &str) -> bool {
148 if self.included_path.iter().any(|m_p| m_p.match_non_exhaustive(path)) {
149 if let Some(ref excluded_path) = self.excluded_path {
150 return !excluded_path.iter().any(|m_e_p| m_e_p.match_non_exhaustive(path));
151 } else {
152 return true;
153 }
154 }
155
156 false
157 }
158}
159
/// The portion of a middleware stack that a context can be handed on to.
///
/// Implementors must be `Sync + Send`.
#[doc(hidden)]
pub trait MiddlewareChain: Sync + Send {
    /// Passes `ctx` to this part of the chain and returns a boxed `'static` future
    /// resolving to the resulting context or a `SaphirError`.
    fn next(&self, ctx: HttpContext) -> BoxFuture<'static, Result<HttpContext, SaphirError>>;
}
164
/// Terminal element of a middleware chain.
///
/// Its `MiddlewareChain::next` takes the router out of the context and dispatches
/// the context to it; `Builder::default()` starts from this element.
#[doc(hidden)]
pub struct MiddleChainEnd;
167
168impl MiddlewareChain for MiddleChainEnd {
169 #[doc(hidden)]
170 #[allow(unused_mut)]
171 #[inline]
172 fn next(&self, mut ctx: HttpContext) -> BoxFuture<'static, Result<HttpContext, SaphirError>> {
173 async {
174 let router = ctx.router.take().ok_or(SaphirError::Internal(InternalError::Stack))?;
175 router.dispatch(ctx).await
176 }
177 .boxed()
178 }
179}
180
181#[doc(hidden)]
182pub struct MiddlewareChainLink<Mid: Middleware, Rest: MiddlewareChain> {
183 rule: Rule,
184 mid: Mid,
185 rest: Rest,
186}
187
188#[doc(hidden)]
189impl<Mid, Rest> MiddlewareChain for MiddlewareChainLink<Mid, Rest>
190where
191 Mid: Middleware + Sync + Send + 'static,
192 Rest: MiddlewareChain,
193{
194 #[doc(hidden)]
195 #[allow(clippy::transmute_ptr_to_ptr)]
196 #[inline]
197 fn next(&self, ctx: HttpContext) -> BoxFuture<'static, Result<HttpContext, SaphirError>> {
198 let (mid, rest) = unsafe {
202 (
203 std::mem::transmute::<&'_ Mid, &'static Mid>(&self.mid),
204 std::mem::transmute::<&'_ dyn MiddlewareChain, &'static dyn MiddlewareChain>(&self.rest),
205 )
206 };
207
208 if ctx.state.request().filter(|req| self.rule.validate_path(req.uri().path())).is_some() {
209 mid.next(ctx, rest)
210 } else {
211 rest.next(ctx)
212 }
213 }
214}