Skip to main content

nestforge_http/
middleware.rs

1use std::sync::Arc;
2
3use axum::{body::Body, extract::Request, http::Method};
4use nestforge_core::{framework_log_event, NextFn, NextFuture};
5
6pub trait NestMiddleware: Send + Sync + 'static {
7    fn handle(&self, req: Request<Body>, next: NextFn) -> NextFuture;
8}
9
10#[derive(Clone)]
11pub struct MiddlewareBinding {
12    middleware: Arc<dyn NestMiddleware>,
13    matcher: RouteMatcher,
14}
15
16impl MiddlewareBinding {
17    fn matches(&self, method: &Method, path: &str) -> bool {
18        self.matcher.matches(method, path)
19    }
20}
21
22#[derive(Clone, Default)]
23struct RouteMatcher {
24    include: Vec<MiddlewareRoute>,
25    exclude: Vec<MiddlewareRoute>,
26}
27
28impl RouteMatcher {
29    fn matches(&self, method: &Method, path: &str) -> bool {
30        if self.exclude.iter().any(|route| route.matches(method, path)) {
31            return false;
32        }
33
34        if self.include.is_empty() {
35            return true;
36        }
37
38        self.include.iter().any(|route| route.matches(method, path))
39    }
40}
41
42#[derive(Clone, Debug, PartialEq, Eq)]
43pub struct MiddlewareRoute {
44    path: String,
45    methods: Option<Vec<Method>>,
46}
47
48impl MiddlewareRoute {
49    pub fn path(path: impl Into<String>) -> Self {
50        Self {
51            path: normalize_path(path.into()),
52            methods: None,
53        }
54    }
55
56    pub fn methods<I>(path: impl Into<String>, methods: I) -> Self
57    where
58        I: IntoIterator<Item = Method>,
59    {
60        Self {
61            path: normalize_path(path.into()),
62            methods: Some(methods.into_iter().collect()),
63        }
64    }
65
66    pub fn get(path: impl Into<String>) -> Self {
67        Self::methods(path, [Method::GET])
68    }
69
70    pub fn post(path: impl Into<String>) -> Self {
71        Self::methods(path, [Method::POST])
72    }
73
74    pub fn put(path: impl Into<String>) -> Self {
75        Self::methods(path, [Method::PUT])
76    }
77
78    pub fn delete(path: impl Into<String>) -> Self {
79        Self::methods(path, [Method::DELETE])
80    }
81
82    fn matches(&self, method: &Method, path: &str) -> bool {
83        if !path_matches_prefix(path, &self.path) {
84            return false;
85        }
86
87        match &self.methods {
88            Some(methods) => methods.iter().any(|candidate| candidate == method),
89            None => true,
90        }
91    }
92}
93
94impl From<&str> for MiddlewareRoute {
95    fn from(value: &str) -> Self {
96        Self::path(value)
97    }
98}
99
100impl From<String> for MiddlewareRoute {
101    fn from(value: String) -> Self {
102        Self::path(value)
103    }
104}
105
106#[derive(Default)]
107pub struct MiddlewareConsumer {
108    bindings: Vec<MiddlewareBinding>,
109}
110
111impl MiddlewareConsumer {
112    pub fn new() -> Self {
113        Self::default()
114    }
115
116    pub fn apply<T>(&mut self) -> MiddlewareBindingBuilder<'_>
117    where
118        T: NestMiddleware + Default,
119    {
120        MiddlewareBindingBuilder::new(self, Arc::new(T::default()))
121    }
122
123    pub fn apply_instance<T>(&mut self, middleware: T) -> MiddlewareBindingBuilder<'_>
124    where
125        T: NestMiddleware,
126    {
127        MiddlewareBindingBuilder::new(self, Arc::new(middleware))
128    }
129
130    pub fn into_bindings(self) -> Vec<MiddlewareBinding> {
131        self.bindings
132    }
133}
134
135pub struct MiddlewareBindingBuilder<'a> {
136    consumer: &'a mut MiddlewareConsumer,
137    middleware: Arc<dyn NestMiddleware>,
138    exclude: Vec<MiddlewareRoute>,
139}
140
141impl<'a> MiddlewareBindingBuilder<'a> {
142    fn new(consumer: &'a mut MiddlewareConsumer, middleware: Arc<dyn NestMiddleware>) -> Self {
143        Self {
144            consumer,
145            middleware,
146            exclude: Vec::new(),
147        }
148    }
149
150    pub fn exclude<I, S>(mut self, routes: I) -> Self
151    where
152        I: IntoIterator<Item = S>,
153        S: Into<MiddlewareRoute>,
154    {
155        self.exclude = routes.into_iter().map(Into::into).collect();
156        self
157    }
158
159    pub fn for_all_routes(self) -> &'a mut MiddlewareConsumer {
160        self.register(Vec::new())
161    }
162
163    pub fn for_routes<I, S>(self, routes: I) -> &'a mut MiddlewareConsumer
164    where
165        I: IntoIterator<Item = S>,
166        S: Into<MiddlewareRoute>,
167    {
168        let include = routes.into_iter().map(Into::into).collect();
169        self.register(include)
170    }
171
172    fn register(self, include: Vec<MiddlewareRoute>) -> &'a mut MiddlewareConsumer {
173        framework_log_event(
174            "middleware_register",
175            &[
176                ("include", format!("{include:?}")),
177                ("exclude", format!("{:?}", self.exclude)),
178            ],
179        );
180        self.consumer.bindings.push(MiddlewareBinding {
181            middleware: self.middleware,
182            matcher: RouteMatcher {
183                include,
184                exclude: self.exclude,
185            },
186        });
187        self.consumer
188    }
189}
190
191pub fn run_middleware_chain(
192    middlewares: Arc<Vec<MiddlewareBinding>>,
193    index: usize,
194    req: Request<Body>,
195    terminal: NextFn,
196) -> NextFuture {
197    if index >= middlewares.len() {
198        return terminal(req);
199    }
200
201    let binding = middlewares[index].clone();
202    if !binding.matches(req.method(), req.uri().path()) {
203        return run_middleware_chain(middlewares, index + 1, req, terminal);
204    }
205
206    let middlewares_for_next = Arc::clone(&middlewares);
207    let terminal_for_next = Arc::clone(&terminal);
208    let next_fn: NextFn = Arc::new(move |next_req| {
209        run_middleware_chain(
210            Arc::clone(&middlewares_for_next),
211            index + 1,
212            next_req,
213            Arc::clone(&terminal_for_next),
214        )
215    });
216
217    binding.middleware.handle(req, next_fn)
218}
219
/// Canonicalizes a configured route path.
///
/// Surrounding whitespace and trailing slashes are removed and a leading
/// slash is added when missing. Empty, whitespace-only, or slash-only input
/// (for example `"//"`) yields the root path `"/"`, so the result is never
/// an empty string.
fn normalize_path(path: String) -> String {
    let core = path.trim().trim_end_matches('/');

    // Everything was whitespace and/or slashes: this is the root path.
    if core.is_empty() {
        return "/".to_string();
    }

    if core.starts_with('/') {
        core.to_string()
    } else {
        format!("/{core}")
    }
}
232
/// Reports whether `path` lies under `prefix` on a path-segment boundary.
///
/// The root prefix `"/"` matches every path. Any other prefix matches when
/// `path` equals it or continues with a `/` right after it, so `/users`
/// matches `/users` and `/users/1` but not `/users-list`.
fn path_matches_prefix(path: &str, prefix: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    // `strip_prefix` avoids building a temporary "{prefix}/" string for
    // every comparison; this runs for each route on each request.
    match path.strip_prefix(prefix) {
        Some(rest) => rest.is_empty() || rest.starts_with('/'),
        None => false,
    }
}
240
#[cfg(test)]
mod tests {
    use axum::http::Method;

