1use crate::backend::Backend;
2use crate::builder::WorkerBuilder;
3use crate::codec::Codec;
4use crate::error::{BoxDynError, Error};
5use crate::request::{Parts, Request};
6use crate::service_fn::{service_fn, ServiceFn};
7use crate::storage::Storage;
8use crate::worker::{Ready, Worker};
9use futures::future::BoxFuture;
10use futures::FutureExt;
11use serde::de::DeserializeOwned;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::fmt::Debug;
15use std::future::Future;
16use std::hash::Hash;
17use std::marker::PhantomData;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::time::Duration;
21use tower::Layer;
22use tower::Service;
23
24type BoxedService<Input, Output> = tower::util::BoxService<Input, Output, crate::error::Error>;
25
26type SteppedService<Compact, Index, Ctx> =
27 BoxedService<Request<StepRequest<Compact, Index>, Ctx>, GoTo<Compact>>;
28
/// The decision returned by a workflow step: where execution goes next.
///
/// `N` is the payload handed to the following step (or the final output for
/// [`GoTo::Done`]); it defaults to `()`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum GoTo<N = ()> {
    /// Continue with the following step, using the payload as its input.
    ///
    /// The step service pushes the follow-up request to storage right away.
    Next(N),
    /// Continue with the following step, but only after `delay`.
    ///
    /// The step service hands the follow-up request to the storage's `schedule`
    /// method instead of `push`.
    Delay {
        /// Input for the following step.
        next: N,
        /// How long to wait before the following step should run.
        delay: Duration,
    },
    /// Stop the workflow; the payload is the final output and nothing further is
    /// enqueued.
    Done(N),
}
44
/// Builder that assembles a workflow out of a sequence of steps.
///
/// Type parameters:
/// * `Ctx` – the request context passed to every step.
/// * `Compact` – the encoded representation of step payloads.
/// * `Input` – the input type of the first step.
/// * `Current` – the input type the next registered step must accept.
/// * `Encode` – the codec converting typed payloads to and from `Compact`.
/// * `Index` – the type used to number steps (defaults to `usize`).
#[derive(Debug)]
pub struct StepBuilder<Ctx, Compact, Input, Current, Encode, Index = usize> {
    /// Registered steps, type-erased and keyed by their index.
    steps: HashMap<Index, SteppedService<Compact, Index, Ctx>>,
    /// Index under which the next registered step will be stored.
    current_index: Index,
    // Type-level markers only; no values of these types are stored.
    current: PhantomData<Current>,
    codec: PhantomData<Encode>,
    input: PhantomData<Input>,
}
54
55impl<Ctx, Compact, Input, Encode, Index: Default> Default
56 for StepBuilder<Ctx, Compact, Input, Input, Encode, Index>
57{
58 fn default() -> Self {
59 Self {
60 steps: HashMap::new(),
61 current_index: Index::default(),
62 current: PhantomData,
63 codec: PhantomData,
64 input: PhantomData,
65 }
66 }
67}
68
69impl<Ctx, Compact, Input, Encode> StepBuilder<Ctx, Compact, Input, Input, Encode, usize> {
70 pub fn new() -> Self {
72 Self {
73 steps: HashMap::new(),
74 current_index: usize::default(),
75 current: PhantomData,
76 codec: PhantomData,
77 input: PhantomData,
78 }
79 }
80
81 pub fn new_with_stepper<I: Default>() -> StepBuilder<Ctx, Compact, Input, Input, Encode, I> {
83 StepBuilder {
84 steps: HashMap::new(),
85 current_index: I::default(),
86 current: PhantomData,
87 codec: PhantomData,
88 input: PhantomData,
89 }
90 }
91}
92
93impl<Ctx, Compact, Input, Current, Encode, Index>
109 StepBuilder<Ctx, Compact, Input, Current, Encode, Index>
110{
111 pub fn build<S>(self, store: S) -> StepService<Ctx, Compact, Input, S, Index> {
113 StepService {
114 inner: self.steps,
115 storage: store,
116 input: PhantomData,
117 }
118 }
119}
120
/// Tower service that runs the step selected by each request's index and enqueues
/// the follow-up step into the storage `S`.
#[derive(Debug)]
pub struct StepService<Ctx, Compact, Input, S, Index> {
    /// Registered steps, type-erased and keyed by their index.
    inner: HashMap<Index, SteppedService<Compact, Index, Ctx>>,
    /// Storage that follow-up steps are pushed or scheduled into.
    storage: S,
    // Type-level marker for the workflow's first input; no value is stored.
    input: PhantomData<Input>,
}
128
129impl<
130 Ctx,
131 Compact,
132 S: Storage<Job = StepRequest<Compact, Index>> + Send + Clone + 'static,
133 Input,
134 Index,
135 > Service<Request<StepRequest<Compact, Index>, Ctx>>
136 for StepService<Ctx, Compact, Input, S, Index>
137where
138 Compact: DeserializeOwned + Send + Clone + 'static,
139 S::Error: Send + Sync + std::error::Error,
140 Index: StepIndex + Send + Sync + 'static,
141{
142 type Response = GoTo<Compact>;
143 type Error = crate::error::Error;
144 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
145
146 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
147 Poll::Ready(Ok(()))
148 }
149
150 fn call(&mut self, req: Request<StepRequest<Compact, Index>, Ctx>) -> Self::Future {
151 let index = &req.args.index;
152 let next_index = index.next();
153
154 let service = self
155 .inner
156 .get_mut(index)
157 .expect("Invalid index in inner services");
158 let fut = service.call(req);
160 let mut storage = self.storage.clone();
161 Box::pin(async move {
162 match fut.await {
163 Ok(response) => {
164 match &response {
165 GoTo::Next(resp) => {
166 storage
167 .push(StepRequest {
168 index: next_index,
169 step: resp.clone(),
170 })
171 .await
172 .map_err(|e| Error::SourceError(Arc::new(e.into())))?;
173 }
174 GoTo::Delay { next, delay } => {
175 storage
176 .schedule(
177 StepRequest {
178 index: next_index,
179 step: next.clone(),
180 },
181 delay.as_secs().try_into().unwrap(),
182 )
183 .await
184 .map_err(|e| Error::SourceError(Arc::new(e.into())))?;
185 }
186 GoTo::Done(_) => {
187 }
189 };
190 Ok(response)
191 }
192 Err(e) => Err(e),
193 }
194 })
195 }
196}
197
/// Adapter that lets a strongly typed step service be stored behind the
/// type-erased step interface.
///
/// The wrapped service takes `Request<Current, Ctx>` and answers `GoTo<Next>`. The
/// adapter presents it as a service over `StepRequest<Compact, _>`, using the
/// `Codec` type parameter to decode the input and encode the answer.
struct TransformingService<S, Compact, Input, Current, Next, Codec> {
    /// The wrapped, strongly typed step service.
    inner: S,
    /// Type-level markers only; no values of these types are stored.
    _marker: PhantomData<(Compact, Input, Codec, Next, Current)>,
}

