1use std::marker::PhantomData;
13use std::ops::ControlFlow;
14use std::pin::Pin;
15use std::sync::Arc;
16
17use futures::Future;
18use futures::future::BoxFuture;
19use tower::BoxError;
20use tower::Layer;
21use tower::Service;
22
23#[allow(clippy::type_complexity)]
25pub struct AsyncCheckpointLayer<S, Fut, Request>
26where
27 S: Service<Request, Error = BoxError> + Clone + Send + 'static,
28 Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
29{
30 checkpoint_fn: Arc<Pin<Box<dyn Fn(Request) -> Fut + Send + Sync + 'static>>>,
31 phantom: PhantomData<S>, }
33
34impl<S, Fut, Request> AsyncCheckpointLayer<S, Fut, Request>
35where
36 S: Service<Request, Error = BoxError> + Clone + Send + 'static,
37 Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
38{
39 pub fn new<F>(checkpoint_fn: F) -> Self
41 where
42 F: Fn(Request) -> Fut + Send + Sync + 'static,
43 {
44 Self {
45 checkpoint_fn: Arc::new(Box::pin(checkpoint_fn)),
46 phantom: PhantomData,
47 }
48 }
49}
50
51impl<S, Fut, Request> Layer<S> for AsyncCheckpointLayer<S, Fut, Request>
52where
53 S: Service<Request, Error = BoxError> + Clone + Send + 'static,
54 <S as Service<Request>>::Future: Send,
55 Request: Send + 'static,
56 <S as Service<Request>>::Response: Send + 'static,
57 Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
58{
59 type Service = AsyncCheckpointService<S, Fut, Request>;
60
61 fn layer(&self, service: S) -> Self::Service {
62 AsyncCheckpointService {
63 checkpoint_fn: Arc::clone(&self.checkpoint_fn),
64 service,
65 }
66 }
67}
68
69#[allow(clippy::type_complexity)]
71pub struct AsyncCheckpointService<S, Fut, Request>
72where
73 Request: Send + 'static,
74 S: Service<Request, Error = BoxError> + Clone + Send + 'static,
75 <S as Service<Request>>::Response: Send + 'static,
76 <S as Service<Request>>::Future: Send + 'static,
77 Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
78{
79 service: S,
80 checkpoint_fn: Arc<Pin<Box<dyn Fn(Request) -> Fut + Send + Sync + 'static>>>,
81}
82
83impl<S, Fut, Request> AsyncCheckpointService<S, Fut, Request>
84where
85 Request: Send + 'static,
86 S: Service<Request, Error = BoxError> + Clone + Send + 'static,
87 <S as Service<Request>>::Response: Send + 'static,
88 <S as Service<Request>>::Future: Send + 'static,
89 Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
90{
91 pub fn new<F>(checkpoint_fn: F, service: S) -> Self
93 where
94 F: Fn(Request) -> Fut + Send + Sync + 'static,
95 {
96 Self {
97 checkpoint_fn: Arc::new(Box::pin(checkpoint_fn)),
98 service,
99 }
100 }
101}
102
103impl<S, Fut, Request> Service<Request> for AsyncCheckpointService<S, Fut, Request>
104where
105 Request: Send + 'static,
106 S: Service<Request, Error = BoxError> + Clone + Send + 'static,
107 <S as Service<Request>>::Response: Send + 'static,
108 <S as Service<Request>>::Future: Send + 'static,
109 Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>
110 + Send
111 + 'static,
112{
113 type Response = <S as Service<Request>>::Response;
114
115 type Error = BoxError;
116
117 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
118
119 fn poll_ready(
120 &mut self,
121 cx: &mut std::task::Context<'_>,
122 ) -> std::task::Poll<Result<(), Self::Error>> {
123 self.service.poll_ready(cx)
124 }
125
126 fn call(&mut self, req: Request) -> Self::Future {
127 let checkpoint_fn = Arc::clone(&self.checkpoint_fn);
128 let service = self.service.clone();
129 let mut inner = std::mem::replace(&mut self.service, service);
130
131 Box::pin(async move {
132 match (checkpoint_fn)(req).await {
133 Ok(ControlFlow::Break(response)) => Ok(response),
134 Ok(ControlFlow::Continue(request)) => inner.call(request).await,
135 Err(error) => Err(error),
136 }
137 })
138 }
139}
140
// Tests for `AsyncCheckpointLayer` / `AsyncCheckpointService`, driven through the crate's
// `MockExecutionService`. Because `AsyncCheckpointService::call` clones the inner service on every
// call, each mock that is called through the checkpoint needs an `expect_clone` expectation.
#[cfg(test)]
mod async_checkpoint_tests {
    use tower::BoxError;
    use tower::Layer;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    use super::*;
    use crate::layers::ServiceBuilderExt;
    use crate::plugin::test::MockExecutionService;
    use crate::services::ExecutionRequest;
    use crate::services::ExecutionResponse;

