1use std::ops::ControlFlow;
13use std::sync::Arc;
14
15use futures::future::BoxFuture;
16use tower::BoxError;
17use tower::Layer;
18use tower::Service;
19
20#[allow(clippy::type_complexity)]
22pub struct CheckpointLayer<S, Request>
23where
24 S: Service<Request> + Send + 'static,
25 Request: Send + 'static,
26 S::Future: Send,
27 S::Response: Send + 'static,
28 S::Error: Send + 'static,
29{
30 checkpoint_fn: Arc<
31 dyn Fn(
32 Request,
33 ) -> Result<
34 ControlFlow<<S as Service<Request>>::Response, Request>,
35 <S as Service<Request>>::Error,
36 > + Send
37 + Sync
38 + 'static,
39 >,
40}
41
42#[allow(clippy::type_complexity)]
43impl<S, Request> CheckpointLayer<S, Request>
44where
45 S: Service<Request> + Send + 'static,
46 Request: Send + 'static,
47 S::Future: Send,
48 S::Response: Send + 'static,
49 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
50{
51 pub fn new(
53 checkpoint_fn: impl Fn(
54 Request,
55 ) -> Result<
56 ControlFlow<<S as Service<Request>>::Response, Request>,
57 <S as Service<Request>>::Error,
58 > + Send
59 + Sync
60 + 'static,
61 ) -> Self {
62 Self {
63 checkpoint_fn: Arc::new(checkpoint_fn),
64 }
65 }
66}
67
68impl<S, Request> Layer<S> for CheckpointLayer<S, Request>
69where
70 S: Service<Request> + Send + 'static,
71 <S as Service<Request>>::Future: Send,
72 Request: Send + 'static,
73 <S as Service<Request>>::Response: Send + 'static,
74 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
75{
76 type Service = CheckpointService<S, Request>;
77
78 fn layer(&self, service: S) -> Self::Service {
79 CheckpointService {
80 checkpoint_fn: Arc::clone(&self.checkpoint_fn),
81 inner: service,
82 }
83 }
84}
85
86#[allow(clippy::type_complexity)]
88pub struct CheckpointService<S, Request>
89where
90 Request: Send + 'static,
91 S: Service<Request> + Send + 'static,
92 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
93 <S as Service<Request>>::Response: Send + 'static,
94 <S as Service<Request>>::Future: Send + 'static,
95{
96 inner: S,
97 checkpoint_fn: Arc<
98 dyn Fn(
99 Request,
100 ) -> Result<
101 ControlFlow<<S as Service<Request>>::Response, Request>,
102 <S as Service<Request>>::Error,
103 > + Send
104 + Sync
105 + 'static,
106 >,
107}
108
109impl<S, Request> Clone for CheckpointService<S, Request>
110where
111 S: Clone,
112 Request: Send + 'static,
114 S: Service<Request> + Send + 'static,
115 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
116 <S as Service<Request>>::Response: Send + 'static,
117 <S as Service<Request>>::Future: Send + 'static,
118{
119 fn clone(&self) -> Self {
120 Self {
121 inner: self.inner.clone(),
122 checkpoint_fn: self.checkpoint_fn.clone(),
123 }
124 }
125}
126
127#[allow(clippy::type_complexity)]
128impl<S, Request> CheckpointService<S, Request>
129where
130 Request: Send + 'static,
131 S: Service<Request> + Send + 'static,
132 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
133 <S as Service<Request>>::Response: Send + 'static,
134 <S as Service<Request>>::Future: Send + 'static,
135{
136 pub fn new(
138 checkpoint_fn: impl Fn(
139 Request,
140 ) -> Result<
141 ControlFlow<<S as Service<Request>>::Response, Request>,
142 <S as Service<Request>>::Error,
143 > + Send
144 + Sync
145 + 'static,
146 inner: S,
147 ) -> Self {
148 Self {
149 checkpoint_fn: Arc::new(checkpoint_fn),
150 inner,
151 }
152 }
153}
154
155impl<S, Request> Service<Request> for CheckpointService<S, Request>
156where
157 S: Service<Request>,
158 S: Send + 'static,
159 S::Future: Send,
160 Request: Send + 'static,
161 <S as Service<Request>>::Response: Send + 'static,
162 <S as Service<Request>>::Error: Into<BoxError> + Send + 'static,
163{
164 type Response = <S as Service<Request>>::Response;
165
166 type Error = <S as Service<Request>>::Error;
167
168 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
169
170 fn poll_ready(
171 &mut self,
172 cx: &mut std::task::Context<'_>,
173 ) -> std::task::Poll<Result<(), Self::Error>> {
174 self.inner.poll_ready(cx)
175 }
176
177 fn call(&mut self, req: Request) -> Self::Future {
178 match (self.checkpoint_fn)(req) {
179 Ok(ControlFlow::Break(response)) => Box::pin(async move { Ok(response) }),
180 Ok(ControlFlow::Continue(request)) => Box::pin(self.inner.call(request)),
181 Err(error) => Box::pin(async move { Err(error) }),
182 }
183 }
184}
185
#[cfg(test)]
mod checkpoint_tests {
    use tower::BoxError;
    use tower::Layer;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    use super::*;
    use crate::layers::ServiceBuilderExt;
    use crate::plugin::test::MockExecutionService;
    use crate::services::ExecutionRequest;
    use crate::services::ExecutionResponse;

