1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use std::{
    fmt,
    sync::atomic::{AtomicU32, Ordering},
};

use fxhash::FxHashMap;
use http_body::Body as HttpBody;
use motore::{BoxCloneService, Service};
use volo::Unwrap;

use super::NamedService;
use crate::{body::Body, context::ServerContext, Request, Response, Status};

/// Opaque identifier for a registered route.
///
/// It is stored as the value in the `matchit` path tree and used as the key
/// into the router's service map.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct RouteId(u32);

impl RouteId {
    /// Allocates the next id from a process-wide counter.
    ///
    /// # Panics
    ///
    /// Panics once the counter reaches `u32::MAX`. The counter saturates
    /// instead of wrapping, so every later call panics too, rather than
    /// silently handing out an id that is already in use.
    fn next() -> Self {
        // `AtomicU64` isn't supported on all platforms
        static ID: AtomicU32 = AtomicU32::new(0);
        // `checked_add` refuses to advance past `u32::MAX`. A plain `fetch_add`
        // would wrap the counter back to 0, and a caller that caught the
        // overflow panic would then start receiving duplicate ids.
        let id = ID
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_add(1))
            .unwrap_or_else(|_| {
                panic!("Over `u32::MAX` routes created. If you need this, please file an issue.")
            });
        Self(id)
    }
}

/// Request router that dispatches on the RPC method stored in the
/// `ServerContext`.
///
/// Each registered service is stored type-erased as a `BoxCloneService` keyed
/// by a `RouteId`. `node` maps method paths to those ids.
pub struct Router<B = hyper::Body> {
    routes: FxHashMap<RouteId, BoxCloneService<ServerContext, Request<B>, Response<Body>, Status>>,
    node: matchit::Router<RouteId>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would require
// `B: Default`, even though no value of type `B` is ever stored in the router.
impl<B> Default for Router<B> {
    fn default() -> Self {
        Self {
            routes: Default::default(),
            node: Default::default(),
        }
    }
}

impl<B> Clone for Router<B> {
    fn clone(&self) -> Self {
        Self {
            routes: self.routes.clone(),
            node: self.node.clone(),
        }
    }
}

impl<B> Router<B>
where
    B: HttpBody + 'static,
{
    /// Creates an empty router with no services registered.
    pub fn new() -> Self {
        Self {
            routes: Default::default(),
            node: Default::default(),
        }
    }

    /// Registers `service` under the path `/{S::NAME}/*rest` and returns the
    /// updated router.
    ///
    /// # Panics
    ///
    /// Panics if `matchit` rejects the path, such as when it conflicts with an
    /// already-registered route. `#[track_caller]` makes the panic point at the
    /// caller of `add_service` rather than at this file.
    #[track_caller]
    pub fn add_service<S>(mut self, service: S) -> Self
    where
        S: Service<ServerContext, Request<B>, Response = Response<Body>, Error = Status>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
    {
        // The leading `/` comes from the format string, so the path is never
        // empty and always starts with `/`. No separate check is needed.
        let path = format!("/{}/*rest", S::NAME);

        let id = RouteId::next();

        self.set_node(path, id);

        self.routes.insert(id, BoxCloneService::new(service));

        self
    }

    /// Inserts `path -> id` into the route tree.
    ///
    /// # Panics
    ///
    /// Panics with a `[VOLO] Invalid route` message if `matchit` returns an
    /// error for the insertion.
    #[track_caller]
    fn set_node(&mut self, path: String, id: RouteId) {
        if let Err(err) = self.node.insert(path, id) {
            panic!("[VOLO] Invalid route: {err}");
        }
    }
}

impl<B> Service<ServerContext, Request<B>> for Router<B>
where
    B: HttpBody + Send,
{
    type Response = Response<Body>;
    type Error = Status;

    /// Forwards `req` to the service whose route matches
    /// `cx.rpc_info.method`.
    ///
    /// # Errors
    ///
    /// Returns `Status::unimplemented` if the context carries no method or no
    /// registered route matches it. Otherwise it returns whatever the matched
    /// service returns.
    async fn call<'s, 'cx>(
        &'s self,
        cx: &'cx mut ServerContext,
        req: Request<B>,
    ) -> Result<Self::Response, Self::Error> {
        // Fail this request instead of panicking if no method was recorded.
        let Some(path) = cx.rpc_info.method.as_ref() else {
            return Err(Status::unimplemented(String::from(
                "[VOLO] missing rpc method in server context",
            )));
        };
        let route = match self.node.at(path) {
            // `add_service` inserts every id into both `node` and `routes`, so
            // a matched id always has a service entry.
            Ok(matched) => self.routes.get(matched.value).volo_unwrap(),
            Err(err) => return Err(Status::unimplemented(err.to_string())),
        };
        // `call` only needs `&self`, so the boxed service is used in place
        // instead of being cloned (and re-boxed) for every request.
        route.call(cx, req).await
    }
}

impl<B> fmt::Debug for Router<B> {
    /// Prints the registered services as `Router { routes: .. }`. The path
    /// tree in `node` is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Router");
        builder.field("routes", &self.routes);
        builder.finish()
    }
}