apalis_workflow/sequential/fold/
mod.rs

1use std::{marker::PhantomData, task::Context};
2
3use apalis_core::{
4    backend::{BackendExt, TaskSinkError, codec::Codec},
5    error::BoxDynError,
6    task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId},
7    task_fn::{TaskFn, task_fn},
8};
9use futures::{FutureExt, Sink, SinkExt, future::BoxFuture};
10use serde::{Deserialize, Serialize};
11use tower::Service;
12
13use crate::{
14    SteppedService,
15    id_generator::GenerateId,
16    sequential::context::{StepContext, WorkflowContext},
17    sequential::router::{GoTo, StepResult, WorkflowRouter},
18    sequential::step::{Layer, Stack, Step},
19    sequential::workflow::Workflow,
20};
21
/// The fold layer that folds over a collection of items.
///
/// Holds the user-supplied fold function `F`. `Init` (the accumulator type) is only
/// tracked at the type level through `PhantomData`.
#[derive(Debug)]
pub struct Fold<F, Init> {
    fold: F,
    _marker: std::marker::PhantomData<Init>,
}

// Written by hand so that cloning only requires `F: Clone`. `#[derive(Clone)]` would
// also demand `Init: Clone` because of the `PhantomData<Init>` field, even though no
// `Init` value is stored.
impl<F: Clone, Init> Clone for Fold<F, Init> {
    fn clone(&self) -> Self {
        Self {
            fold: self.fold.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}
28
29impl<F, Init, S> Layer<S> for Fold<F, Init>
30where
31    F: Clone,
32    Init: Clone,
33{
34    type Step = FoldStep<S, F, Init>;
35
36    fn layer(&self, step: S) -> Self::Step {
37        FoldStep {
38            inner: step,
39            fold: self.fold.clone(),
40            _marker: std::marker::PhantomData,
41        }
42    }
43}
44impl<Start, C, L, I: IntoIterator<Item = C>, B: BackendExt> Workflow<Start, I, B, L> {
45    /// Folds over a collection of items in the workflow.
46    pub fn fold<F, Output, FnArgs, Init>(
47        self,
48        fold: F,
49    ) -> Workflow<Start, Output, B, Stack<Fold<TaskFn<F, (Init, C), B::Context, FnArgs>, Init>, L>>
50    where
51        TaskFn<F, (Init, C), B::Context, FnArgs>:
52            Service<Task<(Init, C), B::Context, B::IdType>, Response = Output>,
53    {
54        self.add_step(Fold {
55            fold: task_fn(fold),
56            _marker: PhantomData,
57        })
58    }
59}
60
/// The fold step that folds over a collection of items.
///
/// `inner` is the step being wrapped and `fold` is the fold function. `Init` (the
/// accumulator type) is only tracked at the type level.
#[derive(Debug)]
pub struct FoldStep<S, F, Init> {
    inner: S,
    fold: F,
    _marker: std::marker::PhantomData<Init>,
}

// Written by hand so that cloning only requires `S: Clone` and `F: Clone`.
// `#[derive(Clone)]` would also demand `Init: Clone` for the `PhantomData<Init>` field.
impl<S: Clone, F: Clone, Init> Clone for FoldStep<S, F, Init> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            fold: self.fold.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}
68
69impl<S, F, Input, I: IntoIterator<Item = Input>, Init, B, MetaErr, Err, CodecError> Step<I, B>
70    for FoldStep<S, F, Init>
71where
72    F: Service<Task<(Init, Input), B::Context, B::IdType>, Response = Init>
73        + Send
74        + Sync
75        + 'static
76        + Clone,
77    S: Step<Init, B>,
78    B: BackendExt<Error = Err>
79        + Send
80        + Sync
81        + Clone
82        + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
83        + Unpin
84        + 'static,
85    I: IntoIterator<Item = Input> + Send + Sync + 'static,
86    B::Context: MetadataExt<FoldState, Error = MetaErr>
87        + MetadataExt<WorkflowContext, Error = MetaErr>
88        + Send
89        + 'static,
90    B::Codec: Codec<(Init, Vec<Input>), Error = CodecError, Compact = B::Compact>
91        + Codec<Init, Error = CodecError, Compact = B::Compact>
92        + Codec<I, Error = CodecError, Compact = B::Compact>
93        + Codec<(Init, Input), Error = CodecError, Compact = B::Compact>
94        + 'static,
95    B::IdType: GenerateId + Send + 'static + Clone,
96    Init: Default + Send + Sync + 'static,
97    Err: std::error::Error + Send + Sync + 'static,
98    CodecError: std::error::Error + Send + Sync + 'static,
99    F::Error: Into<BoxDynError> + Send + 'static,
100    MetaErr: std::error::Error + Send + Sync + 'static,
101    F::Future: Send + 'static,
102    B::Compact: Send + 'static,
103    Input: Send + 'static,
104{
105    type Response = Init;
106    type Error = F::Error;
107    fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
108        let svc = SteppedService::new(FoldService {
109            fold: self.fold.clone(),
110            _marker: PhantomData::<(Init, I, B)>,
111        });
112        let count = ctx.steps.len();
113        ctx.steps.insert(count, svc);
114        self.inner.register(ctx)
115    }
116}
117
/// Service that performs the fold, applying the fold function to one item per
/// invocation.
///
/// `Init`, `I` and `B` are carried only as type markers.
#[derive(Debug)]
pub struct FoldService<F, Init, I, B> {
    fold: F,
    _marker: PhantomData<(Init, I, B)>,
}
124
125impl<F: Clone, Init, I, B> Clone for FoldService<F, Init, I, B> {
126    fn clone(&self) -> Self {
127        Self {
128            fold: self.fold.clone(),
129            _marker: std::marker::PhantomData,
130        }
131    }
132}
133
134impl<F, Init, I, B> FoldService<F, Init, I, B> {
135    /// Creates a new `FoldService` with the given fold function.
136    pub fn new(fold: F) -> Self {
137        Self {
138            fold,
139            _marker: std::marker::PhantomData,
140        }
141    }
142}
143
144impl<F, Init, I, B, Input, CodecError, MetaErr, Err>
145    Service<Task<B::Compact, B::Context, B::IdType>> for FoldService<F, Init, I, B>
146where
147    F: Service<Task<(Init, Input), B::Context, B::IdType>, Response = Init>
148        + Send
149        + 'static
150        + Clone,
151    B: BackendExt<Error = Err>
152        + Send
153        + Sync
154        + Clone
155        + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
156        + Unpin
157        + 'static,
158    I: IntoIterator<Item = Input> + Send + 'static,
159    B::Context: MetadataExt<FoldState, Error = MetaErr>
160        + MetadataExt<WorkflowContext, Error = MetaErr>
161        + Send
162        + 'static,
163    B::Codec: Codec<(Init, Vec<Input>), Error = CodecError, Compact = B::Compact>
164        + Codec<Init, Error = CodecError, Compact = B::Compact>
165        + Codec<I, Error = CodecError, Compact = B::Compact>
166        + Codec<(Init, Input), Error = CodecError, Compact = B::Compact>
167        + 'static,
168    B::IdType: GenerateId + Send + 'static,
169    Init: Default + Send + 'static,
170    Err: std::error::Error + Send + Sync + 'static,
171    CodecError: std::error::Error + Send + Sync + 'static,
172    F::Error: Into<BoxDynError> + Send + 'static,
173    MetaErr: std::error::Error + Send + Sync + 'static,
174    F::Future: Send + 'static,
175    B::Compact: Send + 'static,
176    Input: Send + 'static,
177{
178    type Response = GoTo<StepResult<B::Compact, B::IdType>>;
179    type Error = BoxDynError;
180    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
181
182    fn poll_ready(&mut self, cx: &mut Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
183        self.fold.poll_ready(cx).map_err(|e| e.into())
184    }
185
186    fn call(&mut self, task: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
187        let state = task.parts.ctx.extract().unwrap_or(FoldState::Unknown);
188        let mut ctx = task.parts.data.get::<StepContext<B>>().cloned().unwrap();
189        let mut fold = self.fold.clone();
190
191        match state {
192            FoldState::Unknown => async move {
193                let task_id = TaskId::new(B::IdType::generate());
194                let steps: Task<I, _, _> = task.try_map(|arg| B::Codec::decode(&arg))?;
195                let steps = steps.args.into_iter().collect::<Vec<_>>();
196                let task = TaskBuilder::new(B::Codec::encode(&(Init::default(), steps))?)
197                    .meta(WorkflowContext {
198                        step_index: ctx.current_step,
199                    })
200                    .with_task_id(task_id.clone())
201                    .meta(FoldState::Collection)
202                    .build();
203                ctx.backend
204                    .send(task)
205                    .await
206                    .map_err(|e| TaskSinkError::PushError(e))?;
207                Ok(GoTo::Next(StepResult {
208                    result: B::Codec::encode(&Init::default())?,
209                    next_task_id: Some(task_id),
210                }))
211            }
212            .boxed(),
213            FoldState::Collection => async move {
214                let args: (Init, Vec<Input>) = B::Codec::decode(&task.args)?;
215                let (acc, items) = args;
216
217                let mut items = items.into_iter();
218                let next = items.next().unwrap();
219                let rest = items.collect::<Vec<_>>();
220                let fold_task = task.map(|_| (acc, next));
221                let response = fold.call(fold_task).await.map_err(|e| e.into())?;
222
223                match rest.len() {
224                    0 if ctx.has_next => {
225                        let task_id = TaskId::new(B::IdType::generate());
226                        let result = B::Codec::encode(&response)?;
227                        let next_step = TaskBuilder::new(result)
228                            .with_task_id(task_id.clone())
229                            .meta(WorkflowContext {
230                                step_index: ctx.current_step + 1,
231                            })
232                            .build();
233                        ctx.backend
234                            .send(next_step)
235                            .await
236                            .map_err(|e| TaskSinkError::PushError(e))?;
237                        Ok(GoTo::Break(StepResult {
238                            result: B::Codec::encode(&response)?,
239                            next_task_id: Some(task_id),
240                        }))
241                    }
242                    0 => Ok(GoTo::Break(StepResult {
243                        result: B::Codec::encode(&response)?,
244                        next_task_id: None,
245                    })),
246                    1.. => {
247                        // Shouldn't this be limited?
248                        let task_id = TaskId::new(B::IdType::generate());
249                        let result = B::Codec::encode(&response)?;
250                        let steps = TaskBuilder::new(B::Codec::encode(&(response, rest))?)
251                            .with_task_id(task_id.clone())
252                            .meta(WorkflowContext {
253                                step_index: ctx.current_step,
254                            })
255                            .meta(FoldState::Collection)
256                            .build();
257                        ctx.backend
258                            .send(steps)
259                            .await
260                            .map_err(|e| TaskSinkError::PushError(e))?;
261                        Ok(GoTo::Next(StepResult {
262                            result,
263                            next_task_id: Some(task_id),
264                        }))
265                    }
266                }
267            }
268            .boxed(),
269        }
270    }
271}
272
273/// The state of the fold operation
274#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
275pub enum FoldState {
276    /// Unknown
277    Unknown,
278    /// Collection has started
279    Collection,
280}