1use std::time::Duration;
2
3use apalis_core::{
4 backend::{BackendExt, codec::Codec},
5 error::BoxDynError,
6 task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId},
7};
8use futures::sink::SinkExt;
9use futures::{FutureExt, Sink, future::BoxFuture};
10use tower::Service;
11
12use crate::{
13 SteppedService, Workflow,
14 context::{StepContext, WorkflowContext},
15 id_generator::GenerateId,
16 router::{GoTo, StepResult, WorkflowRouter},
17 step::{Layer, Stack, Step},
18};
19
/// Layer that postpones the next workflow step by a fixed [`Duration`].
///
/// Created by [`Workflow::delay_for`]; turned into a [`DelayForStep`] when layered.
#[derive(Debug, Clone)]
pub struct DelayFor {
    /// Fixed delay handed to every task this layer produces.
    duration: Duration,
}
25
26impl<S> Layer<S> for DelayFor
27where
28 S: Clone,
29{
30 type Step = DelayForStep<S>;
31
32 fn layer(&self, step: S) -> Self::Step {
33 DelayForStep {
34 inner: step,
35 duration: self.duration,
36 }
37 }
38}
39
/// Step produced by the `DelayFor` layer: the wrapped step `S` plus a fixed delay.
#[derive(Debug, Clone)]
pub struct DelayForStep<S> {
    /// The wrapped step; it is registered after the delay service.
    inner: S,
    /// Fixed delay returned for every task.
    duration: Duration,
}
46
47impl<Input, B, S, Err> Step<Input, B> for DelayForStep<S>
48where
49 B::IdType: GenerateId + Send + 'static,
50 B::Compact: Send + 'static,
51 B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
52 + Unpin
53 + Send
54 + Sync
55 + Clone
56 + 'static,
57 Err: std::error::Error + Send + Sync + 'static,
58 S: Clone + Send + 'static,
59 S::Response: Send + 'static,
60 B::Codec: Codec<Duration, Compact = B::Compact> + Codec<Input, Compact = B::Compact> + 'static,
61 <B::Codec as Codec<Duration>>::Error: Into<BoxDynError>,
62 B::Context: Send + 'static + MetadataExt<WorkflowContext>,
63 Input: Send + 'static,
64 <B::Codec as Codec<Input>>::Error: Into<BoxDynError>,
65 B: BackendExt,
66 S: Step<Input, B>,
67{
68 type Response = Input;
69 type Error = BoxDynError;
70 fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
71 let duration = self.duration;
72 let svc = SteppedService::new(DelayWithStep {
73 f: Box::new(move |_| duration),
74 inner: self.inner.clone(),
75 _marker: std::marker::PhantomData,
76 });
77 let count = ctx.steps.len();
78 ctx.steps.insert(count, svc);
79 self.inner.register(ctx)
80 }
81}
82
/// Layer that postpones the next workflow step by a duration computed per task.
///
/// Created by [`Workflow::delay_with`]. The closure `f` is cloned into each step this layer
/// produces.
pub struct DelayWith<F, B, Input> {
    /// Closure computing the delay for a task.
    f: F,
    // `fn() -> _` records the type parameters without claiming ownership of them. The struct's
    // `Send`/`Sync`/drop-check behaviour therefore does not depend on `B` and `Input`, which are
    // never stored.
    _marker: std::marker::PhantomData<fn() -> (B, Input)>,
}

