Skip to main content

apollo_router/layers/
sync_checkpoint.rs

//! Synchronous Checkpoint.
//!
//! Provides a general mechanism for controlling the flow of a request. Useful in any situation
//! where the caller wishes to decide synchronously whether a request should proceed.
//!
//! If the checkpoint closure returns `Ok(ControlFlow::Continue(request))`, the request is passed
//! on to the next service in the chain of responsibilities. If it returns
//! `Ok(ControlFlow::Break(response))`, the control flow is broken and that response is passed
//! back to the invoking service. If it returns `Err(error)`, the error is returned instead.
//!
//! See [`Layer`] and [`Service`] for more details.
11
12use std::ops::ControlFlow;
13use std::sync::Arc;
14
15use futures::future::BoxFuture;
16use tower::BoxError;
17use tower::Layer;
18use tower::Service;
19
20/// [`Layer`] for Synchronous Checkpoints. See [`ServiceBuilderExt::checkpoint()`](crate::layers::ServiceBuilderExt::checkpoint()).
21#[allow(clippy::type_complexity)]
22pub struct CheckpointLayer<S, Request>
23where
24    S: Service<Request> + Send + 'static,
25    Request: Send + 'static,
26    S::Future: Send,
27    S::Response: Send + 'static,
28    S::Error: Send + 'static,
29{
30    checkpoint_fn: Arc<
31        dyn Fn(
32                Request,
33            ) -> Result<
34                ControlFlow<<S as Service<Request>>::Response, Request>,
35                <S as Service<Request>>::Error,
36            > + Send
37            + Sync
38            + 'static,
39    >,
40}
41
42#[allow(clippy::type_complexity)]
43impl<S, Request> CheckpointLayer<S, Request>
44where
45    S: Service<Request> + Send + 'static,
46    Request: Send + 'static,
47    S::Future: Send,
48    S::Response: Send + 'static,
49    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
50{
51    /// Create a `CheckpointLayer` from a function that takes a Service Request and returns a `ControlFlow`
52    pub fn new(
53        checkpoint_fn: impl Fn(
54            Request,
55        ) -> Result<
56            ControlFlow<<S as Service<Request>>::Response, Request>,
57            <S as Service<Request>>::Error,
58        > + Send
59        + Sync
60        + 'static,
61    ) -> Self {
62        Self {
63            checkpoint_fn: Arc::new(checkpoint_fn),
64        }
65    }
66}
67
68impl<S, Request> Layer<S> for CheckpointLayer<S, Request>
69where
70    S: Service<Request> + Send + 'static,
71    <S as Service<Request>>::Future: Send,
72    Request: Send + 'static,
73    <S as Service<Request>>::Response: Send + 'static,
74    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
75{
76    type Service = CheckpointService<S, Request>;
77
78    fn layer(&self, service: S) -> Self::Service {
79        CheckpointService {
80            checkpoint_fn: Arc::clone(&self.checkpoint_fn),
81            inner: service,
82        }
83    }
84}
85
86/// [`Service`] for Synchronous Checkpoints. See [`ServiceBuilderExt::checkpoint()`](crate::layers::ServiceBuilderExt::checkpoint()).
87#[allow(clippy::type_complexity)]
88pub struct CheckpointService<S, Request>
89where
90    Request: Send + 'static,
91    S: Service<Request> + Send + 'static,
92    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
93    <S as Service<Request>>::Response: Send + 'static,
94    <S as Service<Request>>::Future: Send + 'static,
95{
96    inner: S,
97    checkpoint_fn: Arc<
98        dyn Fn(
99                Request,
100            ) -> Result<
101                ControlFlow<<S as Service<Request>>::Response, Request>,
102                <S as Service<Request>>::Error,
103            > + Send
104            + Sync
105            + 'static,
106    >,
107}
108
109impl<S, Request> Clone for CheckpointService<S, Request>
110where
111    S: Clone,
112    // bounds to match the service struct...
113    Request: Send + 'static,
114    S: Service<Request> + Send + 'static,
115    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
116    <S as Service<Request>>::Response: Send + 'static,
117    <S as Service<Request>>::Future: Send + 'static,
118{
119    fn clone(&self) -> Self {
120        Self {
121            inner: self.inner.clone(),
122            checkpoint_fn: self.checkpoint_fn.clone(),
123        }
124    }
125}
126
127#[allow(clippy::type_complexity)]
128impl<S, Request> CheckpointService<S, Request>
129where
130    Request: Send + 'static,
131    S: Service<Request> + Send + 'static,
132    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
133    <S as Service<Request>>::Response: Send + 'static,
134    <S as Service<Request>>::Future: Send + 'static,
135{
136    /// Create a `CheckpointLayer` from a function that takes a Service Request and returns a `ControlFlow`
137    pub fn new(
138        checkpoint_fn: impl Fn(
139            Request,
140        ) -> Result<
141            ControlFlow<<S as Service<Request>>::Response, Request>,
142            <S as Service<Request>>::Error,
143        > + Send
144        + Sync
145        + 'static,
146        inner: S,
147    ) -> Self {
148        Self {
149            checkpoint_fn: Arc::new(checkpoint_fn),
150            inner,
151        }
152    }
153}
154
155impl<S, Request> Service<Request> for CheckpointService<S, Request>
156where
157    S: Service<Request>,
158    S: Send + 'static,
159    S::Future: Send,
160    Request: Send + 'static,
161    <S as Service<Request>>::Response: Send + 'static,
162    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
163{
164    type Response = <S as Service<Request>>::Response;
165
166    type Error = <S as Service<Request>>::Error;
167
168    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
169
170    fn poll_ready(
171        &mut self,
172        cx: &mut std::task::Context<'_>,
173    ) -> std::task::Poll<Result<(), Self::Error>> {
174        self.inner.poll_ready(cx)
175    }
176
177    fn call(&mut self, req: Request) -> Self::Future {
178        match (self.checkpoint_fn)(req) {
179            Ok(ControlFlow::Break(response)) => Box::pin(async move { Ok(response) }),
180            Ok(ControlFlow::Continue(request)) => Box::pin(self.inner.call(request)),
181            Err(error) => Box::pin(async move { Err(error) }),
182        }
183    }
184}
185
#[cfg(test)]
mod checkpoint_tests {
    use tower::BoxError;
    use tower::Layer;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    use super::*;
    use crate::layers::ServiceBuilderExt;
    use crate::plugin::test::MockExecutionService;
    use crate::services::ExecutionRequest;
    use crate::services::ExecutionResponse;

