Skip to main content

nestforge_http/
middleware.rs

1use std::sync::Arc;
2
3use axum::{body::Body, extract::Request, http::Method};
4use nestforge_core::{framework_log_event, NextFn, NextFuture};
5
6/// Trait for defining custom middleware.
7///
8/// Middleware sits between the incoming request and the NestForge pipeline (guards/interceptors/handlers).
9/// It can modify the request, response, or short-circuit the execution.
10///
11/// # Example
12/// ```rust
13/// struct LoggerMiddleware;
14/// impl NestMiddleware for LoggerMiddleware {
15///     fn handle(&self, req: Request<Body>, next: NextFn) -> NextFuture {
16///         Box::pin(async move {
17///             println!("Request: {:?}", req.uri());
18///             next(req).await
19///         })
20///     }
21/// }
22/// ```
23pub trait NestMiddleware: Send + Sync + 'static {
24    fn handle(&self, req: Request<Body>, next: NextFn) -> NextFuture;
25}
26
27#[derive(Clone)]
28pub struct MiddlewareBinding {
29    middleware: Arc<dyn NestMiddleware>,
30    matcher: RouteMatcher,
31}
32
33impl MiddlewareBinding {
34    fn matches(&self, method: &Method, path: &str) -> bool {
35        self.matcher.matches(method, path)
36    }
37}
38
39#[derive(Clone, Default)]
40struct RouteMatcher {
41    include: Vec<MiddlewareRoute>,
42    exclude: Vec<MiddlewareRoute>,
43}
44
45impl RouteMatcher {
46    fn matches(&self, method: &Method, path: &str) -> bool {
47        if self.exclude.iter().any(|route| route.matches(method, path)) {
48            return false;
49        }
50
51        if self.include.is_empty() {
52            return true;
53        }
54
55        self.include.iter().any(|route| route.matches(method, path))
56    }
57}
58
59/// A definition of a route (path + method) for middleware targeting.
60///
61/// Supports exact path matching and prefix matching (if path ends with `/`).
62#[derive(Clone, Debug, PartialEq, Eq)]
63pub struct MiddlewareRoute {
64    path: String,
65    methods: Option<Vec<Method>>,
66}
67
68impl MiddlewareRoute {
69    /// Matches any HTTP method on the given path.
70    pub fn path(path: impl Into<String>) -> Self {
71        Self {
72            path: normalize_path(path.into()),
73            methods: None,
74        }
75    }
76
77    /// Matches specific HTTP methods on the given path.
78    pub fn methods<I>(path: impl Into<String>, methods: I) -> Self
79    where
80        I: IntoIterator<Item = Method>,
81    {
82        Self {
83            path: normalize_path(path.into()),
84            methods: Some(methods.into_iter().collect()),
85        }
86    }
87
88    /// Matches GET requests.
89    pub fn get(path: impl Into<String>) -> Self {
90        Self::methods(path, [Method::GET])
91    }
92
93    /// Matches POST requests.
94    pub fn post(path: impl Into<String>) -> Self {
95        Self::methods(path, [Method::POST])
96    }
97
98    /// Matches PUT requests.
99    pub fn put(path: impl Into<String>) -> Self {
100        Self::methods(path, [Method::PUT])
101    }
102
103    /// Matches DELETE requests.
104    pub fn delete(path: impl Into<String>) -> Self {
105        Self::methods(path, [Method::DELETE])
106    }
107
108    fn matches(&self, method: &Method, path: &str) -> bool {
109        if !path_matches_prefix(path, &self.path) {
110            return false;
111        }
112
113        match &self.methods {
114            Some(methods) => methods.iter().any(|candidate| candidate == method),
115            None => true,
116        }
117    }
118}
119
120impl From<&str> for MiddlewareRoute {
121    fn from(value: &str) -> Self {
122        Self::path(value)
123    }
124}
125
126impl From<String> for MiddlewareRoute {
127    fn from(value: String) -> Self {
128        Self::path(value)
129    }
130}
131
132/// A builder for configuring middleware application.
133///
134/// Allows you to apply middleware to specific routes or exclude routes.
135#[derive(Default)]
136pub struct MiddlewareConsumer {
137    bindings: Vec<MiddlewareBinding>,
138}
139
140impl MiddlewareConsumer {
141    pub fn new() -> Self {
142        Self::default()
143    }
144
145    /// Apply a middleware type (default constructed).
146    pub fn apply<T>(&mut self) -> MiddlewareBindingBuilder<'_>
147    where
148        T: NestMiddleware + Default,
149    {
150        MiddlewareBindingBuilder::new(self, Arc::new(T::default()))
151    }
152
153    /// Apply a specific middleware instance.
154    pub fn apply_instance<T>(&mut self, middleware: T) -> MiddlewareBindingBuilder<'_>
155    where
156        T: NestMiddleware,
157    {
158        MiddlewareBindingBuilder::new(self, Arc::new(middleware))
159    }
160
161    pub fn into_bindings(self) -> Vec<MiddlewareBinding> {
162        self.bindings
163    }
164}
165
166/// Helper struct for chaining middleware configuration methods.
167pub struct MiddlewareBindingBuilder<'a> {
168    consumer: &'a mut MiddlewareConsumer,
169    middleware: Arc<dyn NestMiddleware>,
170    exclude: Vec<MiddlewareRoute>,
171}
172
173impl<'a> MiddlewareBindingBuilder<'a> {
174    fn new(consumer: &'a mut MiddlewareConsumer, middleware: Arc<dyn NestMiddleware>) -> Self {
175        Self {
176            consumer,
177            middleware,
178            exclude: Vec::new(),
179        }
180    }
181
182    /// Exclude specific routes from this middleware.
183    pub fn exclude<I, S>(mut self, routes: I) -> Self
184    where
185        I: IntoIterator<Item = S>,
186        S: Into<MiddlewareRoute>,
187    {
188        self.exclude = routes.into_iter().map(Into::into).collect();
189        self
190    }
191
192    /// Apply the middleware to all routes.
193    pub fn for_all_routes(self) -> &'a mut MiddlewareConsumer {
194        self.register(Vec::new())
195    }
196
197    /// Apply the middleware to specific routes.
198    pub fn for_routes<I, S>(self, routes: I) -> &'a mut MiddlewareConsumer
199    where
200        I: IntoIterator<Item = S>,
201        S: Into<MiddlewareRoute>,
202    {
203        let include = routes.into_iter().map(Into::into).collect();
204        self.register(include)
205    }
206
207    fn register(self, include: Vec<MiddlewareRoute>) -> &'a mut MiddlewareConsumer {
208        framework_log_event(
209            "middleware_register",
210            &[
211                ("include", format!("{include:?}")),
212                ("exclude", format!("{:?}", self.exclude)),
213            ],
214        );
215        self.consumer.bindings.push(MiddlewareBinding {
216            middleware: self.middleware,
217            matcher: RouteMatcher {
218                include,
219                exclude: self.exclude,
220            },
221        });
222        self.consumer
223    }
224}
225
226pub fn run_middleware_chain(
227    middlewares: Arc<Vec<MiddlewareBinding>>,
228    index: usize,
229    req: Request<Body>,
230    terminal: NextFn,
231) -> NextFuture {
232    if index >= middlewares.len() {
233        return terminal(req);
234    }
235
236    let binding = middlewares[index].clone();
237    if !binding.matches(req.method(), req.uri().path()) {
238        return run_middleware_chain(middlewares, index + 1, req, terminal);
239    }
240
241    let middlewares_for_next = Arc::clone(&middlewares);
242    let terminal_for_next = Arc::clone(&terminal);
243    let next_fn: NextFn = Arc::new(move |next_req| {
244        run_middleware_chain(
245            Arc::clone(&middlewares_for_next),
246            index + 1,
247            next_req,
248            Arc::clone(&terminal_for_next),
249        )
250    });
251
252    binding.middleware.handle(req, next_fn)
253}
254
/// Canonicalizes a route path for matching.
///
/// Surrounding whitespace and trailing `/` are removed, and a leading `/` is added if
/// missing. Empty and slash-only inputs (`""`, `"/"`, `"//"`) all become the root `"/"`.
fn normalize_path(path: String) -> String {
    // Strip trailing slashes before the emptiness check so "//" collapses to root, not "".
    let trimmed = path.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        return "/".to_string();
    }

    if trimmed.starts_with('/') {
        trimmed.to_string()
    } else {
        format!("/{trimmed}")
    }
}
267
/// Returns `true` when `path` equals `prefix` or lies beneath it on a `/` segment boundary.
///
/// The root prefix `/` matches every path. `/users` matches `/users` and `/users/1` but not
/// `/users-list`.
fn path_matches_prefix(path: &str, prefix: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    // `strip_prefix` checks the boundary without allocating a `"{prefix}/"` string on every
    // request.
    match path.strip_prefix(prefix) {
        Some(rest) => rest.is_empty() || rest.starts_with('/'),
        None => false,
    }
}
275
#[cfg(test)]
mod tests {
    use axum::http::Method;

