1use std::{marker::PhantomData, task::Context};
2
3use apalis_core::{
4 backend::{BackendExt, TaskSinkError, codec::Codec},
5 error::BoxDynError,
6 task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId},
7 task_fn::{TaskFn, task_fn},
8};
9use futures::{FutureExt, Sink, SinkExt, future::BoxFuture};
10use serde::{Deserialize, Serialize};
11use tower::Service;
12
13use crate::{
14 SteppedService,
15 id_generator::GenerateId,
16 sequential::context::{StepContext, WorkflowContext},
17 sequential::router::{GoTo, StepResult, WorkflowRouter},
18 sequential::step::{Layer, Stack, Step},
19 sequential::workflow::Workflow,
20};
21
/// Layer that adds a fold step on top of an existing step stack.
///
/// `F` is the folding service and `Init` is the accumulator type. `Init` is
/// never stored; it is only tracked at the type level.
#[derive(Debug, Clone)]
pub struct Fold<F, Init> {
    /// Folding service; cloned into each [`FoldStep`] built by this layer.
    fold: F,
    /// Zero-sized marker tying the layer to its accumulator type.
    _marker: PhantomData<Init>,
}
28
29impl<F, Init, S> Layer<S> for Fold<F, Init>
30where
31 F: Clone,
32 Init: Clone,
33{
34 type Step = FoldStep<S, F, Init>;
35
36 fn layer(&self, step: S) -> Self::Step {
37 FoldStep {
38 inner: step,
39 fold: self.fold.clone(),
40 _marker: std::marker::PhantomData,
41 }
42 }
43}
44impl<Start, C, L, I: IntoIterator<Item = C>, B: BackendExt> Workflow<Start, I, B, L> {
45 pub fn fold<F, Output, FnArgs, Init>(
47 self,
48 fold: F,
49 ) -> Workflow<Start, Output, B, Stack<Fold<TaskFn<F, (Init, C), B::Context, FnArgs>, Init>, L>>
50 where
51 TaskFn<F, (Init, C), B::Context, FnArgs>:
52 Service<Task<(Init, C), B::Context, B::IdType>, Response = Output>,
53 {
54 self.add_step(Fold {
55 fold: task_fn(fold),
56 _marker: PhantomData,
57 })
58 }
59}
60
/// Step produced by the [`Fold`] layer: an inner step plus the folding service.
///
/// `Init` is the accumulator type and is only tracked at the type level.
#[derive(Debug, Clone)]
pub struct FoldStep<S, F, Init> {
    /// The step this fold step was layered on top of.
    inner: S,
    /// Folding service handed to the `FoldService` at registration time.
    fold: F,
    /// Zero-sized marker for the accumulator type.
    _marker: PhantomData<Init>,
}
68
69impl<S, F, Input, I: IntoIterator<Item = Input>, Init, B, MetaErr, Err, CodecError> Step<I, B>
70 for FoldStep<S, F, Init>
71where
72 F: Service<Task<(Init, Input), B::Context, B::IdType>, Response = Init>
73 + Send
74 + Sync
75 + 'static
76 + Clone,
77 S: Step<Init, B>,
78 B: BackendExt<Error = Err>
79 + Send
80 + Sync
81 + Clone
82 + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
83 + Unpin
84 + 'static,
85 I: IntoIterator<Item = Input> + Send + Sync + 'static,
86 B::Context: MetadataExt<FoldState, Error = MetaErr>
87 + MetadataExt<WorkflowContext, Error = MetaErr>
88 + Send
89 + 'static,
90 B::Codec: Codec<(Init, Vec<Input>), Error = CodecError, Compact = B::Compact>
91 + Codec<Init, Error = CodecError, Compact = B::Compact>
92 + Codec<I, Error = CodecError, Compact = B::Compact>
93 + Codec<(Init, Input), Error = CodecError, Compact = B::Compact>
94 + 'static,
95 B::IdType: GenerateId + Send + 'static + Clone,
96 Init: Default + Send + Sync + 'static,
97 Err: std::error::Error + Send + Sync + 'static,
98 CodecError: std::error::Error + Send + Sync + 'static,
99 F::Error: Into<BoxDynError> + Send + 'static,
100 MetaErr: std::error::Error + Send + Sync + 'static,
101 F::Future: Send + 'static,
102 B::Compact: Send + 'static,
103 Input: Send + 'static,
104{
105 type Response = Init;
106 type Error = F::Error;
107 fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
108 let svc = SteppedService::new(FoldService {
109 fold: self.fold.clone(),
110 _marker: PhantomData::<(Init, I, B)>,
111 });
112 let count = ctx.steps.len();
113 ctx.steps.insert(count, svc);
114 self.inner.register(ctx)
115 }
116}
117
/// Service that executes a fold over a collection, one item per task.
///
/// `Init` is the accumulator type, `I` the input collection type and `B` the
/// backend type; none of them is stored, only the folding service `F`.
#[derive(Debug)]
pub struct FoldService<F, Init, I, B> {
    /// Folding service invoked with `(accumulator, item)` tasks.
    fold: F,
    /// Zero-sized marker for the accumulator, input and backend types.
    _marker: PhantomData<(Init, I, B)>,
}

// Implemented by hand so that only `F` has to be `Clone`; the marker types
// carry no data and need no bounds.
impl<F: Clone, Init, I, B> Clone for FoldService<F, Init, I, B> {
    fn clone(&self) -> Self {
        Self::new(self.fold.clone())
    }
}

