apollo_router/layers/
async_checkpoint.rs

1//! Asynchronous Checkpoint
2//!
3//! Provides a general mechanism for controlling the flow of a request. Useful in any situation
4//! where the caller wishes to provide control flow for a request.
5//!
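//! If the checkpoint function resolves to `ControlFlow::Continue`, the request is passed on to
//! the next service in the chain of responsibilities. If it resolves to `ControlFlow::Break`,
//! the control flow is broken and the supplied response is passed back to the invoking service.
//! If it resolves to an error, that error is returned to the invoking service instead.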
9//!
10//! See [`Layer`] and [`Service`] for more details.
11
12use std::marker::PhantomData;
13use std::ops::ControlFlow;
14use std::pin::Pin;
15use std::sync::Arc;
16
17use futures::Future;
18use futures::future::BoxFuture;
19use tower::BoxError;
20use tower::Layer;
21use tower::Service;
22
23/// [`Layer`] for Asynchronous Checkpoints. See [`ServiceBuilderExt::checkpoint_async()`](crate::layers::ServiceBuilderExt::checkpoint_async()).
24#[allow(clippy::type_complexity)]
25pub struct AsyncCheckpointLayer<S, Fut, Request>
26where
27    S: Service<Request, Error = BoxError> + Clone + Send + 'static,
28    Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
29{
30    checkpoint_fn: Arc<Pin<Box<dyn Fn(Request) -> Fut + Send + Sync + 'static>>>,
31    phantom: PhantomData<S>, // We use PhantomData because the compiler can't detect that S is used in the Future.
32}
33
34impl<S, Fut, Request> AsyncCheckpointLayer<S, Fut, Request>
35where
36    S: Service<Request, Error = BoxError> + Clone + Send + 'static,
37    Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
38{
39    /// Create an `AsyncCheckpointLayer` from a function that takes a Service Request and returns a `ControlFlow`
40    pub fn new<F>(checkpoint_fn: F) -> Self
41    where
42        F: Fn(Request) -> Fut + Send + Sync + 'static,
43    {
44        Self {
45            checkpoint_fn: Arc::new(Box::pin(checkpoint_fn)),
46            phantom: PhantomData,
47        }
48    }
49}
50
51impl<S, Fut, Request> Layer<S> for AsyncCheckpointLayer<S, Fut, Request>
52where
53    S: Service<Request, Error = BoxError> + Clone + Send + 'static,
54    <S as Service<Request>>::Future: Send,
55    Request: Send + 'static,
56    <S as Service<Request>>::Response: Send + 'static,
57    Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
58{
59    type Service = AsyncCheckpointService<S, Fut, Request>;
60
61    fn layer(&self, service: S) -> Self::Service {
62        AsyncCheckpointService {
63            checkpoint_fn: Arc::clone(&self.checkpoint_fn),
64            service,
65        }
66    }
67}
68
69/// [`Service`] for Asynchronous Checkpoints. See [`ServiceBuilderExt::checkpoint_async()`](crate::layers::ServiceBuilderExt::checkpoint_async()).
70#[allow(clippy::type_complexity)]
71pub struct AsyncCheckpointService<S, Fut, Request>
72where
73    Request: Send + 'static,
74    S: Service<Request, Error = BoxError> + Clone + Send + 'static,
75    <S as Service<Request>>::Response: Send + 'static,
76    <S as Service<Request>>::Future: Send + 'static,
77    Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
78{
79    service: S,
80    checkpoint_fn: Arc<Pin<Box<dyn Fn(Request) -> Fut + Send + Sync + 'static>>>,
81}
82
83impl<S, Fut, Request> AsyncCheckpointService<S, Fut, Request>
84where
85    Request: Send + 'static,
86    S: Service<Request, Error = BoxError> + Clone + Send + 'static,
87    <S as Service<Request>>::Response: Send + 'static,
88    <S as Service<Request>>::Future: Send + 'static,
89    Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>,
90{
91    /// Create an `AsyncCheckpointLayer` from a function that takes a Service Request and returns a `ControlFlow`
92    pub fn new<F>(checkpoint_fn: F, service: S) -> Self
93    where
94        F: Fn(Request) -> Fut + Send + Sync + 'static,
95    {
96        Self {
97            checkpoint_fn: Arc::new(Box::pin(checkpoint_fn)),
98            service,
99        }
100    }
101}
102
103impl<S, Fut, Request> Service<Request> for AsyncCheckpointService<S, Fut, Request>
104where
105    Request: Send + 'static,
106    S: Service<Request, Error = BoxError> + Clone + Send + 'static,
107    <S as Service<Request>>::Response: Send + 'static,
108    <S as Service<Request>>::Future: Send + 'static,
109    Fut: Future<Output = Result<ControlFlow<<S as Service<Request>>::Response, Request>, BoxError>>
110        + Send
111        + 'static,
112{
113    type Response = <S as Service<Request>>::Response;
114
115    type Error = BoxError;
116
117    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
118
119    fn poll_ready(
120        &mut self,
121        cx: &mut std::task::Context<'_>,
122    ) -> std::task::Poll<Result<(), Self::Error>> {
123        self.service.poll_ready(cx)
124    }
125
126    fn call(&mut self, req: Request) -> Self::Future {
127        let checkpoint_fn = Arc::clone(&self.checkpoint_fn);
128        let service = self.service.clone();
129        let mut inner = std::mem::replace(&mut self.service, service);
130
131        Box::pin(async move {
132            match (checkpoint_fn)(req).await {
133                Ok(ControlFlow::Break(response)) => Ok(response),
134                Ok(ControlFlow::Continue(request)) => inner.call(request).await,
135                Err(error) => Err(error),
136            }
137        })
138    }
139}
140
#[cfg(test)]
mod async_checkpoint_tests {
    use tower::BoxError;
    use tower::Layer;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    use super::*;
    use crate::layers::ServiceBuilderExt;
    use crate::plugin::test::MockExecutionService;
    use crate::services::ExecutionRequest;
    use crate::services::ExecutionResponse;

    /// Label attached to responses produced by the mocked inner service.
    const MOCK_LABEL: &str = "from_mock_service";
    /// Label attached to responses produced directly by a breaking checkpoint.
    const CHECKPOINT_LABEL: &str = "returned_before_mock_service";

    /// Checkpoint that never reaches the inner service: it always breaks with a
    /// response labelled `CHECKPOINT_LABEL`.
    async fn break_with_label(
        _req: ExecutionRequest,
    ) -> Result<ControlFlow<ExecutionResponse, ExecutionRequest>, BoxError> {
        let response = ExecutionResponse::fake_builder()
            .label(CHECKPOINT_LABEL.to_string())
            .build()
            .unwrap();
        Ok(ControlFlow::Break(response))
    }

    /// Extracts the label of the first response yielded by `response`, panicking if
    /// there is no response or no label.
    async fn first_label(mut response: ExecutionResponse) -> String {
        let first = response.next_response().await.unwrap();
        first.label.unwrap()
    }

    #[tokio::test]
    async fn test_service_builder() {
        let mut mock = MockExecutionService::new();

        mock.expect_clone().return_once(MockExecutionService::new);

        mock.expect_call().times(1).returning(|req| {
            Ok(ExecutionResponse::fake_builder()
                .label(MOCK_LABEL.to_string())
                .context(req.context)
                .build()
                .unwrap())
        });

        let stack = ServiceBuilder::new()
            .checkpoint_async(|req: ExecutionRequest| async { Ok(ControlFlow::Continue(req)) })
            .service(mock);

        let response = stack
            .oneshot(ExecutionRequest::fake_builder().build())
            .await
            .unwrap();

        assert_eq!(first_label(response).await, MOCK_LABEL)
    }

    #[tokio::test]
    async fn test_continue() {
        let mut mock = MockExecutionService::new();

        mock.expect_clone().return_once(MockExecutionService::new);

        mock.expect_call().times(1).returning(|_req| {
            Ok(ExecutionResponse::fake_builder()
                .label(MOCK_LABEL.to_string())
                .build()
                .unwrap())
        });

        let stack = AsyncCheckpointLayer::new(|req| async { Ok(ControlFlow::Continue(req)) })
            .layer(mock);

        let response = stack
            .oneshot(ExecutionRequest::fake_builder().build())
            .await
            .unwrap();

        assert_eq!(first_label(response).await, MOCK_LABEL)
    }

    #[tokio::test]
    async fn test_return() {
        let mut mock = MockExecutionService::new();
        mock.expect_clone().return_once(MockExecutionService::new);

        // No `expect_call` is registered: the checkpoint breaks before the mock is reached.
        let stack = AsyncCheckpointLayer::new(break_with_label).layer(mock);

        let response = stack
            .oneshot(ExecutionRequest::fake_builder().build())
            .await
            .unwrap();

        assert_eq!(first_label(response).await, CHECKPOINT_LABEL)
    }

    #[tokio::test]
    async fn test_error() {
        let expected_error = "checkpoint_error";

        let mut mock = MockExecutionService::new();
        mock.expect_clone().return_once(MockExecutionService::new);

        let stack = AsyncCheckpointLayer::new(move |_req| async move {
            Err(BoxError::from(expected_error))
        })
        .layer(mock);

        let outcome = stack
            .oneshot(ExecutionRequest::fake_builder().build())
            .await;

        let actual_error = match outcome {
            Ok(_) => unreachable!(),
            Err(error) => error.to_string(),
        };

        assert_eq!(actual_error, expected_error)
    }

    #[tokio::test]
    async fn test_service_builder_oneshot() {
        let mut mock = MockExecutionService::new();

        mock.expect_call()
            .times(1)
            .returning(|req: ExecutionRequest| {
                Ok(ExecutionResponse::fake_builder()
                    .label(MOCK_LABEL.to_string())
                    .context(req.context)
                    .build()
                    .unwrap())
            });

        mock.expect_clone().returning(MockExecutionService::new);

        let stack = ServiceBuilder::new()
            .checkpoint_async(|req: ExecutionRequest| async { Ok(ControlFlow::Continue(req)) })
            .service(mock);

        let response = stack
            .oneshot(ExecutionRequest::fake_builder().build())
            .await
            .unwrap();

        assert_eq!(first_label(response).await, MOCK_LABEL)
    }

    #[tokio::test]
    #[should_panic]
    async fn test_service_builder_buffered_oneshot() {
        let mut mock = MockExecutionService::new();

        mock.expect_call()
            .times(1)
            .returning(|req: ExecutionRequest| {
                Ok(ExecutionResponse::fake_builder()
                    .label(MOCK_LABEL.to_string())
                    .context(req.context)
                    .build()
                    .unwrap())
            });

        let mut stack = ServiceBuilder::new()
            .checkpoint_async(|req: ExecutionRequest| async { Ok(ControlFlow::Continue(req)) })
            .buffered()
            .service(mock);

        let first_request = ExecutionRequest::fake_builder().build();
        let second_request = ExecutionRequest::fake_builder().build();

        let _ = stack.call(first_request).await.unwrap();
        // Using the same stack a second time is expected to panic.
        let _ = stack.call(second_request).await.unwrap();
    }

    #[tokio::test]
    async fn test_double_ready_doesnt_panic() {
        let mut mock = MockExecutionService::new();

        mock.expect_clone().returning(MockExecutionService::new);

        let mut stack = AsyncCheckpointLayer::new(break_with_label).layer(mock);

        stack.ready().await.unwrap();
        stack
            .call(ExecutionRequest::fake_builder().build())
            .await
            .unwrap();

        let second_ready = stack.ready().await;
        assert!(second_ready.is_ok());
    }

    #[tokio::test]
    async fn test_double_call_doesnt_panic() {
        let mut mock = MockExecutionService::new();

        // Each clone must itself be cloneable, because every `call` swaps a new clone
        // into the stack.
        mock.expect_clone().returning(|| {
            let mut cloned = MockExecutionService::new();
            cloned.expect_clone().returning(MockExecutionService::new);
            cloned
        });

        let mut stack = AsyncCheckpointLayer::new(break_with_label).layer(mock);

        stack.ready().await.unwrap();

        stack
            .call(ExecutionRequest::fake_builder().build())
            .await
            .unwrap();

        stack.ready().await.unwrap();

        let second_call = stack
            .call(ExecutionRequest::fake_builder().build())
            .await;
        assert!(second_call.is_ok());
    }
}
420        );
421    }
422}