1use std::{marker::PhantomData, task::Context};
2
3use apalis_core::{
4 backend::{BackendExt, TaskSinkError, codec::Codec},
5 error::BoxDynError,
6 task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId},
7 task_fn::{TaskFn, task_fn},
8};
9use futures::{FutureExt, Sink, SinkExt, future::BoxFuture};
10use serde::{Deserialize, Serialize};
11use tower::Service;
12
13use crate::{
14 SteppedService,
15 context::{StepContext, WorkflowContext},
16 id_generator::GenerateId,
17 router::{GoTo, StepResult, WorkflowRouter},
18 step::{Layer, Stack, Step},
19 workflow::Workflow,
20};
21
/// Layer appended by [`Workflow::fold`].
///
/// It only holds the fold function (`fold`) until it is layered onto a step,
/// at which point it becomes a [`FoldStep`]. `Init` is the accumulator type and
/// is tracked purely at the type level.
#[derive(Clone, Debug)]
pub struct Fold<F, Init> {
    fold: F,
    _marker: PhantomData<Init>,
}
28
29impl<F, Init, S> Layer<S> for Fold<F, Init>
30where
31 F: Clone,
32 Init: Clone,
33{
34 type Step = FoldStep<S, F, Init>;
35
36 fn layer(&self, step: S) -> Self::Step {
37 FoldStep {
38 inner: step,
39 fold: self.fold.clone(),
40 _marker: std::marker::PhantomData,
41 }
42 }
43}
44impl<Start, C, L, I: IntoIterator<Item = C>, B: BackendExt> Workflow<Start, I, B, L> {
45 pub fn fold<F, Output, FnArgs, Init>(
47 self,
48 fold: F,
49 ) -> Workflow<Start, Output, B, Stack<Fold<TaskFn<F, (Init, C), B::Context, FnArgs>, Init>, L>>
50 where
51 TaskFn<F, (Init, C), B::Context, FnArgs>:
52 Service<Task<(Init, C), B::Context, B::IdType>, Response = Output>,
53 {
54 self.add_step(Fold {
55 fold: task_fn(fold),
56 _marker: PhantomData,
57 })
58 }
59}
60
/// Step produced by layering a [`Fold`] onto another step.
///
/// It pairs the fold function (`fold`) with the wrapped step (`inner`). When
/// registered, it adds its own fold service and then lets `inner` register itself.
#[derive(Clone, Debug)]
pub struct FoldStep<S, F, Init> {
    inner: S,
    fold: F,
    _marker: PhantomData<Init>,
}
68
69impl<S, F, Input, I: IntoIterator<Item = Input>, Init, B, MetaErr, Err, CodecError> Step<I, B>
70 for FoldStep<S, F, Init>
71where
72 F: Service<Task<(Init, Input), B::Context, B::IdType>, Response = Init>
73 + Send
74 + 'static
75 + Clone,
76 S: Step<Init, B>,
77 B: BackendExt<Error = Err>
78 + Send
79 + Sync
80 + Clone
81 + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
82 + Unpin
83 + 'static,
84 I: IntoIterator<Item = Input> + Send + 'static,
85 B::Context: MetadataExt<FoldState, Error = MetaErr>
86 + MetadataExt<WorkflowContext, Error = MetaErr>
87 + Send
88 + 'static,
89 B::Codec: Codec<(Init, Vec<Input>), Error = CodecError, Compact = B::Compact>
90 + Codec<Init, Error = CodecError, Compact = B::Compact>
91 + Codec<I, Error = CodecError, Compact = B::Compact>
92 + Codec<(Init, Input), Error = CodecError, Compact = B::Compact>
93 + 'static,
94 B::IdType: GenerateId + Send + 'static + Clone,
95 Init: Default + Send + 'static,
96 Err: std::error::Error + Send + Sync + 'static,
97 CodecError: std::error::Error + Send + Sync + 'static,
98 F::Error: Into<BoxDynError> + Send + 'static,
99 MetaErr: std::error::Error + Send + Sync + 'static,
100 F::Future: Send + 'static,
101 B::Compact: Send + 'static,
102 Input: Send + 'static,
103{
104 type Response = Init;
105 type Error = F::Error;
106 fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
107 let svc = SteppedService::new(FoldService {
108 fold: self.fold.clone(),
109 _marker: PhantomData::<(Init, I, B)>,
110 });
111 let count = ctx.steps.len();
112 ctx.steps.insert(count, svc);
113 self.inner.register(ctx)
114 }
115}
116
/// Router-facing service that executes a fold one item per task.
///
/// `fold` is the user's fold service. `Init` (accumulator), `I` (input
/// collection) and `B` (backend) are tracked only at the type level.
#[derive(Clone, Debug)]
pub struct FoldService<F, Init, I, B> {
    fold: F,
    _marker: PhantomData<(Init, I, B)>,
}
123
124impl<F, Init, I, B> FoldService<F, Init, I, B> {
125 pub fn new(fold: F) -> Self {
127 Self {
128 fold,
129 _marker: std::marker::PhantomData,
130 }
131 }
132}
133
134impl<F, Init, I, B, Input, CodecError, MetaErr, Err>
135 Service<Task<B::Compact, B::Context, B::IdType>> for FoldService<F, Init, I, B>
136where
137 F: Service<Task<(Init, Input), B::Context, B::IdType>, Response = Init>
138 + Send
139 + 'static
140 + Clone,
141 B: BackendExt<Error = Err>
142 + Send
143 + Sync
144 + Clone
145 + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
146 + Unpin
147 + 'static,
148 I: IntoIterator<Item = Input> + Send + 'static,
149 B::Context: MetadataExt<FoldState, Error = MetaErr>
150 + MetadataExt<WorkflowContext, Error = MetaErr>
151 + Send
152 + 'static,
153 B::Codec: Codec<(Init, Vec<Input>), Error = CodecError, Compact = B::Compact>
154 + Codec<Init, Error = CodecError, Compact = B::Compact>
155 + Codec<I, Error = CodecError, Compact = B::Compact>
156 + Codec<(Init, Input), Error = CodecError, Compact = B::Compact>
157 + 'static,
158 B::IdType: GenerateId + Send + 'static,
159 Init: Default + Send + 'static,
160 Err: std::error::Error + Send + Sync + 'static,
161 CodecError: std::error::Error + Send + Sync + 'static,
162 F::Error: Into<BoxDynError> + Send + 'static,
163 MetaErr: std::error::Error + Send + Sync + 'static,
164 F::Future: Send + 'static,
165 B::Compact: Send + 'static,
166 Input: Send + 'static,
167{
168 type Response = GoTo<StepResult<B::Compact, B::IdType>>;
169 type Error = BoxDynError;
170 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
171
172 fn poll_ready(&mut self, cx: &mut Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
173 self.fold.poll_ready(cx).map_err(|e| e.into())
174 }
175
176 fn call(&mut self, task: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
177 let state = task.parts.ctx.extract().unwrap_or(FoldState::Unknown);
178 let mut ctx = task.parts.data.get::<StepContext<B>>().cloned().unwrap();
179 let mut fold = self.fold.clone();
180
181 match state {
182 FoldState::Unknown => async move {
183 let task_id = TaskId::new(B::IdType::generate());
184 let steps: Task<I, _, _> = task.try_map(|arg| B::Codec::decode(&arg))?;
185 let steps = steps.args.into_iter().collect::<Vec<_>>();
186 let task = TaskBuilder::new(B::Codec::encode(&(Init::default(), steps))?)
187 .meta(WorkflowContext {
188 step_index: ctx.current_step,
189 })
190 .with_task_id(task_id.clone())
191 .meta(FoldState::Collection)
192 .build();
193 ctx.backend
194 .send(task)
195 .await
196 .map_err(|e| TaskSinkError::PushError(e))?;
197 Ok(GoTo::Next(StepResult {
198 result: B::Codec::encode(&Init::default())?,
199 next_task_id: Some(task_id),
200 }))
201 }
202 .boxed(),
203 FoldState::Collection => async move {
204 let args: (Init, Vec<Input>) = B::Codec::decode(&task.args)?;
205 let (acc, items) = args;
206
207 let mut items = items.into_iter();
208 let next = items.next().unwrap();
209 let rest = items.collect::<Vec<_>>();
210 let fold_task = task.map(|_| (acc, next));
211 let response = fold.call(fold_task).await.map_err(|e| e.into())?;
212
213 match rest.len() {
214 0 if ctx.has_next => {
215 let task_id = TaskId::new(B::IdType::generate());
216 let result = B::Codec::encode(&response)?;
217 let next_step = TaskBuilder::new(result)
218 .with_task_id(task_id.clone())
219 .meta(WorkflowContext {
220 step_index: ctx.current_step + 1,
221 })
222 .build();
223 ctx.backend
224 .send(next_step)
225 .await
226 .map_err(|e| TaskSinkError::PushError(e))?;
227 Ok(GoTo::Break(StepResult {
228 result: B::Codec::encode(&response)?,
229 next_task_id: Some(task_id),
230 }))
231 }
232 0 => Ok(GoTo::Break(StepResult {
233 result: B::Codec::encode(&response)?,
234 next_task_id: None,
235 })),
236 1.. => {
237 let task_id = TaskId::new(B::IdType::generate());
239 let result = B::Codec::encode(&response)?;
240 let steps = TaskBuilder::new(B::Codec::encode(&(response, rest))?)
241 .with_task_id(task_id.clone())
242 .meta(WorkflowContext {
243 step_index: ctx.current_step,
244 })
245 .meta(FoldState::Collection)
246 .build();
247 ctx.backend
248 .send(steps)
249 .await
250 .map_err(|e| TaskSinkError::PushError(e))?;
251 Ok(GoTo::Next(StepResult {
252 result,
253 next_task_id: Some(task_id),
254 }))
255 }
256 }
257 }
258 .boxed(),
259 }
260 }
261}
262
263#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
265pub enum FoldState {
266 Unknown,
268 Collection,
270}