1use std::sync::Arc;
2
3use axum::{body::Body, extract::Request, http::Method};
4use nestforge_core::{framework_log_event, NextFn, NextFuture};
5
/// A middleware that can be bound to routes through [`MiddlewareConsumer`].
///
/// Implementations receive the incoming request and a `next` callback. Calling
/// `next(req)` continues the chain: any later bindings whose routes match, then
/// the terminal handler passed to [`run_middleware_chain`].
///
/// The `Send + Sync + 'static` bounds are required because middleware is stored
/// as `Arc<dyn NestMiddleware>` and shared across requests.
///
/// NOTE(review): whether a middleware may short-circuit (respond without calling
/// `next`) depends on `NextFuture`'s output type, which is defined in
/// `nestforge_core`. Confirm there.
pub trait NestMiddleware: Send + Sync + 'static {
    /// Handles `req`, typically forwarding it (possibly modified) to `next`.
    fn handle(&self, req: Request<Body>, next: NextFn) -> NextFuture;
}
9
10#[derive(Clone)]
11pub struct MiddlewareBinding {
12 middleware: Arc<dyn NestMiddleware>,
13 matcher: RouteMatcher,
14}
15
16impl MiddlewareBinding {
17 fn matches(&self, method: &Method, path: &str) -> bool {
18 self.matcher.matches(method, path)
19 }
20}
21
22#[derive(Clone, Default)]
23struct RouteMatcher {
24 include: Vec<MiddlewareRoute>,
25 exclude: Vec<MiddlewareRoute>,
26}
27
28impl RouteMatcher {
29 fn matches(&self, method: &Method, path: &str) -> bool {
30 if self
31 .exclude
32 .iter()
33 .any(|route| route.matches(method, path))
34 {
35 return false;
36 }
37
38 if self.include.is_empty() {
39 return true;
40 }
41
42 self.include
43 .iter()
44 .any(|route| route.matches(method, path))
45 }
46}
47
48#[derive(Clone, Debug, PartialEq, Eq)]
49pub struct MiddlewareRoute {
50 path: String,
51 methods: Option<Vec<Method>>,
52}
53
54impl MiddlewareRoute {
55 pub fn path(path: impl Into<String>) -> Self {
56 Self {
57 path: normalize_path(path.into()),
58 methods: None,
59 }
60 }
61
62 pub fn methods<I>(path: impl Into<String>, methods: I) -> Self
63 where
64 I: IntoIterator<Item = Method>,
65 {
66 Self {
67 path: normalize_path(path.into()),
68 methods: Some(methods.into_iter().collect()),
69 }
70 }
71
72 pub fn get(path: impl Into<String>) -> Self {
73 Self::methods(path, [Method::GET])
74 }
75
76 pub fn post(path: impl Into<String>) -> Self {
77 Self::methods(path, [Method::POST])
78 }
79
80 pub fn put(path: impl Into<String>) -> Self {
81 Self::methods(path, [Method::PUT])
82 }
83
84 pub fn delete(path: impl Into<String>) -> Self {
85 Self::methods(path, [Method::DELETE])
86 }
87
88 fn matches(&self, method: &Method, path: &str) -> bool {
89 if !path_matches_prefix(path, &self.path) {
90 return false;
91 }
92
93 match &self.methods {
94 Some(methods) => methods.iter().any(|candidate| candidate == method),
95 None => true,
96 }
97 }
98}
99
100impl From<&str> for MiddlewareRoute {
101 fn from(value: &str) -> Self {
102 Self::path(value)
103 }
104}
105
106impl From<String> for MiddlewareRoute {
107 fn from(value: String) -> Self {
108 Self::path(value)
109 }
110}
111
112#[derive(Default)]
113pub struct MiddlewareConsumer {
114 bindings: Vec<MiddlewareBinding>,
115}
116
117impl MiddlewareConsumer {
118 pub fn new() -> Self {
119 Self::default()
120 }
121
122 pub fn apply<T>(&mut self) -> MiddlewareBindingBuilder<'_>
123 where
124 T: NestMiddleware + Default,
125 {
126 MiddlewareBindingBuilder::new(self, Arc::new(T::default()))
127 }
128
129 pub fn apply_instance<T>(&mut self, middleware: T) -> MiddlewareBindingBuilder<'_>
130 where
131 T: NestMiddleware,
132 {
133 MiddlewareBindingBuilder::new(self, Arc::new(middleware))
134 }
135
136 pub fn into_bindings(self) -> Vec<MiddlewareBinding> {
137 self.bindings
138 }
139}
140
141pub struct MiddlewareBindingBuilder<'a> {
142 consumer: &'a mut MiddlewareConsumer,
143 middleware: Arc<dyn NestMiddleware>,
144 exclude: Vec<MiddlewareRoute>,
145}
146
147impl<'a> MiddlewareBindingBuilder<'a> {
148 fn new(consumer: &'a mut MiddlewareConsumer, middleware: Arc<dyn NestMiddleware>) -> Self {
149 Self {
150 consumer,
151 middleware,
152 exclude: Vec::new(),
153 }
154 }
155
156 pub fn exclude<I, S>(mut self, routes: I) -> Self
157 where
158 I: IntoIterator<Item = S>,
159 S: Into<MiddlewareRoute>,
160 {
161 self.exclude = routes.into_iter().map(Into::into).collect();
162 self
163 }
164
165 pub fn for_all_routes(self) -> &'a mut MiddlewareConsumer {
166 self.register(Vec::new())
167 }
168
169 pub fn for_routes<I, S>(self, routes: I) -> &'a mut MiddlewareConsumer
170 where
171 I: IntoIterator<Item = S>,
172 S: Into<MiddlewareRoute>,
173 {
174 let include = routes.into_iter().map(Into::into).collect();
175 self.register(include)
176 }
177
178 fn register(self, include: Vec<MiddlewareRoute>) -> &'a mut MiddlewareConsumer {
179 framework_log_event(
180 "middleware_register",
181 &[("include", format!("{include:?}")), ("exclude", format!("{:?}", self.exclude))],
182 );
183 self.consumer.bindings.push(MiddlewareBinding {
184 middleware: self.middleware,
185 matcher: RouteMatcher {
186 include,
187 exclude: self.exclude,
188 },
189 });
190 self.consumer
191 }
192}
193
194pub fn run_middleware_chain(
195 middlewares: Arc<Vec<MiddlewareBinding>>,
196 index: usize,
197 req: Request<Body>,
198 terminal: NextFn,
199) -> NextFuture {
200 if index >= middlewares.len() {
201 return terminal(req);
202 }
203
204 let binding = middlewares[index].clone();
205 if !binding.matches(req.method(), req.uri().path()) {
206 return run_middleware_chain(middlewares, index + 1, req, terminal);
207 }
208
209 let middlewares_for_next = Arc::clone(&middlewares);
210 let terminal_for_next = Arc::clone(&terminal);
211 let next_fn: NextFn = Arc::new(move |next_req| {
212 run_middleware_chain(
213 Arc::clone(&middlewares_for_next),
214 index + 1,
215 next_req,
216 Arc::clone(&terminal_for_next),
217 )
218 });
219
220 binding.middleware.handle(req, next_fn)
221}
222
/// Normalizes a route path for prefix matching.
///
/// Surrounding whitespace and trailing slashes are removed, and a single
/// leading `/` is ensured. Empty input, and input made only of slashes, becomes
/// `"/"`. Interior slashes are left untouched.
fn normalize_path(path: String) -> String {
    // Trim trailing slashes *before* the empty check, so inputs like "//" or
    // " /// " collapse to the root instead of an empty string.
    let trimmed = path.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        return "/".to_string();
    }

    if trimmed.starts_with('/') {
        trimmed.to_string()
    } else {
        format!("/{trimmed}")
    }
}
235
/// Returns `true` when `path` equals `prefix` or lies beneath it on a `/`
/// boundary. For example, `/users` matches `/users` and `/users/1` but not
/// `/users-list`. The root prefix `/` matches every path.
///
/// This runs per request for every route, so it avoids allocating.
fn path_matches_prefix(path: &str, prefix: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    match path.strip_prefix(prefix) {
        // Whatever follows the prefix must be empty or start a new segment.
        Some(rest) => rest.is_empty() || rest.starts_with('/'),
        None => false,
    }
}
243
#[cfg(test)]
mod tests {
    use axum::http::Method;

