1use std::time::Duration;
2
3use apalis_core::{
4 backend::{BackendExt, codec::Codec},
5 error::BoxDynError,
6 task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId},
7};
8use futures::sink::SinkExt;
9use futures::{FutureExt, Sink, future::BoxFuture};
10use tower::Service;
11
12use crate::{
13 SteppedService, Workflow,
14 id_generator::GenerateId,
15 sequential::context::{StepContext, WorkflowContext},
16 sequential::router::{GoTo, StepResult, WorkflowRouter},
17 sequential::step::{Layer, Stack, Step},
18};
19
/// Layer that adds a delay step with a fixed `duration` to a workflow.
///
/// Applying it to a step (see its `Layer` impl) produces a [`DelayForStep`].
#[derive(Debug, Clone)]
pub struct DelayFor {
    /// How long the workflow waits at this step.
    duration: Duration,
}
25
26impl<S> Layer<S> for DelayFor
27where
28 S: Clone,
29{
30 type Step = DelayForStep<S>;
31
32 fn layer(&self, step: S) -> Self::Step {
33 DelayForStep {
34 inner: step,
35 duration: self.duration,
36 }
37 }
38}
39
/// Step produced by [`DelayFor`]: the wrapped step plus the fixed delay to apply.
#[derive(Debug, Clone)]
pub struct DelayForStep<S> {
    /// The step that was wrapped by the layer.
    inner: S,
    /// Fixed delay copied from the originating [`DelayFor`].
    duration: Duration,
}
46
47impl<Input, B, S, Err> Step<Input, B> for DelayForStep<S>
48where
49 B::IdType: GenerateId + Send + 'static,
50 B::Compact: Send + 'static,
51 B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
52 + Unpin
53 + Send
54 + Sync
55 + Clone
56 + 'static,
57 Err: std::error::Error + Send + Sync + 'static,
58 S: Clone + Send + Sync + 'static,
59 S::Response: Send + 'static,
60 B::Codec: Codec<Duration, Compact = B::Compact> + Codec<Input, Compact = B::Compact> + 'static,
61 <B::Codec as Codec<Duration>>::Error: Into<BoxDynError>,
62 B::Context: Send + 'static + MetadataExt<WorkflowContext>,
63 Input: Send + Sync + 'static,
64 <B::Codec as Codec<Input>>::Error: Into<BoxDynError>,
65 B: BackendExt,
66 S: Step<Input, B>,
67{
68 type Response = Input;
69 type Error = BoxDynError;
70 fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
71 let duration = self.duration;
72 let svc = SteppedService::new(DelayWithStep {
73 f: Box::new(move |_| duration),
74 inner: self.inner.clone(),
75 _marker: std::marker::PhantomData,
76 });
77 let count = ctx.steps.len();
78 ctx.steps.insert(count, svc);
79 self.inner.register(ctx)
80 }
81}
82
/// Layer that adds a delay step whose duration is computed per task by the closure `f`.
///
/// Applying it to a step (see its `Layer` impl) produces a [`DelayWithStep`].
#[derive(Debug)]
pub struct DelayWith<F, B, Input> {
    /// Closure that computes the delay from a task.
    f: F,
    /// Ties the layer to the backend `B` and input type `Input` without storing either.
    _marker: std::marker::PhantomData<(B, Input)>,
}

