1use std::convert::Infallible;
2use std::marker::PhantomData;
3use std::task::Context;
4
5use apalis_core::backend::TaskSinkError;
6use apalis_core::backend::codec::Codec;
7use apalis_core::error::BoxDynError;
8use apalis_core::task::builder::TaskBuilder;
9use apalis_core::task::metadata::MetadataExt;
10use apalis_core::task::task_id::TaskId;
11use apalis_core::task_fn::{FromRequest, TaskFn, task_fn};
12use apalis_core::{backend::BackendExt, task::Task};
13use futures::future::BoxFuture;
14use futures::{FutureExt, Sink, SinkExt};
15use serde::{Deserialize, Serialize};
16use tower::Service;
17
18use crate::id_generator::GenerateId;
19use crate::sequential::router::WorkflowRouter;
20use crate::sequential::{GoTo, Layer, Stack, Step, StepContext, StepResult, WorkflowContext};
21use crate::{SteppedService, Workflow};
22
/// Layer that turns a step into a repeat-until step; built by `Workflow::repeat_until`.
///
/// `F` is the repeater service. `Input` and `Output` are only tracked at the type level, so
/// `Clone` is implemented by hand below. A derived `Clone` would needlessly demand
/// `Input: Clone` and `Output: Clone` just because of the `PhantomData` marker.
#[derive(Debug)]
pub struct RepeatUntil<F, Input, Output> {
    repeater: F,
    _marker: PhantomData<(Input, Output)>,
}

impl<F: Clone, Input, Output> Clone for RepeatUntil<F, Input, Output> {
    /// Clones the repeater; the marker carries no data and is recreated.
    fn clone(&self) -> Self {
        Self {
            repeater: self.repeater.clone(),
            _marker: PhantomData,
        }
    }
}
29
30impl<F, Input, Output, S> Layer<S> for RepeatUntil<F, Input, Output>
31where
32 F: Clone,
33{
34 type Step = RepeatUntilStep<S, F, Input, Output>;
35
36 fn layer(&self, step: S) -> Self::Step {
37 RepeatUntilStep {
38 inner: step,
39 repeater: self.repeater.clone(),
40 _marker: std::marker::PhantomData,
41 }
42 }
43}
44impl<Start, L, Input, B: BackendExt> Workflow<Start, Input, B, L> {
45 pub fn repeat_until<F, Output, FnArgs>(
47 self,
48 repeater: F,
49 ) -> Workflow<
50 Start,
51 Output,
52 B,
53 Stack<RepeatUntil<TaskFn<F, Input, B::Context, FnArgs>, Input, Output>, L>,
54 >
55 where
56 TaskFn<F, Input, B::Context, FnArgs>:
57 Service<Task<Input, B::Context, B::IdType>, Response = Option<Output>>,
58 {
59 self.add_step(RepeatUntil {
60 repeater: task_fn(repeater),
61 _marker: PhantomData::<(Input, Output)>,
62 })
63 }
64}
65
/// Step produced by applying a [`RepeatUntil`] layer to an inner step `S`.
///
/// `R` is the repeater service. `Input` and `Output` are only tracked at the type level, so
/// `Clone` is implemented by hand below. A derived `Clone` would needlessly require
/// `Input: Clone` and `Output: Clone` because of the `PhantomData` marker.
#[derive(Debug)]
pub struct RepeatUntilStep<S, R, Input, Output> {
    inner: S,
    repeater: R,
    _marker: PhantomData<(Input, Output)>,
}