    /// A pass-through checkpoint installed via `ServiceBuilderExt::checkpoint` forwards the
    /// request, so the inner service is expected to be called exactly once.
    #[tokio::test]
    async fn test_service_builder() {
        let expected_label = "from_mock_service";

        let mut execution_service = MockExecutionService::new();

        execution_service
            .expect_call()
            .times(1)
            .returning(move |req: ExecutionRequest| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .context(req.context)
                    .build()
                    .unwrap())
            });

        let service_stack = ServiceBuilder::new()
            .checkpoint(|req: ExecutionRequest| Ok(ControlFlow::Continue(req)))
            .service(execution_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    /// `ControlFlow::Continue` forwards the request to the inner service exactly once.
    #[tokio::test]
    async fn test_continue() {
        let expected_label = "from_mock_service";
        let mut execution_service = MockExecutionService::new();

        execution_service
            .expect_call()
            .times(1)
            .returning(move |_req| {
                Ok(ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .build()
                    .unwrap())
            });

        let service_stack =
            CheckpointLayer::new(|req| Ok(ControlFlow::Continue(req))).layer(execution_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    /// `ControlFlow::Break` answers with the checkpoint's own response. No call expectation
    /// is registered on the mock, so the inner service is not expected to be invoked.
    #[tokio::test]
    async fn test_return() {
        let expected_label = "returned_before_mock_service";
        let execution_service = MockExecutionService::new();

        let service_stack = CheckpointLayer::new(move |_req| {
            Ok(ControlFlow::Break(
                ExecutionResponse::fake_builder()
                    .label(expected_label.to_string())
                    .build()
                    .unwrap(),
            ))
        })
        .layer(execution_service);

        let request = ExecutionRequest::fake_builder().build();

        let actual_label = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .label
            .unwrap();

        assert_eq!(actual_label, expected_label)
    }

    /// An `Err` from the checkpoint is returned as the service error. No call expectation is
    /// registered on the mock, so the inner service is not expected to be invoked.
    #[tokio::test]
    async fn test_error() {
        let expected_error = "checkpoint_error";
        let execution_service = MockExecutionService::new();

        let service_stack = CheckpointLayer::new(move |_req| Err(BoxError::from(expected_error)))
            .layer(execution_service);

        let request = ExecutionRequest::fake_builder().build();

        // An explicit match gives a meaningful failure message instead of
        // "entered unreachable code" if the checkpoint unexpectedly lets the request through.
        let actual_error = match service_stack.oneshot(request).await {
            Ok(_) => panic!("the checkpoint should have short-circuited with an error"),
            Err(error) => error.to_string(),
        };

        assert_eq!(actual_error, expected_error)
    }
}