    // `checkpoint_async` on a `ServiceBuilder` returning `Continue` forwards the request to the
    // mock, whose label must come back in the response.
    #[tokio::test]
    async fn test_service_builder() {
        let expected_label = "from_mock_service";

        let mut execution_service = MockExecutionService::new();

        // The clone is the replacement left inside the checkpoint service by `call`.
        execution_service
            .expect_clone()
            .return_once(MockExecutionService::new);

        // The original mock is the instance `call` moves into the response future, so it is the
        // one that must receive the request.
        execution_service
            .expect_call()
            .times(1)
            .returning(move |req| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .context(req.context)
                    .build()
                    .unwrap())
            });

        let service_stack = ServiceBuilder::new()
            .checkpoint_async(|req: ExecutionRequest| async { Ok(ControlFlow::Continue(req)) })
            .service(execution_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    // Same as `test_service_builder`, but builds the stack with `AsyncCheckpointLayer::new`
    // directly instead of the `ServiceBuilder` extension.
    #[tokio::test]
    async fn test_continue() {
        let expected_label = "from_mock_service";
        let mut router_service = MockExecutionService::new();

        router_service
            .expect_clone()
            .return_once(MockExecutionService::new);

        router_service
            .expect_call()
            .times(1)
            .returning(move |_req| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .build()
                    .unwrap())
            });
        let service_stack =
            AsyncCheckpointLayer::new(|req| async { Ok(ControlFlow::Continue(req)) })
                .layer(router_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    // A `Break` from the checkpoint returns the checkpoint's own response. The mock registers no
    // `expect_call`, so the inner service is not expected to be invoked.
    #[tokio::test]
    async fn test_return() {
        let expected_label = "returned_before_mock_service";
        let mut router_service = MockExecutionService::new();
        router_service
            .expect_clone()
            .return_once(MockExecutionService::new);

        let service_stack = AsyncCheckpointLayer::new(|_req| async {
            Ok(ControlFlow::Break(
                ExecutionResponse::fake_builder()
                    .label("returned_before_mock_service".to_string())
                    .build()
                    .unwrap(),
            ))
        })
        .layer(router_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    // An `Err` from the checkpoint is returned unchanged to the caller; the inner service is not
    // expected to be called (no `expect_call` is registered).
    #[tokio::test]
    async fn test_error() {
        let expected_error = "checkpoint_error";
        let mut router_service = MockExecutionService::new();
        router_service
            .expect_clone()
            .return_once(MockExecutionService::new);

        let service_stack =
            AsyncCheckpointLayer::new(
                move |_req| async move { Err(BoxError::from(expected_error)) },
            )
            .layer(router_service);

        let request = ExecutionRequest::fake_builder().build();

        // Any `Ok` here would be a test failure, hence `unreachable!` in the map.
        let actual_error = service_stack
            .oneshot(request)
            .await
            .map(|_| unreachable!())
            .unwrap_err()
            .to_string();

        assert_eq!(actual_error, expected_error)
    }

    // Variant of `test_service_builder` that registers `expect_clone` after `expect_call`, and
    // with `returning` (reusable) instead of `return_once`.
    #[tokio::test]
    async fn test_service_builder_oneshot() {
        let expected_label = "from_mock_service";

        let mut execution_service = MockExecutionService::new();
        execution_service
            .expect_call()
            .times(1)
            .returning(move |req: ExecutionRequest| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .context(req.context)
                    .build()
                    .unwrap())
            });

        execution_service
            .expect_clone()
            .returning(MockExecutionService::new);

        let service_stack = ServiceBuilder::new()
            .checkpoint_async(|req: ExecutionRequest| async { Ok(ControlFlow::Continue(req)) })
            .service(execution_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    // Expected to panic when the buffered checkpoint stack is called.
    // NOTE(review): this mock registers no `expect_clone`, and `call` is invoked without a
    // preceding `ready()`; `buffered()` is defined elsewhere, so it is not evident from this file
    // which of these (or the second call) triggers the panic. A bare `#[should_panic]` accepts any
    // panic — consider `#[should_panic(expected = "...")]` to pin the intended failure.
    #[tokio::test]
    #[should_panic]
    async fn test_service_builder_buffered_oneshot() {
        let expected_label = "from_mock_service";

        let mut execution_service = MockExecutionService::new();
        execution_service
            .expect_call()
            .times(1)
            .returning(move |req: ExecutionRequest| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .context(req.context)
                    .build()
                    .unwrap())
            });

        let mut service_stack = ServiceBuilder::new()
            .checkpoint_async(|req: ExecutionRequest| async { Ok(ControlFlow::Continue(req)) })
            .buffered()
            .service(execution_service);

        let request = ExecutionRequest::fake_builder().build();
        let request_again = ExecutionRequest::fake_builder().build();

        let _ = service_stack.call(request).await.unwrap();
        // Second use of the same stack; the mock only expects one `call`.
        let _ = service_stack.call(request_again).await.unwrap();
    }

    // After one ready/call cycle, the service must still report ready: `call` leaves a clone of
    // the inner mock in place, and that clone is what the second `ready()` polls.
    #[tokio::test]
    async fn test_double_ready_doesnt_panic() {
        let mut router_service = MockExecutionService::new();

        router_service
            .expect_clone()
            .returning(MockExecutionService::new);

        let mut service_stack = AsyncCheckpointLayer::new(|_req| async {
            Ok(ControlFlow::Break(
                ExecutionResponse::fake_builder()
                    .label("returned_before_mock_service".to_string())
                    .build()
                    .unwrap(),
            ))
        })
        .layer(router_service);

        service_stack.ready().await.unwrap();
        service_stack
            .call(ExecutionRequest::fake_builder().build())
            .await
            .unwrap();

        assert!(service_stack.ready().await.is_ok());
    }

    // Two full ready/call cycles. Each `call` clones the current inner service, so the first
    // clone (`mes`) must itself expect a clone for the second call to succeed.
    #[tokio::test]
    async fn test_double_call_doesnt_panic() {
        let mut router_service = MockExecutionService::new();

        router_service.expect_clone().returning(|| {
            let mut mes = MockExecutionService::new();
            mes.expect_clone().returning(MockExecutionService::new);
            mes
        });

        let mut service_stack = AsyncCheckpointLayer::new(|_req| async {
            Ok(ControlFlow::Break(
                ExecutionResponse::fake_builder()
                    .label("returned_before_mock_service".to_string())
                    .build()
                    .unwrap(),
            ))
        })
        .layer(router_service);

        service_stack.ready().await.unwrap();

        service_stack
            .call(ExecutionRequest::fake_builder().build())
            .await
            .unwrap();

        service_stack.ready().await.unwrap();

        assert!(
            service_stack
                .call(ExecutionRequest::fake_builder().build())
                .await
                .is_ok()
        );
    }
}