apalis_workflow/sequential/repeat_until/
mod.rs

1use std::convert::Infallible;
2use std::marker::PhantomData;
3use std::task::Context;
4
5use apalis_core::backend::TaskSinkError;
6use apalis_core::backend::codec::Codec;
7use apalis_core::error::BoxDynError;
8use apalis_core::task::builder::TaskBuilder;
9use apalis_core::task::metadata::MetadataExt;
10use apalis_core::task::task_id::TaskId;
11use apalis_core::task_fn::{FromRequest, TaskFn, task_fn};
12use apalis_core::{backend::BackendExt, task::Task};
13use futures::future::BoxFuture;
14use futures::{FutureExt, Sink, SinkExt};
15use serde::{Deserialize, Serialize};
16use tower::Service;
17
18use crate::id_generator::GenerateId;
19use crate::sequential::router::WorkflowRouter;
20use crate::sequential::{GoTo, Layer, Stack, Step, StepContext, StepResult, WorkflowContext};
21use crate::{SteppedService, Workflow};
22
/// A layer that represents a `repeat_until` step in the workflow.
///
/// It holds the `repeater` service; applying the layer to a step produces a
/// [`RepeatUntilStep`]. `Input` and `Output` exist only at the type level.
#[derive(Debug)]
pub struct RepeatUntil<F, Input, Output> {
    repeater: F,
    _marker: PhantomData<(Input, Output)>,
}

// Implemented by hand: `#[derive(Clone)]` would require `Input: Clone` and
// `Output: Clone`, although they only appear inside `PhantomData`.
impl<F, Input, Output> Clone for RepeatUntil<F, Input, Output>
where
    F: Clone,
{
    fn clone(&self) -> Self {
        Self {
            repeater: self.repeater.clone(),
            _marker: PhantomData,
        }
    }
}
29
30impl<F, Input, Output, S> Layer<S> for RepeatUntil<F, Input, Output>
31where
32    F: Clone,
33{
34    type Step = RepeatUntilStep<S, F, Input, Output>;
35
36    fn layer(&self, step: S) -> Self::Step {
37        RepeatUntilStep {
38            inner: step,
39            repeater: self.repeater.clone(),
40            _marker: std::marker::PhantomData,
41        }
42    }
43}
44impl<Start, L, Input, B: BackendExt> Workflow<Start, Input, B, L> {
45    /// Folds over a collection of items in the workflow.
46    pub fn repeat_until<F, Output, FnArgs>(
47        self,
48        repeater: F,
49    ) -> Workflow<
50        Start,
51        Output,
52        B,
53        Stack<RepeatUntil<TaskFn<F, Input, B::Context, FnArgs>, Input, Output>, L>,
54    >
55    where
56        TaskFn<F, Input, B::Context, FnArgs>:
57            Service<Task<Input, B::Context, B::IdType>, Response = Option<Output>>,
58    {
59        self.add_step(RepeatUntil {
60            repeater: task_fn(repeater),
61            _marker: PhantomData::<(Input, Output)>,
62        })
63    }
64}
65
/// The step implementation for the `repeat_until` layer.
///
/// `inner` is the wrapped step, which registers itself after this one. `repeater` is the
/// service invoked on every iteration. `Input` and `Output` exist only at the type level.
#[derive(Debug)]
pub struct RepeatUntilStep<S, R, Input, Output> {
    inner: S,
    repeater: R,
    _marker: PhantomData<(Input, Output)>,
}

