covert_framework/
router.rs

1use std::{future::Future, marker::PhantomData, pin::Pin};
2
3use covert_types::{error::ApiError, request::Request, response::Response};
4use tower::{Layer, Service, ServiceExt};
5
6use super::{
7    method_router::{MethodRouter, Route},
8    SyncService,
9};
10
/// Type-state marker for a [`Router`] that is still collecting routes and layers.
#[derive(Debug, Clone, Copy)]
pub struct Building;

/// Type-state marker for a [`Router`] that has been built and can dispatch requests.
#[derive(Debug, Clone, Copy)]
pub struct Ready;
13
14/// Wrapper around `matchit::Router`
15pub struct Router<Stage = Building> {
16    routes: Vec<(&'static str, MethodRouter)>,
17    router: matchit::Router<MethodRouter>,
18    _marker: PhantomData<Stage>,
19}
20
21impl Default for Router {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl Router {
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            routes: Vec::default(),
32            router: matchit::Router::default(),
33            _marker: PhantomData,
34        }
35    }
36
37    #[must_use]
38    pub fn route(mut self, path: &'static str, route: MethodRouter) -> Self {
39        self.routes.push((path, route));
40        self
41    }
42
43    #[must_use]
44    pub fn layer<L>(mut self, layer: L) -> Self
45    where
46        L: Layer<Route>,
47        L::Service:
48            Service<Request, Error = ApiError, Response = Response> + Clone + Send + 'static,
49        <L::Service as Service<Request>>::Future: Send + 'static,
50    {
51        self.routes = self
52            .routes
53            .into_iter()
54            .map(|(path, route)| (path, route.layer(&layer)))
55            .collect();
56        self
57    }
58
59    pub fn build(mut self) -> Router<Ready> {
60        for (path, route) in self.routes.clone() {
61            self.router
62                .insert(path, route)
63                .expect("No path should overlap");
64        }
65        Router::<Ready> {
66            routes: self.routes,
67            router: self.router,
68            _marker: PhantomData,
69        }
70    }
71}
72
73impl Router<Ready> {
74    // TODO: rename to `into_make_service`?
75    pub fn into_service(self) -> SyncService<Request, Response> {
76        SyncService::new(self)
77    }
78}
79
80impl Clone for Router<Ready> {
81    fn clone(&self) -> Self {
82        Self {
83            routes: self.routes.clone(),
84            router: self.router.clone(),
85            _marker: PhantomData,
86        }
87    }
88}
89
90impl Service<Request> for Router<Ready> {
91    type Response = Response;
92
93    type Error = ApiError;
94
95    type Future =
96        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
97
98    fn poll_ready(
99        &mut self,
100        _cx: &mut std::task::Context<'_>,
101    ) -> std::task::Poll<Result<(), Self::Error>> {
102        std::task::Poll::Ready(Ok(()))
103    }
104
105    fn call(&mut self, mut req: Request) -> Self::Future {
106        let prefixed_path = if req.path.starts_with('/') {
107            req.path.clone()
108        } else {
109            format!("/{}", req.path)
110        };
111        let Ok(matched_router) =  self.router.at(&prefixed_path) else {
112            return Box::pin(async { Err(ApiError::not_found()) });
113        };
114        req.params = matched_router
115            .params
116            .iter()
117            .map(|(_key, val)| val.to_string())
118            .collect();
119        let matched_router = matched_router.value.clone();
120        Box::pin(async move { matched_router.oneshot(req).await })
121    }
122}