1use viz_core::{
2 BoxHandler, Handler, HandlerExt, IntoResponse, Next, Request, Response, Result, Transform,
3};
4
5use crate::{Resources, Route};
6
// Generates one `Router` builder method per HTTP verb.
//
// `$name` is used both as the generated method's name and as the `Route`
// constructor method that is called; `$verb` only appears in the generated
// documentation string.
macro_rules! export_verb {
    ($name:ident $verb:ty) => {
        #[doc = concat!(" Adds a handler with a path and HTTP `", stringify!($verb), "` verb pair.")]
        #[must_use]
        pub fn $name<S, H, O>(self, path: S, handler: H) -> Self
        where
            S: AsRef<str>,
            H: Handler<Request, Output = Result<O>> + Clone,
            O: IntoResponse + Send + 'static,
        {
            // Build the single-verb route first, then register it under `path`.
            let route = Route::new().$name(handler);
            self.route(path, route)
        }
    };
}
21
22#[derive(Clone, Debug, Default)]
24pub struct Router {
25 pub(crate) routes: Option<Vec<(String, Route)>>,
26}
27
28impl Router {
29 #[must_use]
31 pub const fn new() -> Self {
32 Self { routes: None }
33 }
34
35 fn push<S>(routes: &mut Vec<(String, Route)>, path: S, route: Route)
36 where
37 S: AsRef<str>,
38 {
39 let path = path.as_ref();
40 match routes
41 .iter_mut()
42 .find_map(|(p, r)| if p == path { Some(r) } else { None })
43 {
44 Some(r) => {
45 *r = route.into_iter().fold(
46 r.clone().into_iter().collect(),
48 |or: Route, (method, handler)| or.on(method, handler),
49 );
50 }
51 None => routes.push((path.to_string(), route)),
52 }
53 }
54
55 #[must_use]
57 pub fn route<S>(mut self, path: S, route: Route) -> Self
58 where
59 S: AsRef<str>,
60 {
61 Self::push(
62 self.routes.get_or_insert_with(Vec::new),
63 path.as_ref().trim_start_matches('/'),
64 route,
65 );
66 self
67 }
68
69 #[must_use]
71 pub fn resources<S>(self, path: S, resource: Resources) -> Self
72 where
73 S: AsRef<str>,
74 {
75 let mut path = path.as_ref().to_string();
76 if !path.ends_with('/') {
77 path.push('/');
78 }
79
80 resource.into_iter().fold(self, |router, (mut sp, route)| {
81 let is_empty = sp.is_empty();
82 sp = path.clone() + &sp;
83 if is_empty {
84 sp = sp.trim_end_matches('/').to_string();
85 }
86 router.route(sp, route)
87 })
88 }
89
90 #[allow(clippy::similar_names)]
92 #[must_use]
93 pub fn nest<S>(self, path: S, router: Self) -> Self
94 where
95 S: AsRef<str>,
96 {
97 let mut path = path.as_ref().to_string();
98 if !path.ends_with('/') {
99 path.push('/');
100 }
101
102 match router.routes {
103 Some(routes) => routes.into_iter().fold(self, |router, (mut sp, route)| {
104 let is_empty = sp.is_empty();
105 sp = path.clone() + &sp;
106 if is_empty {
107 sp = sp.trim_end_matches('/').to_string();
108 }
109 router.route(sp, route)
110 }),
111 None => self,
112 }
113 }
114
// Invokes `repeat!` with `export_verb` followed by `method_name VERB` pairs.
// NOTE(review): `repeat!` is not defined in this file; presumably it expands
// `export_verb!` once per pair, yielding the `get`, `post`, ... builder
// methods - confirm against its definition.
115 repeat!(
116 export_verb
117 get GET
118 post POST
119 put PUT
120 delete DELETE
121 head HEAD
122 options OPTIONS
123 connect CONNECT
124 patch PATCH
125 trace TRACE
126 );
127
128 #[must_use]
130 pub fn any<S, H, O>(self, path: S, handler: H) -> Self
131 where
132 S: AsRef<str>,
133 H: Handler<Request, Output = Result<O>> + Clone,
134 O: IntoResponse + Send + 'static,
135 {
136 self.route(path, Route::new().any(handler))
137 }
138
139 #[must_use]
141 pub fn map_handler<F>(self, f: F) -> Self
142 where
143 F: Fn(BoxHandler<Request, Result<Response>>) -> BoxHandler<Request, Result<Response>>,
144 {
145 Self {
146 routes: self.routes.map(|routes| {
147 routes
148 .into_iter()
149 .map(|(path, route)| {
150 (
151 path,
152 route
153 .into_iter()
154 .map(|(method, handler)| (method, f(handler)))
155 .collect(),
156 )
157 })
158 .collect()
159 }),
160 }
161 }
162
163 #[must_use]
165 pub fn with<T>(self, t: T) -> Self
166 where
167 T: Transform<BoxHandler>,
168 T::Output: Handler<Request, Output = Result<Response>> + Clone,
169 {
170 self.map_handler(|handler| t.transform(handler).boxed())
171 }
172
173 #[must_use]
175 pub fn with_handler<H>(self, f: H) -> Self
176 where
177 H: Handler<Next<Request, BoxHandler>, Output = Result<Response>> + Clone,
178 {
179 self.map_handler(|handler| handler.around(f.clone()).boxed())
180 }
181}
182
183#[cfg(test)]
184#[allow(clippy::unused_async)]
185mod tests {
186 use http_body_util::{BodyExt, Full};
187 use std::sync::Arc;
188 use viz_core::{
189 Body, Error, Handler, HandlerExt, IntoResponse, Method, Next, Request, RequestExt,
190 Response, ResponseExt, Result, StatusCode, Transform, async_trait,
191 types::{Params, RouteInfo},
192 };
193
194 use crate::{Resources, Route, Router, Tree, any, get};
195
196 #[derive(Clone)]
197 struct Logger;
198
199 impl Logger {
200 const fn new() -> Self {
201 Self
202 }
203 }
204
205 impl<H: Clone> Transform<H> for Logger {
206 type Output = LoggerHandler<H>;
207
208 fn transform(&self, h: H) -> Self::Output {
209 LoggerHandler(h)
210 }
211 }
212
213 #[derive(Clone)]
214 struct LoggerHandler<H>(H);
215
216 #[async_trait]
217 impl<H> Handler<Request> for LoggerHandler<H>
218 where
219 H: Handler<Request>,
220 {
221 type Output = H::Output;
222
223 async fn call(&self, req: Request) -> Self::Output {
224 self.0.call(req).await
225 }
226 }
227
228 #[tokio::test]
229 async fn router() -> anyhow::Result<()> {
230 async fn index(_: Request) -> Result<Response> {
231 Ok(Response::text("index"))
232 }
233
234 async fn all(_: Request) -> Result<Response> {
235 Ok(Response::text("any"))
236 }
237
238 async fn not_found(_: Request) -> Result<impl IntoResponse> {
239 Ok(StatusCode::NOT_FOUND)
240 }
241
242 async fn search(_: Request) -> Result<Response> {
243 Ok(Response::text("search"))
244 }
245
246 async fn show(req: Request) -> Result<Response> {
247 let ids: Vec<String> = req.params()?;
248 let items = ids.into_iter().fold(String::new(), |mut s, id| {
249 s.push(' ');
250 s.push_str(&id);
251 s
252 });
253 Ok(Response::text("show".to_string() + &items))
254 }
255
256 async fn create(_: Request) -> Result<Response> {
257 Ok(Response::text("create"))
258 }
259
260 async fn update(req: Request) -> Result<Response> {
261 let ids: Vec<String> = req.params()?;
262 let items = ids.into_iter().fold(String::new(), |mut s, id| {
263 s.push(' ');
264 s.push_str(&id);
265 s
266 });
267 Ok(Response::text("update".to_string() + &items))
268 }
269
270 async fn delete(req: Request) -> Result<Response> {
271 let ids: Vec<String> = req.params()?;
272 let items = ids.into_iter().fold(String::new(), |mut s, id| {
273 s.push(' ');
274 s.push_str(&id);
275 s
276 });
277 Ok(Response::text("delete".to_string() + &items))
278 }
279
280 async fn middle<H>((req, h): Next<Request, H>) -> Result<Response>
281 where
282 H: Handler<Request, Output = Result<Response>>,
283 {
284 h.call(req).await
285 }
286
287 let users = Resources::default()
288 .named("user")
289 .index(index)
290 .create(create.before(|r: Request| async { Ok(r) }).around(middle))
291 .show(show)
292 .update(update)
293 .destroy(delete)
294 .map_handler(|h| {
295 h.and_then(|res: Response| async {
296 let (parts, body) = res.into_parts();
297
298 let mut buf = bytes::BytesMut::new();
299 buf.extend(b"users: ");
300 buf.extend(body.collect().await.map_err(Error::boxed)?.to_bytes());
301
302 Ok(Response::from_parts(parts, Full::from(buf.freeze()).into()))
303 })
304 .boxed()
305 });
306
307 let posts = Router::new().route("search", get(search)).resources(
308 "",
309 Resources::default()
310 .named("post")
311 .create(create)
312 .show(show)
313 .update(update)
314 .destroy(delete)
315 .map_handler(|h| {
316 h.and_then(|res: Response| async {
317 let (parts, body) = res.into_parts();
318
319 let mut buf = bytes::BytesMut::new();
320 buf.extend(b"posts: ");
321 buf.extend(body.collect().await.map_err(Error::boxed)?.to_bytes());
322
323 Ok(Response::from_parts(parts, Full::from(buf.freeze()).into()))
324 })
325 .boxed()
326 }),
327 );
328
329 let router = Router::new()
330 .get("", index)
332 .resources("users", users.clone())
333 .nest("posts", posts.resources(":post_id/users", users))
334 .route("search", any(all))
335 .route("*", Route::new().any(not_found))
336 .with(Logger::new());
337
338 let tree: Tree = router.into();
339
340 let (req, method, path) = client(Method::GET, "/posts");
342 let node = tree.find(&method, &path);
343 assert!(node.is_some());
344 let (h, _) = node.unwrap();
345 assert_eq!(
346 h.call(req).await?.into_body().collect().await?.to_bytes(),
347 ""
348 );
349
350 let (req, method, path) = client(Method::POST, "/posts");
352 let node = tree.find(&method, &path);
353 assert!(node.is_some());
354 let (h, _) = node.unwrap();
355 assert_eq!(
356 h.call(req).await?.into_body().collect().await?.to_bytes(),
357 "posts: create"
358 );
359
360 let (mut req, method, path) = client(Method::GET, "/posts/foo");
362 let node = tree.find(&method, &path);
363 assert!(node.is_some());
364 let (h, route) = node.unwrap();
365 req.extensions_mut().insert(Arc::from(RouteInfo {
366 id: *route.id,
367 pattern: route.pattern(),
368 params: route.params().into(),
369 }));
370 assert_eq!(
371 h.call(req).await?.into_body().collect().await?.to_bytes(),
372 "posts: show foo"
373 );
374
375 let (mut req, method, path) = client(Method::PUT, "/posts/foo");
377 let node = tree.find(&method, &path);
378 assert!(node.is_some());
379 let (h, route) = node.unwrap();
380 req.extensions_mut().insert(Arc::from(RouteInfo {
381 id: *route.id,
382 pattern: route.pattern(),
383 params: Into::<Params>::into(route.params()),
384 }));
385 assert_eq!(
386 h.call(req).await?.into_body().collect().await?.to_bytes(),
387 "posts: update foo"
388 );
389
390 let (mut req, method, path) = client(Method::DELETE, "/posts/foo");
392 let node = tree.find(&method, &path);
393 assert!(node.is_some());
394 let (h, route) = node.unwrap();
395 req.extensions_mut().insert(Arc::from(RouteInfo {
396 id: *route.id,
397 pattern: route.pattern(),
398 params: route.params().into(),
399 }));
400 assert_eq!(
401 h.call(req).await?.into_body().collect().await?.to_bytes(),
402 "posts: delete foo"
403 );
404
405 let (req, method, path) = client(Method::GET, "/posts/foo/users");
407 let node = tree.find(&method, &path);
408 assert!(node.is_some());
409 let (h, _) = node.unwrap();
410 assert_eq!(
411 h.call(req).await?.into_body().collect().await?.to_bytes(),
412 "users: index"
413 );
414
415 let (req, method, path) = client(Method::POST, "/posts/foo/users");
417 let node = tree.find(&method, &path);
418 assert!(node.is_some());
419 let (h, _) = node.unwrap();
420 assert_eq!(
421 h.call(req).await?.into_body().collect().await?.to_bytes(),
422 "users: create"
423 );
424
425 let (mut req, method, path) = client(Method::GET, "/posts/foo/users/bar");
427 let node = tree.find(&method, &path);
428 assert!(node.is_some());
429 let (h, route) = node.unwrap();
430 req.extensions_mut().insert(Arc::from(RouteInfo {
431 id: *route.id,
432 pattern: route.pattern(),
433 params: route.params().into(),
434 }));
435 assert_eq!(
436 h.call(req).await?.into_body().collect().await?.to_bytes(),
437 "users: show foo bar"
438 );
439
440 let (mut req, method, path) = client(Method::PUT, "/posts/foo/users/bar");
442 let node = tree.find(&method, &path);
443 assert!(node.is_some());
444 let (h, route) = node.unwrap();
445 let route_info = Arc::from(RouteInfo {
446 id: *route.id,
447 pattern: route.pattern(),
448 params: route.params().into(),
449 });
450 assert_eq!(route.pattern(), "/posts/:post_id/users/:user_id");
451 assert_eq!(route_info.pattern, "/posts/:post_id/users/:user_id");
452 req.extensions_mut().insert(route_info);
453 assert_eq!(
454 h.call(req).await?.into_body().collect().await?.to_bytes(),
455 "users: update foo bar"
456 );
457
458 let (mut req, method, path) = client(Method::DELETE, "/posts/foo/users/bar");
460 let node = tree.find(&method, &path);
461 assert!(node.is_some());
462 let (h, route) = node.unwrap();
463 req.extensions_mut().insert(Arc::from(RouteInfo {
464 id: *route.id,
465 pattern: route.pattern(),
466 params: route.params().into(),
467 }));
468 assert_eq!(
469 h.call(req).await?.into_body().collect().await?.to_bytes(),
470 "users: delete foo bar"
471 );
472
473 Ok(())
474 }
475
// Builds a router from plain routes, a resource set and two nested routers,
// converts it into a `Tree`, and compares the tree's pretty-printed `Debug`
// output with the expected rendering.
476 #[test]
477 fn debug() {
478 let search = Route::new().get(|_: Request| async { Ok(Response::text("search")) });
479
480 let orgs = Resources::default()
481 .index(|_: Request| async { Ok(Response::text("list posts")) })
482 .create(|_: Request| async { Ok(Response::text("create post")) })
483 .show(|_: Request| async { Ok(Response::text("show post")) });
484
485 let settings = Router::new()
486 .get("/", |_: Request| async { Ok(Response::text("settings")) })
487 .get("/:page", |_: Request| async {
488 Ok(Response::text("setting page"))
489 });
490
491 let app = Router::new()
492 .get("/", |_: Request| async { Ok(Response::text("index")) })
493 .route("search", search.clone())
494 .resources(":org", orgs)
495 .nest("settings", settings)
496 .nest("api", Router::new().route("/search", search));
497
// The `Debug` rendering checked below comes from the `Tree` built out of the router.
498 let tree: Tree = app.into();
499
// NOTE(review): the expected literal spans multiple lines and is compared
// verbatim, so its whitespace is significant - keep it byte-for-byte.
500 assert_eq!(
501 format!("{tree:#?}"),
502 "Tree {
503 method: GET,
504 paths:
505 / •0
506 ├── api/search •6
507 ├── se
508 │ ├── arch •1
509 │ └── ttings •4
510 │ └── /
511 │ └── : •5
512 └── : •2
513 └── /
514 └── : •3
515 ,
516 method: POST,
517 paths:
518 /
519 └── : •0
520 ,
521}"
522 );
523 }
524
525 fn client(method: Method, path: &str) -> (Request, Method, String) {
526 (
527 Request::builder()
528 .method(method.clone())
529 .uri(path.to_owned())
530 .body(Body::Empty)
531 .unwrap(),
532 method,
533 path.to_string(),
534 )
535 }
536}