1mod pattern;
2mod route;
3mod service;
4
5pub(crate) use self::pattern::Pattern;
6pub use self::route::Path;
7pub(crate) use self::route::Route;
8use crate::endpoint::Endpoint;
9use crate::middleware::Middleware;
10use crate::{Request, Response};
11use std::pin::Pin;
12use std::sync::Arc;
13use tokio::sync::watch;
14
/// Request router: holds the registered routes, the regex set used to match
/// request paths against them, the middleware stack, and an optional
/// fallback endpoint.
pub struct Router {
    // Compiled set of all route patterns, rebuilt by `prepare`. As of the last
    // `prepare` call, set index `i` corresponds to `routes[i]`.
    regex: regex::RegexSet,
    // Registered routes; `Path` values returned by `at`/`under` push into this.
    routes: Vec<Arc<Route>>,
    // Middleware stack handed to `crate::middleware::Next` for every request.
    middleware: Vec<Pin<Box<dyn Middleware>>>,
    // Endpoint used when no route matches the request's path and method.
    fallback: Option<Pin<Box<dyn Endpoint>>>,
    // Receiver half of the channel created by `termination_signal`, if any.
    terminate: Option<watch::Receiver<bool>>,
}
55
56impl Default for Router {
57 fn default() -> Self {
58 Router {
59 regex: regex::RegexSet::empty(),
60 middleware: vec![],
61 routes: vec![],
62 fallback: None,
63 terminate: None,
64 }
65 }
66}
67
68impl Router {
69 #[allow(clippy::missing_panics_doc)]
77 pub fn prepare(&mut self) {
78 let patterns = self
79 .routes
80 .iter()
81 .map(|route| route.pattern.regex().as_str());
82 let set = regex::RegexSet::new(patterns).unwrap();
85 self.regex = set;
86 }
87
88 pub(crate) fn routes(&self) -> &[Arc<Route>] {
89 &self.routes[..]
90 }
91
92 pub fn at<P: AsRef<str>>(&mut self, prefix: P) -> Path<'_> {
94 Path::new(join_paths("", prefix.as_ref()), &mut self.routes)
95 }
96
97 pub fn under<P: AsRef<str>, F: FnOnce(&mut Path<'_>)>(
100 &mut self,
101 prefix: P,
102 build: F,
103 ) -> &mut Self {
104 let mut path = Path::new(join_paths("", prefix.as_ref()), &mut self.routes);
105 build(&mut path);
106 self
107 }
108
109 pub fn with<M: Middleware>(&mut self, middleware: M) -> &mut Self {
120 self.middleware.push(Box::pin(middleware));
121 self
122 }
123
124 pub fn fallback<E: Endpoint>(&mut self, endpoint: E) -> &mut Self {
144 self.fallback = Some(Box::pin(endpoint));
145 self
146 }
147
148 pub fn termination_signal(&mut self) -> watch::Sender<bool> {
158 let (tx, rx) = watch::channel(false);
159 self.terminate = Some(rx);
160 tx
161 }
162
163 pub async fn handle(&self, request: Request) -> Result<Response, anyhow::Error> {
172 Pin::new(self).apply(request).await
173 }
174
175 pub(crate) fn lookup(&self, path: &str, method: &http::Method) -> Option<Arc<Route>> {
176 self.regex
177 .matches(path)
178 .into_iter()
179 .map(|i| &self.routes[i])
180 .filter(|r| r.matches(method))
181 .next_back()
182 .cloned()
183 }
184
185 fn fallback_endpoint(&self) -> Option<Pin<&dyn Endpoint>> {
186 self.fallback.as_ref().map(Pin::as_ref)
187 }
188}
189
190#[async_trait]
191impl crate::Endpoint for Router {
192 async fn apply(self: Pin<&Self>, mut request: Request) -> Result<Response, anyhow::Error> {
193 let route = self.lookup(request.uri().path(), request.method());
194 if let Some(route) = route.clone() {
195 if let Some(fragment) =
198 crate::request::fragment::Fragment::new(request.uri().path(), &route)
199 {
200 request.extensions_mut().insert(fragment);
201 }
202 request.extensions_mut().insert(route);
203 }
204
205 let endpoint = {
206 let route_endpoint = || route.as_ref().map(|e| e.endpoint().as_ref());
207 let fallback_endpoint = || self.fallback_endpoint();
208 route_endpoint()
209 .or_else(fallback_endpoint)
210 .unwrap_or_else(default_endpoint)
211 };
212 log::trace!("{} {} --> {:?}", request.method(), request.uri(), endpoint);
213 let next = crate::middleware::Next::new(&self.middleware[..], endpoint);
214 next.apply(request).await
215 }
216}
217
218impl std::fmt::Debug for Router {
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 f.debug_struct("Router")
221 .field("regex", &self.regex)
222 .field("routes", &self.routes)
223 .finish()
224 }
225}
226
// Endpoint of last resort, returned by `default_endpoint`. It ignores the
// request and answers with `Response::empty_500()` (by its name, an empty
// 500 response).
lazy_static::lazy_static! {
    // The endpoint instance itself. The closure captures nothing, so it
    // coerces to the `fn(Request) -> Response` pointer type named here.
    static ref DEFAULT_ENDPOINT: crate::endpoints::SyncEndpoint<fn(Request) -> Response> = crate::endpoints::SyncEndpoint::new(|_| Response::empty_500());
    // A `'static` pinned reference to `DEFAULT_ENDPOINT`. `Pin::new` only
    // accepts `Unpin` pointees, hence the `+ Unpin` in the declared type.
    static ref DEFAULT_ENDPOINT_PIN: Pin<&'static (dyn Endpoint + Unpin + 'static)> = Pin::new(&*DEFAULT_ENDPOINT);
}
231
232pub(crate) fn default_endpoint<'r>() -> Pin<&'r dyn Endpoint> {
235 *DEFAULT_ENDPOINT_PIN
236}
237
/// Joins two URL path segments so that exactly one `/` separates them.
///
/// The rules are:
/// - If neither side supplies a slash at the seam, one is inserted, so
///   `("", "id")` gives `"/id"`.
/// - If both sides do, the one leading `extend` is dropped.
/// - Otherwise the inputs are simply concatenated.
///
/// Only a single slash is considered on each side of the seam.
fn join_paths(base: &str, extend: &str) -> String {
    let (separator, tail) = match (base.ends_with('/'), extend.strip_prefix('/')) {
        // Both sides supply a slash: keep base's, drop extend's.
        (true, Some(rest)) => ("", rest),
        // Exactly one side supplies the slash: concatenate as-is.
        (false, Some(_)) | (true, None) => ("", extend),
        // Neither side does: insert one.
        (false, None) => ("/", extend),
    };

    // Reserve the exact final length up front. The previous version reserved
    // one byte too few when inserting a separator, which forced a
    // reallocation followed by a shrink.
    let mut buffer = String::with_capacity(base.len() + separator.len() + tail.len());
    buffer.push_str(base);
    buffer.push_str(separator);
    buffer.push_str(tail);
    buffer
}
259
#[cfg(test)]
mod test {
    use super::*;
    use crate::request::Request;
    use crate::response::Response;
    use crate::UnderError;