impl<S, Compact, Codec, Input, Current, Next>
    TransformingService<S, Compact, Input, Current, Next, Codec>
{
    /// Wraps `inner`; no work is done until the adapter is called.
    fn new(inner: S) -> Self {
        Self {
            inner,
            _marker: PhantomData,
        }
    }
}
221
222impl<S, Ctx, Input, Current, Next, Compact, Encode, Index>
223 Service<Request<StepRequest<Compact, Index>, Ctx>>
224 for TransformingService<S, Compact, Input, Current, Next, Encode>
225where
226 S: Service<Request<Current, Ctx>, Response = GoTo<Next>>,
227 Ctx: Default,
228 S::Future: Send + 'static,
229 Current: DeserializeOwned,
230 Next: Serialize,
231 Encode: Codec<Compact = Compact>,
232 Encode::Error: Debug,
233{
234 type Response = GoTo<Compact>;
235 type Error = S::Error;
236 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
237
238 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
239 self.inner.poll_ready(cx)
240 }
241
242 fn call(&mut self, req: Request<StepRequest<Compact, Index>, Ctx>) -> Self::Future {
243 let transformed_req: Request<Current, Ctx> = {
244 Request::new_with_parts(
245 Encode::decode(req.args.step).expect(&format!(
246 "Could not decode step, expecting {}",
247 std::any::type_name::<Current>()
248 )),
249 req.parts,
250 )
251 };
252 let fut = self.inner.call(transformed_req).map(|res| match res {
253 Ok(o) => Ok(match o {
254 GoTo::Next(next) => {
255 GoTo::Next(Encode::encode(next).expect("Could not encode the next step"))
256 }
257 GoTo::Delay { next, delay } => GoTo::Delay {
258 next: Encode::encode(next).expect("Could not encode the next step"),
259 delay,
260 },
261 GoTo::Done(res) => {
262 GoTo::Done(Encode::encode(res).expect("Could not encode the next step"))
263 }
264 }),
265 Err(e) => Err(e),
266 });
267
268 Box::pin(fut)
269 }
270}
271
/// One unit of workflow work: the payload for a step, plus the index of the step that
/// should process it.
///
/// When pushed into storage, `T` is the encoded `Compact` form of the step input.
#[derive(Debug, Serialize, Deserialize)]
pub struct StepRequest<T, Index = usize> {
    /// The step's input payload.
    step: T,
    /// Index of the registered step that should process `step`.
    index: Index,
}
278
279impl<T, Index> StepRequest<T, Index> {
280 pub fn new(step: T) -> Self
282 where
283 Index: Default,
284 {
285 Self {
286 step,
287 index: Index::default(),
288 }
289 }
290
291 pub fn new_with_index(step: T, index: Index) -> Self {
293 Self { step, index }
294 }
295}
296
/// Appends a tower service as the next step of a [`StepBuilder`].
///
/// In the provided implementation for `StepBuilder`, the service receives
/// `Request<Current, Ctx>` and answers `GoTo<Next>`. The returned builder then
/// expects its following step to accept `Next`.
pub trait Step<S, Ctx, Compact, Input, Current, Next, Encode, Index> {
    /// Registers `service` as the next step and returns the updated builder.
    fn step(self, service: S) -> StepBuilder<Ctx, Compact, Input, Next, Encode, Index>;
}
302
303impl<S, Ctx, Input, Current, Next, Compact, Encode, Index>
304 Step<S, Ctx, Compact, Input, Current, Next, Encode, Index>
305 for StepBuilder<Ctx, Compact, Input, Current, Encode, Index>
306where
307 S: Service<Request<Current, Ctx>, Response = GoTo<Next>, Error = crate::error::Error>
308 + Send
309 + 'static
310 + Sync,
311 S::Future: Send + 'static,
312 Current: DeserializeOwned + Send + 'static,
313 S::Response: 'static,
314 Input: Send + 'static + Serialize,
315 Ctx: Default + Send,
316 Next: 'static + Send + Serialize,
317 Compact: Send + 'static,
318 Encode: Codec<Compact = Compact> + Send + 'static,
319 Encode::Error: Debug,
320 Index: StepIndex,
321{
322 fn step(mut self, service: S) -> StepBuilder<Ctx, Compact, Input, Next, Encode, Index> {
323 let next = self.current_index.next();
324 self.steps.insert(
325 self.current_index,
326 BoxedService::new(TransformingService::<
327 S,
328 Compact,
329 Input,
330 Current,
331 Next,
332 Encode,
333 >::new(service)),
334 );
335 StepBuilder {
336 steps: self.steps,
337 current: PhantomData,
338 codec: PhantomData,
339 input: PhantomData,
340 current_index: next,
341 }
342 }
343}
344
/// Convenience over [`Step`]: appends a plain function as the next step.
pub trait StepFn<F, FnArgs, Ctx, Compact, Input, Current, Next, Codec, Index> {
    /// Wraps `f` into a service and registers it as the next step; the blanket
    /// implementation does this via `service_fn`.
    fn step_fn(self, f: F) -> StepBuilder<Ctx, Compact, Input, Next, Codec, Index>;
}
350
351impl<
352 S,
353 Ctx: Send + Sync,
354 F: Send + Sync,
355 FnArgs: Send + Sync,
356 Input,
357 Current,
358 Next,
359 Compact,
360 Encode,
361 Index,
362 > StepFn<F, FnArgs, Ctx, Compact, Input, Current, Next, Encode, Index> for S
363where
364 S: Step<ServiceFn<F, Current, Ctx, FnArgs>, Ctx, Compact, Input, Current, Next, Encode, Index>,
365{
366 fn step_fn(self, f: F) -> StepBuilder<Ctx, Compact, Input, Next, Encode, Index> {
367 self.step(service_fn(f))
368 }
369}
370
/// Builds a ready-to-run [`Worker`] from a [`StepBuilder`].
pub trait StepWorkerFactory<Ctx, Compact, Input, Output, Index> {
    /// The backend the worker pulls step requests from.
    type Source;

