covert_framework/
method_router.rs

1use std::{collections::HashMap, future::Future, pin::Pin, task::Poll};
2
3use covert_types::auth::AuthPolicy;
4use covert_types::error::ApiError;
5use covert_types::request::{Operation, Request};
6use covert_types::response::Response;
7use tower::{util::BoxCloneService, Service};
8use tower::{Layer, ServiceExt};
9
10use covert_types::state::StorageState;
11
12use super::handler::Handler;
13
14#[derive(Debug, Clone)]
15pub struct Route {
16    handler: BoxCloneService<Request, Response, ApiError>,
17    config: RouteConfig,
18}
19
20#[derive(Debug, Clone)]
21pub struct RouteConfig {
22    pub policy: AuthPolicy,
23    pub state: Vec<StorageState>,
24}
25
26impl RouteConfig {
27    #[must_use]
28    pub fn unauthenticated() -> Self {
29        Self {
30            policy: AuthPolicy::Unauthenticated,
31            ..Default::default()
32        }
33    }
34}
35
36impl Default for RouteConfig {
37    fn default() -> Self {
38        Self {
39            policy: AuthPolicy::Authenticated,
40            state: vec![StorageState::Unsealed],
41        }
42    }
43}
44
45impl Route {
46    #[must_use]
47    pub fn new(handler: BoxCloneService<Request, Response, ApiError>, config: RouteConfig) -> Self {
48        Self { handler, config }
49    }
50}
51
52impl Service<Request> for Route {
53    type Response = Response;
54
55    type Error = ApiError;
56
57    type Future =
58        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
59
60    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
61        self.handler.poll_ready(cx)
62    }
63
64    fn call(&mut self, req: Request) -> Self::Future {
65        let state = req
66            .extensions
67            .get::<StorageState>()
68            .expect("the storage should always have a state");
69        if !self.config.state.contains(state) {
70            let state = *state;
71            return Box::pin(async move { Err(ApiError::invalid_state(state)) });
72        }
73
74        let Some(policy) = req.extensions.get::<AuthPolicy>() else {
75            return Box::pin(async { Err(ApiError::unauthorized()) });
76        };
77        let auth = match self.config.policy {
78            AuthPolicy::Authenticated => *policy == AuthPolicy::Authenticated,
79            AuthPolicy::Unauthenticated => true,
80        };
81        if !auth {
82            return Box::pin(async { Err(ApiError::unauthorized()) });
83        }
84
85        self.handler.call(req)
86    }
87}
88
89#[derive(Debug, Clone)]
90pub struct MethodRouter {
91    routes: HashMap<Operation, Route>,
92}
93
94impl Default for MethodRouter {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
/// Generates a pair of builder methods on [`MethodRouter`] for one operation.
///
/// `$method` registers a handler with [`RouteConfig::default`], and
/// `$method_with_config` registers it with an explicit configuration. Both
/// delegate to [`MethodRouter::on`], so every registration path shares one
/// implementation, as the free functions from `top_level_handlers!` do.
macro_rules! chained_handlers {
    ($operation:ident, $method:ident, $method_with_config:ident) => {
        #[doc = concat!(
            "Registers `handler` for [`Operation::",
            stringify!($operation),
            "`] using [`RouteConfig::default`], replacing any handler previously ",
            "registered for that operation."
        )]
        #[must_use]
        pub fn $method<H, T>(self, handler: H) -> Self
        where
            H: Handler<T>,
            T: Send + 'static,
        {
            self.on(Operation::$operation, handler, RouteConfig::default())
        }

        #[doc = concat!(
            "Registers `handler` for [`Operation::",
            stringify!($operation),
            "`] with the access rules in `config`, replacing any handler ",
            "previously registered for that operation."
        )]
        #[must_use]
        pub fn $method_with_config<H, T>(self, handler: H, config: RouteConfig) -> Self
        where
            H: Handler<T>,
            T: Send + 'static,
        {
            self.on(Operation::$operation, handler, config)
        }
    };
}
125
126macro_rules! top_level_handlers {
127    ($operation:ident, $method:ident, $method_with_config:ident) => {
128        #[must_use]
129        pub fn $method<H, T>(handler: H) -> MethodRouter
130        where
131            H: Handler<T>,
132            T: Send + 'static,
133        {
134            MethodRouter::new().on(Operation::$operation, handler, RouteConfig::default())
135        }
136
137        #[must_use]
138        pub fn $method_with_config<H, T>(handler: H, config: RouteConfig) -> MethodRouter
139        where
140            H: Handler<T>,
141            T: Send + 'static,
142        {
143            MethodRouter::new().on(Operation::$operation, handler, config)
144        }
145    };
146}
147
148top_level_handlers!(Create, create, create_with_config);
149top_level_handlers!(Read, read, read_with_config);
150top_level_handlers!(Update, update, update_with_config);
151top_level_handlers!(Delete, delete, delete_with_config);
152top_level_handlers!(Revoke, revoke, revoke_with_config);
153top_level_handlers!(Renew, renew, renew_with_config);
154
155impl MethodRouter {
156    #[must_use]
157    pub fn new() -> Self {
158        Self {
159            routes: HashMap::default(),
160        }
161    }
162
163    chained_handlers!(Create, create, create_with_config);
164    chained_handlers!(Read, read, read_with_config);
165    chained_handlers!(Update, update, update_with_config);
166    chained_handlers!(Delete, delete, delete_with_config);
167    chained_handlers!(Revoke, revoke, revoke_with_config);
168    chained_handlers!(Renew, renew, renew_with_config);
169
170    #[must_use]
171    pub fn on<H, T>(mut self, operation: Operation, handler: H, config: RouteConfig) -> Self
172    where
173        H: Handler<T>,
174        T: Send + 'static,
175    {
176        let route = handler.into_route(config);
177        self.routes.insert(operation, route);
178        self
179    }
180
181    #[must_use]
182    pub fn layer<L>(self, layer: L) -> Self
183    where
184        L: Layer<Route>,
185        L::Service:
186            Service<Request, Error = ApiError, Response = Response> + Clone + Send + 'static,
187        <L::Service as Service<Request>>::Future: Send + 'static,
188    {
189        let routes = self
190            .routes
191            .into_iter()
192            .map(|(op, route)| {
193                let config = route.config.clone();
194                let svc = layer.layer(route);
195                let svc = BoxCloneService::new(svc);
196                let route = Route::new(svc, config);
197                (op, route)
198            })
199            .collect();
200
201        Self { routes }
202    }
203}
204
205impl Service<Request> for MethodRouter {
206    type Response = Response;
207
208    type Error = ApiError;
209
210    type Future = Pin<Box<dyn Future<Output = Result<Response, ApiError>> + Send + 'static>>;
211
212    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
213        Poll::Ready(Ok(()))
214    }
215
216    fn call(&mut self, req: Request) -> Self::Future {
217        let route = self.routes.get(&req.operation).map(Clone::clone);
218
219        Box::pin(async move {
220            match route {
221                Some(route) => route.oneshot(req).await,
222                None => Err(ApiError::not_found()),
223            }
224        })
225    }
226}