1use futures::future::{self, Future};
17use hyper::service::Service;
18use hyper::{Body, Method, Request, Response, StatusCode};
19use std::collections::hash_map::DefaultHasher;
20use std::hash::{Hash, Hasher};
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
// Hashes of the two wildcard segment tokens, computed once and compared against
// node keys while matching. A `*` node matches any single path segment; a `**`
// node ends matching in `Router::get`, so its handler serves every deeper path.
lazy_static! {
    static ref WILDCARD_HASH: u64 = calculate_hash(&"*");
    static ref WILDCARD_STOP_HASH: u64 = calculate_hash(&"**");
}
29
/// Pinned, boxed, `Send` future resolving to an HTTP response.
///
/// This is the return type of every `Handler` hook and of the router's
/// `Service::call`.
pub type ResponseFuture =
    Pin<Box<dyn Future<Output = Result<Response<Body>, hyper::Error>> + Send>>;
32
33pub trait Handler {
34 fn get(&self, _req: Request<Body>) -> ResponseFuture {
35 not_found()
36 }
37
38 fn post(&self, _req: Request<Body>) -> ResponseFuture {
39 not_found()
40 }
41
42 fn put(&self, _req: Request<Body>) -> ResponseFuture {
43 not_found()
44 }
45
46 fn patch(&self, _req: Request<Body>) -> ResponseFuture {
47 not_found()
48 }
49
50 fn delete(&self, _req: Request<Body>) -> ResponseFuture {
51 not_found()
52 }
53
54 fn head(&self, _req: Request<Body>) -> ResponseFuture {
55 not_found()
56 }
57
58 fn options(&self, _req: Request<Body>) -> ResponseFuture {
59 not_found()
60 }
61
62 fn trace(&self, _req: Request<Body>) -> ResponseFuture {
63 not_found()
64 }
65
66 fn connect(&self, _req: Request<Body>) -> ResponseFuture {
67 not_found()
68 }
69
70 fn call(
71 &self,
72 req: Request<Body>,
73 mut _handlers: Box<dyn Iterator<Item = HandlerObj>>,
74 ) -> ResponseFuture {
75 match *req.method() {
76 Method::GET => self.get(req),
77 Method::POST => self.post(req),
78 Method::PUT => self.put(req),
79 Method::DELETE => self.delete(req),
80 Method::PATCH => self.patch(req),
81 Method::OPTIONS => self.options(req),
82 Method::CONNECT => self.connect(req),
83 Method::TRACE => self.trace(req),
84 Method::HEAD => self.head(req),
85 _ => not_found(),
86 }
87 }
88}
89
/// Errors reported by `Router::add_route` and `Router::get`.
#[derive(Clone, thiserror::Error, Eq, Debug, PartialEq, Serialize, Deserialize)]
pub enum RouterError {
    /// `add_route` resolved to a node that already has a handler; carries the route.
    #[error("Route {0} already exists")]
    RouteAlreadyExists(String),
    /// `get` found no child matching some segment of the path; carries the path.
    #[error("Route {0} not found")]
    RouteNotFound(String),
    /// `get` resolved the path to a node that has no handler; carries the path.
    #[error("Route value not found for {0}")]
    NoValue(String),
}
99
/// Path router backed by a segment trie stored as a flat arena of `Node`s.
///
/// Nodes refer to each other by `NodeId` (an index into `nodes`). Index 0 is
/// the root created by `Router::new`.
#[derive(Clone)]
pub struct Router {
    // Node arena; a node's `NodeId` is its position in this vector.
    nodes: Vec<Node>,
}
104
/// Index of a `Node` inside `Router::nodes`.
#[derive(Debug, Clone, Copy)]
struct NodeId(usize);

/// Maximum number of children per node; `Node::add_child` panics beyond this.
const MAX_CHILDREN: usize = 16;