impl<S: Clone, R: Clone, Input, Output> Clone for RepeatUntilStep<S, R, Input, Output> {
    /// Clones the inner step and the repeater; the marker carries no data.
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            repeater: self.repeater.clone(),
            _marker: PhantomData,
        }
    }
}
73
/// Service registered in the workflow router for a repeat-until step.
///
/// `F` is the user's repeater service. `B` (backend), `Input` and `Output` are only tracked at
/// the type level through the `PhantomData` marker.
#[derive(Debug)]
pub struct RepeatUntilService<F, B, Input, Output> {
    // Repeater invoked with the decoded step input on every call.
    repeater: F,
    _marker: std::marker::PhantomData<(B, Input, Output)>,
}
80
81impl<F, B, Input, Output> Clone for RepeatUntilService<F, B, Input, Output>
82where
83 F: Clone,
84{
85 fn clone(&self) -> Self {
86 Self {
87 repeater: self.repeater.clone(),
88 _marker: std::marker::PhantomData,
89 }
90 }
91}
92
93impl<F, Res, B, Input, CodecError, MetaErr, Err> Service<Task<B::Compact, B::Context, B::IdType>>
94 for RepeatUntilService<F, B, Input, Res>
95where
96 F: Service<Task<Input, B::Context, B::IdType>, Response = Option<Res>> + Send + 'static + Clone,
97 B: BackendExt<Error = Err>
98 + Send
99 + Sync
100 + Clone
101 + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
102 + Unpin
103 + 'static,
104 B::Context: MetadataExt<RepeaterState<B::IdType>, Error = MetaErr>
105 + MetadataExt<WorkflowContext, Error = MetaErr>
106 + Send
107 + 'static,
108 B::Codec: Codec<Input, Error = CodecError, Compact = B::Compact>
109 + Codec<Res, Error = CodecError, Compact = B::Compact>
110 + Codec<Option<Res>, Error = CodecError, Compact = B::Compact>
111 + 'static,
112 B::IdType: GenerateId + Send + 'static,
113 Err: std::error::Error + Send + Sync + 'static,
114 CodecError: std::error::Error + Send + Sync + 'static,
115 F::Error: Into<BoxDynError> + Send + 'static,
116 MetaErr: std::error::Error + Send + Sync + 'static,
117 F::Future: Send + 'static,
118 B::Compact: Send + 'static,
119 Input: Send + 'static, Res: Send + 'static,
121{
122 type Response = GoTo<StepResult<B::Compact, B::IdType>>;
123 type Error = BoxDynError;
124 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
125
126 fn poll_ready(&mut self, cx: &mut Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
127 self.repeater.poll_ready(cx).map_err(|e| e.into())
128 }
129
130 fn call(&mut self, task: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
131 let state: RepeaterState<B::IdType> = task.parts.ctx.extract().unwrap_or_default();
132 let mut ctx =
133 task.parts.data.get::<StepContext<B>>().cloned().expect(
134 "StepContext missing, Did you call the repeater outside of a workflow step?",
135 );
136 let mut repeater = self.repeater.clone();
137
138 (async move {
139 let mut compact = None;
140 let decoded: Input = B::Codec::decode(&task.args)?;
141 let prev_task_id = task.parts.task_id.clone();
142 let repeat_task = task.map(|c| {
143 compact = Some(c);
144 decoded
145 });
146 let response = repeater.call(repeat_task).await.map_err(|e| e.into())?;
147 Ok(match response {
148 Some(res) if ctx.has_next => {
149 let task_id = TaskId::new(B::IdType::generate());
150 let next_step = TaskBuilder::new(B::Codec::encode(&res)?)
151 .with_task_id(task_id.clone())
152 .meta(WorkflowContext {
153 step_index: ctx.current_step + 1,
154 })
155 .build();
156 ctx.backend
157 .send(next_step)
158 .await
159 .map_err(|e| TaskSinkError::PushError(e))?;
160 GoTo::Next(StepResult {
161 result: B::Codec::encode(&res)?,
162 next_task_id: Some(task_id),
163 })
164 }
165 Some(res) => GoTo::Break(StepResult {
166 result: B::Codec::encode(&res)?,
167 next_task_id: None,
168 }),
169 None => {
170 let task_id = TaskId::new(B::IdType::generate());
171 let next_step =
172 TaskBuilder::new(compact.take().expect("Compact args should be set"))
173 .with_task_id(task_id.clone())
174 .meta(WorkflowContext {
175 step_index: ctx.current_step,
176 })
177 .meta(RepeaterState {
178 iterations: state.iterations + 1,
179 prev_task_id,
180 })
181 .build();
182 ctx.backend
183 .send(next_step)
184 .await
185 .map_err(|e| TaskSinkError::PushError(e))?;
186 GoTo::Break(StepResult {
187 result: B::Codec::encode(&None::<Res>)?,
188 next_task_id: Some(task_id),
189 })
190 }
191 })
192 }
193 .boxed()) as _
194 }
195}
196
/// Bookkeeping attached as metadata to tasks re-enqueued by a repeat-until step.
///
/// Derives the serde traits, presumably so the backend can persist it alongside the task.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RepeaterState<IdType> {
    // Number of times the step has been re-enqueued because the repeater returned `None`.
    iterations: usize,
    // Id of the task whose `None` result caused the re-enqueue (`None` on the first attempt).
    prev_task_id: Option<TaskId<IdType>>,
}
203
204impl<IdType> Default for RepeaterState<IdType> {
205 fn default() -> Self {
206 Self {
207 iterations: 0,
208 prev_task_id: None,
209 }
210 }
211}
212
213impl<IdType> RepeaterState<IdType> {
214 pub fn iterations(&self) -> usize {
216 self.iterations
217 }
218
219 pub fn previous_task_id(&self) -> Option<&TaskId<IdType>> {
221 self.prev_task_id.as_ref()
222 }
223}
224
225impl<Args: Sync, Ctx: MetadataExt<Self> + Sync, IdType: Sync> FromRequest<Task<Args, Ctx, IdType>>
226 for RepeaterState<IdType>
227{
228 type Error = Infallible;
229 async fn from_request(task: &Task<Args, Ctx, IdType>) -> Result<Self, Infallible> {
230 let state: Self = task.parts.ctx.extract().unwrap_or_default();
231 Ok(Self {
232 iterations: state.iterations,
233 prev_task_id: state.prev_task_id,
234 })
235 }
236}
237
238impl<B, F, Input, Res, S, MetaErr, Err, CodecError> Step<Input, B>
239 for RepeatUntilStep<S, F, Input, Res>
240where
241 F: Service<Task<Input, B::Context, B::IdType>, Response = Option<Res>>
242 + Send
243 + Sync
244 + 'static
245 + Clone,
246 B: BackendExt<Error = Err>
247 + Send
248 + Sync
249 + Clone
250 + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
251 + Unpin
252 + 'static,
253 B::Context: MetadataExt<RepeaterState<B::IdType>, Error = MetaErr>
254 + MetadataExt<WorkflowContext, Error = MetaErr>
255 + Send
256 + 'static,
257 B::Codec: Codec<Input, Error = CodecError, Compact = B::Compact>
258 + Codec<Res, Error = CodecError, Compact = B::Compact>
259 + Codec<Option<Res>, Error = CodecError, Compact = B::Compact>
260 + 'static,
261 B::IdType: GenerateId + Send + 'static,
262 Err: std::error::Error + Send + Sync + 'static,
263 CodecError: std::error::Error + Send + Sync + 'static,
264 F::Error: Into<BoxDynError> + Send + 'static,
265 MetaErr: std::error::Error + Send + Sync + 'static,
266 F::Future: Send + 'static,
267 B::Compact: Send + 'static,
268 Input: Send + Sync + 'static, Res: Send + Sync + 'static,
270 S: Step<Input, B> + Send + 'static,
271{
272 type Response = Res;
273 type Error = F::Error;
274 fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
275 let svc = SteppedService::new(RepeatUntilService {
276 repeater: self.repeater.clone(),
277 _marker: PhantomData::<(B, Input, Res)>,
278 });
279 let count = ctx.steps.len();
280 ctx.steps.insert(count, svc);
281 self.inner.register(ctx)
282 }
283}