// routerify/router/mod.rs

1use crate::constants;
2use crate::data_map::ScopedDataMap;
3use crate::middleware::{PostMiddleware, PreMiddleware};
4use crate::route::Route;
5use crate::types::RequestInfo;
6use crate::Error;
7use crate::RouteError;
8use hyper::{body::HttpBody, header, Method, Request, Response, StatusCode};
9use regex::RegexSet;
10use std::any::Any;
11use std::fmt::{self, Debug, Formatter};
12use std::future::Future;
13use std::pin::Pin;
14
15pub use self::builder::RouterBuilder;
16
17mod builder;
18
19pub(crate) type ErrHandlerWithoutInfo<B> =
20    Box<dyn Fn(RouteError) -> ErrHandlerWithoutInfoReturn<B> + Send + Sync + 'static>;
21pub(crate) type ErrHandlerWithoutInfoReturn<B> = Box<dyn Future<Output = Response<B>> + Send + 'static>;
22
23pub(crate) type ErrHandlerWithInfo<B> =
24    Box<dyn Fn(RouteError, RequestInfo) -> ErrHandlerWithInfoReturn<B> + Send + Sync + 'static>;
25pub(crate) type ErrHandlerWithInfoReturn<B> = Box<dyn Future<Output = Response<B>> + Send + 'static>;
26
/// Represents a modular, lightweight and mountable router type.
///
/// A router consists of some routes, some pre-middlewares and some post-middlewares.
///
/// This `Router<B, E>` type accepts two type parameters: `B` and `E`.
///
/// * The `B` represents the response body type which will be used by route handlers and the middlewares and this body type must implement
///   the [HttpBody](https://docs.rs/hyper/0.14.4/hyper/body/trait.HttpBody.html) trait. For example, `B` could be the [hyper::Body](https://docs.rs/hyper/0.14.4/hyper/body/struct.Body.html)
///   type.
/// * The `E` represents any error type which will be used by route handlers and the middlewares. This error type must be convertible
///   into `Box<dyn std::error::Error + Send + Sync>`, which is the case for any `'static` type implementing
///   [std::error::Error](https://doc.rust-lang.org/std/error/trait.Error.html) + `Send` + `Sync`.
///
/// A `Router` can be created using the `Router::builder()` method.
///
/// # Examples
///
/// ```
/// use routerify::Router;
/// use hyper::{Response, Request, Body};
///
/// // A handler for "/about" page.
/// // We will use hyper::Body as response body type and hyper::Error as error type.
/// async fn about_handler(_: Request<Body>) -> Result<Response<Body>, hyper::Error> {
///     Ok(Response::new(Body::from("About page")))
/// }
///
/// # fn run() -> Router<Body, hyper::Error> {
/// // Create a router with hyper::Body as response body type and hyper::Error as error type.
/// let router: Router<Body, hyper::Error> = Router::builder()
///     .get("/about", about_handler)
///     .build()
///     .unwrap();
/// # router
/// # }
/// # run();
/// ```
pub struct Router<B, E> {
    // The path regexes of these four collections are compiled into `regex_set` in exactly this
    // order (pre-middlewares, routes, post-middlewares, scoped data maps) by
    // `Router::init_regex_set()`, and `Router::match_regex_set()` relies on that order.
    pub(crate) pre_middlewares: Vec<PreMiddleware<E>>,
    pub(crate) routes: Vec<Route<B, E>>,
    pub(crate) post_middlewares: Vec<PostMiddleware<B, E>>,
    pub(crate) scoped_data_maps: Vec<ScopedDataMap>,

    // This handler should be added only on the root Router.
    // Any error handler attached to a scoped router will be ignored.
    pub(crate) err_handler: Option<ErrHandler<B>>,

    // `None` until the RouterService calls `Router::init_regex_set()`; `match_regex_set()`
    // panics if it is still `None`.
    regex_set: Option<RegexSet>,