// Implemented by hand: `#[derive(Clone)]` would additionally require `Input: Clone` and
// `Output: Clone`, although they only appear inside `PhantomData`.
impl<S, R, Input, Output> Clone for RepeatUntilStep<S, R, Input, Output>
where
    S: Clone,
    R: Clone,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            repeater: self.repeater.clone(),
            _marker: PhantomData,
        }
    }
}
73
/// Service that executes a single iteration of a `repeat_until` step.
///
/// Only `repeater` carries data. `B`, `Input` and `Output` are type-level markers.
#[derive(Debug)]
pub struct RepeatUntilService<F, B, Input, Output> {
    repeater: F,
    _marker: PhantomData<(B, Input, Output)>,
}
80
81impl<F, B, Input, Output> Clone for RepeatUntilService<F, B, Input, Output>
82where
83    F: Clone,
84{
85    fn clone(&self) -> Self {
86        Self {
87            repeater: self.repeater.clone(),
88            _marker: std::marker::PhantomData,
89        }
90    }
91}
92
93impl<F, Res, B, Input, CodecError, MetaErr, Err> Service<Task<B::Compact, B::Context, B::IdType>>
94    for RepeatUntilService<F, B, Input, Res>
95where
96    F: Service<Task<Input, B::Context, B::IdType>, Response = Option<Res>> + Send + 'static + Clone,
97    B: BackendExt<Error = Err>
98        + Send
99        + Sync
100        + Clone
101        + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
102        + Unpin
103        + 'static,
104    B::Context: MetadataExt<RepeaterState<B::IdType>, Error = MetaErr>
105        + MetadataExt<WorkflowContext, Error = MetaErr>
106        + Send
107        + 'static,
108    B::Codec: Codec<Input, Error = CodecError, Compact = B::Compact>
109        + Codec<Res, Error = CodecError, Compact = B::Compact>
110        + Codec<Option<Res>, Error = CodecError, Compact = B::Compact>
111        + 'static,
112    B::IdType: GenerateId + Send + 'static,
113    Err: std::error::Error + Send + Sync + 'static,
114    CodecError: std::error::Error + Send + Sync + 'static,
115    F::Error: Into<BoxDynError> + Send + 'static,
116    MetaErr: std::error::Error + Send + Sync + 'static,
117    F::Future: Send + 'static,
118    B::Compact: Send + 'static,
119    Input: Send + 'static, // We don't need Clone because decoding just needs a reference
120    Res: Send + 'static,
121{
122    type Response = GoTo<StepResult<B::Compact, B::IdType>>;
123    type Error = BoxDynError;
124    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
125
126    fn poll_ready(&mut self, cx: &mut Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
127        self.repeater.poll_ready(cx).map_err(|e| e.into())
128    }
129
130    fn call(&mut self, task: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
131        let state: RepeaterState<B::IdType> = task.parts.ctx.extract().unwrap_or_default();
132        let mut ctx =
133            task.parts.data.get::<StepContext<B>>().cloned().expect(
134                "StepContext missing, Did you call the repeater outside of a workflow step?",
135            );
136        let mut repeater = self.repeater.clone();
137
138        (async move {
139            let mut compact = None;
140            let decoded: Input = B::Codec::decode(&task.args)?;
141            let prev_task_id = task.parts.task_id.clone();
142            let repeat_task = task.map(|c| {
143                compact = Some(c);
144                decoded
145            });
146            let response = repeater.call(repeat_task).await.map_err(|e| e.into())?;
147            Ok(match response {
148                Some(res) if ctx.has_next => {
149                    let task_id = TaskId::new(B::IdType::generate());
150                    let next_step = TaskBuilder::new(B::Codec::encode(&res)?)
151                        .with_task_id(task_id.clone())
152                        .meta(WorkflowContext {
153                            step_index: ctx.current_step + 1,
154                        })
155                        .build();
156                    ctx.backend
157                        .send(next_step)
158                        .await
159                        .map_err(|e| TaskSinkError::PushError(e))?;
160                    GoTo::Next(StepResult {
161                        result: B::Codec::encode(&res)?,
162                        next_task_id: Some(task_id),
163                    })
164                }
165                Some(res) => GoTo::Break(StepResult {
166                    result: B::Codec::encode(&res)?,
167                    next_task_id: None,
168                }),
169                None => {
170                    let task_id = TaskId::new(B::IdType::generate());
171                    let next_step =
172                        TaskBuilder::new(compact.take().expect("Compact args should be set"))
173                            .with_task_id(task_id.clone())
174                            .meta(WorkflowContext {
175                                step_index: ctx.current_step,
176                            })
177                            .meta(RepeaterState {
178                                iterations: state.iterations + 1,
179                                prev_task_id,
180                            })
181                            .build();
182                    ctx.backend
183                        .send(next_step)
184                        .await
185                        .map_err(|e| TaskSinkError::PushError(e))?;
186                    GoTo::Break(StepResult {
187                        result: B::Codec::encode(&None::<Res>)?,
188                        next_task_id: Some(task_id),
189                    })
190                }
191            })
192        }
193        .boxed()) as _
194    }
195}
196
197/// The state of the fold operation
198#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
199pub struct RepeaterState<IdType> {
200    iterations: usize,
201    prev_task_id: Option<TaskId<IdType>>,
202}
203
204impl<IdType> Default for RepeaterState<IdType> {
205    fn default() -> Self {
206        Self {
207            iterations: 0,
208            prev_task_id: None,
209        }
210    }
211}
212
213impl<IdType> RepeaterState<IdType> {
214    /// Get the number of iterations completed so far.
215    pub fn iterations(&self) -> usize {
216        self.iterations
217    }
218
219    /// Get the previous task id.
220    pub fn previous_task_id(&self) -> Option<&TaskId<IdType>> {
221        self.prev_task_id.as_ref()
222    }
223}
224
225impl<Args: Sync, Ctx: MetadataExt<Self> + Sync, IdType: Sync> FromRequest<Task<Args, Ctx, IdType>>
226    for RepeaterState<IdType>
227{
228    type Error = Infallible;
229    async fn from_request(task: &Task<Args, Ctx, IdType>) -> Result<Self, Infallible> {
230        let state: Self = task.parts.ctx.extract().unwrap_or_default();
231        Ok(Self {
232            iterations: state.iterations,
233            prev_task_id: state.prev_task_id,
234        })
235    }
236}
237
238impl<B, F, Input, Res, S, MetaErr, Err, CodecError> Step<Input, B>
239    for RepeatUntilStep<S, F, Input, Res>
240where
241    F: Service<Task<Input, B::Context, B::IdType>, Response = Option<Res>>
242        + Send
243        + Sync
244        + 'static
245        + Clone,
246    B: BackendExt<Error = Err>
247        + Send
248        + Sync
249        + Clone
250        + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
251        + Unpin
252        + 'static,
253    B::Context: MetadataExt<RepeaterState<B::IdType>, Error = MetaErr>
254        + MetadataExt<WorkflowContext, Error = MetaErr>
255        + Send
256        + 'static,
257    B::Codec: Codec<Input, Error = CodecError, Compact = B::Compact>
258        + Codec<Res, Error = CodecError, Compact = B::Compact>
259        + Codec<Option<Res>, Error = CodecError, Compact = B::Compact>
260        + 'static,
261    B::IdType: GenerateId + Send + 'static,
262    Err: std::error::Error + Send + Sync + 'static,
263    CodecError: std::error::Error + Send + Sync + 'static,
264    F::Error: Into<BoxDynError> + Send + 'static,
265    MetaErr: std::error::Error + Send + Sync + 'static,
266    F::Future: Send + 'static,
267    B::Compact: Send + 'static,
268    Input: Send + Sync + 'static, // We don't need Clone because decoding just needs a reference
269    Res: Send + Sync + 'static,
270    S: Step<Input, B> + Send + 'static,
271{
272    type Response = Res;
273    type Error = F::Error;
274    fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
275        let svc = SteppedService::new(RepeatUntilService {
276            repeater: self.repeater.clone(),
277            _marker: PhantomData::<(B, Input, Res)>,
278        });
279        let count = ctx.steps.len();
280        ctx.steps.insert(count, svc);
281        self.inner.register(ctx)
282    }
283}