1use std::collections::HashMap;
2use std::fmt;
3
4use matchit::Match;
5use puzz_core::http::uri::{Parts, PathAndQuery, Uri};
6use puzz_core::response::IntoResponse;
7use puzz_core::service::util::BoxService;
8use puzz_core::service::{Service, ServiceExt};
9use puzz_core::{BoxError, Request, Response};
10
11use crate::error::NotFound;
12use crate::RouteFuture;
13
/// Name of the catch-all parameter that the router appends to wildcard and
/// nested route patterns. Its captured value is filtered out of the
/// user-visible [`Params`].
const PRIVATE_TAIL_PARAM: &str = "__private__tail_param";
15
16enum Endpoint {
17 Full(BoxService<Request, Response, BoxError>),
18 Nest(BoxService<Request, Response, BoxError>),
19}
20
21pub struct Router {
37 inner: matchit::Router<Endpoint>,
38}
39
40impl Router {
41 pub fn new() -> Self {
43 Self {
44 inner: matchit::Router::new(),
45 }
46 }
47
48 pub fn route<S>(self, path: &str, service: S) -> Self
62 where
63 S: Service<Request> + 'static,
64 S::Response: IntoResponse,
65 S::Error: Into<BoxError>,
66 {
67 if !path.starts_with('/') {
68 panic!("Path must start with a `/`");
69 }
70 let path = if path.ends_with('*') {
71 format!("{path}{PRIVATE_TAIL_PARAM}")
72 } else {
73 path.into()
74 };
75 self.add_route(path, Endpoint::Full(Self::into_box_service(service)))
76 }
77
78 pub fn nest<S>(self, path: &str, service: S) -> Self
92 where
93 S: Service<Request> + 'static,
94 S::Response: IntoResponse,
95 S::Error: Into<BoxError>,
96 {
97 if !path.starts_with('/') {
98 panic!("Path must start with a `/`");
99 }
100 let path = if path.ends_with('/') {
101 format!("{path}*{PRIVATE_TAIL_PARAM}")
102 } else {
103 format!("{path}/*{PRIVATE_TAIL_PARAM}")
104 };
105 self.add_route(path, Endpoint::Nest(Self::into_box_service(service)))
106 }
107
108 fn add_route(mut self, path: String, endpoint: Endpoint) -> Self {
109 if let Err(e) = self.inner.insert(path, endpoint) {
110 panic!("{e}");
111 }
112 self
113 }
114
115 fn into_box_service<S>(service: S) -> BoxService<Request, Response, BoxError>
116 where
117 S: Service<Request> + 'static,
118 S::Response: IntoResponse,
119 S::Error: Into<BoxError>,
120 {
121 service
122 .map_response(IntoResponse::into_response)
123 .map_err(Into::into)
124 .boxed()
125 }
126}
127
128impl fmt::Debug for Router {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 f.debug_struct("Router").finish()
131 }
132}
133
134impl Service<Request> for Router {
135 type Response = Response;
136 type Error = BoxError;
137 type Future = RouteFuture;
138
139 fn call(&self, mut request: Request) -> Self::Future {
140 match self.inner.at(request.uri().path()) {
141 Ok(Match { value, params }) => {
142 let fut = match value {
143 Endpoint::Full(service) => {
144 let (params, _) = take_params(params);
145 insert_params(&mut request, params);
146 service.call(request)
147 }
148 Endpoint::Nest(service) => {
149 let (params, tail) = take_params(params);
150 insert_params(&mut request, params);
151 replace_path(&mut request, &tail.unwrap());
152 service.call(request)
153 }
154 };
155 RouteFuture::Future { fut }
156 }
157 Err(_) => RouteFuture::Error {
158 err: Some(NotFound::new(request).into()),
159 },
160 }
161 }
162}
163
/// Path parameters captured by the router, keyed by parameter name.
///
/// The router stores this in the request extensions. Captures from successive
/// matches (e.g. an outer router and a nested one) are merged into one map.
#[derive(Debug, Clone)]
pub struct Params(HashMap<String, String>);

impl Params {
    /// Creates an empty parameter map.
    pub(crate) fn new() -> Self {
        Params(HashMap::default())
    }

    /// Borrows the underlying name-to-value map.
    pub fn get_ref(&self) -> &HashMap<String, String> {
        let Params(map) = self;
        map
    }

    /// Consumes the wrapper and returns the underlying name-to-value map.
    pub fn into_inner(self) -> HashMap<String, String> {
        let Params(map) = self;
        map
    }
}
181
182fn take_params(params: matchit::Params) -> (Vec<(String, String)>, Option<String>) {
183 let mut path = None;
184 (
185 params
186 .iter()
187 .filter_map(|(k, v)| {
188 if k == PRIVATE_TAIL_PARAM {
189 path = Some(v.to_owned());
190 None
191 } else {
192 Some((k.to_owned(), v.to_owned()))
193 }
194 })
195 .collect(),
196 path,
197 )
198}
199
200fn insert_params(request: &mut Request, captures: Vec<(String, String)>) {
201 let extensions = request.extensions_mut();
202
203 let params = if let Some(params) = extensions.get_mut::<Params>() {
204 params
205 } else {
206 extensions.insert(Params::new());
207 extensions.get_mut::<Params>().unwrap()
208 };
209
210 params.0.extend(captures);
211}
212
213fn replace_path(request: &mut Request, path: &str) {
214 let uri = request.uri_mut();
215
216 let path_and_query = if let Some(query) = uri.query() {
217 format!("{}?{}", path, query)
218 .parse::<PathAndQuery>()
219 .unwrap()
220 } else {
221 path.parse().unwrap()
222 };
223
224 replace_path_and_query(uri, path_and_query);
225}
226
227fn replace_path_and_query(uri: &mut Uri, path_and_query: PathAndQuery) {
228 let mut parts = Parts::default();
229
230 parts.scheme = uri.scheme().cloned();
231 parts.authority = uri.authority().cloned();
232 parts.path_and_query = Some(path_and_query);
233
234 *uri = Uri::from_parts(parts).unwrap();
235}