/// Shared handle to a handler or middleware, usable across threads (`Send + Sync`).
pub type HandlerObj = Arc<dyn Handler + Send + Sync>;
111
/// One node of the routing trie, corresponding to a single path segment.
#[derive(Clone)]
pub struct Node {
    // Hash (via `calculate_hash`) of the path segment this node matches.
    key: u64,
    // Handler for the route that ends at this node, if one was registered.
    value: Option<HandlerObj>,
    // Child ids. Only the first `children_count` entries are populated; the rest
    // keep the `NodeId(0)` placeholder written by `Node::new`.
    children: [NodeId; MAX_CHILDREN],
    // Number of populated entries at the front of `children`.
    children_count: usize,
    // Middleware attached to this node, in registration order; `None` until the
    // first one is added.
    mws: Option<Vec<HandlerObj>>,
}
120
121impl Router {
122 pub fn new() -> Router {
123 let root = Node::new(calculate_hash(&""), None);
124 let mut nodes = vec![];
125 nodes.push(root);
126 Router { nodes }
127 }
128
129 pub fn add_middleware(&mut self, mw: HandlerObj) {
130 self.node_mut(NodeId(0)).add_middleware(mw);
131 }
132
133 fn root(&self) -> NodeId {
134 NodeId(0)
135 }
136
137 fn node(&self, id: NodeId) -> &Node {
138 &self.nodes[id.0]
139 }
140
141 fn node_mut(&mut self, id: NodeId) -> &mut Node {
142 &mut self.nodes[id.0]
143 }
144
145 fn find(&self, parent: NodeId, key: u64) -> Option<NodeId> {
146 let node = self.node(parent);
147 node.children
148 .iter()
149 .find(|&id| {
150 let node_key = self.node(*id).key;
151 node_key == key || node_key == *WILDCARD_HASH || node_key == *WILDCARD_STOP_HASH
152 })
153 .cloned()
154 }
155
156 fn add_empty_node(&mut self, parent: NodeId, key: u64) -> NodeId {
157 let id = NodeId(self.nodes.len());
158 self.nodes.push(Node::new(key, None));
159 self.node_mut(parent).add_child(id);
160 id
161 }
162
163 pub fn add_route(
164 &mut self,
165 route: &'static str,
166 value: HandlerObj,
167 ) -> Result<&mut Node, RouterError> {
168 let keys = generate_path(route);
169 let mut node_id = self.root();
170 for key in keys {
171 node_id = self
172 .find(node_id, key)
173 .unwrap_or_else(|| self.add_empty_node(node_id, key));
174 }
175 match self.node(node_id).value() {
176 None => {
177 let node = self.node_mut(node_id);
178 node.set_value(value);
179 Ok(node)
180 }
181 Some(_) => Err(RouterError::RouteAlreadyExists(route.to_string())),
182 }
183 }
184
185 pub fn get(&self, path: &str) -> Result<impl Iterator<Item = HandlerObj>, RouterError> {
186 let keys = generate_path(path);
187 let mut handlers = vec![];
188 let mut node_id = self.root();
189 collect_node_middleware(&mut handlers, self.node(node_id));
190 for key in keys {
191 node_id = self
192 .find(node_id, key)
193 .ok_or(RouterError::RouteNotFound(path.to_string()))?;
194 let node = self.node(node_id);
195 collect_node_middleware(&mut handlers, self.node(node_id));
196 if node.key == *WILDCARD_STOP_HASH {
197 break;
198 }
199 }
200
201 if let Some(h) = self.node(node_id).value() {
202 handlers.push(h);
203 Ok(handlers.into_iter())
204 } else {
205 Err(RouterError::NoValue(path.to_string()))
206 }
207 }
208}
209
210impl Service<Request<Body>> for Router {
211 type Response = Response<Body>;
212 type Error = hyper::Error;
213 type Future = ResponseFuture;
214
215 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216 Poll::Ready(Ok(()))
217 }
218
219 fn call(&mut self, req: Request<Body>) -> Self::Future {
220 match self.get(req.uri().path()) {
221 Err(_) => not_found(),
222 Ok(mut handlers) => match handlers.next() {
223 None => not_found(),
224 Some(h) => h.call(req, Box::new(handlers)),
225 },
226 }
227 }
228}
229
230impl Node {
231 fn new(key: u64, value: Option<HandlerObj>) -> Node {
232 Node {
233 key,
234 value,
235 children: [NodeId(0); MAX_CHILDREN],
236 children_count: 0,
237 mws: None,
238 }
239 }
240
241 pub fn add_middleware(&mut self, mw: HandlerObj) -> &mut Node {
242 if self.mws.is_none() {
243 self.mws = Some(vec![]);
244 }
245 if let Some(ref mut mws) = self.mws {
246 mws.push(mw.clone());
247 }
248 self
249 }
250
251 fn value(&self) -> Option<HandlerObj> {
252 match &self.value {
253 None => None,
254 Some(v) => Some(v.clone()),
255 }
256 }
257
258 fn set_value(&mut self, value: HandlerObj) {
259 self.value = Some(value);
260 }
261
262 fn add_child(&mut self, child_id: NodeId) {
263 if self.children_count == MAX_CHILDREN {
264 panic!("Can't add a route, children limit exceeded");
265 }
266 self.children[self.children_count] = child_id;
267 self.children_count += 1;
268 }
269}
270
271pub fn not_found() -> ResponseFuture {
272 let mut response = Response::new(Body::empty());
273 *response.status_mut() = StatusCode::NOT_FOUND;
274 Box::pin(future::ok(response))
275}
276
/// Hashes `t` with a freshly constructed `DefaultHasher`.
///
/// Each call starts from the same initial state, so equal inputs produce equal
/// keys within a process; route segments are compared by these keys.
fn calculate_hash<T: Hash>(t: &T) -> u64 {
    let mut state = DefaultHasher::default();
    Hash::hash(t, &mut state);
    Hasher::finish(&state)
}
282
283fn generate_path(route: &str) -> Vec<u64> {
284 route
285 .split('/')
286 .skip(1)
287 .map(|path| calculate_hash(&path))
288 .collect()
289}
290
291fn collect_node_middleware(handlers: &mut Vec<HandlerObj>, node: &Node) {
292 if let Some(ref mws) = node.mws {
293 for mw in mws {
294 handlers.push(mw.clone());
295 }
296 }
297}
298
#[cfg(test)]
mod tests {

