1use super::{HttpMethod, RouteInfo, RouteRegistry, params::{ParamExtractor, ParamType}};
4use axum::{
5 Router as AxumRouter,
6 routing::{get, post, put, delete, patch},
7 handler::Handler,
8 response::IntoResponse,
9};
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex};
12
13#[derive(Debug)]
15pub struct Router<S = ()>
16where
17 S: Clone + Send + Sync + 'static,
18{
19 axum_router: AxumRouter<S>,
20 registry: Arc<Mutex<RouteRegistry>>,
21 route_counter: Arc<Mutex<usize>>,
22}
23
24impl<S> Router<S>
25where
26 S: Clone + Send + Sync + 'static,
27{
28 pub fn new() -> Self {
30 Self {
31 axum_router: AxumRouter::new(),
32 registry: Arc::new(Mutex::new(RouteRegistry::new())),
33 route_counter: Arc::new(Mutex::new(0)),
34 }
35 }
36
37 pub fn with_state(state: S) -> Self {
39 Self {
40 axum_router: AxumRouter::new().with_state(state),
41 registry: Arc::new(Mutex::new(RouteRegistry::new())),
42 route_counter: Arc::new(Mutex::new(0)),
43 }
44 }
45
46 fn next_route_id(&self) -> String {
48 let mut counter = self.route_counter.lock().unwrap();
49 *counter += 1;
50 format!("route_{}", counter)
51 }
52
53 fn register_route(&self, method: HttpMethod, path: &str, name: Option<String>) -> String {
55 let route_id = self.next_route_id();
56 let params = self.extract_param_names(path);
57
58 let route_info = RouteInfo {
59 name: name.clone(),
60 path: path.to_string(),
61 method,
62 params,
63 group: None, };
65
66 self.registry.lock().unwrap().register(route_id.clone(), route_info);
67 route_id
68 }
69
70 fn extract_param_names(&self, path: &str) -> Vec<String> {
72 path.split('/')
73 .filter_map(|segment| {
74 if segment.starts_with('{') && segment.ends_with('}') {
75 Some(segment[1..segment.len()-1].to_string())
76 } else {
77 None
78 }
79 })
80 .collect()
81 }
82
83 pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
85 where
86 H: Handler<T, S>,
87 T: 'static,
88 {
89 self.register_route(HttpMethod::GET, path, None);
90 self.axum_router = self.axum_router.route(path, get(handler));
91 self
92 }
93
94 pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
96 where
97 H: Handler<T, S>,
98 T: 'static,
99 {
100 self.register_route(HttpMethod::POST, path, None);
101 self.axum_router = self.axum_router.route(path, post(handler));
102 self
103 }
104
105 pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
107 where
108 H: Handler<T, S>,
109 T: 'static,
110 {
111 self.register_route(HttpMethod::PUT, path, None);
112 self.axum_router = self.axum_router.route(path, put(handler));
113 self
114 }
115
116 pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
118 where
119 H: Handler<T, S>,
120 T: 'static,
121 {
122 self.register_route(HttpMethod::DELETE, path, None);
123 self.axum_router = self.axum_router.route(path, delete(handler));
124 self
125 }
126
127 pub fn patch<H, T>(mut self, path: &str, handler: H) -> Self
129 where
130 H: Handler<T, S>,
131 T: 'static,
132 {
133 self.register_route(HttpMethod::PATCH, path, None);
134 self.axum_router = self.axum_router.route(path, patch(handler));
135 self
136 }
137
138 pub fn merge(mut self, other: Router<S>) -> Self {
140 if let (Ok(mut self_registry), Ok(other_registry)) =
142 (self.registry.lock(), other.registry.lock()) {
143 for (id, route_info) in other_registry.all_routes() {
144 self_registry.register(id.clone(), route_info.clone());
145 }
146 }
147
148 self.axum_router = self.axum_router.merge(other.axum_router);
150 self
151 }
152
153 pub(crate) fn merge_axum(mut self, other: AxumRouter<S>) -> Self {
155 self.axum_router = self.axum_router.merge(other);
156 self
157 }
158
159 pub fn nest(mut self, path: &str, router: Router<S>) -> Self {
161 if let (Ok(mut self_registry), Ok(router_registry)) =
164 (self.registry.lock(), router.registry.lock()) {
165 for (id, route_info) in router_registry.all_routes() {
166 self_registry.register(id.clone(), route_info.clone());
167 }
168 }
169
170 self.axum_router = self.axum_router.nest(path, router.axum_router);
171 self
172 }
173
174 pub(crate) fn nest_axum(mut self, path: &str, router: AxumRouter<S>) -> Self {
176 self.axum_router = self.axum_router.nest(path, router);
177 self
178 }
179
180 pub fn into_axum_router(self) -> AxumRouter<S> {
182 self.axum_router
183 }
184
185 pub fn registry(&self) -> Arc<Mutex<RouteRegistry>> {
187 Arc::clone(&self.registry)
188 }
189
190 pub fn url_for(&self, name: &str, params: &HashMap<String, String>) -> Option<String> {
192 let registry = self.registry.lock().unwrap();
193 if let Some(route) = registry.get_by_name(name) {
194 let mut url = route.path.clone();
195 for (key, value) in params {
196 url = url.replace(&format!("{{{}}}", key), value);
197 }
198 Some(url)
199 } else {
200 None
201 }
202 }
203}
204
205impl<S> Default for Router<S>
206where
207 S: Clone + Send + Sync + 'static,
208{
209 fn default() -> Self {
210 Self::new()
211 }
212}
213
214#[derive(Debug, Default)]
216pub struct RouteBuilder {
217 name: Option<String>,
218 param_types: HashMap<String, ParamType>,
219 middleware: Vec<String>, }
221
222impl RouteBuilder {
223 pub fn new() -> Self {
224 Self::default()
225 }
226
227 pub fn name(mut self, name: &str) -> Self {
229 self.name = Some(name.to_string());
230 self
231 }
232
233 pub fn param(mut self, name: &str, param_type: ParamType) -> Self {
235 self.param_types.insert(name.to_string(), param_type);
236 self
237 }
238
239 pub fn build(self) -> Route {
241 Route {
242 name: self.name,
243 param_types: self.param_types,
244 middleware: self.middleware,
245 }
246 }
247}
248
249#[derive(Debug)]
251pub struct Route {
252 pub name: Option<String>,
253 pub param_types: HashMap<String, ParamType>,
254 pub middleware: Vec<String>,
255}
256
257impl Route {
258 pub fn builder() -> RouteBuilder {
259 RouteBuilder::new()
260 }
261
262 pub fn param_extractor(&self) -> ParamExtractor {
264 let mut extractor = ParamExtractor::new();
265 for (name, param_type) in &self.param_types {
266 extractor = extractor.param(name, param_type.clone());
267 }
268 extractor
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use axum::response::Html;

    /// Minimal handler shared by all route registrations in these tests.
    async fn handler() -> Html<&'static str> {
        Html("<h1>Hello, World!</h1>")
    }

    #[test]
    fn test_router_creation() {
        let router = Router::<()>::new()
            .get("/", handler)
            .post("/users", handler)
            .get("/users/{id}", handler);

        let registry = router.registry();
        let reg = registry.lock().unwrap();
        assert_eq!(reg.all_routes().len(), 3);
    }

    #[test]
    fn test_param_extraction() {
        let router = Router::<()>::new();
        let params = router.extract_param_names("/users/{id}/posts/{slug}");
        assert_eq!(params, vec!["id", "slug"]);
    }

    #[test]
    fn test_url_generation() {
        // No `mut` needed: the registry is mutated through its Mutex, not through `router`.
        let router = Router::<()>::new().get("/users/{id}/posts/{slug}", handler);

        {
            // Register a named route directly, since the verb methods register unnamed routes.
            let mut registry = router.registry.lock().unwrap();
            let route_info = RouteInfo {
                name: Some("user.posts.show".to_string()),
                path: "/users/{id}/posts/{slug}".to_string(),
                method: HttpMethod::GET,
                params: vec!["id".to_string(), "slug".to_string()],
                group: None,
            };
            registry.register("test_route".to_string(), route_info);
        }

        let mut params = HashMap::new();
        params.insert("id".to_string(), "123".to_string());
        params.insert("slug".to_string(), "hello-world".to_string());

        let url = router.url_for("user.posts.show", &params);
        assert_eq!(url, Some("/users/123/posts/hello-world".to_string()));
    }
}