    /// The service the worker runs. For `WorkerBuilder` this is the step service
    /// wrapped in the builder's middleware layer.
    type Service;

    /// The codec used to encode and decode step payloads.
    type Codec;

    /// Consumes `self` and `builder` and produces a worker that executes the
    /// assembled workflow.
    fn build_stepped(
        self,
        builder: StepBuilder<Ctx, Compact, Input, Output, Self::Codec, Index>,
    ) -> Worker<Ready<Self::Service, Self::Source>>;
}
394
395impl<Req, P, M, Ctx, Input, Compact, Output, Index>
396 StepWorkerFactory<Ctx, Compact, Input, Output, Index>
397 for WorkerBuilder<Req, Ctx, P, M, StepService<Ctx, Compact, Input, P, Index>>
398where
399 Compact: Send + 'static + Sync,
400 P: Backend<Request<StepRequest<Compact, Index>, Ctx>> + 'static,
401 P: Storage<Job = StepRequest<Compact, Index>> + Clone,
402 M: Layer<StepService<Ctx, Compact, Input, P, Index>> + 'static,
403{
404 type Source = P;
405
406 type Service = M::Service;
407
408 type Codec = <P as Backend<Request<StepRequest<Compact, Index>, Ctx>>>::Codec;
409
410 fn build_stepped(
411 self,
412 builder: StepBuilder<Ctx, Compact, Input, Output, Self::Codec, Index>,
413 ) -> Worker<Ready<M::Service, P>> {
414 let worker_id = self.id;
415 let poller = self.source;
416 let middleware = self.layer;
417 let service = builder.build(poller.clone());
418 let service = middleware.service(service);
419
420 Worker::new(worker_id, Ready::new(service, poller))
421 }
422}
423
/// Errors produced when pushing workflow steps into storage.
#[derive(Debug, thiserror::Error)]
pub enum StepError {
    /// Encoding the step payload with the storage's codec failed.
    #[error("CodecError: {0}")]
    CodecError(BoxDynError),
    /// The storage failed to accept the encoded step.
    #[error("StorageError: {0}")]
    StorageError(BoxDynError),
}
434
435pub trait SteppableStorage<S: Storage, Codec, Compact, Input, Index> {
437 fn push_step<T: Serialize + Send>(
439 &mut self,
440 step: StepRequest<T, Index>,
441 ) -> impl Future<Output = Result<Parts<S::Context>, StepError>> + Send;
442
443 fn start_stepped(
445 &mut self,
446 step: Input,
447 ) -> impl Future<Output = Result<Parts<S::Context>, StepError>> + Send
448 where
449 Input: Serialize + Send,
450 Index: Default,
451 Self: Send,
452 {
453 async {
454 self.push_step(StepRequest {
455 step,
456 index: Index::default(),
457 })
458 .await
459 }
460 }
461}
462
463impl<S, Encode, Compact, Input, Index> SteppableStorage<S, Encode, Compact, Input, Index> for S
464where
465 S: Storage<Job = StepRequest<Compact, Index>, Codec = Encode>
466 + Backend<Request<StepRequest<Compact, Index>, <S as Storage>::Context>>
467 + Send,
468 Encode: Codec<Compact = Compact>,
469 Encode::Error: std::error::Error + Send + Sync + 'static,
470 S::Error: std::error::Error + Send + Sync + 'static,
471 Compact: Send,
472 Index: Send,
473{
474 async fn push_step<T: Serialize + Send>(
475 &mut self,
476 step: StepRequest<T, Index>,
477 ) -> Result<Parts<S::Context>, StepError> {
478 self.push(StepRequest {
479 index: step.index,
480 step: Encode::encode(&step.step).map_err(|e| StepError::CodecError(Box::new(e)))?,
481 })
482 .await
483 .map_err(|e| StepError::StorageError(Box::new(e)))
484 }
485}
486
/// Types that can number the steps of a workflow.
///
/// Indices are used as `HashMap` keys, hence the `Eq + Hash` supertraits.
pub trait StepIndex: Eq + Hash {
    /// Returns the index of the step that follows `self`.
    fn next(&self) -> Self;
}

// Implements `StepIndex` for unsigned integer types by adding one. Like any integer
// `+`, this panics at the type's maximum when overflow checks are enabled (the
// default in debug builds) and wraps otherwise.
macro_rules! impl_step_index_for_unsigned {
    ($($ty:ty),* $(,)?) => {
        $(
            impl StepIndex for $ty {
                fn next(&self) -> Self {
                    *self + 1
                }
            }
        )*
    };
}

// `usize` and `u32` were supported before; the other unsigned widths are added so
// any of them can be used with `StepBuilder::new_with_stepper`.
impl_step_index_for_unsigned!(u8, u16, u32, u64, u128, usize);