hyper_tree_router/
router.rs

1use super::{parameters::UrlParams, route};
2use hyper::{header::CONTENT_LENGTH, service::Service, Body, Request, Response, StatusCode};
3use prefix_tree_map::PrefixTreeMap;
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10type HttpResult<T> = Result<T, StatusCode>;
11
12/// Router service to be passed to hyper
13/// passed to Server.bind().serve
14pub struct Router {
15    pub routes: PrefixTreeMap<String, String, route::Route>,
16}
17
18impl Router {
19    pub fn new(routes: PrefixTreeMap<String, String, route::Route>) -> Self {
20        Self { routes }
21    }
22}
23
24pub struct RouterSvc {
25    routes: PrefixTreeMap<String, String, route::Route>,
26    error_handler: fn(StatusCode) -> Response<Body>,
27}
28
29impl RouterSvc {
30    pub fn new(routes: PrefixTreeMap<String, String, route::Route>) -> Self {
31        Self {
32            routes,
33            error_handler: Self::default_error_handler,
34        }
35    }
36
37    pub fn route(&self, request: &Request<Body>) -> HttpResult<(route::Handler, UrlParams)> {
38        let path = request
39            .uri()
40            .path()
41            .split('/')
42            .map(|s| s.to_string())
43            .collect::<Vec<_>>();
44        let mut params = UrlParams::new();
45
46        match self.routes.find_and_capture(&path, &mut params) {
47            Some(route) => match route.handlers.get(request.method()) {
48                Some(handler) => Ok((*handler, params)),
49                _ => Err(StatusCode::METHOD_NOT_ALLOWED),
50            },
51            None => Err(StatusCode::NOT_FOUND),
52        }
53    }
54
55    fn default_error_handler(status_code: StatusCode) -> Response<Body> {
56        let (error_msg, error_code) = match status_code {
57            StatusCode::NOT_FOUND => ("Page Not Found", StatusCode::NOT_FOUND),
58            StatusCode::METHOD_NOT_ALLOWED => {
59                ("method not supported", StatusCode::METHOD_NOT_ALLOWED)
60            }
61            StatusCode::NOT_IMPLEMENTED => ("not implemented", StatusCode::NOT_IMPLEMENTED),
62            _ => ("Internal Server Error", StatusCode::INTERNAL_SERVER_ERROR),
63        };
64        Response::builder()
65            .header(CONTENT_LENGTH, error_msg.len() as u64)
66            .status(error_code)
67            .body(Body::from(error_msg))
68            .expect("Failed to construct a response")
69    }
70}
71
72impl Service<Request<Body>> for RouterSvc {
73    type Response = Response<Body>;
74    type Error = hyper::Error;
75    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
76
77    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78        Poll::Ready(Ok(()))
79    }
80    fn call(&mut self, request: Request<Body>) -> Self::Future {
81        let resp = match self.route(&request) {
82            Ok((handler, url_params)) => handler(url_params, request),
83            Err(status_code) => (self.error_handler)(status_code),
84        };
85        Box::pin(async { Ok(resp) })
86    }
87}
88
89impl<T> Service<T> for Router {
90    type Response = RouterSvc;
91    type Error = std::io::Error;
92    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
93
94    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95        Ok(()).into()
96    }
97
98    fn call(&mut self, _: T) -> Self::Future {
99        let routes = self.routes.clone();
100        let fut = async move { Ok(RouterSvc::new(routes)) };
101        Box::pin(fut)
102    }
103}