1use actus_controller::{Controller, Params, RouteDef};
5use actus_reply::{Reply, WebError};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tracing::{debug, warn};
9
10#[derive(Default)]
12struct RouteNode {
13 children: HashMap<String, RouteNode>,
15 controller: Option<Arc<dyn Controller>>,
21}
22
23#[derive(Clone)]
28pub struct RouteMatch {
29 pub controller: Arc<dyn Controller>,
31 pub action: String,
34}
35
/// A mounted controller that declares a rate-limit class, as reported by
/// `Router::rate_limit_classes`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RateLimitClass {
    /// Mount path of the controller: trie segments joined with `/` (`""` for the root).
    pub mount: String,
    /// Class name the controller returned from `actus_rate_limit`.
    pub class: &'static str,
}
47
48pub struct Router {
51 root: RouteNode,
52}
53
54impl Router {
55 pub fn match_controller(&self, path_parts: &[String]) -> Option<RouteMatch> {
64 let mut current_node = &self.root;
65 let mut longest_match: Option<(Arc<dyn Controller>, usize)> = None;
68
69 if let Some(ref controller) = current_node.controller {
72 longest_match = Some((controller.clone(), 0));
73 }
74
75 for (i, segment) in path_parts.iter().enumerate() {
76 match current_node.children.get(segment) {
77 Some(child_node) => {
78 current_node = child_node;
79 if let Some(ref controller) = current_node.controller {
80 longest_match = Some((controller.clone(), i + 1));
81 }
82 }
83 None => break,
86 }
87 }
88
89 longest_match.map(|(controller, prefix_len)| {
90 let action: String = if prefix_len >= path_parts.len() {
96 String::new()
97 } else {
98 path_parts[prefix_len..].join("/")
99 };
100 RouteMatch { controller, action }
101 })
102 }
103
104 pub async fn route(&self, path_parts: &[String], params: Params) -> Reply {
111 match self.match_controller(path_parts) {
112 Some(rm) => {
113 debug!(action = %rm.action, "routing to controller");
114 rm.controller.actus_dispatch(&rm.action, params).await
115 }
116 None => {
117 debug!(?path_parts, "no route matched");
118 Err(WebError::NotFound)
119 }
120 }
121 }
122
123 pub fn routes(&self) -> Vec<(String, RouteDef)> {
137 let mut out = Vec::new();
138 let mut prefix: Vec<String> = Vec::new();
139 Self::walk(&self.root, &mut prefix, &mut out);
140 out
141 }
142
143 fn walk(node: &RouteNode, prefix: &mut Vec<String>, out: &mut Vec<(String, RouteDef)>) {
144 if let Some(controller) = &node.controller {
145 let mount = prefix.join("/");
146 for rd in controller.actus_describe_routes() {
147 out.push((mount.clone(), rd));
148 }
149 }
150 let mut children: Vec<(&String, &RouteNode)> = node.children.iter().collect();
154 children.sort_by(|a, b| a.0.cmp(b.0));
155 for (seg, child) in children {
156 prefix.push(seg.clone());
157 Self::walk(child, prefix, out);
158 prefix.pop();
159 }
160 }
161
162 pub fn rate_limit_classes(&self) -> Vec<RateLimitClass> {
176 let mut out = Vec::new();
177 let mut prefix: Vec<String> = Vec::new();
178 Self::walk_classes(&self.root, &mut prefix, &mut out);
179 out
180 }
181
182 fn walk_classes(node: &RouteNode, prefix: &mut Vec<String>, out: &mut Vec<RateLimitClass>) {
183 if let Some(controller) = &node.controller
184 && let Some(class) = controller.actus_rate_limit()
185 {
186 out.push(RateLimitClass {
187 mount: prefix.join("/"),
188 class,
189 });
190 }
191 let mut children: Vec<(&String, &RouteNode)> = node.children.iter().collect();
192 children.sort_by(|a, b| a.0.cmp(b.0));
193 for (seg, child) in children {
194 prefix.push(seg.clone());
195 Self::walk_classes(child, prefix, out);
196 prefix.pop();
197 }
198 }
199}
200
201#[derive(Default)]
203pub struct RouterBuilder {
204 root: RouteNode,
205}
206
207impl RouterBuilder {
208 pub fn new() -> Self {
210 Self::default()
211 }
212
213 pub fn add_route(mut self, path: &str, controller: Arc<dyn Controller>) -> Self {
224 let path = path.trim_matches('/');
225 let parts: Vec<&str> = if path.is_empty() {
226 Vec::new()
227 } else {
228 path.split('/').collect()
229 };
230
231 let segments: &[&str] = match parts.split_last() {
233 Some((last, head)) if *last == "*" => head,
234 _ => parts.as_slice(),
235 };
236
237 if segments.contains(&"*") {
238 warn!(
239 route = path,
240 "'*' is only meaningful as the last segment of a route path; ignoring route"
241 );
242 return self;
243 }
244
245 let mut current_node = &mut self.root;
246 for part in segments {
247 current_node = current_node
248 .children
249 .entry((*part).to_string())
250 .or_default();
251 }
252 if let Some(prev) = ¤t_node.controller {
258 warn!(
259 route = path,
260 previous = prev.__name(),
261 new = controller.__name(),
262 "duplicate route mount: the later controller overwrites the earlier one",
263 );
264 }
265 current_node.controller = Some(controller);
266 self
267 }
268
269 pub fn build(self) -> Router {
271 Router { root: self.root }
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use actus_controller::Verb;
    use bytes::Bytes;
    use std::sync::Mutex;

    /// Shared record of `(spy label, action)` pairs, one entry per dispatch a `Spy` receives.
    type SpyLog = Arc<Mutex<Vec<(String, String)>>>;

    /// Test controller that records every dispatch into a shared log and answers `NotFound`.
    struct Spy {
        /// Label written into the log; the `build` helper uses the route string.
        path: String,
        log: SpyLog,
    }

    #[actus_controller::async_trait]
    impl Controller for Spy {
        async fn actus_dispatch(&self, action: &str, _params: Params) -> Reply {
            self.log
                .lock()
                .unwrap()
                .push((self.path.clone(), action.to_string()));
            Err(WebError::NotFound)
        }

        fn __name(&self) -> &'static str {
            "spy"
        }
    }

    /// Builds a router with one `Spy` per entry of `routes`, each labelled with its own
    /// route string, all writing to one shared log.
    fn build(routes: &[&str]) -> (Router, SpyLog) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut b = RouterBuilder::new();
        for r in routes {
            b = b.add_route(
                r,
                Arc::new(Spy {
                    path: r.to_string(),
                    log: log.clone(),
                }),
            );
        }
        (b.build(), log)
    }

    /// A GET request with no query, headers or body.
    fn empty_params() -> Params {
        Params::new(
            Verb::GET,
            HashMap::new(),
            None,
            Bytes::new(),
            HashMap::new(),
        )
    }

    /// Routes `path` through `router` and returns the dispatch the spies logged.
    ///
    /// `path` is split on `/` with empty segments dropped. Returns `None` if no spy was
    /// reached. The log is cleared first so earlier hits cannot leak into the result.
    async fn hit(router: &Router, log: &SpyLog, path: &str) -> Option<(String, String)> {
        log.lock().unwrap().clear();
        let pp: Vec<String> = path
            .trim_matches('/')
            .split('/')
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect();
        let _ = router.route(&pp, empty_params()).await;
        log.lock().unwrap().last().cloned()
    }

    /// Expected `(spy label, action)` pair for a dispatch.
    fn at(path: &str, action: &str) -> Option<(String, String)> {
        Some((path.to_string(), action.to_string()))
    }

    #[tokio::test]
    async fn longest_prefix_wins_and_passes_the_remainder() {
        let (r, log) = build(&["api/users", "api/users/admin"]);
        assert_eq!(hit(&r, &log, "/api/users").await, at("api/users", ""));
        assert_eq!(hit(&r, &log, "/api/users/42").await, at("api/users", "42"));
        assert_eq!(
            hit(&r, &log, "/api/users/42/posts").await,
            at("api/users", "42/posts")
        );
        // The deeper mount takes over once its full prefix is present.
        assert_eq!(
            hit(&r, &log, "/api/users/admin").await,
            at("api/users/admin", "")
        );
        assert_eq!(
            hit(&r, &log, "/api/users/admin/x").await,
            at("api/users/admin", "x")
        );
        // `api` is only a waypoint node with no controller of its own.
        assert_eq!(hit(&r, &log, "/api").await, None);
        assert_eq!(hit(&r, &log, "/nope").await, None);
        assert_eq!(hit(&r, &log, "/").await, None);
    }

    #[tokio::test]
    async fn trailing_star_is_sugar_for_mounting_at_the_prefix() {
        let (star, log_s) = build(&["api/folder/*"]);
        let (plain, log_p) = build(&["api/folder"]);
        for path in ["/api/folder", "/api/folder/x", "/api/folder/x/y/z"] {
            assert_eq!(
                hit(&star, &log_s, path).await.map(|(_, a)| a),
                hit(&plain, &log_p, path).await.map(|(_, a)| a),
                "`api/folder/*` and `api/folder` must route `{path}` identically"
            );
        }
        assert_eq!(hit(&star, &log_s, "/api/folder").await.unwrap().1, "");
        assert_eq!(
            hit(&star, &log_s, "/api/folder/deep/path").await.unwrap().1,
            "deep/path"
        );
    }

    #[tokio::test]
    async fn trailing_star_and_plain_mount_share_one_node() {
        // `api/x/*` and `api/x` resolve to the same trie node, so the later registration
        // replaces the earlier one rather than coexisting with it.
        let (r, log) = build(&["api/x/*", "api/x"]);
        assert_eq!(hit(&r, &log, "/api/x/y").await, at("api/x", "y"));
    }

    #[tokio::test]
    async fn root_star_is_the_global_fallback() {
        let (r, log) = build(&["*", "api/users"]);
        // A more specific mount still wins over the root fallback.
        assert_eq!(hit(&r, &log, "/api/users").await, at("api/users", ""));
        assert_eq!(hit(&r, &log, "/api/users/9").await, at("api/users", "9"));
        // The root controller matches the empty path with an empty action.
        assert_eq!(hit(&r, &log, "/").await, at("*", ""));
        assert_eq!(
            hit(&r, &log, "/anything/here").await,
            at("*", "anything/here")
        );
        // A partially matching path with no deeper controller falls back to the root.
        assert_eq!(hit(&r, &log, "/api/missing").await, at("*", "api/missing"));

        // Mounting at "" is equivalent to mounting at "*".
        let (r2, log2) = build(&["", "api/users"]);
        assert_eq!(hit(&r2, &log2, "/whatever").await, at("", "whatever"));
        assert_eq!(hit(&r2, &log2, "/api/users").await, at("api/users", ""));
    }

    #[test]
    fn surrounding_slashes_are_ignored_when_mounting() {
        let (r, _log) = build(&["/api/users/"]);
        let pp: Vec<String> = vec!["api".into(), "users".into(), "7".into()];
        assert_eq!(r.match_controller(&pp).expect("matches").action, "7");
    }

    #[test]
    fn match_controller_returns_some_for_matched_path_and_none_for_404() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let router = RouterBuilder::new()
            .add_route(
                "api/users",
                Arc::new(Spy {
                    path: "users".to_string(),
                    log: log.clone(),
                }),
            )
            .add_route(
                "api/users/admin",
                Arc::new(Spy {
                    path: "admin".to_string(),
                    log: log.clone(),
                }),
            )
            .build();

        // Exact hit on the deeper mount: nothing left over for the action.
        let pp: Vec<String> = vec!["api".into(), "users".into(), "admin".into()];
        let rm = router.match_controller(&pp).expect("matches");
        assert_eq!(rm.action, "");
        assert_eq!(rm.controller.__name(), "spy");

        // Falls back to the shallower mount and hands over the remainder.
        let pp: Vec<String> = vec!["api".into(), "users".into(), "42".into()];
        let rm = router.match_controller(&pp).expect("matches");
        assert_eq!(rm.action, "42");

        // No controller on the walk: no match.
        let pp: Vec<String> = vec!["api".into(), "missing".into()];
        assert!(router.match_controller(&pp).is_none());

        // Empty path with nothing mounted at the root: no match.
        let pp: Vec<String> = vec![];
        assert!(router.match_controller(&pp).is_none());
    }

    #[tokio::test]
    async fn duplicate_mount_overwrites_earlier_one() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let first = Arc::new(Spy {
            path: "first".to_string(),
            log: log.clone(),
        });
        let second = Arc::new(Spy {
            path: "second".to_string(),
            log: log.clone(),
        });
        let router = RouterBuilder::new()
            .add_route("api/users", first)
            .add_route("api/users", second)
            .build();
        let pp: Vec<String> = vec!["api".into(), "users".into()];
        let _ = router.route(&pp, empty_params()).await;
        assert_eq!(
            log.lock().unwrap().last().map(|(p, _)| p.as_str()),
            Some("second"),
        );
    }

    /// Test controller whose only interesting behaviour is its route description.
    struct Described {
        routes: &'static [RouteDef],
    }

    #[actus_controller::async_trait]
    impl Controller for Described {
        async fn actus_dispatch(&self, _action: &str, _params: Params) -> Reply {
            Err(WebError::NotFound)
        }

        fn __name(&self) -> &'static str {
            "described"
        }

        fn actus_describe_routes(&self) -> Vec<RouteDef> {
            self.routes.to_vec()
        }
    }

    // Purely synchronous: `routes()` awaits nothing, so no async runtime is needed.
    #[test]
    fn routes_introspection_returns_mount_paths_and_routedefs() {
        static USERS_ROUTES: &[RouteDef] = &[
            RouteDef {
                pattern: "",
                handler_id: "handler_0",
                handler: "list",
                verb: &[Verb::GET],
                params: &[],
                doc: None,
            },
            RouteDef {
                pattern: "{id}",
                handler_id: "handler_1",
                handler: "get",
                verb: &[Verb::GET],
                params: &[],
                doc: None,
            },
        ];
        static HEALTH_ROUTES: &[RouteDef] = &[RouteDef {
            pattern: "",
            handler_id: "handler_0",
            handler: "ping",
            verb: actus_controller::DEFAULT_VERBS,
            params: &[],
            doc: None,
        }];

        let router = RouterBuilder::new()
            .add_route(
                "api/users",
                Arc::new(Described {
                    routes: USERS_ROUTES,
                }),
            )
            .add_route(
                "health",
                Arc::new(Described {
                    routes: HEALTH_ROUTES,
                }),
            )
            .build();

        let pairs = router.routes();
        let mounts: Vec<&str> = pairs.iter().map(|(m, _)| m.as_str()).collect();
        let handlers: Vec<&str> = pairs.iter().map(|(_, r)| r.handler).collect();
        assert_eq!(
            mounts,
            vec!["api/users", "api/users", "health"],
            "DFS sorts child segments alphabetically — 'api' before 'health'",
        );
        assert_eq!(handlers, vec!["list", "get", "ping"]);
    }

    /// Test controller that declares the wrapped rate-limit class.
    struct Classed(&'static str);

    #[actus_controller::async_trait]
    impl Controller for Classed {
        async fn actus_dispatch(&self, _action: &str, _params: Params) -> Reply {
            Err(WebError::NotFound)
        }

        fn __name(&self) -> &'static str {
            "classed"
        }

        fn actus_rate_limit(&self) -> Option<&'static str> {
            Some(self.0)
        }
    }

    #[test]
    fn rate_limit_classes_lists_only_classed_controllers_with_mounts() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let router = RouterBuilder::new()
            .add_route("api/auth", Arc::new(Classed("auth")))
            .add_route(
                "api/health",
                Arc::new(Spy {
                    path: "health".into(),
                    log: log.clone(),
                }),
            )
            .add_route("api/tasks", Arc::new(Classed("tasks")))
            .build();

        assert_eq!(
            router.rate_limit_classes(),
            vec![
                RateLimitClass {
                    mount: "api/auth".to_string(),
                    class: "auth",
                },
                RateLimitClass {
                    mount: "api/tasks".to_string(),
                    class: "tasks",
                },
            ],
        );
        // The unclassed controller must not appear at all.
        assert!(
            router
                .rate_limit_classes()
                .iter()
                .all(|rlc| rlc.mount != "api/health"),
        );
    }

    #[tokio::test]
    async fn star_only_in_last_position() {
        // A mid-path `*` is rejected at registration, so nothing is mounted.
        let (r, log) = build(&["a/*/b"]);
        assert_eq!(hit(&r, &log, "/a/x/b").await, None);
        assert_eq!(hit(&r, &log, "/a/*/b").await, None);
    }

    #[tokio::test]
    async fn misplaced_star_route_does_not_disturb_other_mounts() {
        // The invalid route is dropped; the valid sibling mount still catches the path.
        let (r, log) = build(&["a/*/b", "a"]);
        assert_eq!(hit(&r, &log, "/a/x/b").await, at("a", "x/b"));
    }
}