1use std::sync::Arc;
2
3use axum::{body::Body, extract::Request, http::Method};
4use nestforge_core::{framework_log_event, NextFn, NextFuture};
5
6pub trait NestMiddleware: Send + Sync + 'static {
7 fn handle(&self, req: Request<Body>, next: NextFn) -> NextFuture;
8}
9
10#[derive(Clone)]
11pub struct MiddlewareBinding {
12 middleware: Arc<dyn NestMiddleware>,
13 matcher: RouteMatcher,
14}
15
16impl MiddlewareBinding {
17 fn matches(&self, method: &Method, path: &str) -> bool {
18 self.matcher.matches(method, path)
19 }
20}
21
22#[derive(Clone, Default)]
23struct RouteMatcher {
24 include: Vec<MiddlewareRoute>,
25 exclude: Vec<MiddlewareRoute>,
26}
27
28impl RouteMatcher {
29 fn matches(&self, method: &Method, path: &str) -> bool {
30 if self.exclude.iter().any(|route| route.matches(method, path)) {
31 return false;
32 }
33
34 if self.include.is_empty() {
35 return true;
36 }
37
38 self.include.iter().any(|route| route.matches(method, path))
39 }
40}
41
42#[derive(Clone, Debug, PartialEq, Eq)]
43pub struct MiddlewareRoute {
44 path: String,
45 methods: Option<Vec<Method>>,
46}
47
48impl MiddlewareRoute {
49 pub fn path(path: impl Into<String>) -> Self {
50 Self {
51 path: normalize_path(path.into()),
52 methods: None,
53 }
54 }
55
56 pub fn methods<I>(path: impl Into<String>, methods: I) -> Self
57 where
58 I: IntoIterator<Item = Method>,
59 {
60 Self {
61 path: normalize_path(path.into()),
62 methods: Some(methods.into_iter().collect()),
63 }
64 }
65
66 pub fn get(path: impl Into<String>) -> Self {
67 Self::methods(path, [Method::GET])
68 }
69
70 pub fn post(path: impl Into<String>) -> Self {
71 Self::methods(path, [Method::POST])
72 }
73
74 pub fn put(path: impl Into<String>) -> Self {
75 Self::methods(path, [Method::PUT])
76 }
77
78 pub fn delete(path: impl Into<String>) -> Self {
79 Self::methods(path, [Method::DELETE])
80 }
81
82 fn matches(&self, method: &Method, path: &str) -> bool {
83 if !path_matches_prefix(path, &self.path) {
84 return false;
85 }
86
87 match &self.methods {
88 Some(methods) => methods.iter().any(|candidate| candidate == method),
89 None => true,
90 }
91 }
92}
93
94impl From<&str> for MiddlewareRoute {
95 fn from(value: &str) -> Self {
96 Self::path(value)
97 }
98}
99
100impl From<String> for MiddlewareRoute {
101 fn from(value: String) -> Self {
102 Self::path(value)
103 }
104}
105
106#[derive(Default)]
107pub struct MiddlewareConsumer {
108 bindings: Vec<MiddlewareBinding>,
109}
110
111impl MiddlewareConsumer {
112 pub fn new() -> Self {
113 Self::default()
114 }
115
116 pub fn apply<T>(&mut self) -> MiddlewareBindingBuilder<'_>
117 where
118 T: NestMiddleware + Default,
119 {
120 MiddlewareBindingBuilder::new(self, Arc::new(T::default()))
121 }
122
123 pub fn apply_instance<T>(&mut self, middleware: T) -> MiddlewareBindingBuilder<'_>
124 where
125 T: NestMiddleware,
126 {
127 MiddlewareBindingBuilder::new(self, Arc::new(middleware))
128 }
129
130 pub fn into_bindings(self) -> Vec<MiddlewareBinding> {
131 self.bindings
132 }
133}
134
135pub struct MiddlewareBindingBuilder<'a> {
136 consumer: &'a mut MiddlewareConsumer,
137 middleware: Arc<dyn NestMiddleware>,
138 exclude: Vec<MiddlewareRoute>,
139}
140
141impl<'a> MiddlewareBindingBuilder<'a> {
142 fn new(consumer: &'a mut MiddlewareConsumer, middleware: Arc<dyn NestMiddleware>) -> Self {
143 Self {
144 consumer,
145 middleware,
146 exclude: Vec::new(),
147 }
148 }
149
150 pub fn exclude<I, S>(mut self, routes: I) -> Self
151 where
152 I: IntoIterator<Item = S>,
153 S: Into<MiddlewareRoute>,
154 {
155 self.exclude = routes.into_iter().map(Into::into).collect();
156 self
157 }
158
159 pub fn for_all_routes(self) -> &'a mut MiddlewareConsumer {
160 self.register(Vec::new())
161 }
162
163 pub fn for_routes<I, S>(self, routes: I) -> &'a mut MiddlewareConsumer
164 where
165 I: IntoIterator<Item = S>,
166 S: Into<MiddlewareRoute>,
167 {
168 let include = routes.into_iter().map(Into::into).collect();
169 self.register(include)
170 }
171
172 fn register(self, include: Vec<MiddlewareRoute>) -> &'a mut MiddlewareConsumer {
173 framework_log_event(
174 "middleware_register",
175 &[
176 ("include", format!("{include:?}")),
177 ("exclude", format!("{:?}", self.exclude)),
178 ],
179 );
180 self.consumer.bindings.push(MiddlewareBinding {
181 middleware: self.middleware,
182 matcher: RouteMatcher {
183 include,
184 exclude: self.exclude,
185 },
186 });
187 self.consumer
188 }
189}
190
191pub fn run_middleware_chain(
192 middlewares: Arc<Vec<MiddlewareBinding>>,
193 index: usize,
194 req: Request<Body>,
195 terminal: NextFn,
196) -> NextFuture {
197 if index >= middlewares.len() {
198 return terminal(req);
199 }
200
201 let binding = middlewares[index].clone();
202 if !binding.matches(req.method(), req.uri().path()) {
203 return run_middleware_chain(middlewares, index + 1, req, terminal);
204 }
205
206 let middlewares_for_next = Arc::clone(&middlewares);
207 let terminal_for_next = Arc::clone(&terminal);
208 let next_fn: NextFn = Arc::new(move |next_req| {
209 run_middleware_chain(
210 Arc::clone(&middlewares_for_next),
211 index + 1,
212 next_req,
213 Arc::clone(&terminal_for_next),
214 )
215 });
216
217 binding.middleware.handle(req, next_fn)
218}
219
/// Canonicalizes a route path for prefix matching.
///
/// Surrounding whitespace and trailing slashes are removed, and a leading `/` is added when
/// missing. Inputs that reduce to nothing (`""`, `"/"`, `"///"`) become `"/"`, the prefix
/// that matches every request path.
fn normalize_path(path: String) -> String {
    // Strip trailing slashes *before* the empty check so inputs like "//" collapse to "/"
    // rather than to an empty string.
    let trimmed = path.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        "/".to_string()
    } else if trimmed.starts_with('/') {
        trimmed.to_string()
    } else {
        format!("/{trimmed}")
    }
}
232
/// Returns `true` when `path` equals `prefix` or lies beneath it at a `/` segment boundary.
///
/// A `prefix` of `"/"` matches every path. For example, `/users` matches `/users` and
/// `/users/1` but not `/users-list`.
fn path_matches_prefix(path: &str, prefix: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    // Slice-based boundary check; avoids allocating `format!("{prefix}/")` on every request.
    matches!(path.strip_prefix(prefix), Some(rest) if rest.is_empty() || rest.starts_with('/'))
}
240
#[cfg(test)]
mod tests {
    use axum::http::Method;