    use super::{normalize_path, path_matches_prefix, MiddlewareRoute, RouteMatcher};

    /// Includes select by prefix, and an exclusion overrides an include.
    #[test]
    fn matcher_supports_prefix_matching_and_excludes() {
        let include = vec![MiddlewareRoute::path("/api")];
        let exclude = vec![MiddlewareRoute::path("/api/health")];
        let matcher = RouteMatcher { include, exclude };

        let get = Method::GET;
        assert!(matcher.matches(&get, "/api/users"));
        assert!(!matcher.matches(&get, "/api/health"));
        assert!(!matcher.matches(&get, "/admin"));
    }

    /// Empty input becomes the root; trailing and missing slashes are fixed.
    #[test]
    fn normalize_path_handles_empty_and_trailing_slashes() {
        let cases = [("", "/"), ("/users/", "/users"), ("users", "/users")];

        for (input, expected) in cases {
            assert_eq!(normalize_path(input.to_string()), expected);
        }
    }

    /// A prefix only matches on a whole path segment.
    #[test]
    fn prefix_matching_requires_boundary() {
        let prefix = "/users";

        assert!(path_matches_prefix("/users/1", prefix));
        assert!(path_matches_prefix("/users", prefix));
        assert!(!path_matches_prefix("/users-list", prefix));
    }

    /// A method-restricted route ignores requests using other methods.
    #[test]
    fn matcher_can_target_specific_http_methods() {
        let matcher = RouteMatcher {
            include: vec![MiddlewareRoute::get("/admin")],
            exclude: Vec::new(),
        };
        let path = "/admin/users";

        assert!(matcher.matches(&Method::GET, path));
        assert!(!matcher.matches(&Method::POST, path));
    }
}
283}