1use std::{collections::HashMap, future::Future, pin::Pin, task::Poll};
2
3use covert_types::auth::AuthPolicy;
4use covert_types::error::ApiError;
5use covert_types::request::{Operation, Request};
6use covert_types::response::Response;
7use tower::{util::BoxCloneService, Service};
8use tower::{Layer, ServiceExt};
9
10use covert_types::state::StorageState;
11
12use super::handler::Handler;
13
14#[derive(Debug, Clone)]
15pub struct Route {
16 handler: BoxCloneService<Request, Response, ApiError>,
17 config: RouteConfig,
18}
19
20#[derive(Debug, Clone)]
21pub struct RouteConfig {
22 pub policy: AuthPolicy,
23 pub state: Vec<StorageState>,
24}
25
26impl RouteConfig {
27 #[must_use]
28 pub fn unauthenticated() -> Self {
29 Self {
30 policy: AuthPolicy::Unauthenticated,
31 ..Default::default()
32 }
33 }
34}
35
36impl Default for RouteConfig {
37 fn default() -> Self {
38 Self {
39 policy: AuthPolicy::Authenticated,
40 state: vec![StorageState::Unsealed],
41 }
42 }
43}
44
45impl Route {
46 #[must_use]
47 pub fn new(handler: BoxCloneService<Request, Response, ApiError>, config: RouteConfig) -> Self {
48 Self { handler, config }
49 }
50}
51
52impl Service<Request> for Route {
53 type Response = Response;
54
55 type Error = ApiError;
56
57 type Future =
58 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
59
60 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
61 self.handler.poll_ready(cx)
62 }
63
64 fn call(&mut self, req: Request) -> Self::Future {
65 let state = req
66 .extensions
67 .get::<StorageState>()
68 .expect("the storage should always have a state");
69 if !self.config.state.contains(state) {
70 let state = *state;
71 return Box::pin(async move { Err(ApiError::invalid_state(state)) });
72 }
73
74 let Some(policy) = req.extensions.get::<AuthPolicy>() else {
75 return Box::pin(async { Err(ApiError::unauthorized()) });
76 };
77 let auth = match self.config.policy {
78 AuthPolicy::Authenticated => *policy == AuthPolicy::Authenticated,
79 AuthPolicy::Unauthenticated => true,
80 };
81 if !auth {
82 return Box::pin(async { Err(ApiError::unauthorized()) });
83 }
84
85 self.handler.call(req)
86 }
87}
88
89#[derive(Debug, Clone)]
90pub struct MethodRouter {
91 routes: HashMap<Operation, Route>,
92}
93
94impl Default for MethodRouter {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
/// Generates a pair of builder methods on [`MethodRouter`] that register a
/// handler for one [`Operation`] variant.
///
/// - `$method(handler)` registers the handler with [`RouteConfig::default()`].
/// - `$method_with_config(handler, config)` registers it with `config`.
///
/// Both delegate to [`MethodRouter::on`], so the registration logic lives in one
/// place, and registering the same operation twice replaces the earlier route.
macro_rules! chained_handlers {
    ($operation:ident, $method:ident, $method_with_config:ident) => {
        #[doc = concat!(
            "Registers `handler` for [`Operation::",
            stringify!($operation),
            "`] using the default [`RouteConfig`]."
        )]
        #[must_use]
        pub fn $method<H, T>(self, handler: H) -> Self
        where
            H: Handler<T>,
            T: Send + 'static,
        {
            self.on::<H, T>(Operation::$operation, handler, RouteConfig::default())
        }

        #[doc = concat!(
            "Registers `handler` for [`Operation::",
            stringify!($operation),
            "`] using the given `config`."
        )]
        #[must_use]
        pub fn $method_with_config<H, T>(self, handler: H, config: RouteConfig) -> Self
        where
            H: Handler<T>,
            T: Send + 'static,
        {
            self.on::<H, T>(Operation::$operation, handler, config)
        }
    };
}
125
126macro_rules! top_level_handlers {
127 ($operation:ident, $method:ident, $method_with_config:ident) => {
128 #[must_use]
129 pub fn $method<H, T>(handler: H) -> MethodRouter
130 where
131 H: Handler<T>,
132 T: Send + 'static,
133 {
134 MethodRouter::new().on(Operation::$operation, handler, RouteConfig::default())
135 }
136
137 #[must_use]
138 pub fn $method_with_config<H, T>(handler: H, config: RouteConfig) -> MethodRouter
139 where
140 H: Handler<T>,
141 T: Send + 'static,
142 {
143 MethodRouter::new().on(Operation::$operation, handler, config)
144 }
145 };
146}
147
148top_level_handlers!(Create, create, create_with_config);
149top_level_handlers!(Read, read, read_with_config);
150top_level_handlers!(Update, update, update_with_config);
151top_level_handlers!(Delete, delete, delete_with_config);
152top_level_handlers!(Revoke, revoke, revoke_with_config);
153top_level_handlers!(Renew, renew, renew_with_config);
154
155impl MethodRouter {
156 #[must_use]
157 pub fn new() -> Self {
158 Self {
159 routes: HashMap::default(),
160 }
161 }
162
163 chained_handlers!(Create, create, create_with_config);
164 chained_handlers!(Read, read, read_with_config);
165 chained_handlers!(Update, update, update_with_config);
166 chained_handlers!(Delete, delete, delete_with_config);
167 chained_handlers!(Revoke, revoke, revoke_with_config);
168 chained_handlers!(Renew, renew, renew_with_config);
169
170 #[must_use]
171 pub fn on<H, T>(mut self, operation: Operation, handler: H, config: RouteConfig) -> Self
172 where
173 H: Handler<T>,
174 T: Send + 'static,
175 {
176 let route = handler.into_route(config);
177 self.routes.insert(operation, route);
178 self
179 }
180
181 #[must_use]
182 pub fn layer<L>(self, layer: L) -> Self
183 where
184 L: Layer<Route>,
185 L::Service:
186 Service<Request, Error = ApiError, Response = Response> + Clone + Send + 'static,
187 <L::Service as Service<Request>>::Future: Send + 'static,
188 {
189 let routes = self
190 .routes
191 .into_iter()
192 .map(|(op, route)| {
193 let config = route.config.clone();
194 let svc = layer.layer(route);
195 let svc = BoxCloneService::new(svc);
196 let route = Route::new(svc, config);
197 (op, route)
198 })
199 .collect();
200
201 Self { routes }
202 }
203}
204
205impl Service<Request> for MethodRouter {
206 type Response = Response;
207
208 type Error = ApiError;
209
210 type Future = Pin<Box<dyn Future<Output = Result<Response, ApiError>> + Send + 'static>>;
211
212 fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
213 Poll::Ready(Ok(()))
214 }
215
216 fn call(&mut self, req: Request) -> Self::Future {
217 let route = self.routes.get(&req.operation).map(Clone::clone);
218
219 Box::pin(async move {
220 match route {
221 Some(route) => route.oneshot(req).await,
222 None => Err(ApiError::not_found()),
223 }
224 })
225 }
226}