// Manual impl: `#[derive(Clone)]` would also demand `B: Clone` and `Input: Clone` because of the
// phantom marker, even though only `f` is actually cloned.
impl<F: Clone, B, Input> Clone for DelayWith<F, B, Input> {
    fn clone(&self) -> Self {
        Self {
            f: self.f.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

// Manual impl: a derived `Debug` would require `F: Debug`, which closures never implement, so the
// closure is elided from the output.
impl<F, B, Input> std::fmt::Debug for DelayWith<F, B, Input> {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.debug_struct("DelayWith").finish_non_exhaustive()
    }
}
89
90impl<S, F: Clone, B, I> Layer<S> for DelayWith<F, B, I> {
91 type Step = DelayWithStep<S, F, B, I>;
92
93 fn layer(&self, step: S) -> Self::Step {
94 DelayWithStep {
95 f: self.f.clone(),
96 inner: step,
97 _marker: std::marker::PhantomData,
98 }
99 }
100}
101
/// Step that re-sends its task's arguments as a new task deferred by a duration computed by `f`.
///
/// `inner` is the wrapped step, and `f` maps the decoded task to the delay.
pub struct DelayWithStep<S, F, B, Input> {
    /// Closure computing the delay for a decoded task.
    f: F,
    /// The wrapped step.
    inner: S,
    // `fn() -> _` records `B`/`Input` without claiming ownership. The step's
    // `Send`/`Sync`/drop-check behaviour therefore does not depend on types that are never stored.
    _marker: std::marker::PhantomData<fn() -> (B, Input)>,
}

// Manual impl: `#[derive(Clone)]` would also demand `B: Clone` and `Input: Clone` because of the
// phantom marker, even though only `f` and `inner` are actually cloned.
impl<S: Clone, F: Clone, B, Input> Clone for DelayWithStep<S, F, B, Input> {
    fn clone(&self) -> Self {
        Self {
            f: self.f.clone(),
            inner: self.inner.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

// Manual impl: a derived `Debug` would require `F: Debug` (never true for closures) and
// `B`/`Input: Debug`. Only the wrapped step is shown.
impl<S: std::fmt::Debug, F, B, Input> std::fmt::Debug for DelayWithStep<S, F, B, Input> {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.debug_struct("DelayWithStep")
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}
109
110impl<Input, F, B, S, Err> Step<Input, B> for DelayWithStep<S, F, B, Input>
111where
112 F: FnMut(Task<Input, B::Context, B::IdType>) -> Duration + Send + 'static + Clone,
113 B::IdType: GenerateId + Send + 'static,
114 B::Compact: Send + 'static,
115 B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
116 + Unpin
117 + Send
118 + Sync
119 + Clone
120 + 'static,
121 Err: std::error::Error + Send + Sync + 'static,
122 S: Clone + Send + 'static,
123 S::Response: Send + 'static,
124 B::Codec: Codec<Duration, Compact = B::Compact> + Codec<Input, Compact = B::Compact> + 'static,
125 <B::Codec as Codec<Duration>>::Error: Into<BoxDynError>,
126 B::Context: Send + 'static + MetadataExt<WorkflowContext>,
127 Input: Send + 'static,
128 <B::Codec as Codec<Input>>::Error: Into<BoxDynError>,
129 B: BackendExt,
130 S: Step<Input, B>,
131{
132 type Response = Input;
133 type Error = BoxDynError;
134 fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
135 let svc = SteppedService::new(Self {
136 f: self.f.clone(),
137 inner: self.inner.clone(),
138 _marker: std::marker::PhantomData,
139 });
140 let count = ctx.steps.len();
141 ctx.steps.insert(count, svc);
142 self.inner.register(ctx)
143 }
144}
145
146impl<S, F, B: BackendExt + Send + Sync + 'static + Clone, Input, Err>
147 Service<Task<B::Compact, B::Context, B::IdType>> for DelayWithStep<S, F, B, Input>
148where
149 F: FnMut(Task<Input, B::Context, B::IdType>) -> Duration + Send + 'static + Clone,
150 S: Step<Input, B> + Send + 'static,
151 S::Response: Send + 'static,
152 B::IdType: GenerateId + Send + 'static,
153 B::Compact: Send + 'static,
154 B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err> + Unpin + Send + Sync,
155 Err: std::error::Error + Send + Sync + 'static,
156 B::Codec: Codec<Duration, Compact = B::Compact> + Codec<Input, Compact = B::Compact> + 'static,
157 <B::Codec as Codec<Duration>>::Error: Into<BoxDynError>,
158 <B::Codec as Codec<Input>>::Error: Into<BoxDynError>,
159 B::Context: Send + 'static + MetadataExt<WorkflowContext>,
160{
161 type Response = GoTo<StepResult<B::Compact, B::IdType>>;
162 type Error = BoxDynError;
163 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
164
165 fn poll_ready(
166 &mut self,
167 _cx: &mut std::task::Context<'_>,
168 ) -> std::task::Poll<Result<(), Self::Error>> {
169 std::task::Poll::Ready(Ok(()))
170 }
171
172 fn call(&mut self, req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
173 let mut ctx: StepContext<B> = req.parts.data.get().cloned().unwrap();
174 let mut f = self.f.clone();
175
176 let task_id = TaskId::new(B::IdType::generate());
177 async move {
178 let decoded: Input = B::Codec::decode(&req.args)
179 .map_err(|e: <B::Codec as Codec<Input>>::Error| e.into())?;
180 let (args, parts) = req.take();
181 let delay_duration = f(Task {
182 args: decoded,
183 parts,
184 });
185 let task = TaskBuilder::new(args)
186 .with_task_id(task_id.clone())
187 .meta(WorkflowContext {
188 step_index: ctx.current_step + 1,
189 })
190 .run_after(delay_duration)
191 .build();
192 ctx.backend
193 .send(task)
194 .await
195 .map_err(|e| BoxDynError::from(e))?;
196 Ok(GoTo::Next(StepResult {
197 result: B::Codec::encode(&delay_duration).map_err(|e| e.into())?,
198 next_task_id: Some(task_id),
199 }))
200 }
201 .boxed()
202 }
203}
204
205impl<Start, Cur, B, L> Workflow<Start, Cur, B, L> {
206 pub fn delay_for(self, delay: Duration) -> Workflow<Start, Cur, B, Stack<DelayFor, L>> {
208 self.add_step(DelayFor { duration: delay })
209 }
210}
211impl<Start, Cur, B, L> Workflow<Start, Cur, B, L> {
212 pub fn delay_with<F, I>(self, f: F) -> Workflow<Start, I, B, Stack<DelayWith<F, B, I>, L>> {
214 self.add_step(DelayWith {
215 f,
216 _marker: std::marker::PhantomData,
217 })
218 }
219}