1#![warn(rust_2018_idioms)]
2
3#[macro_use]
4extern crate tracing;
5
6use std::collections::hash_map::{Entry, HashMap};
7
8use conduit::{box_error, Handler, HandlerResult, Method, RequestExt};
9use route_recognizer::{Match, Params, Router};
10
11#[derive(Default)]
12pub struct RouteBuilder {
13 routers: HashMap<Method, Router<WrappedHandler>>,
14}
15
/// The route pattern (e.g. `"/posts/:id"`) under which a handler was registered.
///
/// [`RouteBuilder`] inserts a copy of the matched pattern into the request
/// extensions before invoking the handler, so handlers can tell which route
/// they were reached through.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RoutePattern(&'static str);

impl RoutePattern {
    /// Returns the pattern string exactly as it was passed at registration.
    ///
    /// The string is `'static`, so the returned slice may outlive this
    /// `RoutePattern` value.
    pub fn pattern(&self) -> &'static str {
        self.0
    }
}
24
25struct WrappedHandler {
26 pattern: RoutePattern,
27 handler: Box<dyn Handler>,
28}
29
30impl conduit::Handler for WrappedHandler {
31 fn call(&self, request: &mut dyn RequestExt) -> HandlerResult {
32 self.handler.call(request)
33 }
34}
35
36#[derive(Debug, thiserror::Error)]
37pub enum RouterError {
38 #[error("Invalid method")]
39 UnknownMethod,
40 #[error("Path not found")]
41 PathNotFound,
42}
43
44impl RouteBuilder {
45 pub fn new() -> Self {
46 Self {
47 routers: HashMap::new(),
48 }
49 }
50
51 #[instrument(level = "trace", skip(self))]
52 fn recognize<'a>(
53 &'a self,
54 method: &Method,
55 path: &str,
56 ) -> Result<Match<&WrappedHandler>, RouterError> {
57 match self.routers.get(method) {
58 Some(router) => router.recognize(path).or(Err(RouterError::PathNotFound)),
59 None => Err(RouterError::UnknownMethod),
60 }
61 }
62
63 #[instrument(level = "trace", skip(self, handler))]
64 pub fn map<H: Handler>(
65 &mut self,
66 method: Method,
67 pattern: &'static str,
68 handler: H,
69 ) -> &mut Self {
70 {
71 let router = match self.routers.entry(method) {
72 Entry::Occupied(e) => e.into_mut(),
73 Entry::Vacant(e) => e.insert(Router::new()),
74 };
75 let wrapped_handler = WrappedHandler {
76 pattern: RoutePattern(pattern),
77 handler: Box::new(handler),
78 };
79 router.add(pattern, wrapped_handler);
80 }
81 self
82 }
83
84 pub fn get<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
85 self.map(Method::GET, pattern, handler)
86 }
87
88 pub fn post<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
89 self.map(Method::POST, pattern, handler)
90 }
91
92 pub fn put<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
93 self.map(Method::PUT, pattern, handler)
94 }
95
96 pub fn delete<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
97 self.map(Method::DELETE, pattern, handler)
98 }
99
100 pub fn head<H: Handler>(&mut self, pattern: &'static str, handler: H) -> &mut Self {
101 self.map(Method::HEAD, pattern, handler)
102 }
103}
104
105impl conduit::Handler for RouteBuilder {
106 #[instrument(level = "trace", skip(self, request))]
107 fn call(&self, request: &mut dyn RequestExt) -> HandlerResult {
108 let mut m = {
109 let method = request.method();
110 let path = request.path();
111
112 match self.recognize(&method, path) {
113 Ok(m) => m,
114 Err(e) => {
115 info!("{}", e);
116 return Err(box_error(e));
117 }
118 }
119 };
120
121 let mut params = Params::new();
124 std::mem::swap(m.params_mut(), &mut params);
125
126 let pattern = m.handler().pattern;
127 debug!(pattern = pattern.0, "matching route handler found");
128
129 {
130 let extensions = request.mut_extensions();
131 extensions.insert(pattern);
132 extensions.insert(params);
133 }
134
135 let span = trace_span!("handler", pattern = pattern.0);
136 span.in_scope(|| m.handler().call(request))
137 }
138}
139
140pub trait RequestParams<'a> {
141 fn params(self) -> &'a Params;
142}
143
144impl<'a> RequestParams<'a> for &'a (dyn RequestExt + 'a) {
145 fn params(self) -> &'a Params {
146 self.extensions().get::<Params>().expect("Missing params")
147 }
148}
149
#[cfg(test)]
mod tests {
    use std::sync::Once;