    use super::*;
    use futures::executor::block_on;

    // Test handler whose `get` hook answers with the wrapped status code; all
    // other hooks keep the default 404.
    struct HandlerImpl(u16);

    impl Handler for HandlerImpl {
        fn get(&self, _req: Request<Body>) -> ResponseFuture {
            let code = self.0;
            Box::pin(async move {
                let res = Response::builder()
                    .status(code)
                    .body(Body::default())
                    .unwrap();
                Ok(res)
            })
        }
    }

    // Registration: duplicate routes are rejected, including a concrete segment
    // ("ccc") captured by an existing `*` sibling that already has a handler.
    // Handlers are never invoked here, so their codes are arbitrary.
    #[test]
    fn test_add_route() {
        let mut routes = Router::new();
        let h1 = Arc::new(HandlerImpl(1));
        let h2 = Arc::new(HandlerImpl(2));
        let h3 = Arc::new(HandlerImpl(3));
        routes.add_route("/v1/users", h1.clone()).unwrap();
        assert!(routes.add_route("/v1/users", h2.clone()).is_err());
        routes.add_route("/v1/users/xxx", h3.clone()).unwrap();
        routes.add_route("/v1/users/xxx/yyy", h3.clone()).unwrap();
        routes.add_route("/v1/zzz/*", h3.clone()).unwrap();
        assert!(routes.add_route("/v1/zzz/ccc", h2.clone()).is_err());
        routes
            .add_route("/v1/zzz/*/zzz", Arc::new(HandlerImpl(6)))
            .unwrap();
    }

    // Lookup: exact routes, unknown segments, handler-less intermediate nodes,
    // and `*` matching any single segment.
    #[test]
    fn test_get() {
        let mut routes = Router::new();
        routes
            .add_route("/v1/users", Arc::new(HandlerImpl(101)))
            .unwrap();
        routes
            .add_route("/v1/users/xxx", Arc::new(HandlerImpl(103)))
            .unwrap();
        routes
            .add_route("/v1/users/xxx/yyy", Arc::new(HandlerImpl(103)))
            .unwrap();
        routes
            .add_route("/v1/zzz/*", Arc::new(HandlerImpl(103)))
            .unwrap();
        routes
            .add_route("/v1/zzz/*/zzz", Arc::new(HandlerImpl(106)))
            .unwrap();

        // No middleware is registered, so the first element of the chain is the
        // route handler; call its `get` hook directly and return the status code.
        let call_handler = |url| {
            let task = async {
                let resp = routes
                    .get(url)
                    .unwrap()
                    .next()
                    .unwrap()
                    .get(Request::new(Body::default()))
                    .await
                    .unwrap();
                resp.status().as_u16()
            };
            block_on(task)
        };

        assert_eq!(call_handler("/v1/users"), 101);
        assert_eq!(call_handler("/v1/users/xxx"), 103);
        assert!(routes.get("/v1/users/yyy").is_err());
        assert_eq!(call_handler("/v1/users/xxx/yyy"), 103);
        // "/v1/zzz" exists only as an intermediate node without a handler.
        assert!(routes.get("/v1/zzz").is_err());
        assert_eq!(call_handler("/v1/zzz/1"), 103);
        assert_eq!(call_handler("/v1/zzz/2"), 103);
        assert_eq!(call_handler("/v1/zzz/2/zzz"), 106);
    }
}