1use tower_service::Service;
2
3use crate::error::Error;
4use crate::future::Maybe;
5use crate::request::PathReq;
6use crate::request::RemovePrefix;
7use crate::route::Route;
8
/// A service wrapper that serves `inner` underneath a fixed path prefix.
///
/// Incoming requests have `prefix` removed from their path before they are
/// handed to `inner`.
#[derive(Debug)]
#[derive(Clone, Copy)]
pub struct Mount<S> {
    /// Service that receives requests once the prefix has been stripped.
    inner: S,
    /// Path prefix under which `inner` is mounted.
    prefix: &'static str,
}
18
19impl<S> Mount<S> {
20 #[inline]
21 pub(crate) fn new<T>(inner: S, prefix: &'static str) -> Mount<S>
22 where
23 S: Service<T>,
24 T: PathReq + RemovePrefix,
25 S::Error: From<Error>,
26 {
27 Mount { inner, prefix }
28 }
29}
30
31impl<S, T> Service<T> for Mount<S>
32where
33 S: Service<T>,
34 T: PathReq + RemovePrefix,
35 S::Error: From<Error>,
36{
37 type Response = S::Response;
38
39 type Error = S::Error;
40
41 type Future = Maybe<S::Future, Result<Self::Response, Self::Error>>;
42
43 #[inline]
44 fn poll_ready(
45 &mut self,
46 _: &mut std::task::Context<'_>,
47 ) -> std::task::Poll<Result<(), Self::Error>> {
48 std::task::Poll::Ready(Ok(()))
49 }
50
51 fn call(&mut self, req: T) -> Self::Future {
52 match req.remove_prefix(self.prefix) {
53 Err(err) => Maybe::ready(Err(err.into())),
54 Ok(req) => Maybe::Future(self.inner.call(req)),
55 }
56 }
57}
58
59impl<S, T> Route<T> for Mount<S>
60where
61 S: Service<T>,
62 T: PathReq + RemovePrefix,
63 S::Error: From<Error>,
64{
65 type Param = Param;
66
67 fn call_with_param(&mut self, req: T, _: Self::Param) -> Self::Future {
68 self.call(req)
69 }
70
71 fn param(&self, req: &T) -> Result<Self::Param, Error> {
72 match req.path().starts_with(self.prefix) {
73 true => Ok(Param),
74 false => Err(Error::Path),
75 }
76 }
77}
78
79#[derive(Clone, Copy, Debug)]
80pub struct Param;
81
82impl<T> crate::param::Param<T> for Param {
83 fn from_request(_: &T) -> Result<Self, Error> {
84 Ok(Param)
85 }
86}
87
#[cfg(test)]
mod tests {
    use http::Request;

    use crate::error::Error;
    use crate::exec::run;
    use crate::macros::param;
    use crate::router::Router;

    // One route parameter type per endpoint exercised below.
    param!(Root, GET, "/");
    param!(Route1, GET, "/route1");
    param!(Route2, GET, "/route2");
    param!(Other, GET, "/other");

    async fn root(_: Request<()>, _: Root) -> Result<&'static str, Error> {
        Ok("root")
    }

    async fn route1(_: Request<()>, _: Route1) -> Result<&'static str, Error> {
        Ok("route1")
    }

    async fn route2(_: Request<()>, _: Route2) -> Result<&'static str, Error> {
        Ok("route2")
    }

    async fn other(_: Request<()>, _: Other) -> Result<&'static str, Error> {
        Ok("other")
    }

    /// Builds an empty-bodied `GET` request for `path`.
    fn req(path: &'static str) -> Request<()> {
        let uri = http::Uri::from_static(path);
        Request::builder()
            .method(http::Method::GET)
            .uri(uri)
            .body(())
            .unwrap()
    }

    #[test]
    fn test() {
        let base = Router::new(root).route(other);
        let first = Router::new(route1);
        let second = Router::new(route2);
        let router = base.mount("/r1", first).mount("/r2", second);

        // Each request path paired with the handler output it must reach.
        let cases = [
            ("/", "root"),
            ("/other", "other"),
            ("/r1/route1", "route1"),
            ("/r2/route2", "route2"),
        ];
        for &(path, expected) in cases.iter() {
            assert_eq!(run(router, req(path)), Ok(expected), "path: {}", path);
        }
    }
}