    use super::{normalize_path, path_matches_prefix, MiddlewareRoute, RouteMatcher};

    #[test]
    fn matcher_supports_prefix_matching_and_excludes() {
        let matcher = RouteMatcher {
            include: vec![MiddlewareRoute::path("/api")],
            exclude: vec![MiddlewareRoute::path("/api/health")],
        };

        let expectations = [("/api/users", true), ("/api/health", false), ("/admin", false)];
        for &(path, expected) in &expectations {
            assert_eq!(matcher.matches(&Method::GET, path), expected, "GET {path}");
        }
    }

    #[test]
    fn normalize_path_handles_empty_and_trailing_slashes() {
        let cases = [("", "/"), ("/users/", "/users"), ("users", "/users")];
        for &(input, expected) in &cases {
            assert_eq!(normalize_path(input.to_owned()), expected, "input {input:?}");
        }
    }

    #[test]
    fn prefix_matching_requires_boundary() {
        for path in &["/users/1", "/users"] {
            assert!(path_matches_prefix(path, "/users"), "{path} should match /users");
        }
        assert!(!path_matches_prefix("/users-list", "/users"));
    }

    #[test]
    fn matcher_can_target_specific_http_methods() {
        let matcher = RouteMatcher {
            include: vec![MiddlewareRoute::get("/admin")],
            exclude: Vec::new(),
        };

        let path = "/admin/users";
        assert!(matcher.matches(&Method::GET, path));
        assert!(!matcher.matches(&Method::POST, path));
    }
}