apalis_workflow/fold/mod.rs

1use std::{marker::PhantomData, task::Context};
2
3use apalis_core::{
4    backend::{BackendExt, TaskSinkError, codec::Codec},
5    error::BoxDynError,
6    task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId},
7    task_fn::{TaskFn, task_fn},
8};
9use futures::{FutureExt, Sink, SinkExt, future::BoxFuture};
10use serde::{Deserialize, Serialize};
11use tower::Service;
12
13use crate::{
14    SteppedService,
15    context::{StepContext, WorkflowContext},
16    id_generator::GenerateId,
17    router::{GoTo, StepResult, WorkflowRouter},
18    step::{Layer, Stack, Step},
19    workflow::Workflow,
20};
21
/// Layer that turns a step into a [`FoldStep`], folding over a collection of items.
///
/// `fold` is the function applied to `(accumulator, item)` pairs; `Init` is the
/// accumulator type, tracked only at the type level.
#[derive(Debug, Clone)]
pub struct Fold<F, Init> {
    fold: F,
    _marker: PhantomData<Init>,
}
28
29impl<F, Init, S> Layer<S> for Fold<F, Init>
30where
31    F: Clone,
32    Init: Clone,
33{
34    type Step = FoldStep<S, F, Init>;
35
36    fn layer(&self, step: S) -> Self::Step {
37        FoldStep {
38            inner: step,
39            fold: self.fold.clone(),
40            _marker: std::marker::PhantomData,
41        }
42    }
43}
44impl<Start, C, L, I: IntoIterator<Item = C>, B: BackendExt> Workflow<Start, I, B, L> {
45    /// Folds over a collection of items in the workflow.
46    pub fn fold<F, Output, FnArgs, Init>(
47        self,
48        fold: F,
49    ) -> Workflow<Start, Output, B, Stack<Fold<TaskFn<F, (Init, C), B::Context, FnArgs>, Init>, L>>
50    where
51        TaskFn<F, (Init, C), B::Context, FnArgs>:
52            Service<Task<(Init, C), B::Context, B::IdType>, Response = Output>,
53    {
54        self.add_step(Fold {
55            fold: task_fn(fold),
56            _marker: PhantomData,
57        })
58    }
59}
60
/// Workflow step that folds over a collection of items.
///
/// Holds the wrapped `inner` step together with the `fold` function; `Init`
/// (the accumulator type) is tracked only at the type level.
#[derive(Debug, Clone)]
pub struct FoldStep<S, F, Init> {
    inner: S,
    fold: F,
    _marker: PhantomData<Init>,
}
68
69impl<S, F, Input, I: IntoIterator<Item = Input>, Init, B, MetaErr, Err, CodecError> Step<I, B>
70    for FoldStep<S, F, Init>
71where
72    F: Service<Task<(Init, Input), B::Context, B::IdType>, Response = Init>
73        + Send
74        + 'static
75        + Clone,
76    S: Step<Init, B>,
77    B: BackendExt<Error = Err>
78        + Send
79        + Sync
80        + Clone
81        + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
82        + Unpin
83        + 'static,
84    I: IntoIterator<Item = Input> + Send + 'static,
85    B::Context: MetadataExt<FoldState, Error = MetaErr>
86        + MetadataExt<WorkflowContext, Error = MetaErr>
87        + Send
88        + 'static,
89    B::Codec: Codec<(Init, Vec<Input>), Error = CodecError, Compact = B::Compact>
90        + Codec<Init, Error = CodecError, Compact = B::Compact>
91        + Codec<I, Error = CodecError, Compact = B::Compact>
92        + Codec<(Init, Input), Error = CodecError, Compact = B::Compact>
93        + 'static,
94    B::IdType: GenerateId + Send + 'static + Clone,
95    Init: Default + Send + 'static,
96    Err: std::error::Error + Send + Sync + 'static,
97    CodecError: std::error::Error + Send + Sync + 'static,
98    F::Error: Into<BoxDynError> + Send + 'static,
99    MetaErr: std::error::Error + Send + Sync + 'static,
100    F::Future: Send + 'static,
101    B::Compact: Send + 'static,
102    Input: Send + 'static,
103{
104    type Response = Init;
105    type Error = F::Error;
106    fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
107        let svc = SteppedService::new(FoldService {
108            fold: self.fold.clone(),
109            _marker: PhantomData::<(Init, I, B)>,
110        });
111        let count = ctx.steps.len();
112        ctx.steps.insert(count, svc);
113        self.inner.register(ctx)
114    }
115}
116
/// Service that executes a fold over a collection of items, one item per task.
///
/// `Init` (accumulator), `I` (input collection) and `B` (backend) are tracked
/// only at the type level.
#[derive(Debug, Clone)]
pub struct FoldService<F, Init, I, B> {
    fold: F,
    _marker: PhantomData<(Init, I, B)>,
}
123
124impl<F, Init, I, B> FoldService<F, Init, I, B> {
125    /// Creates a new `FoldService` with the given fold function.
126    pub fn new(fold: F) -> Self {
127        Self {
128            fold,
129            _marker: std::marker::PhantomData,
130        }
131    }
132}
133
134impl<F, Init, I, B, Input, CodecError, MetaErr, Err>
135    Service<Task<B::Compact, B::Context, B::IdType>> for FoldService<F, Init, I, B>
136where
137    F: Service<Task<(Init, Input), B::Context, B::IdType>, Response = Init>
138        + Send
139        + 'static
140        + Clone,
141    B: BackendExt<Error = Err>
142        + Send
143        + Sync
144        + Clone
145        + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
146        + Unpin
147        + 'static,
148    I: IntoIterator<Item = Input> + Send + 'static,
149    B::Context: MetadataExt<FoldState, Error = MetaErr>
150        + MetadataExt<WorkflowContext, Error = MetaErr>
151        + Send
152        + 'static,
153    B::Codec: Codec<(Init, Vec<Input>), Error = CodecError, Compact = B::Compact>
154        + Codec<Init, Error = CodecError, Compact = B::Compact>
155        + Codec<I, Error = CodecError, Compact = B::Compact>
156        + Codec<(Init, Input), Error = CodecError, Compact = B::Compact>
157        + 'static,
158    B::IdType: GenerateId + Send + 'static,
159    Init: Default + Send + 'static,
160    Err: std::error::Error + Send + Sync + 'static,
161    CodecError: std::error::Error + Send + Sync + 'static,
162    F::Error: Into<BoxDynError> + Send + 'static,
163    MetaErr: std::error::Error + Send + Sync + 'static,
164    F::Future: Send + 'static,
165    B::Compact: Send + 'static,
166    Input: Send + 'static,
167{
168    type Response = GoTo<StepResult<B::Compact, B::IdType>>;
169    type Error = BoxDynError;
170    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
171
172    fn poll_ready(&mut self, cx: &mut Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
173        self.fold.poll_ready(cx).map_err(|e| e.into())
174    }
175
176    fn call(&mut self, task: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
177        let state = task.parts.ctx.extract().unwrap_or(FoldState::Unknown);
178        let mut ctx = task.parts.data.get::<StepContext<B>>().cloned().unwrap();
179        let mut fold = self.fold.clone();
180
181        match state {
182            FoldState::Unknown => async move {
183                let task_id = TaskId::new(B::IdType::generate());
184                let steps: Task<I, _, _> = task.try_map(|arg| B::Codec::decode(&arg))?;
185                let steps = steps.args.into_iter().collect::<Vec<_>>();
186                let task = TaskBuilder::new(B::Codec::encode(&(Init::default(), steps))?)
187                    .meta(WorkflowContext {
188                        step_index: ctx.current_step,
189                    })
190                    .with_task_id(task_id.clone())
191                    .meta(FoldState::Collection)
192                    .build();
193                ctx.backend
194                    .send(task)
195                    .await
196                    .map_err(|e| TaskSinkError::PushError(e))?;
197                Ok(GoTo::Next(StepResult {
198                    result: B::Codec::encode(&Init::default())?,
199                    next_task_id: Some(task_id),
200                }))
201            }
202            .boxed(),
203            FoldState::Collection => async move {
204                let args: (Init, Vec<Input>) = B::Codec::decode(&task.args)?;
205                let (acc, items) = args;
206
207                let mut items = items.into_iter();
208                let next = items.next().unwrap();
209                let rest = items.collect::<Vec<_>>();
210                let fold_task = task.map(|_| (acc, next));
211                let response = fold.call(fold_task).await.map_err(|e| e.into())?;
212
213                match rest.len() {
214                    0 if ctx.has_next => {
215                        let task_id = TaskId::new(B::IdType::generate());
216                        let result = B::Codec::encode(&response)?;
217                        let next_step = TaskBuilder::new(result)
218                            .with_task_id(task_id.clone())
219                            .meta(WorkflowContext {
220                                step_index: ctx.current_step + 1,
221                            })
222                            .build();
223                        ctx.backend
224                            .send(next_step)
225                            .await
226                            .map_err(|e| TaskSinkError::PushError(e))?;
227                        Ok(GoTo::Break(StepResult {
228                            result: B::Codec::encode(&response)?,
229                            next_task_id: Some(task_id),
230                        }))
231                    }
232                    0 => Ok(GoTo::Break(StepResult {
233                        result: B::Codec::encode(&response)?,
234                        next_task_id: None,
235                    })),
236                    1.. => {
237                        // Shouldn't this be limited?
238                        let task_id = TaskId::new(B::IdType::generate());
239                        let result = B::Codec::encode(&response)?;
240                        let steps = TaskBuilder::new(B::Codec::encode(&(response, rest))?)
241                            .with_task_id(task_id.clone())
242                            .meta(WorkflowContext {
243                                step_index: ctx.current_step,
244                            })
245                            .meta(FoldState::Collection)
246                            .build();
247                        ctx.backend
248                            .send(steps)
249                            .await
250                            .map_err(|e| TaskSinkError::PushError(e))?;
251                        Ok(GoTo::Next(StepResult {
252                            result,
253                            next_task_id: Some(task_id),
254                        }))
255                    }
256                }
257            }
258            .boxed(),
259        }
260    }
261}
262
263/// The state of the fold operation
264#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
265pub enum FoldState {
266    /// Unknown
267    Unknown,
268    /// Collection has started
269    Collection,
270}