apollo_router/layers/
sync_checkpoint.rs

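//! Synchronous Checkpoint.
//!
//! Provides a general mechanism for controlling the flow of a request. Useful in any situation
//! where the caller wishes to decide, synchronously and before the next service runs, what
//! happens to a request.
//!
//! The checkpoint closure inspects each request and returns one of three outcomes:
//! - `Ok(ControlFlow::Continue(request))`: the request is passed on to the next service in the
//!   chain of responsibility.
//! - `Ok(ControlFlow::Break(response))`: the control flow is broken and the response is passed
//!   back to the invoking service without calling the next service.
//! - `Err(error)`: the error is passed back to the invoking service.
//!
//! See [`Layer`] and [`Service`] for more details.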
11
12use std::ops::ControlFlow;
13use std::sync::Arc;
14
15use futures::future::BoxFuture;
16use tower::BoxError;
17use tower::Layer;
18use tower::Service;
19
20/// [`Layer`] for Synchronous Checkpoints. See [`ServiceBuilderExt::checkpoint()`](crate::layers::ServiceBuilderExt::checkpoint()).
21#[allow(clippy::type_complexity)]
22pub struct CheckpointLayer<S, Request>
23where
24    S: Service<Request> + Send + 'static,
25    Request: Send + 'static,
26    S::Future: Send,
27    S::Response: Send + 'static,
28    S::Error: Send + 'static,
29{
30    checkpoint_fn: Arc<
31        dyn Fn(
32                Request,
33            ) -> Result<
34                ControlFlow<<S as Service<Request>>::Response, Request>,
35                <S as Service<Request>>::Error,
36            > + Send
37            + Sync
38            + 'static,
39    >,
40}
41
42#[allow(clippy::type_complexity)]
43impl<S, Request> CheckpointLayer<S, Request>
44where
45    S: Service<Request> + Send + 'static,
46    Request: Send + 'static,
47    S::Future: Send,
48    S::Response: Send + 'static,
49    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
50{
51    /// Create a `CheckpointLayer` from a function that takes a Service Request and returns a `ControlFlow`
52    pub fn new(
53        checkpoint_fn: impl Fn(
54            Request,
55        ) -> Result<
56            ControlFlow<<S as Service<Request>>::Response, Request>,
57            <S as Service<Request>>::Error,
58        > + Send
59        + Sync
60        + 'static,
61    ) -> Self {
62        Self {
63            checkpoint_fn: Arc::new(checkpoint_fn),
64        }
65    }
66}
67
68impl<S, Request> Layer<S> for CheckpointLayer<S, Request>
69where
70    S: Service<Request> + Send + 'static,
71    <S as Service<Request>>::Future: Send,
72    Request: Send + 'static,
73    <S as Service<Request>>::Response: Send + 'static,
74    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
75{
76    type Service = CheckpointService<S, Request>;
77
78    fn layer(&self, service: S) -> Self::Service {
79        CheckpointService {
80            checkpoint_fn: Arc::clone(&self.checkpoint_fn),
81            inner: service,
82        }
83    }
84}
85
86/// [`Service`] for Synchronous Checkpoints. See [`ServiceBuilderExt::checkpoint()`](crate::layers::ServiceBuilderExt::checkpoint()).
87#[derive(Clone)]
88#[allow(clippy::type_complexity)]
89pub struct CheckpointService<S, Request>
90where
91    Request: Send + 'static,
92    S: Service<Request> + Send + 'static,
93    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
94    <S as Service<Request>>::Response: Send + 'static,
95    <S as Service<Request>>::Future: Send + 'static,
96{
97    inner: S,
98    checkpoint_fn: Arc<
99        dyn Fn(
100                Request,
101            ) -> Result<
102                ControlFlow<<S as Service<Request>>::Response, Request>,
103                <S as Service<Request>>::Error,
104            > + Send
105            + Sync
106            + 'static,
107    >,
108}
109
110#[allow(clippy::type_complexity)]
111impl<S, Request> CheckpointService<S, Request>
112where
113    Request: Send + 'static,
114    S: Service<Request> + Send + 'static,
115    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
116    <S as Service<Request>>::Response: Send + 'static,
117    <S as Service<Request>>::Future: Send + 'static,
118{
119    /// Create a `CheckpointLayer` from a function that takes a Service Request and returns a `ControlFlow`
120    pub fn new(
121        checkpoint_fn: impl Fn(
122            Request,
123        ) -> Result<
124            ControlFlow<<S as Service<Request>>::Response, Request>,
125            <S as Service<Request>>::Error,
126        > + Send
127        + Sync
128        + 'static,
129        inner: S,
130    ) -> Self {
131        Self {
132            checkpoint_fn: Arc::new(checkpoint_fn),
133            inner,
134        }
135    }
136}
137
138impl<S, Request> Service<Request> for CheckpointService<S, Request>
139where
140    S: Service<Request>,
141    S: Send + 'static,
142    S::Future: Send,
143    Request: Send + 'static,
144    <S as Service<Request>>::Response: Send + 'static,
145    <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
146{
147    type Response = <S as Service<Request>>::Response;
148
149    type Error = <S as Service<Request>>::Error;
150
151    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
152
153    fn poll_ready(
154        &mut self,
155        cx: &mut std::task::Context<'_>,
156    ) -> std::task::Poll<Result<(), Self::Error>> {
157        self.inner.poll_ready(cx)
158    }
159
160    fn call(&mut self, req: Request) -> Self::Future {
161        match (self.checkpoint_fn)(req) {
162            Ok(ControlFlow::Break(response)) => Box::pin(async move { Ok(response) }),
163            Ok(ControlFlow::Continue(request)) => Box::pin(self.inner.call(request)),
164            Err(error) => Box::pin(async move { Err(error) }),
165        }
166    }
167}
168
#[cfg(test)]
mod checkpoint_tests {
    use tower::BoxError;
    use tower::Layer;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    use super::*;
    use crate::layers::ServiceBuilderExt;
    use crate::plugin::test::MockExecutionService;
    use crate::services::ExecutionRequest;
    use crate::services::ExecutionResponse;

    /// A checkpoint installed through `ServiceBuilderExt::checkpoint` that always continues
    /// forwards the request to the inner service, which is expected to be called once.
    #[tokio::test]
    async fn test_service_builder() {
        let expected_label = "from_mock_service";

        let mut execution_service = MockExecutionService::new();

        execution_service
            .expect_call()
            .times(1)
            .returning(move |req: ExecutionRequest| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .context(req.context)
                    .build()
                    .unwrap())
            });

        let service_stack = ServiceBuilder::new()
            .checkpoint(|req: ExecutionRequest| Ok(ControlFlow::Continue(req)))
            .service(execution_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    /// `ControlFlow::Continue` forwards the request to the wrapped service.
    #[tokio::test]
    async fn test_continue() {
        let expected_label = "from_mock_service";
        let mut router_service = MockExecutionService::new();

        router_service
            .expect_call()
            .times(1)
            .returning(move |_req| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .build()
                    .unwrap())
            });

        let service_stack =
            CheckpointLayer::new(|req| Ok(ControlFlow::Continue(req))).layer(router_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    /// `ControlFlow::Break` short-circuits: the checkpoint's own response is returned. No call
    /// expectation is registered on the mock service.
    #[tokio::test]
    async fn test_return() {
        let expected_label = "returned_before_mock_service";
        let router_service = MockExecutionService::new();

        // Build the short-circuit response from `expected_label` itself so the assertion below
        // cannot drift out of sync with the label the checkpoint returns.
        let service_stack = CheckpointLayer::new(move |_req| {
            Ok(ControlFlow::Break(
                ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .build()
                    .unwrap(),
            ))
        })
        .layer(router_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    /// An `Err` from the checkpoint is returned as the error of the call.
    #[tokio::test]
    async fn test_error() {
        let expected_error = "checkpoint_error";
        let router_service = MockExecutionService::new();

        let service_stack = CheckpointLayer::new(move |_req| Err(BoxError::from(expected_error)))
            .layer(router_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_error = service_stack
            .oneshot(request)
            .await
            .map(|_| unreachable!())
            .unwrap_err()
            .to_string();

        assert_eq!(actual_error, expected_error)
    }
}