    use super::{RequestParams, RouteBuilder, RoutePattern};

    use conduit::{Body, Handler, Method, Response, StatusCode};
    use conduit_test::{MockRequest, ResponseExt};

    /// Installs a global `tracing` subscriber the first time it is called.
    ///
    /// The global default subscriber can only be set once per process, so
    /// every test goes through this `Once`-guarded helper.
    fn init_tracing() {
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            tracing_subscriber::FmtSubscriber::builder()
                .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::FULL)
                .with_test_writer()
                .init();
        });
    }

    #[test]
    fn basic_get() {
        init_tracing();

        let router = test_router();
        let mut req = MockRequest::new(Method::GET, "/posts/1");
        let res = router.call(&mut req).expect("No response");

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(*res.into_cow(), b"1, GET, /posts/:id"[..]);
    }

    #[test]
    fn basic_post() {
        init_tracing();

        let router = test_router();
        let mut req = MockRequest::new(Method::POST, "/posts/10");
        let res = router.call(&mut req).expect("No response");

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(*res.into_cow(), b"10, POST, /posts/:id"[..]);
    }

    #[test]
    fn put_and_delete() {
        init_tracing();

        let mut router = RouteBuilder::new();
        router.put("/posts/:id", test_handler);
        router.delete("/posts/:id", test_handler);

        let mut req = MockRequest::new(Method::PUT, "/posts/3");
        let res = router.call(&mut req).expect("No response");
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(*res.into_cow(), b"3, PUT, /posts/:id"[..]);

        let mut req = MockRequest::new(Method::DELETE, "/posts/4");
        let res = router.call(&mut req).expect("No response");
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(*res.into_cow(), b"4, DELETE, /posts/:id"[..]);
    }

    #[test]
    fn path_not_found() {
        init_tracing();

        let router = test_router();
        let mut req = MockRequest::new(Method::POST, "/nonexistent");
        let err = router.call(&mut req).err().unwrap();

        assert_eq!(err.to_string(), "Path not found");
    }

    #[test]
    fn unknown_method() {
        init_tracing();

        let router = test_router();
        let mut req = MockRequest::new(Method::DELETE, "/posts/1");
        let err = router.call(&mut req).err().unwrap();

        assert_eq!(err.to_string(), "Invalid method");
    }

    #[test]
    fn catch_all() {
        init_tracing();

        let mut router = RouteBuilder::new();
        router.get("/*", test_handler);

        let mut req = MockRequest::new(Method::GET, "/foo");
        let res = router.call(&mut req).expect("No response");
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(*res.into_cow(), b", GET, /*"[..]);
    }

    /// Router with GET and POST routes on `/posts/:id`.
    fn test_router() -> RouteBuilder {
        let mut router = RouteBuilder::new();
        router.post("/posts/:id", test_handler);
        router.get("/posts/:id", test_handler);
        router
    }

    /// Echoes the `id` param (or empty), the request method and the matched
    /// pattern, joined with `", "`.
    fn test_handler(req: &mut dyn conduit::RequestExt) -> conduit::HttpResult {
        let res = vec![
            req.params().find("id").unwrap_or("").to_string(),
            format!("{:?}", req.method()),
            req.extensions()
                .get::<RoutePattern>()
                .unwrap()
                .pattern()
                .to_string(),
        ];

        let bytes = res.join(", ").into_bytes();
        Response::builder().body(Body::from_vec(bytes))
    }
}