Skip to main content

az_axum_route/
lib.rs

1use std::collections::BTreeMap;
2
3use axum::{routing::MethodRouter, Router};
4pub use inventory;
5use serde::Serialize;
6
7pub struct RouteDef<S> {
8    pub path: &'static str,
9    pub method_router: MethodRouter<S>,
10}
11
12impl<S> RouteDef<S> {
13    pub fn new(path: &'static str, method_router: MethodRouter<S>) -> Self {
14        Self {
15            path,
16            method_router,
17        }
18    }
19}
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
22pub struct RouteDescriptor {
23    pub scope: &'static str,
24    pub method: &'static str,
25    pub path: &'static str,
26    pub handler: &'static str,
27    pub guard: Option<RouteGuard>,
28}
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
31#[serde(rename_all = "snake_case")]
32pub enum RouteGuard {
33    DeveloperMode,
34}
35
36pub fn mount_routes<S, I>(router: Router<S>, routes: I) -> Router<S>
37where
38    S: Clone + Send + Sync + 'static,
39    I: IntoIterator<Item = RouteDef<S>>,
40{
41    mount_routes_with_prefix(router, "", routes)
42}
43
44pub fn mount_routes_with_prefix<S, I>(mut router: Router<S>, prefix: &str, routes: I) -> Router<S>
45where
46    S: Clone + Send + Sync + 'static,
47    I: IntoIterator<Item = RouteDef<S>>,
48{
49    let mut grouped = BTreeMap::<String, MethodRouter<S>>::new();
50    for route in routes {
51        let path = join_route_path(prefix, route.path);
52        let merged = match grouped.remove(&path) {
53            Some(existing) => existing.merge(route.method_router),
54            None => route.method_router,
55        };
56        grouped.insert(path, merged);
57    }
58
59    for (path, method_router) in grouped {
60        router = router.route(&path, method_router);
61    }
62
63    router
64}
65
66fn join_route_path(prefix: &str, path: &str) -> String {
67    let normalized_prefix = normalize_path_segment(prefix);
68    let normalized_path = normalize_path_segment(path);
69    match (normalized_prefix.as_str(), normalized_path.as_str()) {
70        ("", "") => "/".to_owned(),
71        ("", path) => path.to_owned(),
72        (prefix, "") => prefix.to_owned(),
73        (prefix, path) => format!("{prefix}{path}"),
74    }
75}
76
/// Normalizes one path fragment to either `""` or `"/segment..."`.
///
/// All leading and trailing `/` characters are removed, and a single leading `/` is added
/// back when anything remains. Interior slashes are left untouched. Inputs made only of
/// slashes (or empty) yield an empty string.
fn normalize_path_segment(path: &str) -> String {
    match path.trim_matches('/') {
        "" => String::new(),
        inner => format!("/{inner}"),
    }
}
89
/// Declares a crate-local route registry for router state type `$state`.
///
/// Invoking `declare_route_registry!(AppState)` generates:
///
/// - `RouteRegistration`: the record type collected through [`inventory`]. Each entry
///   carries route metadata plus a `route` constructor producing a [`RouteDef`].
/// - `registered_route_descriptors()`: every registration as a [`RouteDescriptor`], sorted.
/// - `assert_no_conflicting_registered_routes()`: panics if two registrations would be
///   mounted on the same scope, method and path.
/// - `mount_registered_routes()` / `mount_registered_routes_with_prefix()`: mount every
///   registration of one scope onto a router, optionally under a path prefix.
///
/// The generated code names `::axum::Router`, so the invoking crate must depend on `axum`
/// directly.
#[macro_export]
macro_rules! declare_route_registry {
    ($state:ty) => {
        /// A route collected by this crate's registry (entries are added with
        /// `inventory::submit!`).
        pub(crate) struct RouteRegistration {
            pub method: &'static str,
            pub path: &'static str,
            pub scope: &'static str,
            pub handler: &'static str,
            pub guard: ::std::option::Option<$crate::RouteGuard>,
            pub route: fn() -> $crate::RouteDef<$state>,
        }

        $crate::inventory::collect!(RouteRegistration);

        /// Returns metadata for every registered route, sorted by `RouteDescriptor`'s `Ord`.
        // The generated helpers are optional; don't warn in crates that skip some of them.
        #[allow(dead_code)]
        pub(crate) fn registered_route_descriptors() -> ::std::vec::Vec<$crate::RouteDescriptor> {
            let mut routes = $crate::inventory::iter::<RouteRegistration>
                .into_iter()
                .map(|registration| $crate::RouteDescriptor {
                    scope: registration.scope,
                    method: registration.method,
                    path: registration.path,
                    handler: registration.handler,
                    guard: registration.guard,
                })
                .collect::<::std::vec::Vec<_>>();
            routes.sort();
            routes
        }

        /// Panics if two registrations in the same scope would be mounted on the same route.
        ///
        /// Two registrations conflict when they share scope and method, and their paths are
        /// equal once surrounding slashes are ignored. That matches the normalization applied
        /// when routes are mounted.
        ///
        /// # Panics
        ///
        /// On the first conflicting pair, naming both handlers.
        #[allow(dead_code)]
        pub(crate) fn assert_no_conflicting_registered_routes() {
            let mut seen = ::std::collections::BTreeMap::<
                (&'static str, &'static str, &'static str),
                &'static str,
            >::new();

            for route in registered_route_descriptors() {
                // Mounting trims surrounding slashes before joining, so `/health` and
                // `health/` land on the same axum route. Key on the trimmed form to catch
                // that here instead of inside axum's `MethodRouter::merge`.
                let key = (route.scope, route.method, route.path.trim_matches('/'));
                if let Some(previous_handler) = seen.insert(key, route.handler) {
                    panic!(
                        "duplicate registered route detected for scope `{}`, method `{}`, path `{}`: `{}` and `{}`",
                        route.scope, route.method, route.path, previous_handler, route.handler
                    );
                }
            }
        }

        /// Mounts every registration whose scope equals `scope`, without a path prefix.
        #[allow(dead_code)]
        pub(crate) fn mount_registered_routes(
            router: ::axum::Router<$state>,
            scope: &str,
        ) -> ::axum::Router<$state> {
            // An empty prefix is exactly what `$crate::mount_routes` uses.
            mount_registered_routes_with_prefix(router, scope, "")
        }

        /// Mounts every registration whose scope equals `scope`, each under `prefix`.
        #[allow(dead_code)]
        pub(crate) fn mount_registered_routes_with_prefix(
            router: ::axum::Router<$state>,
            scope: &str,
            prefix: &str,
        ) -> ::axum::Router<$state> {
            $crate::mount_routes_with_prefix(
                router,
                prefix,
                $crate::inventory::iter::<RouteRegistration>
                    .into_iter()
                    .filter(|registration| registration.scope == scope)
                    .map(|registration| (registration.route)()),
            )
        }
    };
}
165
#[cfg(test)]
mod tests {
    use super::{join_route_path, normalize_path_segment, RouteDescriptor, RouteGuard};

    #[test]
    fn normalize_path_segment_handles_root_and_slashes() {
        assert_eq!(normalize_path_segment(""), "");
        assert_eq!(normalize_path_segment("/"), "");
        assert_eq!(normalize_path_segment("///"), "");
        assert_eq!(normalize_path_segment("api/v1"), "/api/v1");
        assert_eq!(normalize_path_segment("/api/v1/"), "/api/v1");
        assert_eq!(normalize_path_segment("//api/v1//"), "/api/v1");
        // Only surrounding slashes are trimmed; interior ones are kept verbatim.
        assert_eq!(normalize_path_segment("api//v1"), "/api//v1");
    }

    #[test]
    fn join_route_path_combines_prefix_and_route_path() {
        assert_eq!(join_route_path("", "/system/health"), "/system/health");
        assert_eq!(
            join_route_path("/api/v1", "/system/health"),
            "/api/v1/system/health"
        );
        assert_eq!(
            join_route_path("/api/v1/", "system/health"),
            "/api/v1/system/health"
        );
        assert_eq!(join_route_path("/", "/system/health"), "/system/health");
        assert_eq!(join_route_path("/api/v1", "/"), "/api/v1");
        assert_eq!(join_route_path("/", "/"), "/");
        // Empty inputs behave like the root, and bare segments gain a leading slash.
        assert_eq!(join_route_path("", ""), "/");
        assert_eq!(join_route_path("api", ""), "/api");
        assert_eq!(join_route_path("", "health"), "/health");
    }

    #[test]
    fn route_descriptor_sort_is_stable_for_route_listing() {
        let mut routes = vec![
            RouteDescriptor {
                scope: "system",
                method: "POST",
                path: "/b",
                handler: "post_b",
                guard: None,
            },
            RouteDescriptor {
                scope: "asset",
                method: "GET",
                path: "/a",
                handler: "get_a",
                guard: Some(RouteGuard::DeveloperMode),
            },
        ];
        routes.sort();

        assert_eq!(routes[0].scope, "asset");
        assert_eq!(routes[1].scope, "system");
    }

    #[test]
    fn route_descriptor_ordering_falls_back_to_guard() {
        let unguarded = RouteDescriptor {
            scope: "system",
            method: "GET",
            path: "/a",
            handler: "get_a",
            guard: None,
        };
        let guarded = RouteDescriptor {
            guard: Some(RouteGuard::DeveloperMode),
            ..unguarded
        };

        // With every other field equal, `None` sorts before any guard.
        assert!(unguarded < guarded);
    }
}