1use std::sync::Arc;
2
3use axum::{body::Body, extract::Request, http::Method};
4use nestforge_core::{framework_log_event, NextFn, NextFuture};
5
/// A middleware that can be bound to routes through [`MiddlewareConsumer`].
///
/// `run_middleware_chain` calls [`handle`](NestMiddleware::handle) with the
/// incoming request and a `next` callback. Invoking `next` resumes the chain
/// with the remaining matching middlewares and finally the terminal handler.
/// Implementations are shared behind an `Arc`, hence the `Send + Sync + 'static`
/// bounds.
pub trait NestMiddleware: Send + Sync + 'static {
    /// Processes `req`, optionally delegating to `next` to continue the chain.
    ///
    /// NOTE(review): whether a middleware may short-circuit (never call `next`)
    /// depends on `NextFn`/`NextFuture` in `nestforge_core` — confirm there.
    fn handle(&self, req: Request<Body>, next: NextFn) -> NextFuture;
}
26
27#[derive(Clone)]
28pub struct MiddlewareBinding {
29 middleware: Arc<dyn NestMiddleware>,
30 matcher: RouteMatcher,
31}
32
33impl MiddlewareBinding {
34 fn matches(&self, method: &Method, path: &str) -> bool {
35 self.matcher.matches(method, path)
36 }
37}
38
39#[derive(Clone, Default)]
40struct RouteMatcher {
41 include: Vec<MiddlewareRoute>,
42 exclude: Vec<MiddlewareRoute>,
43}
44
45impl RouteMatcher {
46 fn matches(&self, method: &Method, path: &str) -> bool {
47 if self.exclude.iter().any(|route| route.matches(method, path)) {
48 return false;
49 }
50
51 if self.include.is_empty() {
52 return true;
53 }
54
55 self.include.iter().any(|route| route.matches(method, path))
56 }
57}
58
59#[derive(Clone, Debug, PartialEq, Eq)]
63pub struct MiddlewareRoute {
64 path: String,
65 methods: Option<Vec<Method>>,
66}
67
68impl MiddlewareRoute {
69 pub fn path(path: impl Into<String>) -> Self {
71 Self {
72 path: normalize_path(path.into()),
73 methods: None,
74 }
75 }
76
77 pub fn methods<I>(path: impl Into<String>, methods: I) -> Self
79 where
80 I: IntoIterator<Item = Method>,
81 {
82 Self {
83 path: normalize_path(path.into()),
84 methods: Some(methods.into_iter().collect()),
85 }
86 }
87
88 pub fn get(path: impl Into<String>) -> Self {
90 Self::methods(path, [Method::GET])
91 }
92
93 pub fn post(path: impl Into<String>) -> Self {
95 Self::methods(path, [Method::POST])
96 }
97
98 pub fn put(path: impl Into<String>) -> Self {
100 Self::methods(path, [Method::PUT])
101 }
102
103 pub fn delete(path: impl Into<String>) -> Self {
105 Self::methods(path, [Method::DELETE])
106 }
107
108 fn matches(&self, method: &Method, path: &str) -> bool {
109 if !path_matches_prefix(path, &self.path) {
110 return false;
111 }
112
113 match &self.methods {
114 Some(methods) => methods.iter().any(|candidate| candidate == method),
115 None => true,
116 }
117 }
118}
119
120impl From<&str> for MiddlewareRoute {
121 fn from(value: &str) -> Self {
122 Self::path(value)
123 }
124}
125
126impl From<String> for MiddlewareRoute {
127 fn from(value: String) -> Self {
128 Self::path(value)
129 }
130}
131
132#[derive(Default)]
136pub struct MiddlewareConsumer {
137 bindings: Vec<MiddlewareBinding>,
138}
139
140impl MiddlewareConsumer {
141 pub fn new() -> Self {
142 Self::default()
143 }
144
145 pub fn apply<T>(&mut self) -> MiddlewareBindingBuilder<'_>
147 where
148 T: NestMiddleware + Default,
149 {
150 MiddlewareBindingBuilder::new(self, Arc::new(T::default()))
151 }
152
153 pub fn apply_instance<T>(&mut self, middleware: T) -> MiddlewareBindingBuilder<'_>
155 where
156 T: NestMiddleware,
157 {
158 MiddlewareBindingBuilder::new(self, Arc::new(middleware))
159 }
160
161 pub fn into_bindings(self) -> Vec<MiddlewareBinding> {
162 self.bindings
163 }
164}
165
166pub struct MiddlewareBindingBuilder<'a> {
168 consumer: &'a mut MiddlewareConsumer,
169 middleware: Arc<dyn NestMiddleware>,
170 exclude: Vec<MiddlewareRoute>,
171}
172
173impl<'a> MiddlewareBindingBuilder<'a> {
174 fn new(consumer: &'a mut MiddlewareConsumer, middleware: Arc<dyn NestMiddleware>) -> Self {
175 Self {
176 consumer,
177 middleware,
178 exclude: Vec::new(),
179 }
180 }
181
182 pub fn exclude<I, S>(mut self, routes: I) -> Self
184 where
185 I: IntoIterator<Item = S>,
186 S: Into<MiddlewareRoute>,
187 {
188 self.exclude = routes.into_iter().map(Into::into).collect();
189 self
190 }
191
192 pub fn for_all_routes(self) -> &'a mut MiddlewareConsumer {
194 self.register(Vec::new())
195 }
196
197 pub fn for_routes<I, S>(self, routes: I) -> &'a mut MiddlewareConsumer
199 where
200 I: IntoIterator<Item = S>,
201 S: Into<MiddlewareRoute>,
202 {
203 let include = routes.into_iter().map(Into::into).collect();
204 self.register(include)
205 }
206
207 fn register(self, include: Vec<MiddlewareRoute>) -> &'a mut MiddlewareConsumer {
208 framework_log_event(
209 "middleware_register",
210 &[
211 ("include", format!("{include:?}")),
212 ("exclude", format!("{:?}", self.exclude)),
213 ],
214 );
215 self.consumer.bindings.push(MiddlewareBinding {
216 middleware: self.middleware,
217 matcher: RouteMatcher {
218 include,
219 exclude: self.exclude,
220 },
221 });
222 self.consumer
223 }
224}
225
226pub fn run_middleware_chain(
227 middlewares: Arc<Vec<MiddlewareBinding>>,
228 index: usize,
229 req: Request<Body>,
230 terminal: NextFn,
231) -> NextFuture {
232 if index >= middlewares.len() {
233 return terminal(req);
234 }
235
236 let binding = middlewares[index].clone();
237 if !binding.matches(req.method(), req.uri().path()) {
238 return run_middleware_chain(middlewares, index + 1, req, terminal);
239 }
240
241 let middlewares_for_next = Arc::clone(&middlewares);
242 let terminal_for_next = Arc::clone(&terminal);
243 let next_fn: NextFn = Arc::new(move |next_req| {
244 run_middleware_chain(
245 Arc::clone(&middlewares_for_next),
246 index + 1,
247 next_req,
248 Arc::clone(&terminal_for_next),
249 )
250 });
251
252 binding.middleware.handle(req, next_fn)
253}
254
/// Normalizes a route path for prefix matching.
///
/// Surrounding whitespace and trailing slashes are removed, and a leading `/`
/// is added when missing (`"users/"` becomes `"/users"`). Input that is empty
/// or made only of slashes after trimming normalizes to the root `"/"`.
/// Previously, input such as `"//"` collapsed to an empty string.
fn normalize_path(path: String) -> String {
    let trimmed = path.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        return "/".to_string();
    }

    if trimmed.starts_with('/') {
        trimmed.to_string()
    } else {
        format!("/{trimmed}")
    }
}
267
/// Returns `true` when `path` equals `prefix` or lies beneath it on a segment
/// boundary. For example, `/users` matches `/users` and `/users/1`, but not
/// `/users-list`. The root prefix `/` matches every path.
///
/// This runs once per request for each registered route, so it uses
/// `str::strip_prefix` instead of building a `"{prefix}/"` String on every call.
fn path_matches_prefix(path: &str, prefix: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    match path.strip_prefix(prefix) {
        // Exact match, or the remainder starts a new path segment.
        Some(rest) => rest.is_empty() || rest.starts_with('/'),
        None => false,
    }
}
275
#[cfg(test)]
mod tests {
    use axum::http::Method;