    /// A checkpoint installed through `ServiceBuilder::checkpoint` that continues must let the
    /// mocked service answer.
    #[tokio::test]
    async fn test_service_builder() {
        let expected_label = "from_mock_service";

        let mut mock_service = MockExecutionService::new();
        mock_service
            .expect_call()
            .times(1)
            .returning(move |req: ExecutionRequest| {
                let response = ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .context(req.context)
                    .build()
                    .unwrap();
                Ok(response)
            });

        let service_stack = ServiceBuilder::new()
            .checkpoint(|req: ExecutionRequest| Ok(ControlFlow::Continue(req)))
            .service(mock_service);

        let request = ExecutionRequest::fake_builder().build();

        let mut response = service_stack.oneshot(request).await.unwrap();
        let actual_label = response.next_response().await.unwrap().label.unwrap();

        assert_eq!(actual_label, expected_label);
    }

    /// `ControlFlow::Continue` forwards the request to the wrapped service.
    #[tokio::test]
    async fn test_continue() {
        let expected_label = "from_mock_service";

        let mut mock_service = MockExecutionService::new();
        mock_service.expect_call().times(1).returning(move |_req| {
            let response = ExecutionResponse::fake_builder()
                .label(expected_label.to_string())
                .build()
                .unwrap();
            Ok(response)
        });

        let service_stack =
            CheckpointLayer::new(|req| Ok(ControlFlow::Continue(req))).layer(mock_service);

        let request = ExecutionRequest::fake_builder().build();

        let mut response = service_stack.oneshot(request).await.unwrap();
        let actual_label = response.next_response().await.unwrap().label.unwrap();

        assert_eq!(actual_label, expected_label);
    }

    /// `ControlFlow::Break` returns the checkpoint's response; the mock has no expectations set.
    #[tokio::test]
    async fn test_return() {
        let expected_label = "returned_before_mock_service";
        let mock_service = MockExecutionService::new();

        let service_stack = CheckpointLayer::new(|_req| {
            let early_response = ExecutionResponse::fake_builder()
                .label("returned_before_mock_service".to_string())
                .build()
                .unwrap();
            Ok(ControlFlow::Break(early_response))
        })
        .layer(mock_service);

        let request = ExecutionRequest::fake_builder().build();

        let mut response = service_stack.oneshot(request).await.unwrap();
        let actual_label = response.next_response().await.unwrap().label.unwrap();

        assert_eq!(actual_label, expected_label);
    }

    /// An `Err` from the checkpoint is surfaced to the caller as the service error.
    #[tokio::test]
    async fn test_error() {
        let expected_error = "checkpoint_error";
        let mock_service = MockExecutionService::new();

        let service_stack = CheckpointLayer::new(move |_req| Err(BoxError::from(expected_error)))
            .layer(mock_service);

        let request = ExecutionRequest::fake_builder().build();

        let outcome = service_stack.oneshot(request).await;
        // Map the success value away before calling `unwrap_err`.
        let actual_error = outcome.map(|_| unreachable!()).unwrap_err().to_string();

        assert_eq!(actual_error, expected_error);
    }
}