peta/
router.rs

1use crate::method;
2use crate::request;
3use crate::response;
4
5use hashbrown::HashMap;
6use tokio::prelude::*;
7
8use std::sync::Arc;
9
10/// Abstraction of the return Boxed Future.
11///
12/// # Example
13/// Each function which will be passed to `Router` must return
14/// `ResponseFut`
15/// ```
16/// fn hello(req: Request) -> ResponseFut {
17///   // rest of the code
18/// }
19///
20/// router.get("/", hello)
21/// ```
22pub type ResponseFut = Box<dyn Future<Item = response::Response, Error = ()> + Send + Sync>;
23
24/// Generates map between `path` and `method` which returns `ResponseFut`
25/// supports `*` and `:` operators.
26///
27/// # Example
28///
29/// ```
30/// let mut router = Router::new();
31///
32/// router.get("/", |req: Request| {
33///   // do not forget to return ResponseFut
34/// });
35///
36/// router.post("/home", |req: Request| {
37///   // do not forget to return ResponseFut
38/// });
39///
40/// // It is important to add default route
41/// // can be simple 404 which will be called if nothing found
42/// router.add_default(|req: Request| {});
43///
44/// let router = router.build();
45///
46/// ```
47pub struct Router {
48  get: Node,
49  put: Node,
50  post: Node,
51  head: Node,
52  patch: Node,
53  delete: Node,
54  options: Node,
55  routes: Node,
56  default: Option<StoreFunc>,
57}
58
59impl Router {
60  /// Create instance of Router.
61  ///
62  /// ```
63  /// let router = Router::new();
64  /// ```
65  pub fn new() -> Router {
66    Router {
67      default: None,
68      // for now leave this one in
69      routes: Node::default(),
70      // create separate bucket for each method (not fun) :(
71      get: Node::default(),
72      put: Node::default(),
73      post: Node::default(),
74      head: Node::default(),
75      patch: Node::default(),
76      delete: Node::default(),
77      options: Node::default(),
78    }
79  }
80
81  /// Wrap router in `Arc` to be able to pass it across threads/components.
82  ///
83  /// ```
84  /// let mut router = Router::new();
85  /// // do some things with router variable
86  ///
87  /// let router = router.build();
88  ///
89  /// // now we can simply clone and call any functions from router instance
90  /// let ref_router = router.clone();
91  /// ```
92  pub fn build(self) -> Arc<Self> {
93    Arc::new(self)
94  }
95
96  /// Adds new `path -> method` map to the Router.
97  ///
98  /// ```
99  /// router.add(method::GET, "/", |req: Request| {});
100  /// router.add(method::POST, "/", |req: Request| {});
101  /// // and so on
102  /// ```
103  pub fn add<F>(&mut self, method: &str, path: &'static str, func: F)
104  where
105    F: Fn(request::Request) -> ResponseFut + Send + Sync + 'static,
106  {
107    // use proper enum
108    let mut node = match method {
109      method::GET => &mut self.get,
110      method::PUT => &mut self.put,
111      method::POST => &mut self.post,
112      method::HEAD => &mut self.head,
113      method::PATCH => &mut self.patch,
114      method::DELETE => &mut self.delete,
115      method::OPTIONS => &mut self.options,
116      _ => &mut self.routes,
117    };
118
119    match path {
120      "/" => {
121        // handle / case
122        node.set_func(Box::new(func));
123      }
124      _ => {
125        // handle rest of the cases
126        for seg in path.split('/') {
127          if !seg.is_empty() {
128            let mut seg_arr = seg.chars();
129            // check if path is param
130            if seg_arr.next() == Some(':') {
131              node = node.add_child(":", Some(seg_arr.as_str()));
132              continue;
133            }
134            node = node.add_child(seg, None);
135          }
136        }
137
138        node.set_func(Box::new(func));
139      }
140    }
141  }
142
143  /// Searches for appropriate `method` which is mapped to specific `path`
144  ///
145  /// ```
146  /// let req: request::Request;
147  /// // it will automatically extract path from `req`
148  /// router.find(req)
149  /// ```
150  pub fn find(&self, mut req: request::Request) -> ResponseFut {
151    let mut node = match req.method() {
152      method::GET => &self.get,
153      method::PUT => &self.put,
154      method::POST => &self.post,
155      method::HEAD => &self.head,
156      method::PATCH => &self.patch,
157      method::DELETE => &self.delete,
158      method::OPTIONS => &self.options,
159      _ => &self.routes,
160    };
161
162    // handle / route
163    if req.uri().path() == "/" {
164      return match node.method.as_ref() {
165        Some(v) => (v)(req),
166        None => (self.default.as_ref().unwrap())(req),
167      };
168    }
169
170    let mut params: Vec<(&'static str, String)> = Vec::with_capacity(10);
171
172    for seg in req.uri().path().split('/') {
173      if !seg.is_empty() {
174        if node.children.is_none() {
175          // do default return
176          return (self.default.as_ref().unwrap())(req);
177        }
178
179        let children = node.children.as_ref().unwrap();
180
181        // find proper node with func
182        let found_node = match children.get(seg) {
183          Some(v) => v,
184          None => {
185            match children.get(":") {
186              Some(v) => {
187                params.push((v.param.unwrap(), seg.to_string()));
188                v
189              }
190              // break from this function if we get in star
191              None => match children.get("*") {
192                Some(v) => {
193                  // we need to attache params in here as we may and loop sooner
194                  if !params.is_empty() {
195                    req.set_params(Some(params));
196                  }
197
198                  return (v.method.as_ref().unwrap())(req);
199                }
200                None => {
201                  // execute default if route not found at all
202                  return (self.default.as_ref().unwrap())(req);
203                }
204              },
205            }
206          }
207        };
208
209        node = found_node;
210      }
211    }
212
213    match node.method.as_ref() {
214      Some(v) => {
215        // set params only if it is not empty
216        if !params.is_empty() {
217          req.set_params(Some(params));
218        }
219        (v)(req)
220      }
221      None => (self.default.as_ref().unwrap())(req),
222    }
223  }
224
225  /// Set default function for routes which were not mapped
226  /// can be simple 404 response.
227  ///
228  /// ```
229  /// router.add_default(|req: Request| {});
230  /// ```
231  pub fn add_default<F>(&mut self, func: F)
232  where
233    F: Fn(request::Request) -> ResponseFut + Send + Sync + 'static,
234  {
235    self.default = Some(Box::new(func));
236  }
237}
238
239/// Abstracts `add` method by removing `method::*` param
240impl Router {
241  pub fn get<F>(&mut self, path: &'static str, func: F)
242  where
243    F: Fn(request::Request) -> ResponseFut + Send + Sync + 'static,
244  {
245    self.add(method::GET, path, func)
246  }
247
248  pub fn put<F>(&mut self, path: &'static str, func: F)
249  where
250    F: Fn(request::Request) -> ResponseFut + Send + Sync + 'static,
251  {
252    self.add(method::PUT, path, func)
253  }
254
255  pub fn post<F>(&mut self, path: &'static str, func: F)
256  where
257    F: Fn(request::Request) -> ResponseFut + Send + Sync + 'static,
258  {
259    self.add(method::POST, path, func)
260  }
261
262  pub fn head<F>(&mut self, path: &'static str, func: F)
263  where
264    F: Fn(request::Request) -> ResponseFut + Send + Sync + 'static,
265  {
266    self.add(method::HEAD, path, func)
267  }
268
269  pub fn patch<F>(&mut self, path: &'static str, func: F)
270  where
271    F: Fn(request::Request) -> ResponseFut + Send + Sync + 'static,
272  {
273    self.add(method::PATCH, path, func)
274  }
275
276  pub fn delete<F>(&mut self, path: &'static str, func: F)
277  where
278    F: Fn(request::Request) -> ResponseFut + Send + Sync + 'static,
279  {
280    self.add(method::DELETE, path, func)
281  }
282
283  pub fn options<F>(&mut self, path: &'static str, func: F)
284  where
285    F: Fn(request::Request) -> ResponseFut + Send + Sync + 'static,
286  {
287    self.add(method::OPTIONS, path, func)
288  }
289}
290
// TODO: document `Node`.
// TODO: consider moving `Node` and `StoreFunc` out of the router module.
293type StoreFunc = Box<
294  dyn Fn(request::Request) -> Box<dyn Future<Item = response::Response, Error = ()> + Send + Sync>
295    + Send
296    + Sync,
297>;
298
299struct Node {
300  param: Option<&'static str>,
301  method: Option<StoreFunc>,
302  children: Option<HashMap<&'static str, Node>>,
303}
304
305impl Node {
306  pub fn default() -> Node {
307    Node {
308      param: None,
309      method: None,
310      children: None,
311    }
312  }
313
314  pub fn set_func(&mut self, func: StoreFunc) {
315    self.method = Some(func);
316  }
317
318  pub fn add_child(&mut self, seg: &'static str, param: Option<&'static str>) -> &mut Node {
319    if self.children.is_none() {
320      self.children = Some(HashMap::new())
321    }
322
323    let node_map = self.children.as_mut().unwrap();
324
325    // if key exist then return existing node ref
326    if node_map.contains_key(seg) {
327      return node_map.get_mut(seg).unwrap();
328    }
329
330    // create new if node
331    node_map.insert(
332      seg,
333      Node {
334        param,
335        method: None,
336        children: None,
337      },
338    );
339    // this item is just added
340    node_map.get_mut(seg).unwrap()
341  }
342}
343
344// Need to improve debug printing
345impl std::fmt::Debug for Router {
346  fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
347    write!(
348      f,
349      "Router {{ \n   get: {:#?}, \npost: {:#?} \ndelete:{:#?}\nput:{:#?} \nroutes:{:#?} \n}}",
350      self.get, self.post, self.delete, self.put, self.routes
351    )
352  }
353}
354
355impl std::fmt::Debug for Node {
356  fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
357    write!(
358      f,
359      "Node {{ \n\tchildren: {:#?}, \n\tmethod: {:#?} \n\tparam:{:#?}\n}}",
360      self.children,
361      self.method.is_some(),
362      self.param
363    )
364  }
365}