// Hand-written rather than derived: `#[derive(Clone)]` would also require
// `B: Clone` and `Input: Clone` because of the `PhantomData` marker, even though
// only `f` is actually cloned. This matches the manual impl on `DelayWithStep`.
impl<F: Clone, B, Input> Clone for DelayWith<F, B, Input> {
    fn clone(&self) -> Self {
        Self {
            f: self.f.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}
89
90impl<S, F: Clone, B, I> Layer<S> for DelayWith<F, B, I> {
91 type Step = DelayWithStep<S, F, B, I>;
92
93 fn layer(&self, step: S) -> Self::Step {
94 DelayWithStep {
95 f: self.f.clone(),
96 inner: step,
97 _marker: std::marker::PhantomData,
98 }
99 }
100}
101
/// Step produced by [`DelayWith`], and also used by [`DelayForStep`] with a
/// constant closure.
///
/// Holds the delay closure `f` and the wrapped step `inner`.
#[derive(Debug)]
pub struct DelayWithStep<S, F, B, Input> {
    /// Closure that computes the delay from a task.
    f: F,
    /// The step that was wrapped by the layer.
    inner: S,
    /// Ties the step to the backend `B` and input type `Input` without storing either.
    _marker: std::marker::PhantomData<(B, Input)>,
}
109
110impl<S: Clone, F: Clone, B, Input> Clone for DelayWithStep<S, F, B, Input> {
111 fn clone(&self) -> Self {
112 Self {
113 f: self.f.clone(),
114 inner: self.inner.clone(),
115 _marker: std::marker::PhantomData,
116 }
117 }
118}
119
120impl<Input, F, B, S, Err> Step<Input, B> for DelayWithStep<S, F, B, Input>
121where
122 F: FnMut(Task<Input, B::Context, B::IdType>) -> Duration + Send + Sync + 'static + Clone,
123 B::IdType: GenerateId + Send + 'static,
124 B::Compact: Send + 'static,
125 B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
126 + Unpin
127 + Send
128 + Sync
129 + Clone
130 + 'static,
131 Err: std::error::Error + Send + Sync + 'static,
132 S: Clone + Send + Sync + 'static,
133 S::Response: Send + 'static,
134 B::Codec: Codec<Duration, Compact = B::Compact> + Codec<Input, Compact = B::Compact> + 'static,
135 <B::Codec as Codec<Duration>>::Error: Into<BoxDynError>,
136 B::Context: Send + 'static + MetadataExt<WorkflowContext>,
137 Input: Send + Sync + 'static,
138 <B::Codec as Codec<Input>>::Error: Into<BoxDynError>,
139 B: BackendExt,
140 S: Step<Input, B>,
141{
142 type Response = Input;
143 type Error = BoxDynError;
144 fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
145 let svc = SteppedService::new(Self {
146 f: self.f.clone(),
147 inner: self.inner.clone(),
148 _marker: std::marker::PhantomData,
149 });
150 let count = ctx.steps.len();
151 ctx.steps.insert(count, svc);
152 self.inner.register(ctx)
153 }
154}
155
156impl<S, F, B: BackendExt + Send + Sync + 'static + Clone, Input, Err>
157 Service<Task<B::Compact, B::Context, B::IdType>> for DelayWithStep<S, F, B, Input>
158where
159 F: FnMut(Task<Input, B::Context, B::IdType>) -> Duration + Send + 'static + Clone,
160 S: Step<Input, B> + Send + 'static,
161 S::Response: Send + 'static,
162 B::IdType: GenerateId + Send + 'static,
163 B::Compact: Send + 'static,
164 B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err> + Unpin + Send + Sync,
165 Err: std::error::Error + Send + Sync + 'static,
166 B::Codec: Codec<Duration, Compact = B::Compact> + Codec<Input, Compact = B::Compact> + 'static,
167 <B::Codec as Codec<Duration>>::Error: Into<BoxDynError>,
168 <B::Codec as Codec<Input>>::Error: Into<BoxDynError>,
169 B::Context: Send + 'static + MetadataExt<WorkflowContext>,
170{
171 type Response = GoTo<StepResult<B::Compact, B::IdType>>;
172 type Error = BoxDynError;
173 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
174
175 fn poll_ready(
176 &mut self,
177 _cx: &mut std::task::Context<'_>,
178 ) -> std::task::Poll<Result<(), Self::Error>> {
179 std::task::Poll::Ready(Ok(()))
180 }
181
182 fn call(&mut self, req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
183 let mut ctx: StepContext<B> = req.parts.data.get().cloned().unwrap();
184 let mut f = self.f.clone();
185
186 let task_id = TaskId::new(B::IdType::generate());
187 async move {
188 let decoded: Input = B::Codec::decode(&req.args)
189 .map_err(|e: <B::Codec as Codec<Input>>::Error| e.into())?;
190 let (args, parts) = req.take();
191 let delay_duration = f(Task {
192 args: decoded,
193 parts,
194 });
195 let task = TaskBuilder::new(args)
196 .with_task_id(task_id.clone())
197 .meta(WorkflowContext {
198 step_index: ctx.current_step + 1,
199 })
200 .run_after(delay_duration)
201 .build();
202 ctx.backend
203 .send(task)
204 .await
205 .map_err(|e| BoxDynError::from(e))?;
206 Ok(GoTo::Next(StepResult {
207 result: B::Codec::encode(&delay_duration).map_err(|e| e.into())?,
208 next_task_id: Some(task_id),
209 }))
210 }
211 .boxed()
212 }
213}
214
215impl<Start, Cur, B, L> Workflow<Start, Cur, B, L> {
216 pub fn delay_for(self, delay: Duration) -> Workflow<Start, Cur, B, Stack<DelayFor, L>> {
218 self.add_step(DelayFor { duration: delay })
219 }
220}
221impl<Start, Cur, B, L> Workflow<Start, Cur, B, L> {
222 pub fn delay_with<F, I>(self, f: F) -> Workflow<Start, I, B, Stack<DelayWith<F, B, I>, L>> {
224 self.add_step(DelayWith {
225 f,
226 _marker: std::marker::PhantomData,
227 })
228 }
229}