Skip to main content

nestforge_http/
middleware.rs

1use std::sync::Arc;
2
3use axum::{body::Body, extract::Request, http::Method};
4use nestforge_core::{framework_log_event, NextFn, NextFuture};
5
6pub trait NestMiddleware: Send + Sync + 'static {
7    fn handle(&self, req: Request<Body>, next: NextFn) -> NextFuture;
8}
9
10#[derive(Clone)]
11pub struct MiddlewareBinding {
12    middleware: Arc<dyn NestMiddleware>,
13    matcher: RouteMatcher,
14}
15
16impl MiddlewareBinding {
17    fn matches(&self, method: &Method, path: &str) -> bool {
18        self.matcher.matches(method, path)
19    }
20}
21
22#[derive(Clone, Default)]
23struct RouteMatcher {
24    include: Vec<MiddlewareRoute>,
25    exclude: Vec<MiddlewareRoute>,
26}
27
28impl RouteMatcher {
29    fn matches(&self, method: &Method, path: &str) -> bool {
30        if self
31            .exclude
32            .iter()
33            .any(|route| route.matches(method, path))
34        {
35            return false;
36        }
37
38        if self.include.is_empty() {
39            return true;
40        }
41
42        self.include
43            .iter()
44            .any(|route| route.matches(method, path))
45    }
46}
47
48#[derive(Clone, Debug, PartialEq, Eq)]
49pub struct MiddlewareRoute {
50    path: String,
51    methods: Option<Vec<Method>>,
52}
53
54impl MiddlewareRoute {
55    pub fn path(path: impl Into<String>) -> Self {
56        Self {
57            path: normalize_path(path.into()),
58            methods: None,
59        }
60    }
61
62    pub fn methods<I>(path: impl Into<String>, methods: I) -> Self
63    where
64        I: IntoIterator<Item = Method>,
65    {
66        Self {
67            path: normalize_path(path.into()),
68            methods: Some(methods.into_iter().collect()),
69        }
70    }
71
72    pub fn get(path: impl Into<String>) -> Self {
73        Self::methods(path, [Method::GET])
74    }
75
76    pub fn post(path: impl Into<String>) -> Self {
77        Self::methods(path, [Method::POST])
78    }
79
80    pub fn put(path: impl Into<String>) -> Self {
81        Self::methods(path, [Method::PUT])
82    }
83
84    pub fn delete(path: impl Into<String>) -> Self {
85        Self::methods(path, [Method::DELETE])
86    }
87
88    fn matches(&self, method: &Method, path: &str) -> bool {
89        if !path_matches_prefix(path, &self.path) {
90            return false;
91        }
92
93        match &self.methods {
94            Some(methods) => methods.iter().any(|candidate| candidate == method),
95            None => true,
96        }
97    }
98}
99
100impl From<&str> for MiddlewareRoute {
101    fn from(value: &str) -> Self {
102        Self::path(value)
103    }
104}
105
106impl From<String> for MiddlewareRoute {
107    fn from(value: String) -> Self {
108        Self::path(value)
109    }
110}
111
112#[derive(Default)]
113pub struct MiddlewareConsumer {
114    bindings: Vec<MiddlewareBinding>,
115}
116
117impl MiddlewareConsumer {
118    pub fn new() -> Self {
119        Self::default()
120    }
121
122    pub fn apply<T>(&mut self) -> MiddlewareBindingBuilder<'_>
123    where
124        T: NestMiddleware + Default,
125    {
126        MiddlewareBindingBuilder::new(self, Arc::new(T::default()))
127    }
128
129    pub fn apply_instance<T>(&mut self, middleware: T) -> MiddlewareBindingBuilder<'_>
130    where
131        T: NestMiddleware,
132    {
133        MiddlewareBindingBuilder::new(self, Arc::new(middleware))
134    }
135
136    pub fn into_bindings(self) -> Vec<MiddlewareBinding> {
137        self.bindings
138    }
139}
140
141pub struct MiddlewareBindingBuilder<'a> {
142    consumer: &'a mut MiddlewareConsumer,
143    middleware: Arc<dyn NestMiddleware>,
144    exclude: Vec<MiddlewareRoute>,
145}
146
147impl<'a> MiddlewareBindingBuilder<'a> {
148    fn new(consumer: &'a mut MiddlewareConsumer, middleware: Arc<dyn NestMiddleware>) -> Self {
149        Self {
150            consumer,
151            middleware,
152            exclude: Vec::new(),
153        }
154    }
155
156    pub fn exclude<I, S>(mut self, routes: I) -> Self
157    where
158        I: IntoIterator<Item = S>,
159        S: Into<MiddlewareRoute>,
160    {
161        self.exclude = routes.into_iter().map(Into::into).collect();
162        self
163    }
164
165    pub fn for_all_routes(self) -> &'a mut MiddlewareConsumer {
166        self.register(Vec::new())
167    }
168
169    pub fn for_routes<I, S>(self, routes: I) -> &'a mut MiddlewareConsumer
170    where
171        I: IntoIterator<Item = S>,
172        S: Into<MiddlewareRoute>,
173    {
174        let include = routes.into_iter().map(Into::into).collect();
175        self.register(include)
176    }
177
178    fn register(self, include: Vec<MiddlewareRoute>) -> &'a mut MiddlewareConsumer {
179        framework_log_event(
180            "middleware_register",
181            &[("include", format!("{include:?}")), ("exclude", format!("{:?}", self.exclude))],
182        );
183        self.consumer.bindings.push(MiddlewareBinding {
184            middleware: self.middleware,
185            matcher: RouteMatcher {
186                include,
187                exclude: self.exclude,
188            },
189        });
190        self.consumer
191    }
192}
193
194pub fn run_middleware_chain(
195    middlewares: Arc<Vec<MiddlewareBinding>>,
196    index: usize,
197    req: Request<Body>,
198    terminal: NextFn,
199) -> NextFuture {
200    if index >= middlewares.len() {
201        return terminal(req);
202    }
203
204    let binding = middlewares[index].clone();
205    if !binding.matches(req.method(), req.uri().path()) {
206        return run_middleware_chain(middlewares, index + 1, req, terminal);
207    }
208
209    let middlewares_for_next = Arc::clone(&middlewares);
210    let terminal_for_next = Arc::clone(&terminal);
211    let next_fn: NextFn = Arc::new(move |next_req| {
212        run_middleware_chain(
213            Arc::clone(&middlewares_for_next),
214            index + 1,
215            next_req,
216            Arc::clone(&terminal_for_next),
217        )
218    });
219
220    binding.middleware.handle(req, next_fn)
221}
222
/// Normalizes a route path to a canonical prefix form.
///
/// Surrounding whitespace and trailing slashes are removed, and a leading `/` is added when
/// missing. Empty or slash-only input (`""`, `"/"`, `"//"`, …) becomes the root `"/"`.
fn normalize_path(path: String) -> String {
    // Strip trailing slashes *before* the root check, so "//" collapses to "/" instead of "".
    let trimmed = path.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        return "/".to_string();
    }

    if trimmed.starts_with('/') {
        trimmed.to_string()
    } else {
        format!("/{trimmed}")
    }
}
235
/// Returns `true` when `path` equals `prefix` or lies beneath it on a `/` boundary.
///
/// `"/"` matches every path. `/users` matches `/users` and `/users/1`, but not `/users-list`.
fn path_matches_prefix(path: &str, prefix: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    // Runs per request per route, so avoid allocating a `"{prefix}/"` string.
    match path.strip_prefix(prefix) {
        Some(rest) => rest.is_empty() || rest.starts_with('/'),
        None => false,
    }
}
243
#[cfg(test)]
mod tests {
    use axum::http::Method;

