1use std::ops::ControlFlow;
13use std::sync::Arc;
14
15use futures::future::BoxFuture;
16use tower::BoxError;
17use tower::Layer;
18use tower::Service;
19
20#[allow(clippy::type_complexity)]
22pub struct CheckpointLayer<S, Request>
23where
24 S: Service<Request> + Send + 'static,
25 Request: Send + 'static,
26 S::Future: Send,
27 S::Response: Send + 'static,
28 S::Error: Send + 'static,
29{
30 checkpoint_fn: Arc<
31 dyn Fn(
32 Request,
33 ) -> Result<
34 ControlFlow<<S as Service<Request>>::Response, Request>,
35 <S as Service<Request>>::Error,
36 > + Send
37 + Sync
38 + 'static,
39 >,
40}
41
42#[allow(clippy::type_complexity)]
43impl<S, Request> CheckpointLayer<S, Request>
44where
45 S: Service<Request> + Send + 'static,
46 Request: Send + 'static,
47 S::Future: Send,
48 S::Response: Send + 'static,
49 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
50{
51 pub fn new(
53 checkpoint_fn: impl Fn(
54 Request,
55 ) -> Result<
56 ControlFlow<<S as Service<Request>>::Response, Request>,
57 <S as Service<Request>>::Error,
58 > + Send
59 + Sync
60 + 'static,
61 ) -> Self {
62 Self {
63 checkpoint_fn: Arc::new(checkpoint_fn),
64 }
65 }
66}
67
68impl<S, Request> Layer<S> for CheckpointLayer<S, Request>
69where
70 S: Service<Request> + Send + 'static,
71 <S as Service<Request>>::Future: Send,
72 Request: Send + 'static,
73 <S as Service<Request>>::Response: Send + 'static,
74 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
75{
76 type Service = CheckpointService<S, Request>;
77
78 fn layer(&self, service: S) -> Self::Service {
79 CheckpointService {
80 checkpoint_fn: Arc::clone(&self.checkpoint_fn),
81 inner: service,
82 }
83 }
84}
85
86#[derive(Clone)]
88#[allow(clippy::type_complexity)]
89pub struct CheckpointService<S, Request>
90where
91 Request: Send + 'static,
92 S: Service<Request> + Send + 'static,
93 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
94 <S as Service<Request>>::Response: Send + 'static,
95 <S as Service<Request>>::Future: Send + 'static,
96{
97 inner: S,
98 checkpoint_fn: Arc<
99 dyn Fn(
100 Request,
101 ) -> Result<
102 ControlFlow<<S as Service<Request>>::Response, Request>,
103 <S as Service<Request>>::Error,
104 > + Send
105 + Sync
106 + 'static,
107 >,
108}
109
110#[allow(clippy::type_complexity)]
111impl<S, Request> CheckpointService<S, Request>
112where
113 Request: Send + 'static,
114 S: Service<Request> + Send + 'static,
115 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
116 <S as Service<Request>>::Response: Send + 'static,
117 <S as Service<Request>>::Future: Send + 'static,
118{
119 pub fn new(
121 checkpoint_fn: impl Fn(
122 Request,
123 ) -> Result<
124 ControlFlow<<S as Service<Request>>::Response, Request>,
125 <S as Service<Request>>::Error,
126 > + Send
127 + Sync
128 + 'static,
129 inner: S,
130 ) -> Self {
131 Self {
132 checkpoint_fn: Arc::new(checkpoint_fn),
133 inner,
134 }
135 }
136}
137
138impl<S, Request> Service<Request> for CheckpointService<S, Request>
139where
140 S: Service<Request>,
141 S: Send + 'static,
142 S::Future: Send,
143 Request: Send + 'static,
144 <S as Service<Request>>::Response: Send + 'static,
145 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
146{
147 type Response = <S as Service<Request>>::Response;
148
149 type Error = <S as Service<Request>>::Error;
150
151 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
152
153 fn poll_ready(
154 &mut self,
155 cx: &mut std::task::Context<'_>,
156 ) -> std::task::Poll<Result<(), Self::Error>> {
157 self.inner.poll_ready(cx)
158 }
159
160 fn call(&mut self, req: Request) -> Self::Future {
161 match (self.checkpoint_fn)(req) {
162 Ok(ControlFlow::Break(response)) => Box::pin(async move { Ok(response) }),
163 Ok(ControlFlow::Continue(request)) => Box::pin(self.inner.call(request)),
164 Err(error) => Box::pin(async move { Err(error) }),
165 }
166 }
167}
168
#[cfg(test)]
mod checkpoint_tests {
    use tower::BoxError;
    use tower::Layer;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    use super::*;
    use crate::layers::ServiceBuilderExt;
    use crate::plugin::test::MockExecutionService;
    use crate::services::ExecutionRequest;
    use crate::services::ExecutionResponse;

    /// A pass-through checkpoint installed via `ServiceBuilder::checkpoint`
    /// must let the request reach the mocked service exactly once.
    #[tokio::test]
    async fn test_service_builder() {
        let expected_label = "from_mock_service";

        let mut mock_service = MockExecutionService::new();
        mock_service
            .expect_call()
            .times(1)
            .returning(move |req: ExecutionRequest| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .context(req.context)
                    .build()
                    .unwrap())
            });

        let stack = ServiceBuilder::new()
            .checkpoint(|req: ExecutionRequest| Ok(ControlFlow::Continue(req)))
            .service(mock_service);

        let request = ExecutionRequest::fake_builder().build();
        let outcome = stack.oneshot(request).await;

        let actual_label = outcome
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label);
    }

    /// `ControlFlow::Continue` forwards the request to the inner service,
    /// whose response is what the caller sees.
    #[tokio::test]
    async fn test_continue() {
        let expected_label = "from_mock_service";

        let mut mock_service = MockExecutionService::new();
        mock_service
            .expect_call()
            .times(1)
            .returning(move |_req| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .build()
                    .unwrap())
            });

        let stack = CheckpointLayer::new(|req| Ok(ControlFlow::Continue(req))).layer(mock_service);

        let request = ExecutionRequest::fake_builder().build();
        let outcome = stack.oneshot(request).await;

        let actual_label = outcome
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label);
    }

    /// `ControlFlow::Break` answers with the checkpoint's own response; the
    /// mock has no expectations set, so the inner service is not expected to
    /// be called.
    #[tokio::test]
    async fn test_return() {
        let expected_label = "returned_before_mock_service";
        let mock_service = MockExecutionService::new();

        let stack = CheckpointLayer::new(|_req| {
            Ok(ControlFlow::Break(
                ExecutionResponse::fake_builder()
                    .label("returned_before_mock_service".to_string())
                    .build()
                    .unwrap(),
            ))
        })
        .layer(mock_service);

        let request = ExecutionRequest::fake_builder().build();
        let outcome = stack.oneshot(request).await;

        let actual_label = outcome
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label);
    }

    /// An `Err` from the checkpoint is returned to the caller unchanged; the
    /// mock has no expectations set, so the inner service is not expected to
    /// be called.
    #[tokio::test]
    async fn test_error() {
        let expected_error = "checkpoint_error";
        let mock_service = MockExecutionService::new();

        let stack = CheckpointLayer::new(move |_req| Err(BoxError::from(expected_error)))
            .layer(mock_service);

        let request = ExecutionRequest::fake_builder().build();
        let outcome = stack.oneshot(request).await;

        // Map the success value away so that `unwrap_err` does not need the
        // response type to implement `Debug`.
        let actual_error = outcome
            .map(|_| unreachable!())
            .unwrap_err()
            .to_string();

        assert_eq!(actual_error, expected_error);
    }
}