    use super::{normalize_path, path_matches_prefix, MiddlewareRoute, RouteMatcher};

    #[test]
    fn matcher_supports_prefix_matching_and_excludes() {
        // Everything under /api except the health endpoint.
        let api_without_health = RouteMatcher {
            include: vec![MiddlewareRoute::path("/api")],
            exclude: vec![MiddlewareRoute::path("/api/health")],
        };

        for (path, expected) in [("/api/users", true), ("/api/health", false), ("/admin", false)] {
            assert_eq!(
                api_without_health.matches(&Method::GET, path),
                expected,
                "GET {path}"
            );
        }
    }

    #[test]
    fn normalize_path_handles_empty_and_trailing_slashes() {
        for (raw, normalized) in [("", "/"), ("/users/", "/users"), ("users", "/users")] {
            assert_eq!(normalize_path(raw.to_string()), normalized, "input {raw:?}");
        }
    }

    #[test]
    fn prefix_matching_requires_boundary() {
        for (path, expected) in [("/users/1", true), ("/users", true), ("/users-list", false)] {
            assert_eq!(path_matches_prefix(path, "/users"), expected, "path {path}");
        }
    }

    #[test]
    fn matcher_can_target_specific_http_methods() {
        let admin_reads = RouteMatcher {
            include: vec![MiddlewareRoute::get("/admin")],
            exclude: Vec::new(),
        };

        for (method, expected) in [(Method::GET, true), (Method::POST, false)] {
            assert_eq!(
                admin_reads.matches(&method, "/admin/users"),
                expected,
                "{method:?} /admin/users"
            );
        }
    }
}