apalis_workflow/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::{
3    collections::HashMap,
4    fmt::Debug,
5    future::Future,
6    marker::PhantomData,
7    task::{Context, Poll},
8    time::Duration,
9};
10
11use apalis_core::{
12    backend::{Backend, TaskSink, TaskSinkError, WeakTaskSink, codec::Codec},
13    error::BoxDynError,
14    task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId},
15    worker::builder::IntoWorkerService,
16};
17use futures::{Sink, future::BoxFuture};
18use serde::{Deserialize, Serialize};
19use tower::Service;
20
21use crate::{context::StepContext, service::WorkFlowService};
22
23mod context;
24mod id_generator;
25mod service;
26mod steps;
27
28pub use crate::steps::{delay::DelayStep, filter_map::FilterMapStep, then::ThenStep};
29pub use id_generator::GenerateId;
30pub use service::{StepResult, handle_workflow_result};
31
32type BoxedService<Input, Output> = tower::util::BoxService<Input, Output, BoxDynError>;
33type SteppedService<Compact, Ctx, IdType> = BoxedService<Task<Compact, Ctx, IdType>, Compact>;
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub enum GoTo<T = ()> {
37    Next(T),
38    DelayFor(Duration, T),
39    Done,
40    /// Breaks the current task execution
41    Break(T),
42
43    /// Execution will continue in another task identified by the String
44    /// Returning this does not guarantee that the task will be executed.
45    /// It may be an invalid task id, or the task may never be scheduled.
46    ContinueAt(String),
47}
48
49pub trait Step<Args, FlowSink, Encode>
50where
51    FlowSink: WeakTaskSink<Self::Response>,
52{
53    type Response;
54    type Error: Send;
55    fn run(
56        &mut self,
57        ctx: &StepContext<FlowSink, Encode>,
58        step: Task<Args, FlowSink::Context, FlowSink::IdType>,
59    ) -> impl Future<Output = Result<GoTo<Self::Response>, Self::Error>> + Send;
60}
61
62pub struct Workflow<Input, Current, FlowSink, Encode, Compact, Context, IdType> {
63    name: String,
64    steps: HashMap<usize, CompositeService<FlowSink, Encode, Compact, Context, IdType>>,
65    _marker: PhantomData<(Input, Current, FlowSink)>,
66}
67
68pub struct CompositeService<FlowSink, Encode, Compact, Context, IdType> {
69    svc: SteppedService<Compact, Context, IdType>,
70    _marker: PhantomData<(FlowSink, Encode)>,
71}
72
73impl<Input, FlowSink, Encode, Compact, Context, IdType>
74    Workflow<Input, Input, FlowSink, Encode, Compact, Context, IdType>
75{
76    pub fn new(name: &str) -> Self {
77        Self {
78            name: name.to_owned(),
79            steps: HashMap::new(),
80            _marker: PhantomData,
81        }
82    }
83}
84
85impl<Input, Current, FlowSink, Encode, Compact>
86    Workflow<Input, Current, FlowSink, Encode, Compact, FlowSink::Context, FlowSink::IdType>
87where
88    Current: Send + 'static,
89    FlowSink: Send + Clone + Sync + 'static + Unpin + Backend,
90{
91    pub fn add_step<S, Res, E, CodecError, BackendError>(
92        mut self,
93        step: S,
94    ) -> Workflow<Input, Res, FlowSink, Encode, Compact, FlowSink::Context, FlowSink::IdType>
95    where
96        FlowSink: WeakTaskSink<Res, Codec = Encode, Error = BackendError>
97            + Sink<Task<Compact, FlowSink::Context, FlowSink::IdType>, Error = BackendError>,
98        Current: std::marker::Send + 'static + Sync,
99        FlowSink::Context: Send + 'static + Sync,
100        S: Step<Current, FlowSink, Encode, Response = Res, Error = E>
101            + Sync
102            + Send
103            + 'static
104            + Clone,
105        S::Response: Send,
106        S::Error: Send,
107        Res: 'static + Sync,
108        FlowSink::IdType: Send,
109        Encode: Codec<Current, Compact = Compact, Error = CodecError>
110            + Codec<GoTo<Res>, Compact = Compact, Error = CodecError>
111            + Codec<Res, Compact = Compact, Error = CodecError>,
112        Compact: Send + Sync + 'static,
113        Encode: Send + Sync + 'static,
114        E: Into<BoxDynError> + Send + Sync + 'static,
115        CodecError: std::error::Error + Send + 'static + Sync,
116        FlowSink::Context: MetadataExt<WorkflowRequest>,
117        BackendError: std::error::Error + Send + Sync + 'static,
118    {
119        self.steps.insert(self.steps.len(), {
120            let svc =
121                SteppedService::<Compact, FlowSink::Context, FlowSink::IdType>::new(StepService {
122                    codec: PhantomData::<(Encode, Current, FlowSink)>,
123                    step,
124                });
125            CompositeService {
126                svc,
127                _marker: PhantomData,
128            }
129        });
130        Workflow {
131            name: self.name,
132            steps: self.steps,
133            _marker: PhantomData,
134        }
135    }
136}
137
/// Adapter that lets a step be driven as a `tower` service.
pub struct StepService<Step, Encode, Args, FlowSink> {
    /// The wrapped step; it is cloned for every call.
    step: Step,
    /// Zero-sized marker for the codec, argument and backend types.
    codec: PhantomData<(Encode, Args, FlowSink)>,
}
142
143impl<Args, S, Encode, Compact, FlowSink, E, CodecError, BackendErr>
144    Service<Task<Compact, FlowSink::Context, FlowSink::IdType>>
145    for StepService<S, Encode, Args, FlowSink>
146where
147    S: Step<Args, FlowSink, Encode, Error = E> + Clone + Send + 'static,
148    Encode: Codec<Args, Compact = Compact, Error = CodecError>
149        + Codec<S::Response, Compact = Compact, Error = CodecError>
150        + Codec<GoTo<S::Response>, Compact = Compact, Error = CodecError>,
151    S::Response: Send + 'static + Sync,
152    S::Error: Send + 'static,
153    FlowSink: Clone
154        + Send
155        + 'static
156        + Sync
157        + WeakTaskSink<S::Response, Codec = Encode, Error = BackendErr>
158        + Unpin
159        + Sink<Task<Compact, FlowSink::Context, FlowSink::IdType>, Error = BackendErr>,
160    Args: Send + 'static,
161    FlowSink::Context: Send + 'static + MetadataExt<WorkflowRequest>,
162    FlowSink::IdType: Send + 'static,
163    Compact: Send + Sync + 'static,
164    Encode: Send + Sync + 'static,
165    E: Into<BoxDynError> + Send + 'static + Sync,
166    CodecError: std::error::Error + Send + 'static + Sync,
167    BackendErr: std::error::Error + Send + 'static + Sync,
168{
169    type Response = Compact;
170    type Error = BoxDynError;
171    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
172    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
173        Poll::Ready(Ok(()))
174    }
175    fn call(&mut self, req: Task<Compact, FlowSink::Context, FlowSink::IdType>) -> Self::Future {
176        let ctx: Option<StepContext<FlowSink, Encode>> = req.parts.data.get().cloned();
177        let mut step = self.step.clone();
178        Box::pin(async move {
179            match ctx {
180                Some(ctx) => {
181                    let mut ctx = ctx.clone();
182                    let req = req.try_map(|arg| Encode::decode(&arg));
183                    match req {
184                        Ok(task) => {
185                            let res = step.run(&ctx, task).await.map_err(|e| e.into())?;
186
187                            let _ = handle_workflow_result::<
188                                S::Response,
189                                Compact,
190                                FlowSink,
191                                BackendErr,
192                            >(&mut ctx, &res)
193                            .await
194                            .map_err(|e| match e {
195                                TaskSinkError::PushError(err) => Box::new(err) as BoxDynError,
196                                TaskSinkError::CodecError(err) => {
197                                    WorkflowError::CodecError(err.into()).into()
198                                }
199                            })?;
200                            Encode::encode(&res)
201                                .map_err(|e| WorkflowError::CodecError(e.into()).into())
202                        }
203                        Err(e) => Err(WorkflowError::CodecError(e.into()).into()),
204                    }
205                }
206                None => Err(WorkflowError::MissingContextError.into()),
207            }
208        })
209    }
210}
211
212#[derive(Debug, thiserror::Error)]
213pub enum WorkflowError {
214    #[error("Missing StepContext")]
215    MissingContextError,
216    #[error("CodecError: {0}")]
217    CodecError(BoxDynError),
218    #[error("SingleStepError: {0}")]
219    SingleStepError(BoxDynError),
220    #[error("SinkError: {0}")]
221    SinkError(BoxDynError),
222    #[error("MetadataError: {0}")]
223    MetadataError(BoxDynError),
224}
225
226#[derive(Debug, Clone, Deserialize, Serialize, Default)]
227pub struct WorkflowRequest {
228    pub step_index: usize,
229}
230
231impl<Input, Current, FlowSink, Encode, Compact, Err>
232    IntoWorkerService<
233        FlowSink,
234        WorkFlowService<FlowSink, Encode, Compact, FlowSink::Context, FlowSink::IdType>,
235        Compact,
236        FlowSink::Context,
237    > for Workflow<Input, Current, FlowSink, Encode, Compact, FlowSink::Context, FlowSink::IdType>
238where
239    FlowSink: Clone
240        + Send
241        + Sync
242        + 'static
243        + Sink<Task<Compact, FlowSink::Context, FlowSink::IdType>, Error = Err>
244        + Unpin,
245    Err: std::error::Error + Send + Sync + 'static,
246    Compact: Send,
247    FlowSink: TaskSink<Compact, Codec = Encode>,
248    FlowSink::Context: MetadataExt<WorkflowRequest> + Send + Sync + 'static,
249    Encode: Send + Sync + 'static + Codec<Compact, Compact = Compact>,
250    Compact: Send + Sync + 'static + Clone,
251    FlowSink::IdType: Send + 'static + Default,
252    FlowSink: Sync + Backend<Args = Compact, Error = Err>,
253    Compact: Send + Sync,
254    FlowSink::Context: Send + Default + MetadataExt<WorkflowRequest>,
255    FlowSink::IdType: GenerateId,
256    <FlowSink::Context as MetadataExt<WorkflowRequest>>::Error: Into<BoxDynError>,
257    Encode::Error: Into<BoxDynError>,
258{
259    fn into_service(
260        self,
261        b: &FlowSink,
262    ) -> WorkFlowService<FlowSink, Encode, Compact, FlowSink::Context, FlowSink::IdType> {
263        let services: HashMap<usize, _> = self
264            .steps
265            .into_iter()
266            .map(|(index, svc)| (index, svc))
267            .collect();
268        WorkFlowService::new(services, b.clone())
269    }
270}
271
272pub trait TaskFlowSink<Args, Compact>: Backend
273where
274    Self::Codec: Codec<Args>,
275{
276    fn push_start(&mut self, step: Args) -> impl Future<Output = Result<(), WorkflowError>> + Send {
277        self.push_step(step, 0)
278    }
279
280    fn push_step(
281        &mut self,
282        step: Args,
283        index: usize,
284    ) -> impl Future<Output = Result<(), WorkflowError>> + Send;
285}
286
287impl<S: Send, Args: Send, Compact> TaskFlowSink<Args, Compact> for S
288where
289    S: WeakTaskSink<Args>,
290    S::IdType: GenerateId + Send,
291    S::Codec: Codec<Args, Compact = Compact>,
292    S::Context: MetadataExt<WorkflowRequest> + Send,
293    S::Error: std::error::Error + Send + Sync + 'static,
294    <S::Codec as Codec<Args>>::Error: Into<BoxDynError> + Send + Sync + 'static,
295    <S::Context as MetadataExt<WorkflowRequest>>::Error: Into<BoxDynError> + Send + Sync + 'static,
296{
297    async fn push_step(&mut self, step: Args, index: usize) -> Result<(), WorkflowError> {
298        let task_id = TaskId::new(S::IdType::generate());
299        let task = TaskBuilder::new(step)
300            .meta(WorkflowRequest { step_index: index })
301            .with_task_id(task_id.clone())
302            .build();
303        self.push_task(task)
304            .await
305            .map_err(|e| WorkflowError::SinkError(e.into()))
306    }
307}
308
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{
        backend::json::JsonStorage,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };

    use crate::{TaskFlowSink, Workflow, WorkflowError};

    /// Three-step flow (compute, delay, fail). The last step always errors, and the
    /// event listener stops the worker on the first error event so `run` can return.
    #[tokio::test]
    async fn simple_workflow() {
        let flow = Workflow::new("odd-numbers-workflow")
            .then(|a: usize| async move { Ok::<_, WorkflowError>(a - 2) })
            .delay_for(Duration::from_secs(1))
            .then(|_| async move { Err::<(), WorkflowError>(WorkflowError::MissingContextError) });

        let mut storage = JsonStorage::new_temp().unwrap();
        storage.push_start(usize::MAX).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(storage)
            .on_event(|ctx, ev| {
                println!("On Event = {ev:?}");
                if let Event::Error(_) = ev {
                    ctx.stop().unwrap();
                }
            })
            .build(flow);
        worker.run().await.unwrap();
    }

    /// Produces `0..a`, filters it three times, then ends on a failing step. The event
    /// listener stops the worker on the first error event so `run` can return.
    #[tokio::test]
    async fn then_workflow() {
        let flow = Workflow::new("then-workflow")
            .then(|a: usize| async move { Ok::<_, WorkflowError>((0..a).collect::<Vec<_>>()) })
            .filter_map(|x| async move { (x % 5 != 0).then_some(x) })
            .filter_map(|x| async move { (x % 3 != 0).then_some(x) })
            .filter_map(|x| async move { (x % 2 != 0).then_some(x) })
            .then(|a| async move {
                dbg!(a);
                Err::<(), WorkflowError>(WorkflowError::MissingContextError)
            });

        let mut storage = JsonStorage::new_temp().unwrap();
        storage.push_start(100).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(storage)
            .on_event(|ctx, ev| {
                println!("On Event = {ev:?}");
                if let Event::Error(_) = ev {
                    ctx.stop().unwrap();
                }
            })
            .build(flow);
        worker.run().await.unwrap();
    }
}
371}