    use super::{normalize_path, path_matches_prefix, MiddlewareRoute, RouteMatcher};

    #[test]
    fn matcher_supports_prefix_matching_and_excludes() {
        let matcher = RouteMatcher {
            include: vec![MiddlewareRoute::path("/api")],
            exclude: vec![MiddlewareRoute::path("/api/health")],
        };

        let cases = [("/api/users", true), ("/api/health", false), ("/admin", false)];
        for &(path, expected) in cases.iter() {
            assert_eq!(matcher.matches(&Method::GET, path), expected, "GET {path}");
        }
    }

    #[test]
    fn normalize_path_handles_empty_and_trailing_slashes() {
        let cases = [("", "/"), ("/users/", "/users"), ("users", "/users")];
        for &(raw, expected) in cases.iter() {
            assert_eq!(normalize_path(raw.to_string()), expected, "normalizing {raw:?}");
        }
    }

    #[test]
    fn prefix_matching_requires_boundary() {
        let cases = [("/users/1", true), ("/users", true), ("/users-list", false)];
        for &(path, expected) in cases.iter() {
            assert_eq!(path_matches_prefix(path, "/users"), expected, "{path} under /users");
        }
    }

    #[test]
    fn matcher_can_target_specific_http_methods() {
        let matcher = RouteMatcher {
            include: vec![MiddlewareRoute::get("/admin")],
            exclude: Vec::new(),
        };

        for (method, expected) in [(Method::GET, true), (Method::POST, false)].iter() {
            assert_eq!(matcher.matches(method, "/admin/users"), *expected, "{method}");
        }
    }
}