puzz_route/
router.rs

1use std::collections::HashMap;
2use std::fmt;
3
4use matchit::Match;
5use puzz_core::http::uri::{Parts, PathAndQuery, Uri};
6use puzz_core::response::IntoResponse;
7use puzz_core::service::util::BoxService;
8use puzz_core::service::{Service, ServiceExt};
9use puzz_core::{BoxError, Request, Response};
10
11use crate::error::NotFound;
12use crate::RouteFuture;
13
/// Name of the internal catch-all parameter appended to `*`-terminated routes
/// and to nested routes. It is filtered out of the [`Params`] handed to
/// services, so users never see it.
const PRIVATE_TAIL_PARAM: &str = "__private__tail_param";
15
16enum Endpoint {
17    Full(BoxService<Request, Response, BoxError>),
18    Nest(BoxService<Request, Response, BoxError>),
19}
20
21/// 路由器
22///
23/// 匹配传入的HTTP请求并将它们分派给[服务](`Service`)进行处理。
24///
25/// # 例子
26///
27/// ```
28/// use std::convert::Infallible;
29///
30/// use puzz_core::service_fn;
31/// use puzz_route::Router;
32///
33/// Router::new()
34///     .route("/hi", service_fn(|_| async { Ok::<_, Infallible>("hi!") }));
35/// ```
36pub struct Router {
37    inner: matchit::Router<Endpoint>,
38}
39
40impl Router {
41    /// 创建一个空的路由器。
42    pub fn new() -> Self {
43        Self {
44            inner: matchit::Router::new(),
45        }
46    }
47
48    /// 将服务挂载到一条路由上。
49    ///
50    /// # 例子
51    ///
52    /// ```
53    /// use std::convert::Infallible;
54    ///
55    /// use puzz_core::service_fn;
56    /// use puzz_route::Router;
57    ///
58    /// Router::new()
59    ///     .route("/hi", service_fn(|_| async { Ok::<_, Infallible>("hi!") }));
60    /// ```
61    pub fn route<S>(self, path: &str, service: S) -> Self
62    where
63        S: Service<Request> + 'static,
64        S::Response: IntoResponse,
65        S::Error: Into<BoxError>,
66    {
67        if !path.starts_with('/') {
68            panic!("Path must start with a `/`");
69        }
70        let path = if path.ends_with('*') {
71            format!("{path}{PRIVATE_TAIL_PARAM}")
72        } else {
73            path.into()
74        };
75        self.add_route(path, Endpoint::Full(Self::into_box_service(service)))
76    }
77
78    /// 将服务挂载到一条嵌套路由上。
79    ///
80    /// # 例子
81    ///
82    /// ```
83    /// use std::convert::Infallible;
84    ///
85    /// use puzz_core::service_fn;
86    /// use puzz_route::Router;
87    ///
88    /// Router::new()
89    ///     .nest("/hi", service_fn(|_| async { Ok::<_, Infallible>("hi!") }));
90    /// ```
91    pub fn nest<S>(self, path: &str, service: S) -> Self
92    where
93        S: Service<Request> + 'static,
94        S::Response: IntoResponse,
95        S::Error: Into<BoxError>,
96    {
97        if !path.starts_with('/') {
98            panic!("Path must start with a `/`");
99        }
100        let path = if path.ends_with('/') {
101            format!("{path}*{PRIVATE_TAIL_PARAM}")
102        } else {
103            format!("{path}/*{PRIVATE_TAIL_PARAM}")
104        };
105        self.add_route(path, Endpoint::Nest(Self::into_box_service(service)))
106    }
107
108    fn add_route(mut self, path: String, endpoint: Endpoint) -> Self {
109        if let Err(e) = self.inner.insert(path, endpoint) {
110            panic!("{e}");
111        }
112        self
113    }
114
115    fn into_box_service<S>(service: S) -> BoxService<Request, Response, BoxError>
116    where
117        S: Service<Request> + 'static,
118        S::Response: IntoResponse,
119        S::Error: Into<BoxError>,
120    {
121        service
122            .map_response(IntoResponse::into_response)
123            .map_err(Into::into)
124            .boxed()
125    }
126}
127
128impl fmt::Debug for Router {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        f.debug_struct("Router").finish()
131    }
132}
133
134impl Service<Request> for Router {
135    type Response = Response;
136    type Error = BoxError;
137    type Future = RouteFuture;
138
139    fn call(&self, mut request: Request) -> Self::Future {
140        match self.inner.at(request.uri().path()) {
141            Ok(Match { value, params }) => {
142                let fut = match value {
143                    Endpoint::Full(service) => {
144                        let (params, _) = take_params(params);
145                        insert_params(&mut request, params);
146                        service.call(request)
147                    }
148                    Endpoint::Nest(service) => {
149                        let (params, tail) = take_params(params);
150                        insert_params(&mut request, params);
151                        replace_path(&mut request, &tail.unwrap());
152                        service.call(request)
153                    }
154                };
155                RouteFuture::Future { fut }
156            }
157            Err(_) => RouteFuture::Error {
158                err: Some(NotFound::new(request).into()),
159            },
160        }
161    }
162}
163
/// Path parameters extracted by the router.
///
/// Maps each captured parameter name to the value it matched in the request
/// path.
#[derive(Debug, Clone)]
pub struct Params(HashMap<String, String>);

impl Params {
    /// Creates an empty parameter set.
    pub(crate) fn new() -> Self {
        Params(HashMap::default())
    }

    /// Borrows the underlying name-to-value map.
    pub fn get_ref(&self) -> &HashMap<String, String> {
        let Params(map) = self;
        map
    }

    /// Consumes the parameters and returns the underlying name-to-value map.
    pub fn into_inner(self) -> HashMap<String, String> {
        let Params(map) = self;
        map
    }
}
181
182fn take_params(params: matchit::Params) -> (Vec<(String, String)>, Option<String>) {
183    let mut path = None;
184    (
185        params
186            .iter()
187            .filter_map(|(k, v)| {
188                if k == PRIVATE_TAIL_PARAM {
189                    path = Some(v.to_owned());
190                    None
191                } else {
192                    Some((k.to_owned(), v.to_owned()))
193                }
194            })
195            .collect(),
196        path,
197    )
198}
199
200fn insert_params(request: &mut Request, captures: Vec<(String, String)>) {
201    let extensions = request.extensions_mut();
202
203    let params = if let Some(params) = extensions.get_mut::<Params>() {
204        params
205    } else {
206        extensions.insert(Params::new());
207        extensions.get_mut::<Params>().unwrap()
208    };
209
210    params.0.extend(captures);
211}
212
213fn replace_path(request: &mut Request, path: &str) {
214    let uri = request.uri_mut();
215
216    let path_and_query = if let Some(query) = uri.query() {
217        format!("{}?{}", path, query)
218            .parse::<PathAndQuery>()
219            .unwrap()
220    } else {
221        path.parse().unwrap()
222    };
223
224    replace_path_and_query(uri, path_and_query);
225}
226
227fn replace_path_and_query(uri: &mut Uri, path_and_query: PathAndQuery) {
228    let mut parts = Parts::default();
229
230    parts.scheme = uri.scheme().cloned();
231    parts.authority = uri.authority().cloned();
232    parts.path_and_query = Some(path_and_query);
233
234    *uri = Uri::from_parts(parts).unwrap();
235}