impl<F, Init, I, B> FoldService<F, Init, I, B> {
    /// Creates a fold service around the folding service `fold`.
    pub fn new(fold: F) -> Self {
        Self {
            fold,
            _marker: PhantomData,
        }
    }
}
143
144impl<F, Init, I, B, Input, CodecError, MetaErr, Err>
145 Service<Task<B::Compact, B::Context, B::IdType>> for FoldService<F, Init, I, B>
146where
147 F: Service<Task<(Init, Input), B::Context, B::IdType>, Response = Init>
148 + Send
149 + 'static
150 + Clone,
151 B: BackendExt<Error = Err>
152 + Send
153 + Sync
154 + Clone
155 + Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err>
156 + Unpin
157 + 'static,
158 I: IntoIterator<Item = Input> + Send + 'static,
159 B::Context: MetadataExt<FoldState, Error = MetaErr>
160 + MetadataExt<WorkflowContext, Error = MetaErr>
161 + Send
162 + 'static,
163 B::Codec: Codec<(Init, Vec<Input>), Error = CodecError, Compact = B::Compact>
164 + Codec<Init, Error = CodecError, Compact = B::Compact>
165 + Codec<I, Error = CodecError, Compact = B::Compact>
166 + Codec<(Init, Input), Error = CodecError, Compact = B::Compact>
167 + 'static,
168 B::IdType: GenerateId + Send + 'static,
169 Init: Default + Send + 'static,
170 Err: std::error::Error + Send + Sync + 'static,
171 CodecError: std::error::Error + Send + Sync + 'static,
172 F::Error: Into<BoxDynError> + Send + 'static,
173 MetaErr: std::error::Error + Send + Sync + 'static,
174 F::Future: Send + 'static,
175 B::Compact: Send + 'static,
176 Input: Send + 'static,
177{
178 type Response = GoTo<StepResult<B::Compact, B::IdType>>;
179 type Error = BoxDynError;
180 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
181
182 fn poll_ready(&mut self, cx: &mut Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
183 self.fold.poll_ready(cx).map_err(|e| e.into())
184 }
185
186 fn call(&mut self, task: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
187 let state = task.parts.ctx.extract().unwrap_or(FoldState::Unknown);
188 let mut ctx = task.parts.data.get::<StepContext<B>>().cloned().unwrap();
189 let mut fold = self.fold.clone();
190
191 match state {
192 FoldState::Unknown => async move {
193 let task_id = TaskId::new(B::IdType::generate());
194 let steps: Task<I, _, _> = task.try_map(|arg| B::Codec::decode(&arg))?;
195 let steps = steps.args.into_iter().collect::<Vec<_>>();
196 let task = TaskBuilder::new(B::Codec::encode(&(Init::default(), steps))?)
197 .meta(WorkflowContext {
198 step_index: ctx.current_step,
199 })
200 .with_task_id(task_id.clone())
201 .meta(FoldState::Collection)
202 .build();
203 ctx.backend
204 .send(task)
205 .await
206 .map_err(|e| TaskSinkError::PushError(e))?;
207 Ok(GoTo::Next(StepResult {
208 result: B::Codec::encode(&Init::default())?,
209 next_task_id: Some(task_id),
210 }))
211 }
212 .boxed(),
213 FoldState::Collection => async move {
214 let args: (Init, Vec<Input>) = B::Codec::decode(&task.args)?;
215 let (acc, items) = args;
216
217 let mut items = items.into_iter();
218 let next = items.next().unwrap();
219 let rest = items.collect::<Vec<_>>();
220 let fold_task = task.map(|_| (acc, next));
221 let response = fold.call(fold_task).await.map_err(|e| e.into())?;
222
223 match rest.len() {
224 0 if ctx.has_next => {
225 let task_id = TaskId::new(B::IdType::generate());
226 let result = B::Codec::encode(&response)?;
227 let next_step = TaskBuilder::new(result)
228 .with_task_id(task_id.clone())
229 .meta(WorkflowContext {
230 step_index: ctx.current_step + 1,
231 })
232 .build();
233 ctx.backend
234 .send(next_step)
235 .await
236 .map_err(|e| TaskSinkError::PushError(e))?;
237 Ok(GoTo::Break(StepResult {
238 result: B::Codec::encode(&response)?,
239 next_task_id: Some(task_id),
240 }))
241 }
242 0 => Ok(GoTo::Break(StepResult {
243 result: B::Codec::encode(&response)?,
244 next_task_id: None,
245 })),
246 1.. => {
247 let task_id = TaskId::new(B::IdType::generate());
249 let result = B::Codec::encode(&response)?;
250 let steps = TaskBuilder::new(B::Codec::encode(&(response, rest))?)
251 .with_task_id(task_id.clone())
252 .meta(WorkflowContext {
253 step_index: ctx.current_step,
254 })
255 .meta(FoldState::Collection)
256 .build();
257 ctx.backend
258 .send(steps)
259 .await
260 .map_err(|e| TaskSinkError::PushError(e))?;
261 Ok(GoTo::Next(StepResult {
262 result,
263 next_task_id: Some(task_id),
264 }))
265 }
266 }
267 }
268 .boxed(),
269 }
270 }
271}
272
273#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
275pub enum FoldState {
276 Unknown,
278 Collection,
280}