    use super::{normalize_path, path_matches_prefix, MiddlewareRoute, RouteMatcher};

    /// Builds a matcher from explicit include/exclude rule lists.
    fn matcher(include: Vec<MiddlewareRoute>, exclude: Vec<MiddlewareRoute>) -> RouteMatcher {
        RouteMatcher { include, exclude }
    }

    #[test]
    fn matcher_supports_prefix_matching_and_excludes() {
        let api = matcher(
            vec![MiddlewareRoute::path("/api")],
            vec![MiddlewareRoute::path("/api/health")],
        );

        let cases = [("/api/users", true), ("/api/health", false), ("/admin", false)];
        for &(path, expected) in &cases {
            assert_eq!(api.matches(&Method::GET, path), expected, "path {path}");
        }
    }

    #[test]
    fn normalize_path_handles_empty_and_trailing_slashes() {
        let cases = [("", "/"), ("/users/", "/users"), ("users", "/users")];
        for &(input, expected) in &cases {
            assert_eq!(normalize_path(input.to_string()), expected, "input {input:?}");
        }
    }

    #[test]
    fn prefix_matching_requires_boundary() {
        let cases = [("/users/1", true), ("/users", true), ("/users-list", false)];
        for &(path, expected) in &cases {
            assert_eq!(path_matches_prefix(path, "/users"), expected, "path {path}");
        }
    }

    #[test]
    fn matcher_can_target_specific_http_methods() {
        let admin_reads = matcher(vec![MiddlewareRoute::get("/admin")], Vec::new());

        assert!(admin_reads.matches(&Method::GET, "/admin/users"));
        assert!(!admin_reads.matches(&Method::POST, "/admin/users"));
    }
}