apalis_workflow/delay/mod.rs

1use std::time::Duration;
2
3use apalis_core::{
4    backend::{BackendExt, codec::Codec},
5    error::BoxDynError,
6    task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId},
7};
8use futures::sink::SinkExt;
9use futures::{FutureExt, Sink, future::BoxFuture};
10use tower::Service;
11
12use crate::{
13    SteppedService, Workflow,
14    context::{StepContext, WorkflowContext},
15    id_generator::GenerateId,
16    router::{GoTo, StepResult, WorkflowRouter},
17    step::{Layer, Stack, Step},
18};
19
/// A [`Layer`] that postpones workflow execution by a fixed [`Duration`].
///
/// Constructed by `Workflow::delay_for`; wraps a step in a [`DelayForStep`].
#[derive(Debug, Clone)]
pub struct DelayFor {
    /// The fixed amount of time to wait.
    duration: Duration,
}
25
26impl<S> Layer<S> for DelayFor
27where
28    S: Clone,
29{
30    type Step = DelayForStep<S>;
31
32    fn layer(&self, step: S) -> Self::Step {
33        DelayForStep {
34            inner: step,
35            duration: self.duration,
36        }
37    }
38}
39
/// Step produced by [`DelayFor`] that delays execution by a fixed duration.
#[derive(Debug, Clone)]
pub struct DelayForStep<S> {
    /// The wrapped step.
    inner: S,
    /// The fixed amount of time to wait.
    duration: Duration,
}
46
47impl<Input, B, S, Err> Step<Input, B> for DelayForStep<S>
48where
49    B::IdType: GenerateId + Send + 'static,
50    B::Compact: Send + 'static,
51    B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
52        + Unpin
53        + Send
54        + Sync
55        + Clone
56        + 'static,
57    Err: std::error::Error + Send + Sync + 'static,
58    S: Clone + Send + 'static,
59    S::Response: Send + 'static,
60    B::Codec: Codec<Duration, Compact = B::Compact> + Codec<Input, Compact = B::Compact> + 'static,
61    <B::Codec as Codec<Duration>>::Error: Into<BoxDynError>,
62    B::Context: Send + 'static + MetadataExt<WorkflowContext>,
63    Input: Send + 'static,
64    <B::Codec as Codec<Input>>::Error: Into<BoxDynError>,
65    B: BackendExt,
66    S: Step<Input, B>,
67{
68    type Response = Input;
69    type Error = BoxDynError;
70    fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
71        let duration = self.duration;
72        let svc = SteppedService::new(DelayWithStep {
73            f: Box::new(move |_| duration),
74            inner: self.inner.clone(),
75            _marker: std::marker::PhantomData,
76        });
77        let count = ctx.steps.len();
78        ctx.steps.insert(count, svc);
79        self.inner.register(ctx)
80    }
81}
82
/// Layer that delays execution by a duration computed from each task.
///
/// The function `f` is called with the decoded task and returns how long to
/// wait. Applying the layer wraps a step in a [`DelayWithStep`]. Constructed
/// by `Workflow::delay_with`.
#[derive(Clone, Debug)]
pub struct DelayWith<F, B, Input> {
    /// Computes the delay for a given task.
    f: F,
    /// Records the backend `B` and input type `Input` without storing values of them.
    _marker: std::marker::PhantomData<(B, Input)>,
}
89
90impl<S, F: Clone, B, I> Layer<S> for DelayWith<F, B, I> {
91    type Step = DelayWithStep<S, F, B, I>;
92
93    fn layer(&self, step: S) -> Self::Step {
94        DelayWithStep {
95            f: self.f.clone(),
96            inner: step,
97            _marker: std::marker::PhantomData,
98        }
99    }
100}
101
/// Step that delays execution by a duration computed per task.
///
/// When it runs, the task's arguments are decoded as `Input` and passed, together
/// with the task's parts, to `f`. The task is then re-enqueued for the following
/// step to run after the returned duration.
#[derive(Clone, Debug)]
pub struct DelayWithStep<S, F, B, Input> {
    /// Called once per task to compute its delay.
    f: F,
    /// The wrapped step.
    inner: S,
    /// Records the backend `B` and input type `Input` without storing values of them.
    _marker: std::marker::PhantomData<(B, Input)>,
}
109
110impl<Input, F, B, S, Err> Step<Input, B> for DelayWithStep<S, F, B, Input>
111where
112    F: FnMut(Task<Input, B::Context, B::IdType>) -> Duration + Send + 'static + Clone,
113    B::IdType: GenerateId + Send + 'static,
114    B::Compact: Send + 'static,
115    B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
116        + Unpin
117        + Send
118        + Sync
119        + Clone
120        + 'static,
121    Err: std::error::Error + Send + Sync + 'static,
122    S: Clone + Send + 'static,
123    S::Response: Send + 'static,
124    B::Codec: Codec<Duration, Compact = B::Compact> + Codec<Input, Compact = B::Compact> + 'static,
125    <B::Codec as Codec<Duration>>::Error: Into<BoxDynError>,
126    B::Context: Send + 'static + MetadataExt<WorkflowContext>,
127    Input: Send + 'static,
128    <B::Codec as Codec<Input>>::Error: Into<BoxDynError>,
129    B: BackendExt,
130    S: Step<Input, B>,
131{
132    type Response = Input;
133    type Error = BoxDynError;
134    fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
135        let svc = SteppedService::new(Self {
136            f: self.f.clone(),
137            inner: self.inner.clone(),
138            _marker: std::marker::PhantomData,
139        });
140        let count = ctx.steps.len();
141        ctx.steps.insert(count, svc);
142        self.inner.register(ctx)
143    }
144}
145
146impl<S, F, B: BackendExt + Send + Sync + 'static + Clone, Input, Err>
147    Service<Task<B::Compact, B::Context, B::IdType>> for DelayWithStep<S, F, B, Input>
148where
149    F: FnMut(Task<Input, B::Context, B::IdType>) -> Duration + Send + 'static + Clone,
150    S: Step<Input, B> + Send + 'static,
151    S::Response: Send + 'static,
152    B::IdType: GenerateId + Send + 'static,
153    B::Compact: Send + 'static,
154    B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err> + Unpin + Send + Sync,
155    Err: std::error::Error + Send + Sync + 'static,
156    B::Codec: Codec<Duration, Compact = B::Compact> + Codec<Input, Compact = B::Compact> + 'static,
157    <B::Codec as Codec<Duration>>::Error: Into<BoxDynError>,
158    <B::Codec as Codec<Input>>::Error: Into<BoxDynError>,
159    B::Context: Send + 'static + MetadataExt<WorkflowContext>,
160{
161    type Response = GoTo<StepResult<B::Compact, B::IdType>>;
162    type Error = BoxDynError;
163    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
164
165    fn poll_ready(
166        &mut self,
167        _cx: &mut std::task::Context<'_>,
168    ) -> std::task::Poll<Result<(), Self::Error>> {
169        std::task::Poll::Ready(Ok(()))
170    }
171
172    fn call(&mut self, req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
173        let mut ctx: StepContext<B> = req.parts.data.get().cloned().unwrap();
174        let mut f = self.f.clone();
175
176        let task_id = TaskId::new(B::IdType::generate());
177        async move {
178            let decoded: Input = B::Codec::decode(&req.args)
179                .map_err(|e: <B::Codec as Codec<Input>>::Error| e.into())?;
180            let (args, parts) = req.take();
181            let delay_duration = f(Task {
182                args: decoded,
183                parts,
184            });
185            let task = TaskBuilder::new(args)
186                .with_task_id(task_id.clone())
187                .meta(WorkflowContext {
188                    step_index: ctx.current_step + 1,
189                })
190                .run_after(delay_duration)
191                .build();
192            ctx.backend
193                .send(task)
194                .await
195                .map_err(|e| BoxDynError::from(e))?;
196            Ok(GoTo::Next(StepResult {
197                result: B::Codec::encode(&delay_duration).map_err(|e| e.into())?,
198                next_task_id: Some(task_id),
199            }))
200        }
201        .boxed()
202    }
203}
204
205impl<Start, Cur, B, L> Workflow<Start, Cur, B, L> {
206    /// Delay the workflow by a fixed duration
207    pub fn delay_for(self, delay: Duration) -> Workflow<Start, Cur, B, Stack<DelayFor, L>> {
208        self.add_step(DelayFor { duration: delay })
209    }
210}
211impl<Start, Cur, B, L> Workflow<Start, Cur, B, L> {
212    /// Delay the workflow by a duration determined by a function
213    pub fn delay_with<F, I>(self, f: F) -> Workflow<Start, I, B, Stack<DelayWith<F, B, I>, L>> {
214        self.add_step(DelayWith {
215            f,
216            _marker: std::marker::PhantomData,
217        })
218    }
219}