    // `None` until the RouterService calls `Router::init_req_info_gen()`, which sets it to
    // `Some(true)` when the error handler is `ErrHandler::WithInfo` or any post-middleware
    // requires request meta, and to `Some(false)` otherwise.
    pub(crate) should_gen_req_info: Option<bool>,
}
78
/// The router's error handler, stored in one of two calling conventions.
///
/// `ErrHandler::execute()` needs a `RequestInfo` for the `WithInfo` variant, which is why
/// `Router::init_req_info_gen()` turns request-info generation on when that variant is used.
pub(crate) enum ErrHandler<B> {
    /// Handler that is called with only the `RouteError`.
    WithoutInfo(ErrHandlerWithoutInfo<B>),
    /// Handler that is called with the `RouteError` and the request's `RequestInfo`.
    WithInfo(ErrHandlerWithInfo<B>),
}
83
84impl<B: HttpBody + Send + Sync + 'static> ErrHandler<B> {
85    pub(crate) async fn execute(&self, err: RouteError, req_info: Option<RequestInfo>) -> Response<B> {
86        match self {
87            ErrHandler::WithoutInfo(ref err_handler) => Pin::from(err_handler(err)).await,
88            ErrHandler::WithInfo(ref err_handler) => {
89                Pin::from(err_handler(err, req_info.expect("No RequestInfo is provided"))).await
90            }
91        }
92    }
93}
94
95impl<B: HttpBody + Send + Sync + 'static, E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static> Router<B, E> {
96    pub(crate) fn new(
97        pre_middlewares: Vec<PreMiddleware<E>>,
98        routes: Vec<Route<B, E>>,
99        post_middlewares: Vec<PostMiddleware<B, E>>,
100        scoped_data_maps: Vec<ScopedDataMap>,
101        err_handler: Option<ErrHandler<B>>,
102    ) -> Self {
103        Router {
104            pre_middlewares,
105            routes,
106            post_middlewares,
107            scoped_data_maps,
108            err_handler,
109            regex_set: None,
110            should_gen_req_info: None,
111        }
112    }
113
114    pub(crate) fn init_regex_set(&mut self) -> crate::Result<()> {
115        let regex_iter = self
116            .pre_middlewares
117            .iter()
118            .map(|m| m.regex.as_str())
119            .chain(self.routes.iter().map(|r| r.regex.as_str()))
120            .chain(self.post_middlewares.iter().map(|m| m.regex.as_str()))
121            .chain(self.scoped_data_maps.iter().map(|d| d.regex.as_str()));
122
123        self.regex_set =
124            Some(RegexSet::new(regex_iter).map_err(|e| Error::new(format!("Couldn't create router RegexSet: {}", e)))?);
125
126        Ok(())
127    }
128
129    pub(crate) fn init_req_info_gen(&mut self) {
130        if let Some(ErrHandler::WithInfo(_)) = self.err_handler {
131            self.should_gen_req_info = Some(true);
132            return;
133        }
134
135        for post_middleware in self.post_middlewares.iter() {
136            if post_middleware.should_require_req_meta() {
137                self.should_gen_req_info = Some(true);
138                return;
139            }
140        }
141
142        self.should_gen_req_info = Some(false);
143    }
144
145    // pub(crate) fn init_keep_alive_middleware(&mut self) {
146    //     let keep_alive_post_middleware = PostMiddleware::new("/*", |mut res| async move {
147    //         res.headers_mut()
148    //             .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
149    //         Ok(res)
150    //     })
151    //     .unwrap();
152
153    //     self.post_middlewares.push(keep_alive_post_middleware);
154    // }
155
156    pub(crate) fn init_global_options_route(&mut self) {
157        let options_method = vec![Method::OPTIONS];
158        let found = self
159            .routes
160            .iter()
161            .any(|route| route.path == "/*" && route.methods.as_slice() == options_method.as_slice());
162
163        if found {
164            return;
165        }
166
167        if let Some(router) = self.downcast_to_hyper_body_type() {
168            let options_route: Route<hyper::Body, E> = Route::new("/*", options_method, |_req| async move {
169                Ok(Response::builder()
170                    .status(StatusCode::NO_CONTENT)
171                    .body(hyper::Body::empty())
172                    .expect("Couldn't create the default OPTIONS response"))
173            })
174            .unwrap();
175
176            router.routes.push(options_route);
177        } else {
178            eprintln!(
179                "Warning: No global `options method` route added. It is recommended to send response to any `options` request.\n\
180                Please add one by calling `.options(\"/*\", handler)` method of the root router builder.\n"
181            );
182        }
183    }
184
185    pub(crate) fn init_default_404_route(&mut self) {
186        let found = self
187            .routes
188            .iter()
189            .any(|route| route.path == "/*" && route.methods.as_slice() == &constants::ALL_POSSIBLE_HTTP_METHODS[..]);
190
191        if found {
192            return;
193        }
194
195        if let Some(router) = self.downcast_to_hyper_body_type() {
196            let default_404_route: Route<hyper::Body, E> =
197                Route::new("/*", constants::ALL_POSSIBLE_HTTP_METHODS.to_vec(), |_req| async move {
198                    Ok(Response::builder()
199                        .status(StatusCode::NOT_FOUND)
200                        .header(header::CONTENT_TYPE, "text/plain")
201                        .body(hyper::Body::from(StatusCode::NOT_FOUND.canonical_reason().unwrap()))
202                        .expect("Couldn't create the default 404 response"))
203                })
204                .unwrap();
205            router.routes.push(default_404_route);
206        } else {
207            eprintln!(
208                "Warning: No default 404 route added. It is recommended to send 404 response to any non-existent route.\n\
209                Please add one by calling `.any(handler)` method of the root router builder.\n"
210            );
211        }
212    }
213
214    pub(crate) fn init_err_handler(&mut self) {
215        let found = self.err_handler.is_some();
216
217        if found {
218            return;
219        }
220
221        if let Some(router) = self.downcast_to_hyper_body_type() {
222            let handler: ErrHandler<hyper::Body> = ErrHandler::WithoutInfo(Box::new(move |err: RouteError| {
223                Box::new(async move {
224                    Response::builder()
225                        .status(StatusCode::INTERNAL_SERVER_ERROR)
226                        .header(header::CONTENT_TYPE, "text/plain")
227                        .body(hyper::Body::from(format!(
228                            "{}: {}",
229                            StatusCode::INTERNAL_SERVER_ERROR.canonical_reason().unwrap(),
230                            err
231                        )))
232                        .expect("Couldn't create a response while handling the server error")
233                })
234            }));
235            router.err_handler = Some(handler);
236        } else {
237            eprintln!(
238                "Warning: No error handler added. It is recommended to add one to see what went wrong if any route or middleware fails.\n\
239                Please add one by calling `.err_handler(handler)` method of the root router builder.\n"
240            );
241        }
242    }
243
244    fn downcast_to_hyper_body_type(&mut self) -> Option<&mut Router<hyper::Body, E>> {
245        let any_obj: &mut dyn Any = self;
246        any_obj.downcast_mut::<Router<hyper::Body, E>>()
247    }
248
249    /// Return a [RouterBuilder](./struct.RouterBuilder.html) instance to build a `Router`.
250    pub fn builder() -> RouterBuilder<B, E> {
251        builder::RouterBuilder::new()
252    }
253
254    pub(crate) async fn process(
255        &self,
256        target_path: &str,
257        mut req: Request<hyper::Body>,
258        mut req_info: Option<RequestInfo>,
259    ) -> crate::Result<Response<B>> {
260        let (
261            matched_pre_middleware_idxs,
262            matched_route_idxs,
263            matched_post_middleware_idxs,
264            matched_scoped_data_map_idxs,
265        ) = self.match_regex_set(target_path);
266
267        let mut route_scope_depth = None;
268        for idx in &matched_route_idxs {
269            let route = &self.routes[*idx];
270            // Middleware should be executed even if there's no route, e.g.
271            // logging. Before doing the depth check make sure that there's
272            // an actual route match, not a catch-all "/*".
273            if route.is_match_method(req.method()) && route.path != "/*" {
274                route_scope_depth = Some(route.scope_depth);
275                break;
276            }
277        }
278
279        let shared_data_maps = matched_scoped_data_map_idxs
280            .into_iter()
281            .map(|idx| self.scoped_data_maps[idx].clone_data_map())
282            .collect::<Vec<_>>();
283
284        if let Some(ref mut req_info) = req_info {
285            if !shared_data_maps.is_empty() {
286                req_info.shared_data_maps.replace(shared_data_maps.clone());
287            }
288        }
289
290        let ext = req.extensions_mut();
291        ext.insert(shared_data_maps);
292
293        let res_pre = self
294            .execute_pre_middleware(req, matched_pre_middleware_idxs, route_scope_depth, req_info.clone())
295            .await?;
296
297        // If pre middlewares succeed then execute the route handler.
298        // If a pre middleware fails and is able to generate error response
299        // (because Router.err_handler is set), then skip directly to post
300        // middleware.
301        let mut resp = None;
302        match res_pre {
303            Ok(transformed_req) => {
304                for idx in matched_route_idxs {
305                    let route = &self.routes[idx];
306
307                    if route.is_match_method(transformed_req.method()) {
308                        let route_resp_res = route.process(target_path, transformed_req).await;
309
310                        let route_resp = match route_resp_res {
311                            Ok(route_resp) => route_resp,
312                            Err(err) => {
313                                if let Some(ref err_handler) = self.err_handler {
314                                    err_handler.execute(err, req_info.clone()).await
315                                } else {
316                                    return Err(err);
317                                }
318                            }
319                        };
320
321                        resp = Some(route_resp);
322                        break;
323                    }
324                }
325            }
326            Err(err_response) => {
327                resp = Some(err_response);
328            }
329        };
330
331        if resp.is_none() {
332            let e = "No handlers added to handle non-existent routes. Tips: Please add an '.any' route at the bottom to handle any routes.";
333            return Err(crate::Error::new(e).into());
334        }
335
336        let mut transformed_res = resp.unwrap();
337        for idx in matched_post_middleware_idxs {
338            let post_middleware = &self.post_middlewares[idx];
339            // Do not execute middleware with the same prefix but from a deeper scope.
340            if route_scope_depth.is_none() || post_middleware.scope_depth <= route_scope_depth.unwrap() {
341                match post_middleware.process(transformed_res, req_info.clone()).await {
342                    Ok(res_resp) => {
343                        transformed_res = res_resp;
344                    }
345                    Err(err) => {
346                        if let Some(ref err_handler) = self.err_handler {
347                            return Ok(err_handler.execute(err, req_info.clone()).await);
348                        } else {
349                            return Err(err);
350                        }
351                    }
352                }
353            }
354        }
355
356        Ok(transformed_res)
357    }
358
359    async fn execute_pre_middleware(
360        &self,
361        req: Request<hyper::Body>,
362        matched_pre_middleware_idxs: Vec<usize>,
363        route_scope_depth: Option<u32>,
364        req_info: Option<RequestInfo>,
365    ) -> crate::Result<Result<Request<hyper::Body>, Response<B>>> {
366        let mut transformed_req = req;
367        for idx in matched_pre_middleware_idxs {
368            let pre_middleware = &self.pre_middlewares[idx];
369            // Do not execute middleware with the same prefix but from a deeper scope.
370            if route_scope_depth.is_none() || pre_middleware.scope_depth <= route_scope_depth.unwrap() {
371                match pre_middleware.process(transformed_req).await {
372                    Ok(res_req) => {
373                        transformed_req = res_req;
374                    }
375                    Err(err) => {
376                        if let Some(ref err_handler) = self.err_handler {
377                            return Ok(Err(err_handler.execute(err, req_info).await));
378                        } else {
379                            return Err(err);
380                        }
381                    }
382                }
383            }
384        }
385        Ok(Ok(transformed_req))
386    }
387
388    fn match_regex_set(&self, target_path: &str) -> (Vec<usize>, Vec<usize>, Vec<usize>, Vec<usize>) {
389        let matches = self
390            .regex_set
391            .as_ref()
392            .expect("The 'regex_set' field in Router is not initialized")
393            .matches(target_path)
394            .into_iter();
395
396        let pre_middlewares_len = self.pre_middlewares.len();
397        let routes_len = self.routes.len();
398        let post_middlewares_len = self.post_middlewares.len();
399        let scoped_data_maps_len = self.scoped_data_maps.len();
400
401        let mut matched_pre_middleware_idxs = Vec::new();
402        let mut matched_route_idxs = Vec::new();
403        let mut matched_post_middleware_idxs = Vec::new();
404        let mut matched_scoped_data_map_idxs = Vec::new();
405
406        for idx in matches {
407            if idx < pre_middlewares_len {
408                matched_pre_middleware_idxs.push(idx);
409            } else if idx >= pre_middlewares_len && idx < (pre_middlewares_len + routes_len) {
410                matched_route_idxs.push(idx - pre_middlewares_len);
411            } else if idx >= (pre_middlewares_len + routes_len)
412                && idx < (pre_middlewares_len + routes_len + post_middlewares_len)
413            {
414                matched_post_middleware_idxs.push(idx - pre_middlewares_len - routes_len);
415            } else if idx >= (pre_middlewares_len + routes_len + post_middlewares_len)
416                && idx < (pre_middlewares_len + routes_len + post_middlewares_len + scoped_data_maps_len)
417            {
418                matched_scoped_data_map_idxs.push(idx - pre_middlewares_len - routes_len - post_middlewares_len);
419            }
420        }
421
422        (
423            matched_pre_middleware_idxs,
424            matched_route_idxs,
425            matched_post_middleware_idxs,
426            matched_scoped_data_map_idxs,
427        )
428    }
429}
430
431impl<B, E> Debug for Router<B, E> {
432    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
433        write!(
434            f,
435            "{{ Pre-Middlewares: {:?}, Routes: {:?}, Post-Middlewares: {:?}, ScopedDataMaps: {:?}, ErrHandler: {:?}, ShouldGenReqInfo: {:?} }}",
436            self.pre_middlewares,
437            self.routes,
438            self.post_middlewares,
439            self.scoped_data_maps,
440            self.err_handler.is_some(),
441            self.should_gen_req_info
442        )
443    }
444}