1use std::sync::Arc;
4
5use axum::body::Body;
6use axum::http::{Method, Request};
7use axum::routing::{any, delete, get, patch, post, put, MethodRouter};
8use axum::Router as AxumRouter;
9
10use crate::container::Container;
11use crate::middleware::MiddlewareRegistry;
12
13pub struct Router {
14 inner: AxumRouter<Container>,
15 registry: MiddlewareRegistry,
16 middleware_stack: Vec<String>,
17 prefix: String,
18 routes: Vec<RouteInfo>,
19}
20
21#[derive(Debug, Clone)]
25pub struct RouteInfo {
26 pub method: Method,
27 pub path: String,
28 pub middleware: Vec<String>,
29}
30
31impl Router {
32 pub fn new(registry: MiddlewareRegistry) -> Self {
33 Self {
34 inner: AxumRouter::new(),
35 registry,
36 middleware_stack: Vec::new(),
37 prefix: String::new(),
38 routes: Vec::new(),
39 }
40 }
41
42 pub fn route_infos(&self) -> &[RouteInfo] {
44 &self.routes
45 }
46
47 pub fn with_state(self) -> AxumRouter<Container> {
48 self.inner
49 }
50
51 pub fn finish(self) -> (AxumRouter<Container>, Vec<RouteInfo>) {
54 (self.inner, self.routes)
55 }
56
57 fn record(&mut self, method: Method, path: &str) {
58 self.routes.push(RouteInfo {
59 method,
60 path: self.full_path(path),
61 middleware: self.middleware_stack.clone(),
62 });
63 }
64
65 fn full_path(&self, path: &str) -> String {
66 if self.prefix.is_empty() {
67 path.to_string()
68 } else {
69 format!("{}{}", self.prefix.trim_end_matches('/'), path)
70 }
71 }
72
73 fn wrap_method_router(&self, mr: MethodRouter<Container>) -> MethodRouter<Container> {
74 let mut mr = mr;
75 for name in self.middleware_stack.iter().rev() {
76 if let Some(mw) = self.registry.get(name) {
77 let mw = mw.clone();
78 let layer = axum::middleware::from_fn(
79 move |req: Request<Body>, next: axum::middleware::Next| {
80 let mw = mw.clone();
81 async move { crate::middleware::invoke(mw, req, next).await }
82 },
83 );
84 mr = mr.layer(layer);
85 } else {
86 tracing::warn!(name, "unknown middleware referenced in route; ignoring");
87 }
88 }
89 mr
90 }
91
92 pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
93 where
94 H: axum::handler::Handler<T, Container>,
95 T: 'static,
96 {
97 self.record(Method::GET, path);
98 let mr = self.wrap_method_router(get(handler));
99 let full = self.full_path(path);
100 self.inner = self.inner.route(&full, mr);
101 self
102 }
103
104 pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
105 where
106 H: axum::handler::Handler<T, Container>,
107 T: 'static,
108 {
109 self.record(Method::POST, path);
110 let mr = self.wrap_method_router(post(handler));
111 let full = self.full_path(path);
112 self.inner = self.inner.route(&full, mr);
113 self
114 }
115
116 pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
117 where
118 H: axum::handler::Handler<T, Container>,
119 T: 'static,
120 {
121 self.record(Method::PUT, path);
122 let mr = self.wrap_method_router(put(handler));
123 let full = self.full_path(path);
124 self.inner = self.inner.route(&full, mr);
125 self
126 }
127
128 pub fn patch<H, T>(mut self, path: &str, handler: H) -> Self
129 where
130 H: axum::handler::Handler<T, Container>,
131 T: 'static,
132 {
133 self.record(Method::PATCH, path);
134 let mr = self.wrap_method_router(patch(handler));
135 let full = self.full_path(path);
136 self.inner = self.inner.route(&full, mr);
137 self
138 }
139
140 pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
141 where
142 H: axum::handler::Handler<T, Container>,
143 T: 'static,
144 {
145 self.record(Method::DELETE, path);
146 let mr = self.wrap_method_router(delete(handler));
147 let full = self.full_path(path);
148 self.inner = self.inner.route(&full, mr);
149 self
150 }
151
152 pub fn any<H, T>(mut self, path: &str, handler: H) -> Self
153 where
154 H: axum::handler::Handler<T, Container>,
155 T: 'static,
156 {
157 self.record(Method::OPTIONS, path); let mr = self.wrap_method_router(any(handler));
159 let full = self.full_path(path);
160 self.inner = self.inner.route(&full, mr);
161 self
162 }
163
164 pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
165 self.prefix = prefix.into();
166 self
167 }
168
169 pub fn middleware<I, S>(mut self, names: I) -> Self
170 where
171 I: IntoIterator<Item = S>,
172 S: Into<String>,
173 {
174 for name in names {
175 self.middleware_stack.push(name.into());
176 }
177 self
178 }
179
180 pub fn group<F>(mut self, build: F) -> Self
181 where
182 F: FnOnce(Router) -> Router,
183 {
184 let inner_router = Router {
185 inner: AxumRouter::new(),
186 registry: self.registry.clone(),
187 middleware_stack: self.middleware_stack.clone(),
188 prefix: self.prefix.clone(),
189 routes: Vec::new(),
190 };
191 let built = build(inner_router);
192 self.routes.extend(built.routes);
193 self.inner = self.inner.merge(built.inner);
194 self
195 }
196
197 pub fn merge(mut self, other: Router) -> Self {
198 self.routes.extend(other.routes);
199 self.inner = self.inner.merge(other.inner);
200 self
201 }
202
203 pub fn nest(mut self, prefix: &str, other: Router) -> Self {
204 for mut r in other.routes {
205 r.path = format!("{}{}", prefix.trim_end_matches('/'), r.path);
206 self.routes.push(r);
207 }
208 self.inner = self.inner.nest(prefix, other.inner);
209 self
210 }
211
212 pub fn with_route_infos(mut self, infos: Vec<RouteInfo>) -> Self {
215 self.routes.extend(infos);
216 self
217 }
218
219 pub fn layer<L>(mut self, layer: L) -> Self
223 where
224 L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
225 L::Service: tower::Service<
226 axum::http::Request<axum::body::Body>,
227 Response = axum::http::Response<axum::body::Body>,
228 Error = std::convert::Infallible,
229 > + Clone
230 + Send
231 + 'static,
232 <L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future:
233 Send + 'static,
234 {
235 self.inner = self.inner.layer(layer);
236 self
237 }
238
239 pub fn adopt(self, other: AxumRouter<Container>) -> Self {
242 Router {
243 inner: self.inner.merge(other),
244 registry: self.registry,
245 middleware_stack: self.middleware_stack,
246 prefix: self.prefix,
247 routes: self.routes,
248 }
249 }
250}
251
252#[derive(Debug, Clone)]
254pub struct Route {
255 pub name: Option<String>,
256 pub method: Method,
257 pub path: String,
258}
259
260#[derive(Default, Clone)]
262pub struct NamedRoutes {
263 routes: Arc<parking_lot::RwLock<indexmap::IndexMap<String, Route>>>,
264}
265
266impl NamedRoutes {
267 pub fn new() -> Self {
268 Self::default()
269 }
270
271 pub fn add(&self, route: Route) {
272 if let Some(name) = route.name.clone() {
273 self.routes.write().insert(name, route);
274 }
275 }
276
277 pub fn url(&self, name: &str, params: &[&str]) -> Option<String> {
278 let routes = self.routes.read();
279 let route = routes.get(name)?;
280 let mut path = route.path.clone();
281 for p in params {
282 if let Some(start) = path.find('{') {
283 if let Some(end) = path[start..].find('}') {
284 path.replace_range(start..=start + end, p);
285 }
286 }
287 }
288 Some(path)
289 }
290}