    use super::{normalize_path, path_matches_prefix, MiddlewareRoute, RouteMatcher};

    #[test]
    fn matcher_supports_prefix_matching_and_excludes() {
        let matcher = RouteMatcher {
            include: vec![MiddlewareRoute::path("/api")],
            exclude: vec![MiddlewareRoute::path("/api/health")],
        };

        assert!(matcher.matches(&Method::GET, "/api/users"));
        assert!(!matcher.matches(&Method::GET, "/api/health"));
        assert!(!matcher.matches(&Method::GET, "/admin"));
    }

    #[test]
    fn normalize_path_handles_empty_and_trailing_slashes() {
        assert_eq!(normalize_path("".to_string()), "/");
        assert_eq!(normalize_path("/users/".to_string()), "/users");
        assert_eq!(normalize_path("users".to_string()), "/users");
    }

    #[test]
    fn prefix_matching_requires_boundary() {
        assert!(path_matches_prefix("/users/1", "/users"));
        assert!(path_matches_prefix("/users", "/users"));
        assert!(!path_matches_prefix("/users-list", "/users"));
    }

    #[test]
    fn root_prefix_matches_every_path() {
        assert!(path_matches_prefix("/", "/"));
        assert!(path_matches_prefix("/anything/nested", "/"));
    }

    #[test]
    fn matcher_can_target_specific_http_methods() {
        let matcher = RouteMatcher {
            include: vec![MiddlewareRoute::get("/admin")],
            exclude: Vec::new(),
        };

        assert!(matcher.matches(&Method::GET, "/admin/users"));
        assert!(!matcher.matches(&Method::POST, "/admin/users"));
    }

    #[test]
    fn empty_include_list_matches_everything_not_excluded() {
        let matcher = RouteMatcher {
            include: Vec::new(),
            exclude: vec![MiddlewareRoute::path("/health")],
        };

        assert!(matcher.matches(&Method::GET, "/users"));
        assert!(matcher.matches(&Method::POST, "/"));
        assert!(matcher.matches(&Method::GET, "/healthz"));
        assert!(!matcher.matches(&Method::GET, "/health"));
        assert!(!matcher.matches(&Method::GET, "/health/live"));
    }

    #[test]
    fn exclusions_can_be_scoped_to_http_methods() {
        let matcher = RouteMatcher {
            include: Vec::new(),
            exclude: vec![MiddlewareRoute::post("/users")],
        };

        assert!(matcher.matches(&Method::GET, "/users"));
        assert!(!matcher.matches(&Method::POST, "/users"));
        assert!(!matcher.matches(&Method::POST, "/users/1"));
    }

    #[test]
    fn methods_constructor_accepts_several_methods() {
        let matcher = RouteMatcher {
            include: vec![MiddlewareRoute::methods(
                "/admin",
                [Method::GET, Method::POST],
            )],
            exclude: Vec::new(),
        };

        assert!(matcher.matches(&Method::GET, "/admin"));
        assert!(matcher.matches(&Method::POST, "/admin/users"));
        assert!(!matcher.matches(&Method::DELETE, "/admin"));
    }

    #[test]
    fn string_conversions_normalize_paths() {
        assert_eq!(MiddlewareRoute::from("users/"), MiddlewareRoute::path("/users"));
        assert_eq!(
            MiddlewareRoute::from(String::from("/users")),
            MiddlewareRoute::path("users")
        );
    }
}