    // The handler is never invoked by these tests; it only has to match the
    // async endpoint signature, so the `unused_async` lint is silenced.
    #[allow(clippy::unused_async)]
    async fn simple_endpoint(_: Request) -> Result<Response, UnderError> {
        unimplemented!()
    }

    /// Builds a prepared router with GET routes for the root, a literal path,
    /// a single-parameter path and a catch-all path.
    fn simple_router() -> Router {
        let mut router = Router::default();
        router.at("/").get(simple_endpoint);
        router.at("/alpha").get(simple_endpoint);
        router.at("/beta/{id}").get(simple_endpoint);
        router.at("/gamma/{all:path}").get(simple_endpoint);
        router.prepare();
        router
    }

    #[test]
    fn test_join_paths() {
        assert_eq!(join_paths("", "/id"), "/id");
        assert_eq!(join_paths("", "id"), "/id");
        assert_eq!(join_paths("/user", "/id"), "/user/id");
        assert_eq!(join_paths("/user/", "/id"), "/user/id");
        assert_eq!(join_paths("/user/", "id"), "/user/id");
        assert_eq!(join_paths("/user", "id"), "/user/id");
        assert_eq!(join_paths("", ""), "/");
    }

    #[test]
    fn test_build() {
        simple_router();
    }

    #[test]
    fn test_debug_format() {
        let router = simple_router();
        assert!(format!("{:?}", router).starts_with("Router"));
    }

    #[test]
    fn test_basic_match() {
        let router = simple_router();
        let result = router
            .lookup("/", &http::Method::GET)
            .expect("`/` should match the root route");
        assert_eq!("/", &result.path);
    }

    #[test]
    fn test_simple_match() {
        let router = simple_router();
        let result = router
            .lookup("/beta/4444", &http::Method::GET)
            .expect("`/beta/4444` should match the parameterised route");
        assert_eq!("/beta/{id}", &result.path);
    }

    #[test]
    fn test_multi_match() {
        let router = simple_router();
        let result = router
            .lookup("/gamma/a/b/c", &http::Method::GET)
            .expect("`/gamma/a/b/c` should match the catch-all route");
        assert_eq!("/gamma/{all:path}", &result.path);
    }

    #[test]
    fn test_missing_match() {
        let router = simple_router();
        let result = router.lookup("/omega/aaa", &http::Method::GET);
        assert!(result.is_none());
    }

    #[test]
    fn test_correct_method() {
        let router = simple_router();
        let result = router.lookup("/alpha", &http::Method::POST);
        assert!(result.is_none());
    }
}