conduit_router/
lib.rs

1#![warn(rust_2018_idioms)]
2
3#[macro_use]
4extern crate tracing;
5
6use std::collections::hash_map::{Entry, HashMap};
7
8use conduit::{box_error, Handler, HandlerResult, Method, RequestExt};
9use route_recognizer::{Match, Params, Router};
10
11#[derive(Default)]
12pub struct RouteBuilder {
13    routers: HashMap<Method, Router<WrappedHandler>>,
14}
15
/// The route pattern (e.g. `"/posts/:id"`) that matched the current request.
///
/// [`RouteBuilder`] inserts a copy of this value into the request extensions
/// before invoking the matched handler, so handlers can look up which
/// registered pattern was used.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RoutePattern(&'static str);

impl RoutePattern {
    /// Returns the pattern string exactly as it was passed to
    /// [`RouteBuilder::map`] (or one of its per-method shorthands).
    ///
    /// The slice is `'static`, so it may outlive this `RoutePattern` value
    /// (for example when kept in a log record or a metrics label).
    pub fn pattern(&self) -> &'static str {
        self.0
    }
}
24
25struct WrappedHandler {
26    pattern: RoutePattern,
27    handler: Box<dyn Handler>,
28}
29
30impl conduit::Handler for WrappedHandler {
31    fn call(&self, request: &mut dyn RequestExt) -> HandlerResult {
32        self.handler.call(request)
33    }
34}
35
36#[derive(Debug, thiserror::Error)]
37pub enum RouterError {
38    #[error("Invalid method")]
39    UnknownMethod,
40    #[error("Path not found")]
41    PathNotFound,
42}
43
44impl RouteBuilder {
45    pub fn new() -> Self {
46        Self {
47            routers: HashMap::new(),
48        }
49    }
50
51    #[instrument(level = "trace", skip(self))]
52    fn recognize<'a>(
53        &'a self,
54        method: &Method,
55        path: &str,
56    ) -> Result<Match<&WrappedHandler>, RouterError> {
57        match self.routers.get(method) {
58            Some(router) => router.recognize(path).or(Err(RouterError::PathNotFound)),
59            None => Err(RouterError::UnknownMethod),
60        }
61    }
62
63    #[instrument(level = "trace", skip(self, handler))]
64    pub fn map<H: Handler>(
65        &mut self,
66        method: Method,
67        pattern: &'static str,
68        handler: H,
69    ) -> &mut Self {
70        {
71            let router = match self.routers.entry(method) {
72                Entry::Occupied(e) => e.into_mut(),
73                Entry::Vacant(e) => e.insert(Router::new()),
74            };
75            let wrapped_handler = WrappedHandler {
76                pattern: RoutePattern(pattern),
77                handler: Box::new(handler),
78            };
79            router.add(pattern, wrapped_handler);
80        }
81        self
82    }
83
84    pub fn get<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
85        self.map(Method::GET, pattern, handler)
86    }
87
88    pub fn post<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
89        self.map(Method::POST, pattern, handler)
90    }
91
92    pub fn put<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
93        self.map(Method::PUT, pattern, handler)
94    }
95
96    pub fn delete<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
97        self.map(Method::DELETE, pattern, handler)
98    }
99
100    pub fn head<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
101        self.map(Method::HEAD, pattern, handler)
102    }
103}
104
105impl conduit::Handler for RouteBuilder {
106    #[instrument(level = "trace", skip(self, request))]
107    fn call(&self, request: &mut dyn RequestExt) -> HandlerResult {
108        let mut m = {
109            let method = request.method();
110            let path = request.path();
111
112            match self.recognize(&method, path) {
113                Ok(m) => m,
114                Err(e) => {
115                    info!("{}", e);
116                    return Err(box_error(e));
117                }
118            }
119        };
120
121        // We don't have `pub` access to the fields to destructure `Params`, so swap with an empty
122        // value to avoid an allocation.
123        let mut params = Params::new();
124        std::mem::swap(m.params_mut(), &mut params);
125
126        let pattern = m.handler().pattern;
127        debug!(pattern = pattern.0, "matching route handler found");
128
129        {
130            let extensions = request.mut_extensions();
131            extensions.insert(pattern);
132            extensions.insert(params);
133        }
134
135        let span = trace_span!("handler", pattern = pattern.0);
136        span.in_scope(|| m.handler().call(request))
137    }
138}
139
140pub trait RequestParams<'a> {
141    fn params(self) -> &'a Params;
142}
143
144impl<'a> RequestParams<'a> for &'a (dyn RequestExt + 'a) {
145    fn params(self) -> &'a Params {
146        self.extensions().get::<Params>().expect("Missing params")
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::{RequestParams, RouteBuilder, RoutePattern};

    use conduit::{Body, Handler, Method, Response, StatusCode};
    use conduit_test::{MockRequest, ResponseExt};

    lazy_static::lazy_static! {
        // Global subscriber shared by every test; `init_tracing` forces it once.
        static ref TRACING: () = {
            tracing_subscriber::FmtSubscriber::builder()
                .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::FULL)
                .with_test_writer()
                .init();
        };
    }

    /// Installs the tracing subscriber on first call; later calls are no-ops.
    fn init_tracing() {
        lazy_static::initialize(&TRACING);
    }

    /// Sends a mock request through `router` and returns its raw result.
    fn dispatch(router: &RouteBuilder, method: Method, path: &str) -> conduit::HandlerResult {
        let mut req = MockRequest::new(method, path);
        router.call(&mut req)
    }

    #[test]
    fn basic_get() {
        init_tracing();

        let res = dispatch(&test_router(), Method::GET, "/posts/1").expect("No response");

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(*res.into_cow(), b"1, GET, /posts/:id"[..]);
    }

    #[test]
    fn basic_post() {
        init_tracing();

        let res = dispatch(&test_router(), Method::POST, "/posts/10").expect("No response");

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(*res.into_cow(), b"10, POST, /posts/:id"[..]);
    }

    #[test]
    fn path_not_found() {
        init_tracing();

        let err = dispatch(&test_router(), Method::POST, "/nonexistent")
            .err()
            .unwrap();

        assert_eq!(err.to_string(), "Path not found");
    }

    #[test]
    fn unknown_method() {
        init_tracing();

        let err = dispatch(&test_router(), Method::DELETE, "/posts/1")
            .err()
            .unwrap();

        assert_eq!(err.to_string(), "Invalid method");
    }

    #[test]
    fn catch_all() {
        init_tracing();

        let mut router = RouteBuilder::new();
        router.get("/*", test_handler);

        let res = dispatch(&router, Method::GET, "/foo").expect("No response");
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(*res.into_cow(), b", GET, /*"[..]);
    }

    /// Router with `POST` and `GET` handlers registered on `/posts/:id`.
    fn test_router() -> RouteBuilder {
        let mut router = RouteBuilder::new();
        router
            .post("/posts/:id", test_handler)
            .get("/posts/:id", test_handler);
        router
    }

    /// Echoes `"<id>, <method>, <pattern>"` so tests can check what the router recorded.
    fn test_handler(req: &mut dyn conduit::RequestExt) -> conduit::HttpResult {
        let id = req.params().find("id").unwrap_or("").to_string();
        let method = format!("{:?}", req.method());
        let pattern = req
            .extensions()
            .get::<RoutePattern>()
            .unwrap()
            .pattern()
            .to_string();

        let body = [id, method, pattern].join(", ");
        Response::builder().body(Body::from_vec(body.into_bytes()))
    }
}