ic_pluto/
router.rs

1use std::{collections::HashMap, future::Future, pin::Pin};
2
3use dyn_clone::{clone_trait_object, DynClone};
4use matchit::{Match, Router as MatchRouter};
5
6use crate::{
7    http::{HttpRequest, HttpResponse},
8    method::Method,
9};
10
11/// A container for a handler and a flag indicating whether the handler supports HTTP upgrades.
12#[derive(Clone)]
13pub(crate) struct HandlerContainer {
14    pub(crate) upgrade: bool,
15    pub(crate) handler: Box<dyn Handler>,
16}
17
18/// A router for HTTP requests.
19/// The router is used to register handlers for different HTTP methods and paths.
20#[derive(Clone)]
21pub struct Router {
22    prefix: String,
23    trees: HashMap<Method, MatchRouter<HandlerContainer>>,
24    pub(crate) handle_options: bool,
25    pub(crate) global_options: Option<HandlerContainer>,
26}
27
28impl Router {
29    /// Create a new router.
30    /// The router is used to register handlers for different HTTP methods and paths.
31    /// The router can be used as a handler for a server.
32    /// # Examples
33    ///
34    /// ``` rust
35    /// use pluto::router::Router;
36    ///
37    /// let mut router = Router::new();
38    /// ```
39    pub fn new() -> Self {
40        Self {
41            prefix: String::from(""),
42            trees: HashMap::new(),
43            handle_options: true,
44            global_options: None,
45        }
46    }
47
48    /// Set a prefix for all paths registered on the router.
49    /// # Examples
50    ///
51    /// ``` rust
52    /// use pluto::router::Router;
53    ///
54    /// let mut router = Router::new();
55    /// router.set_global_prefix("/api".to_string());
56    /// ```
57    pub fn set_global_prefix(&mut self, p: String) -> &mut Self {
58        self.prefix = p;
59        self
60    }
61
62    /// Register a handler for a path and method.
63    /// The handler is called for requests with a matching path and method.
64    /// # Examples
65    ///
66    /// ``` rust
67    /// use pluto::router::Router;
68    /// use pluto::http::{HttpRequest, HttpResponse};
69    /// use pluto::method::Method;
70    /// use serde_json::json;
71    /// use std::collections::HashMap;
72    ///
73    /// let mut router = Router::new();
74    /// router.handle("/hello", false, Method::GET, |req: HttpRequest| async move {
75    ///     Ok(HttpResponse {
76    ///         status_code: 200,
77    ///         headers: HashMap::new(),
78    ///         body: json!({
79    ///             "statusCode": 200,
80    ///             "message": "Hello World from GET",
81    ///         })
82    ///         .into(),
83    ///     })
84    /// });
85    /// ```
86    pub fn handle(
87        &mut self,
88        path: &str,
89        upgrade: bool,
90        method: Method,
91        handler: impl Handler + 'static,
92    ) -> &mut Self {
93        if !path.starts_with('/') {
94            panic!("expect path beginning with '/', found: '{}'", path);
95        }
96        let mut global_path = self.prefix.to_owned() + path;
97        if global_path.ends_with("/") {
98            global_path.pop();
99        }
100
101        match self.trees.entry(method).or_default().insert(
102            global_path,
103            HandlerContainer {
104                handler: Box::new(handler),
105                upgrade: upgrade,
106            },
107        ) {
108            Err(err) => panic!("\nERROR: {}\n", err),
109            Ok(_) => {}
110        }
111        self
112    }
113
114    /// Lookup a handler for a path and method.
115    /// The handler is called for requests with a matching path and method.
116    pub(crate) fn lookup<'a>(
117        &'a self,
118        method: Method,
119        path: &'a str,
120    ) -> Result<Match<&HandlerContainer>, String> {
121        if let Some(tree_at_path) = self.trees.get(&method) {
122            if let Ok(match_result) = tree_at_path.at(path) {
123                return Ok(match_result);
124            }
125        }
126
127        if path == "" {
128            return Err(format!("Cannot {} {}", method, "/"));
129        }
130        return Err(format!("Cannot {} {}", method, path));
131    }
132
133    /// Register a handler for GET requests at a path.
134    /// The handler is called for requests with the GET method and a matching path.
135    /// # Examples
136    ///
137    /// ``` rust
138    /// use pluto::router::Router;
139    /// use pluto::http::{HttpRequest, HttpResponse};
140    /// use pluto::method::Method;
141    /// use serde_json::json;
142    /// use std::collections::HashMap;
143    ///
144    /// let mut router = Router::new();
145    /// router.get("/hello", false, |req: HttpRequest| async move {
146    ///     Ok(HttpResponse {
147    ///         status_code: 200,
148    ///         headers: HashMap::new(),
149    ///         body: json!({
150    ///             "statusCode": 200,
151    ///             "message": "Hello World from GET",
152    ///         })
153    ///         .into(),
154    ///     })
155    /// });
156    /// ```
157    pub fn get(&mut self, path: &str, upgrade: bool, handler: impl Handler + 'static) -> &mut Self {
158        self.handle(path, upgrade, Method::GET, handler)
159    }
160
161    /// Register a handler for HEAD requests at a path.
162    /// The handler is called for requests with the HEAD method and a matching path.
163    /// # Examples
164    ///
165    /// ``` rust
166    /// use pluto::router::Router;
167    /// use pluto::http::{HttpRequest, HttpResponse};
168    /// use pluto::method::Method;
169    /// use serde_json::json;
170    /// use std::collections::HashMap;
171    ///
172    /// let mut router = Router::new();
173    /// router.head("/hello", false, |req: HttpRequest| async move {
174    ///     Ok(HttpResponse {
175    ///         status_code: 200,
176    ///         headers: HashMap::new(),
177    ///         body: json!({
178    ///             "statusCode": 200,
179    ///             "message": "Hello World from HEAD",
180    ///         })
181    ///         .into(),
182    ///     })
183    /// });
184    /// ```
185    pub fn head(
186        &mut self,
187        path: &str,
188        upgrade: bool,
189        handler: impl Handler + 'static,
190    ) -> &mut Self {
191        self.handle(path, upgrade, Method::HEAD, handler)
192    }
193
194    /// Register a handler for OPTIONS requests at a path.
195    /// The handler is called for requests with the OPTIONS method and a matching path.
196    /// # Examples
197    ///
198    /// ``` rust
199    /// use pluto::router::Router;
200    /// use pluto::http::{HttpRequest, HttpResponse};
201    /// use pluto::method::Method;
202    /// use serde_json::json;
203    /// use std::collections::HashMap;
204    ///
205    /// let mut router = Router::new();
206    /// router.options("/hello", false, |req: HttpRequest| async move {
207    ///     Ok(HttpResponse {
208    ///         status_code: 200,
209    ///         headers: HashMap::new(),
210    ///         body: json!({
211    ///             "statusCode": 200,
212    ///             "message": "Hello World from OPTIONS",
213    ///         })
214    ///         .into(),
215    ///     })
216    /// });
217    /// ```
218    pub fn options(
219        &mut self,
220        path: &str,
221        upgrade: bool,
222        handler: impl Handler + 'static,
223    ) -> &mut Self {
224        self.handle(path, upgrade, Method::OPTIONS, handler)
225    }
226
227    /// Register a handler for POST requests at a path.
228    /// The handler is called for requests with the POST method and a matching path.
229    /// # Examples
230    ///
231    /// ``` rust
232    /// use pluto::router::Router;
233    /// use pluto::http::{HttpRequest, HttpResponse};
234    /// use pluto::method::Method;
235    /// use serde_json::json;
236    /// use std::collections::HashMap;
237    ///
238    /// let mut router = Router::new();
239    /// router.post("/hello", false, |req: HttpRequest| async move {
240    ///     Ok(HttpResponse {
241    ///         status_code: 200,
242    ///         headers: HashMap::new(),
243    ///         body: json!({
244    ///             "statusCode": 200,
245    ///             "message": "Hello World from POST",
246    ///         })
247    ///         .into(),
248    ///     })
249    /// });
250    /// ```
251    pub fn post(
252        &mut self,
253        path: &str,
254        upgrade: bool,
255        handler: impl Handler + 'static,
256    ) -> &mut Self {
257        self.handle(path, upgrade, Method::POST, handler)
258    }
259
260    /// Register a handler for PUT requests at a path.
261    /// The handler is called for requests with the PUT method and a matching path.
262    /// # Examples
263    ///
264    /// ``` rust
265    /// use pluto::router::Router;
266    /// use pluto::http::{HttpRequest, HttpResponse};
267    /// use pluto::method::Method;
268    /// use serde_json::json;
269    /// use std::collections::HashMap;
270    ///
271    /// let mut router = Router::new();
272    /// router.put("/hello", false, |req: HttpRequest| async move {
273    ///     Ok(HttpResponse {
274    ///         status_code: 200,
275    ///         headers: HashMap::new(),
276    ///         body: json!({
277    ///             "statusCode": 200,
278    ///             "message": "Hello World from PUT",
279    ///         })
280    ///         .into(),
281    ///     })
282    /// });
283    /// ```
284    pub fn put(&mut self, path: &str, upgrade: bool, handler: impl Handler + 'static) -> &mut Self {
285        self.handle(path, upgrade, Method::PUT, handler)
286    }
287
288    /// Register a handler for PATCH requests at a path.
289    /// The handler is called for requests with the PATCH method and a matching path.
290    /// # Examples
291    ///
292    /// ``` rust
293    /// use pluto::router::Router;
294    /// use pluto::http::{HttpRequest, HttpResponse};
295    /// use pluto::method::Method;
296    /// use serde_json::json;
297    /// use std::collections::HashMap;
298    ///
299    /// let mut router = Router::new();
300    /// router.patch("/hello", false, |req: HttpRequest| async move {
301    ///     Ok(HttpResponse {
302    ///         status_code: 200,
303    ///         headers: HashMap::new(),
304    ///         body: json!({
305    ///             "statusCode": 200,
306    ///             "message": "Hello World from PATCH",
307    ///         })
308    ///         .into(),
309    ///     })
310    /// });
311    /// ```
312    pub fn patch(
313        &mut self,
314        path: &str,
315        upgrade: bool,
316        handler: impl Handler + 'static,
317    ) -> &mut Self {
318        self.handle(path, upgrade, Method::PATCH, handler)
319    }
320
321    /// Register a handler for DELETE requests at a path.
322    /// The handler is called for requests with the DELETE method and a matching path.
323    /// # Examples
324    ///
325    /// ``` rust
326    /// use pluto::router::Router;
327    /// use pluto::http::{HttpRequest, HttpResponse};
328    /// use pluto::method::Method;
329    /// use serde_json::json;
330    /// use std::collections::HashMap;
331    ///
332    /// let mut router = Router::new();
333    /// router.delete("/hello", false, |req: HttpRequest| async move {
334    ///     Ok(HttpResponse {
335    ///         status_code: 200,
336    ///         headers: HashMap::new(),
337    ///         body: json!({
338    ///             "statusCode": 200,
339    ///             "message": "Hello World from DELETE",
340    ///         })
341    ///         .into(),
342    ///     })
343    /// });
344    /// ```
345    pub fn delete(
346        &mut self,
347        path: &str,
348        upgrade: bool,
349        handler: impl Handler + 'static,
350    ) -> &mut Self {
351        self.handle(path, upgrade, Method::DELETE, handler)
352    }
353
354    /// Allow the router to handle OPTIONS requests.
355    /// If enabled, the router will automatically respond to OPTIONS requests with the allowed methods for a path.
356    /// If disabled, the router will respond to OPTIONS requests with a 404.
357    /// # Examples
358    ///
359    /// ``` rust
360    /// use pluto::router::Router;
361    ///
362    /// let mut router = Router::new();
363    /// router.handle_options(true);
364    /// ```
365    pub fn handle_options(&mut self, handle: bool) {
366        self.handle_options = handle;
367    }
368
369    /// Register a default handler for not registered requests.
370    /// The handler is called for requests when router can't matching path or method to any handler.
371    /// # Examples
372    ///
373    /// ``` rust
374    /// use pluto::router::Router;
375    /// use pluto::http::{HttpRequest, HttpResponse};
376    /// use serde_json::json;
377    /// use std::collections::HashMap;
378    ///
379    /// let mut router = Router::new();
380    /// router.global_options(false, |req: HttpRequest| async move {
381    ///    Ok(HttpResponse {
382    ///         status_code: 404,
383    ///         headers: HashMap::new(),
384    ///         body: json!({
385    ///             "statusCode": 404,
386    ///             "message": "Not Found",
387    ///         })
388    ///         .into(),
389    ///     })
390    /// });
391    /// ```
392    pub fn global_options(mut self, upgrade: bool, handler: impl Handler + 'static) -> Self {
393        self.global_options = Some(HandlerContainer {
394            handler: Box::new(handler),
395            upgrade: upgrade,
396        });
397        self
398    }
399
400    /// Get the allowed methods for a path.
401    /// # Examples
402    ///
403    /// ``` rust
404    /// use pluto::router::Router;
405    /// use pluto::http::{HttpRequest, HttpResponse};
406    /// use serde_json::json;
407    /// use std::collections::HashMap;
408    ///
409    /// let mut router = Router::new();
410    /// router.get("/hello", false, |req: HttpRequest| async move {
411    ///    Ok(HttpResponse {
412    ///         status_code: 200,
413    ///         headers: HashMap::new(),
414    ///         body: json!({
415    ///             "statusCode": 200,
416    ///             "message": "Hello World from GET",
417    ///         })
418    ///         .into(),
419    ///     })
420    /// });
421    /// router.post("/hello", false, |req: HttpRequest| async move {
422    ///   Ok(HttpResponse {
423    ///         status_code: 200,
424    ///         headers: HashMap::new(),
425    ///         body: json!({
426    ///             "statusCode": 200,
427    ///             "message": "Hello World from POST",
428    ///         })
429    ///         .into(),
430    ///     })
431    /// });
432    /// let mut allowed = router.allowed("/hello");
433    /// allowed.sort();
434    /// assert_eq!(allowed, vec!["GET", "OPTIONS", "POST"]);
435    pub fn allowed(&self, path: &str) -> Vec<&str> {
436        let mut allowed = match path {
437            "*" => {
438                let mut allowed = Vec::with_capacity(self.trees.len());
439                for method in self
440                    .trees
441                    .keys()
442                    .filter(|&method| method != Method::OPTIONS)
443                {
444                    allowed.push(method.as_ref());
445                }
446                allowed
447            }
448            _ => self
449                .trees
450                .keys()
451                .filter(|&method| method != Method::OPTIONS)
452                .filter(|&method| {
453                    self.trees
454                        .get(method)
455                        .map(|node| node.at(path).is_ok())
456                        .unwrap_or(false)
457                })
458                .map(AsRef::as_ref)
459                .collect::<Vec<_>>(),
460        };
461
462        if !allowed.is_empty() {
463            allowed.push(Method::OPTIONS.as_ref())
464        }
465
466        allowed
467    }
468}
469
470clone_trait_object!(Handler);
471pub trait Handler: Send + Sync + DynClone {
472    /// Handle a request.
473    /// The handler is called for requests with a matching path and method.
474    fn handle(
475        &self,
476        req: HttpRequest,
477    ) -> Pin<Box<dyn Future<Output = Result<HttpResponse, HttpResponse>> + Send + Sync>>;
478}
479
480impl<F, R> Handler for F
481where
482    F: Fn(HttpRequest) -> R + Send + Sync + DynClone,
483    R: Future<Output = Result<HttpResponse, HttpResponse>> + Send + Sync + 'static,
484{
485    /// Handle a request.
486    /// The handler is called for requests with a matching path and method.
487    fn handle(
488        &self,
489        req: HttpRequest,
490    ) -> Pin<Box<dyn Future<Output = Result<HttpResponse, HttpResponse>> + Send + Sync>> {
491        Box::pin(self(req))
492    }
493}
494
495#[cfg(test)]
496mod test {
497    use serde_json::json;
498
499    use super::*;
500    use crate::http::{HttpRequest, HttpResponse};
501    use crate::method::Method;
502
503    #[test]
504    fn test_router() {
505        let mut router = Router::new();
506        router.get("/hello", false, |_req: HttpRequest| async move {
507            Ok(HttpResponse {
508                status_code: 200,
509                headers: HashMap::new(),
510                body: json!({
511                    "statusCode": 200,
512                    "message": "Hello World from GET",
513                })
514                .into(),
515            })
516        });
517        router.post("/hello", false, |_req: HttpRequest| async move {
518            Ok(HttpResponse {
519                status_code: 200,
520                headers: HashMap::new(),
521                body: json!({
522                    "statusCode": 200,
523                    "message": "Hello World from POST",
524                })
525                .into(),
526            })
527        });
528        router.put("/hello", false, |_req: HttpRequest| async move {
529            Ok(HttpResponse {
530                status_code: 200,
531                headers: HashMap::new(),
532                body: json!({
533                    "statusCode": 200,
534                    "message": "Hello World from PUT",
535                })
536                .into(),
537            })
538        });
539        router.patch("/hello", false, |_req: HttpRequest| async move {
540            Ok(HttpResponse {
541                status_code: 200,
542                headers: HashMap::new(),
543                body: json!({
544                    "statusCode": 200,
545                    "message": "Hello World from PATCH",
546                })
547                .into(),
548            })
549        });
550        router.delete("/hello", false, |_req: HttpRequest| async move {
551            Ok(HttpResponse {
552                status_code: 200,
553                headers: HashMap::new(),
554                body: json!({
555                    "statusCode": 200,
556                    "message": "Hello World from DELETE",
557                })
558                .into(),
559            })
560        });
561        router.head("/hello", false, |_req: HttpRequest| async move {
562            Ok(HttpResponse {
563                status_code: 200,
564                headers: HashMap::new(),
565                body: json!({
566                    "statusCode": 200,
567                    "message": "Hello World from HEAD",
568                })
569                .into(),
570            })
571        });
572        router.options("/hello", false, |_req: HttpRequest| async move {
573            Ok(HttpResponse {
574                status_code: 200,
575                headers: HashMap::new(),
576                body: json!({
577                    "statusCode": 200,
578                    "message": "Hello World from OPTIONS",
579                })
580                .into(),
581            })
582        });
583
584        let mut allowed = router.allowed("/hello");
585        allowed.sort();
586        assert_eq!(
587            allowed,
588            vec![
589                Method::DELETE.as_ref(),
590                Method::GET.as_ref(),
591                Method::HEAD.as_ref(),
592                Method::OPTIONS.as_ref(),
593                Method::PATCH.as_ref(),
594                Method::POST.as_ref(),
595                Method::PUT.as_ref(),
596            ]
597        );
598    }
599
600    #[test]
601    fn test_router_prefix() {
602        let mut router = Router::new();
603        router.set_global_prefix("/api".to_string());
604        router.get("/hello", false, |_req: HttpRequest| async move {
605            Ok(HttpResponse {
606                status_code: 200,
607                headers: HashMap::new(),
608                body: json!({
609                    "statusCode": 200,
610                    "message": "Hello World from GET",
611                })
612                .into(),
613            })
614        });
615        router.post("/hello", false, |_req: HttpRequest| async move {
616            Ok(HttpResponse {
617                status_code: 200,
618                headers: HashMap::new(),
619                body: json!({
620                    "statusCode": 200,
621                    "message": "Hello World from POST",
622                })
623                .into(),
624            })
625        });
626        router.put("/hello", false, |_req: HttpRequest| async move {
627            Ok(HttpResponse {
628                status_code: 200,
629                headers: HashMap::new(),
630                body: json!({
631                    "statusCode": 200,
632                    "message": "Hello World from PUT",
633                })
634                .into(),
635            })
636        });
637        router.patch("/hello", false, |_req: HttpRequest| async move {
638            Ok(HttpResponse {
639                status_code: 200,
640                headers: HashMap::new(),
641                body: json!({
642                    "statusCode": 200,
643                    "message": "Hello World from PATCH",
644                })
645                .into(),
646            })
647        });
648        router.delete("/hello", false, |_req: HttpRequest| async move {
649            Ok(HttpResponse {
650                status_code: 200,
651                headers: HashMap::new(),
652                body: json!({
653                    "statusCode": 200,
654                    "message": "Hello World from DELETE",
655                })
656                .into(),
657            })
658        });
659        router.head("/hello", false, |_req: HttpRequest| async move {
660            Ok(HttpResponse {
661                status_code: 200,
662                headers: HashMap::new(),
663                body: json!({
664                    "statusCode": 200,
665                    "message": "Hello World from HEAD",
666                })
667                .into(),
668            })
669        });
670        router.options("/hello", false, |_req: HttpRequest| async move {
671            Ok(HttpResponse {
672                status_code: 200,
673                headers: HashMap::new(),
674                body: json!({
675                    "statusCode": 200,
676                    "message": "Hello World from OPTIONS",
677                })
678                .into(),
679            })
680        });
681
682        let mut allowed = router.allowed("/api/hello");
683        allowed.sort();
684        assert_eq!(
685            allowed,
686            vec![
687                Method::DELETE.as_ref(),
688                Method::GET.as_ref(),
689                Method::HEAD.as_ref(),
690                Method::OPTIONS.as_ref(),
691                Method::PATCH.as_ref(),
692                Method::POST.as_ref(),
693                Method::PUT.as_ref(),
694            ]
695        );
696    }
697
698    #[tokio::test]
699    async fn test_lookup_works() {
700        let mut router = Router::new();
701        let response = HttpResponse {
702            status_code: 200,
703            headers: HashMap::new(),
704            body: json!({
705                "message": "Hello World from GET",
706            })
707            .into(),
708        };
709        router.get("/hello", false, |_req: HttpRequest| async move {
710            Ok(HttpResponse {
711                status_code: 200,
712                headers: HashMap::new(),
713                body: json!({
714                    "message": "Hello World from GET",
715                })
716                .into(),
717            })
718        });
719
720        let allowed = router.lookup(Method::GET, "/hello").unwrap();
721        let result = allowed
722            .value
723            .handler
724            .handle(
725                crate::http::RawHttpRequest {
726                    method: "GET".to_string(),
727                    url: "http:://localhost:8080/hello".to_string(),
728                    headers: Vec::new(),
729                    body: Vec::new(),
730                }
731                .into(),
732            )
733            .await
734            .unwrap();
735
736        assert_eq!(
737            result,
738            HttpResponse {
739                status_code: 200,
740                headers: HashMap::new(),
741                body: json!({
742                    "message": "Hello World from GET",
743                })
744                .into(),
745            }
746        );
747    }
748}