1use std::sync::Arc;
4
5use axum::body::Body;
6use axum::http::{Method, Request};
7use axum::routing::{any, delete, get, patch, post, put, MethodRouter};
8use axum::Router as AxumRouter;
9
10use crate::container::Container;
11use crate::middleware::MiddlewareRegistry;
12
13pub struct Router {
14 inner: AxumRouter<Container>,
15 registry: MiddlewareRegistry,
16 middleware_stack: Vec<String>,
17 prefix: String,
18 routes: Vec<RouteInfo>,
19}
20
/// Metadata captured by [`Router`] for each route it registers.
///
/// Exposed through `Router::route_infos` and `Router::finish`.
#[derive(Debug, Clone)]
pub struct RouteInfo {
    /// Method the route was registered with. `Router::any` records `OPTIONS` here even
    /// though the route answers every method.
    pub method: Method,
    /// Full path, including the builder's prefix and any `Router::nest` prefix.
    pub path: String,
    /// Middleware names active when the route was registered, in stack order.
    /// Names missing from the registry are still listed, although no layer was applied for them.
    pub middleware: Vec<String>,
}
30
31impl Router {
32 pub fn new(registry: MiddlewareRegistry) -> Self {
33 Self {
34 inner: AxumRouter::new(),
35 registry,
36 middleware_stack: Vec::new(),
37 prefix: String::new(),
38 routes: Vec::new(),
39 }
40 }
41
42 pub fn route_infos(&self) -> &[RouteInfo] {
44 &self.routes
45 }
46
47 pub fn with_state(self) -> AxumRouter<Container> {
48 self.inner
49 }
50
51 pub fn finish(self) -> (AxumRouter<Container>, Vec<RouteInfo>) {
54 (self.inner, self.routes)
55 }
56
57 fn record(&mut self, method: Method, path: &str) {
58 check_path_syntax(method.as_str(), path);
59 self.routes.push(RouteInfo {
60 method,
61 path: self.full_path(path),
62 middleware: self.middleware_stack.clone(),
63 });
64 }
65
66 fn full_path(&self, path: &str) -> String {
67 if self.prefix.is_empty() {
68 path.to_string()
69 } else {
70 format!("{}{}", self.prefix.trim_end_matches('/'), path)
71 }
72 }
73
74 fn wrap_method_router(&self, mr: MethodRouter<Container>) -> MethodRouter<Container> {
75 let mut mr = mr;
76 for name in self.middleware_stack.iter().rev() {
77 if let Some(mw) = self.registry.get(name) {
78 let mw = mw.clone();
79 let layer = axum::middleware::from_fn(
80 move |req: Request<Body>, next: axum::middleware::Next| {
81 let mw = mw.clone();
82 async move { crate::middleware::invoke(mw, req, next).await }
83 },
84 );
85 mr = mr.layer(layer);
86 } else {
87 tracing::warn!(name, "unknown middleware referenced in route; ignoring");
88 }
89 }
90 mr
91 }
92
93 pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
94 where
95 H: axum::handler::Handler<T, Container>,
96 T: 'static,
97 {
98 self.record(Method::GET, path);
99 let mr = self.wrap_method_router(get(handler));
100 let full = self.full_path(path);
101 self.inner = self.inner.route(&full, mr);
102 self
103 }
104
105 pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
106 where
107 H: axum::handler::Handler<T, Container>,
108 T: 'static,
109 {
110 self.record(Method::POST, path);
111 let mr = self.wrap_method_router(post(handler));
112 let full = self.full_path(path);
113 self.inner = self.inner.route(&full, mr);
114 self
115 }
116
117 pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
118 where
119 H: axum::handler::Handler<T, Container>,
120 T: 'static,
121 {
122 self.record(Method::PUT, path);
123 let mr = self.wrap_method_router(put(handler));
124 let full = self.full_path(path);
125 self.inner = self.inner.route(&full, mr);
126 self
127 }
128
129 pub fn patch<H, T>(mut self, path: &str, handler: H) -> Self
130 where
131 H: axum::handler::Handler<T, Container>,
132 T: 'static,
133 {
134 self.record(Method::PATCH, path);
135 let mr = self.wrap_method_router(patch(handler));
136 let full = self.full_path(path);
137 self.inner = self.inner.route(&full, mr);
138 self
139 }
140
141 pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
142 where
143 H: axum::handler::Handler<T, Container>,
144 T: 'static,
145 {
146 self.record(Method::DELETE, path);
147 let mr = self.wrap_method_router(delete(handler));
148 let full = self.full_path(path);
149 self.inner = self.inner.route(&full, mr);
150 self
151 }
152
153 pub fn any<H, T>(mut self, path: &str, handler: H) -> Self
154 where
155 H: axum::handler::Handler<T, Container>,
156 T: 'static,
157 {
158 self.record(Method::OPTIONS, path); let mr = self.wrap_method_router(any(handler));
160 let full = self.full_path(path);
161 self.inner = self.inner.route(&full, mr);
162 self
163 }
164
165 pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
166 self.prefix = prefix.into();
167 self
168 }
169
170 pub fn middleware<I, S>(mut self, names: I) -> Self
171 where
172 I: IntoIterator<Item = S>,
173 S: Into<String>,
174 {
175 for name in names {
176 self.middleware_stack.push(name.into());
177 }
178 self
179 }
180
181 pub fn group<F>(mut self, build: F) -> Self
182 where
183 F: FnOnce(Router) -> Router,
184 {
185 let inner_router = Router {
186 inner: AxumRouter::new(),
187 registry: self.registry.clone(),
188 middleware_stack: self.middleware_stack.clone(),
189 prefix: self.prefix.clone(),
190 routes: Vec::new(),
191 };
192 let built = build(inner_router);
193 self.routes.extend(built.routes);
194 self.inner = self.inner.merge(built.inner);
195 self
196 }
197
198 pub fn merge(mut self, other: Router) -> Self {
199 self.routes.extend(other.routes);
200 self.inner = self.inner.merge(other.inner);
201 self
202 }
203
204 pub fn nest(mut self, prefix: &str, other: Router) -> Self {
205 for mut r in other.routes {
206 r.path = format!("{}{}", prefix.trim_end_matches('/'), r.path);
207 self.routes.push(r);
208 }
209 self.inner = self.inner.nest(prefix, other.inner);
210 self
211 }
212
213 pub fn with_route_infos(mut self, infos: Vec<RouteInfo>) -> Self {
216 self.routes.extend(infos);
217 self
218 }
219
220 pub fn layer<L>(mut self, layer: L) -> Self
224 where
225 L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
226 L::Service: tower::Service<
227 axum::http::Request<axum::body::Body>,
228 Response = axum::http::Response<axum::body::Body>,
229 Error = std::convert::Infallible,
230 > + Clone
231 + Send
232 + 'static,
233 <L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future:
234 Send + 'static,
235 {
236 self.inner = self.inner.layer(layer);
237 self
238 }
239
240 pub fn adopt(self, other: AxumRouter<Container>) -> Self {
243 Router {
244 inner: self.inner.merge(other),
245 registry: self.registry,
246 middleware_stack: self.middleware_stack,
247 prefix: self.prefix,
248 routes: self.routes,
249 }
250 }
251}
252
/// A route definition that can be registered with [`NamedRoutes`] for URL generation.
#[derive(Debug, Clone)]
pub struct Route {
    /// Lookup key for `NamedRoutes::url`. Routes with `None` are ignored by `NamedRoutes::add`.
    pub name: Option<String>,
    /// HTTP method of the route. Stored for reference; URL generation does not consult it.
    pub method: Method,
    /// Path template whose placeholders `NamedRoutes::url` fills in.
    pub path: String,
}
260
/// Shared registry of named routes, keyed by route name, used to build URLs.
///
/// The map sits behind an `Arc`, so clones of a `NamedRoutes` share one registry. A route
/// added through any handle is visible through all of them. Keys keep insertion order (`IndexMap`).
#[derive(Default, Clone)]
pub struct NamedRoutes {
    routes: Arc<parking_lot::RwLock<indexmap::IndexMap<String, Route>>>,
}
266
267impl NamedRoutes {
268 pub fn new() -> Self {
269 Self::default()
270 }
271
272 pub fn add(&self, route: Route) {
273 if let Some(name) = route.name.clone() {
274 self.routes.write().insert(name, route);
275 }
276 }
277
278 pub fn url(&self, name: &str, params: &[&str]) -> Option<String> {
279 let routes = self.routes.read();
280 let route = routes.get(name)?;
281 let mut path = route.path.clone();
282 for p in params {
283 if let Some(start) = path.find('{') {
284 if let Some(end) = path[start..].find('}') {
285 path.replace_range(start..=start + end, p);
286 }
287 }
288 }
289 Some(path)
290 }
291}
292
/// Panics if `path` contains an axum-0.8 style `{param}` or `{*wildcard}` placeholder.
///
/// This crate runs on axum 0.7, which does not recognise `{id}` as a parameter. A route
/// written in the 0.8 style would register without error and then never match. Failing
/// at registration time, with a concrete rewritten path, turns that silent 404 into an
/// obvious startup error. `method` is used only in the panic message.
///
/// Not treated as placeholders:
/// - empty braces `{}`
/// - brace spans containing `/`
/// - spans starting with `#`
/// - `{{...}}` escapes
///
/// # Panics
/// Panics on the first 0.8-style placeholder found. The message suggests `:name` for a
/// named parameter and `*name` for a catch-all.
fn check_path_syntax(method: &str, path: &str) {
    let bytes = path.as_bytes();
    for (i, &b) in bytes.iter().enumerate() {
        if b != b'{' {
            continue;
        }
        // Second brace of a `{{` escape. The first brace was already examined and
        // skipped, because its "name" starts with `{`.
        if i > 0 && bytes[i - 1] == b'{' {
            continue;
        }
        let rest = &path[i + 1..];
        let Some(end_offset) = rest.find('}') else {
            continue;
        };
        let name = &rest[..end_offset];
        if name.is_empty() || name.contains('/') {
            continue;
        }
        if name.starts_with('{') || name.starts_with('#') {
            continue;
        }
        // axum 0.8 writes catch-alls as `{*rest}`. The 0.7 equivalent is `*rest`;
        // `:*rest` would be a single-segment parameter named "*rest".
        let replacement = if name.starts_with('*') {
            name.to_owned()
        } else {
            format!(":{name}")
        };
        let suggested = path.replacen(&format!("{{{name}}}"), &replacement, 1);
        panic!(
            "{method} route `{path}` uses axum-0.8 syntax `{{{name}}}` but Anvilforge \
             runs on axum 0.7. Use `{replacement}` instead:\n\n \"{suggested}\"\n\n\
             Without this rewrite, requests to that path segment 404 silently."
        );
    }
}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// axum-0.7 `:param` segments must pass without panicking.
    #[test]
    fn accepts_axum_07_colon_syntax() {
        let colon_paths = [
            "/handles/:handle",
            "/users/:id/posts/:post_id",
            "/.well-known/sidevers/resolve/:handle",
        ];
        for route_path in colon_paths {
            check_path_syntax("GET", route_path);
        }
    }

    /// Paths without any placeholders must pass for any method.
    #[test]
    fn accepts_literal_paths() {
        let cases = [("GET", "/health"), ("GET", "/"), ("POST", "/handles/claim")];
        for (verb, route_path) in cases {
            check_path_syntax(verb, route_path);
        }
    }

    /// A `{name}` placeholder is rejected and named in the panic message.
    #[test]
    #[should_panic(expected = "uses axum-0.8 syntax `{handle}`")]
    fn rejects_axum_08_brace_syntax() {
        check_path_syntax("GET", "/handles/{handle}");
    }

    /// The panic message proposes the matching `:name` rewrite.
    #[test]
    #[should_panic(expected = "Use `:id` instead")]
    fn suggests_concrete_rewrite() {
        check_path_syntax("GET", "/users/{id}/posts");
    }

    /// `{}` has no name, so it is not treated as a placeholder.
    #[test]
    fn ignores_empty_braces() {
        check_path_syntax("GET", "/foo{}/bar");
    }
}