1use std::sync::Arc;
4
5use axum::body::Body;
6use axum::http::{Method, Request};
7use axum::routing::{any, delete, get, patch, post, put, MethodRouter};
8use axum::Router as AxumRouter;
9
10use crate::container::Container;
11use crate::middleware::MiddlewareRegistry;
12
13pub struct Router {
14 inner: AxumRouter<Container>,
15 registry: MiddlewareRegistry,
16 middleware_stack: Vec<String>,
17 prefix: String,
18}
19
20impl Router {
21 pub fn new(registry: MiddlewareRegistry) -> Self {
22 Self {
23 inner: AxumRouter::new(),
24 registry,
25 middleware_stack: Vec::new(),
26 prefix: String::new(),
27 }
28 }
29
30 pub fn with_state(self) -> AxumRouter<Container> {
31 self.inner
32 }
33
34 fn full_path(&self, path: &str) -> String {
35 if self.prefix.is_empty() {
36 path.to_string()
37 } else {
38 format!("{}{}", self.prefix.trim_end_matches('/'), path)
39 }
40 }
41
42 fn wrap_method_router(&self, mr: MethodRouter<Container>) -> MethodRouter<Container> {
43 let mut mr = mr;
44 for name in self.middleware_stack.iter().rev() {
45 if let Some(mw) = self.registry.get(name) {
46 let mw = mw.clone();
47 let layer = axum::middleware::from_fn(move |req: Request<Body>, next: axum::middleware::Next| {
48 let mw = mw.clone();
49 async move {
50 crate::middleware::invoke(mw, req, next).await
51 }
52 });
53 mr = mr.layer(layer);
54 } else {
55 tracing::warn!(name, "unknown middleware referenced in route; ignoring");
56 }
57 }
58 mr
59 }
60
61 pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
62 where
63 H: axum::handler::Handler<T, Container>,
64 T: 'static,
65 {
66 let mr = self.wrap_method_router(get(handler));
67 let full = self.full_path(path);
68 self.inner = self.inner.route(&full, mr);
69 self
70 }
71
72 pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
73 where
74 H: axum::handler::Handler<T, Container>,
75 T: 'static,
76 {
77 let mr = self.wrap_method_router(post(handler));
78 let full = self.full_path(path);
79 self.inner = self.inner.route(&full, mr);
80 self
81 }
82
83 pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
84 where
85 H: axum::handler::Handler<T, Container>,
86 T: 'static,
87 {
88 let mr = self.wrap_method_router(put(handler));
89 let full = self.full_path(path);
90 self.inner = self.inner.route(&full, mr);
91 self
92 }
93
94 pub fn patch<H, T>(mut self, path: &str, handler: H) -> Self
95 where
96 H: axum::handler::Handler<T, Container>,
97 T: 'static,
98 {
99 let mr = self.wrap_method_router(patch(handler));
100 let full = self.full_path(path);
101 self.inner = self.inner.route(&full, mr);
102 self
103 }
104
105 pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
106 where
107 H: axum::handler::Handler<T, Container>,
108 T: 'static,
109 {
110 let mr = self.wrap_method_router(delete(handler));
111 let full = self.full_path(path);
112 self.inner = self.inner.route(&full, mr);
113 self
114 }
115
116 pub fn any<H, T>(mut self, path: &str, handler: H) -> Self
117 where
118 H: axum::handler::Handler<T, Container>,
119 T: 'static,
120 {
121 let mr = self.wrap_method_router(any(handler));
122 let full = self.full_path(path);
123 self.inner = self.inner.route(&full, mr);
124 self
125 }
126
127 pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
128 self.prefix = prefix.into();
129 self
130 }
131
132 pub fn middleware<I, S>(mut self, names: I) -> Self
133 where
134 I: IntoIterator<Item = S>,
135 S: Into<String>,
136 {
137 for name in names {
138 self.middleware_stack.push(name.into());
139 }
140 self
141 }
142
143 pub fn group<F>(mut self, build: F) -> Self
144 where
145 F: FnOnce(Router) -> Router,
146 {
147 let inner_router = Router {
148 inner: AxumRouter::new(),
149 registry: self.registry.clone(),
150 middleware_stack: self.middleware_stack.clone(),
151 prefix: self.prefix.clone(),
152 };
153 let built = build(inner_router);
154 self.inner = self.inner.merge(built.inner);
155 self
156 }
157
158 pub fn merge(mut self, other: Router) -> Self {
159 self.inner = self.inner.merge(other.inner);
160 self
161 }
162
163 pub fn nest(mut self, prefix: &str, other: Router) -> Self {
164 self.inner = self.inner.nest(prefix, other.inner);
165 self
166 }
167}
168
169#[derive(Debug, Clone)]
171pub struct Route {
172 pub name: Option<String>,
173 pub method: Method,
174 pub path: String,
175}
176
177#[derive(Default, Clone)]
179pub struct NamedRoutes {
180 routes: Arc<parking_lot::RwLock<indexmap::IndexMap<String, Route>>>,
181}
182
183impl NamedRoutes {
184 pub fn new() -> Self {
185 Self::default()
186 }
187
188 pub fn add(&self, route: Route) {
189 if let Some(name) = route.name.clone() {
190 self.routes.write().insert(name, route);
191 }
192 }
193
194 pub fn url(&self, name: &str, params: &[&str]) -> Option<String> {
195 let routes = self.routes.read();
196 let route = routes.get(name)?;
197 let mut path = route.path.clone();
198 for p in params {
199 if let Some(start) = path.find('{') {
200 if let Some(end) = path[start..].find('}') {
201 path.replace_range(start..=start + end, p);
202 }
203 }
204 }
205 Some(path)
206 }
207}