apalis_workflow/dag/
mod.rs

1use std::{collections::HashMap, marker::PhantomData, sync::Mutex};
2
3use apalis_core::{
4    error::BoxDynError,
5    task::Task,
6    task_fn::{TaskFn, task_fn},
7};
8use petgraph::{
9    Direction,
10    algo::toposort,
11    dot::Config,
12    graph::{DiGraph, EdgeIndex, NodeIndex},
13};
14use tower::{Service, ServiceBuilder};
15
16use crate::{BoxedService, SteppedService};
17
18/// Directed Acyclic Graph (DAG) workflow builder
19#[derive(Debug)]
20pub struct DagFlow<Input = (), Output = ()> {
21    graph: Mutex<DiGraph<SteppedService<(), (), ()>, ()>>,
22    node_mapping: Mutex<HashMap<String, NodeIndex>>,
23    _marker: PhantomData<(Input, Output)>,
24}
25
26impl Default for DagFlow {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl DagFlow {
33    /// Create a new DAG workflow builder
34    #[must_use]
35    pub fn new() -> Self {
36        Self {
37            graph: Mutex::new(DiGraph::new()),
38            node_mapping: Mutex::new(HashMap::new()),
39            _marker: PhantomData,
40        }
41    }
42
43    /// Add a node to the DAG
44    #[must_use]
45    #[allow(clippy::todo)]
46    pub fn add_node<S, Input>(&self, name: &str, service: S) -> NodeBuilder<'_, Input, S::Response>
47    where
48        S: Service<Task<Input, (), ()>> + Send + 'static,
49        S::Future: Send + 'static,
50    {
51        let svc = ServiceBuilder::new()
52            .map_request(|r: Task<(), (), ()>| todo!())
53            .map_response(|r: S::Response| todo!())
54            .map_err(|_e: S::Error| {
55                let boxed: BoxDynError = todo!();
56                boxed
57            })
58            .service(service);
59        let node = self.graph.lock().unwrap().add_node(BoxedService::new(svc));
60        self.node_mapping
61            .lock()
62            .unwrap()
63            .insert(name.to_owned(), node);
64        NodeBuilder {
65            id: node,
66            dag: self,
67            io: PhantomData,
68        }
69    }
70
71    /// Add a task function node to the DAG
72    pub fn node<F, Input, O, FnArgs>(&self, node: F) -> NodeBuilder<'_, Input, O>
73    where
74        TaskFn<F, Input, (), FnArgs>: Service<Task<Input, (), ()>, Response = O>,
75        F: Send + 'static,
76        Input: Send + 'static,
77        FnArgs: Send + 'static,
78        <TaskFn<F, Input, (), FnArgs> as Service<Task<Input, (), ()>>>::Future: Send + 'static,
79    {
80        self.add_node(std::any::type_name::<F>(), task_fn(node))
81    }
82
83    /// Add a routing node to the DAG
84    pub fn route<F, Input, O, FnArgs>(&self, router: F) -> NodeBuilder<'_, Input, O>
85    where
86        TaskFn<F, Input, (), FnArgs>: Service<Task<Input, (), ()>, Response = O>,
87        F: Send + 'static,
88        Input: Send + 'static,
89        FnArgs: Send + 'static,
90        <TaskFn<F, Input, (), FnArgs> as Service<Task<Input, (), ()>>>::Future: Send + 'static,
91        O: Into<NodeIndex>,
92    {
93        self.add_node::<TaskFn<F, Input, (), FnArgs>, Input>(
94            std::any::type_name::<F>(),
95            task_fn(router),
96        )
97    }
98
99    /// Build the DAG executor
100    pub fn build(self) -> Result<DagExecutor, String> {
101        // Validate DAG (check for cycles)
102        let sorted =
103            toposort(&*self.graph.lock().unwrap(), None).map_err(|_| "DAG contains cycles")?;
104
105        fn find_edge_nodes<N, E>(graph: &DiGraph<N, E>, direction: Direction) -> Vec<NodeIndex> {
106            graph
107                .node_indices()
108                .filter(|&n| graph.neighbors_directed(n, direction).count() == 0)
109                .collect()
110        }
111
112        let graph = self.graph.into_inner().unwrap();
113
114        Ok(DagExecutor {
115            start_nodes: find_edge_nodes(&graph, Direction::Incoming),
116            end_nodes: find_edge_nodes(&graph, Direction::Outgoing),
117            graph,
118            node_mapping: self.node_mapping.into_inner().unwrap(),
119            topological_order: sorted,
120        })
121    }
122}
123
124/// Executor for DAG workflows
125#[derive(Debug)]
126pub struct DagExecutor<Ctx = (), IdType = ()> {
127    graph: DiGraph<SteppedService<(), Ctx, IdType>, ()>,
128    node_mapping: HashMap<String, NodeIndex>,
129    topological_order: Vec<NodeIndex>,
130    start_nodes: Vec<NodeIndex>,
131    end_nodes: Vec<NodeIndex>,
132}
133
134impl DagExecutor {
135    /// Get a node by name
136    pub fn get_node_by_name_mut(&mut self, name: &str) -> Option<&mut SteppedService<(), (), ()>> {
137        self.node_mapping
138            .get(name)
139            .and_then(|&idx| self.graph.node_weight_mut(idx))
140    }
141
142    /// Export the DAG to DOT format
143    #[must_use]
144    pub fn to_dot(&self) -> String {
145        let names = self
146            .node_mapping
147            .iter()
148            .map(|(name, &idx)| (idx, name.clone()))
149            .collect::<HashMap<_, _>>();
150        let get_node_attributes = |_, (index, _)| {
151            format!(
152                "label=\"{}\"",
153                names.get(&index).cloned().unwrap_or_default()
154            )
155        };
156        let dot = petgraph::dot::Dot::with_attr_getters(
157            &self.graph,
158            &[Config::NodeNoLabel, Config::EdgeNoLabel],
159            &|_, _| String::new(),
160            &get_node_attributes,
161        );
162        format!("{dot:?}")
163    }
164}
165
166/// Builder for a node in the DAG
167#[derive(Clone, Debug)]
168pub struct NodeBuilder<'a, Input, Output = ()> {
169    pub(crate) id: NodeIndex,
170    pub(crate) dag: &'a DagFlow,
171    pub(crate) io: PhantomData<(Input, Output)>,
172}
173
174impl<Input, Output> NodeBuilder<'_, Input, Output> {
175    /// Specify dependencies for this node
176    #[allow(clippy::needless_pass_by_value)]
177    pub fn depends_on<D>(self, deps: D) -> NodeHandle<Input, Output>
178    where
179        D: DepsCheck<Input>,
180    {
181        let mut edges = Vec::new();
182        for dep in deps.to_node_ids() {
183            edges.push(self.dag.graph.lock().unwrap().add_edge(dep, self.id, ()));
184        }
185        NodeHandle {
186            id: self.id,
187            edges,
188            _phantom: PhantomData,
189        }
190    }
191}
192
193/// Handle for a node in the DAG
194#[derive(Clone, Debug)]
195pub struct NodeHandle<Input, Output = ()> {
196    pub(crate) id: NodeIndex,
197    pub(crate) edges: Vec<EdgeIndex>,
198    pub(crate) _phantom: PhantomData<(Input, Output)>,
199}
200
201/// Trait for converting dependencies into node IDs
202pub trait DepsCheck<Input> {
203    /// Convert dependencies to node IDs
204    fn to_node_ids(&self) -> Vec<NodeIndex>;
205}
206
207impl DepsCheck<()> for () {
208    fn to_node_ids(&self) -> Vec<NodeIndex> {
209        Vec::new()
210    }
211}
212
213impl<'a, Input, Output> DepsCheck<Output> for &NodeBuilder<'a, Input, Output> {
214    fn to_node_ids(&self) -> Vec<NodeIndex> {
215        vec![self.id]
216    }
217}
218
219impl<Input, Output> DepsCheck<Output> for &NodeHandle<Input, Output> {
220    fn to_node_ids(&self) -> Vec<NodeIndex> {
221        vec![self.id]
222    }
223}
224
225impl<Input, Output> DepsCheck<Output> for (&NodeHandle<Input, Output>,) {
226    fn to_node_ids(&self) -> Vec<NodeIndex> {
227        vec![self.0.id]
228    }
229}
230
231impl<'a, Input, Output> DepsCheck<Output> for (&NodeBuilder<'a, Input, Output>,) {
232    fn to_node_ids(&self) -> Vec<NodeIndex> {
233        vec![self.0.id]
234    }
235}
236
237impl<Output, T: DepsCheck<Output>> DepsCheck<Vec<Output>> for Vec<T> {
238    fn to_node_ids(&self) -> Vec<NodeIndex> {
239        self.iter().flat_map(|item| item.to_node_ids()).collect()
240    }
241}
242
243macro_rules! impl_deps_check {
244    ($( $len:literal => ( $( $in:ident $out:ident $idx:tt ),+ ) ),+ $(,)?) => {
245        $(
246            impl<'a, $( $in, )+ $( $out, )+> DepsCheck<( $( $out, )+ )>
247                for ( $( &NodeBuilder<'a, $in, $out>, )+ )
248            {
249                fn to_node_ids(&self) -> Vec<NodeIndex> {
250                    vec![ $( self.$idx.id ),+ ]
251                }
252            }
253
254            impl<$( $in, )+ $( $out, )+> DepsCheck<( $( $out, )+ )>
255                for ( $( &NodeHandle<$in, $out>, )+ )
256            {
257                fn to_node_ids(&self) -> Vec<NodeIndex> {
258                    vec![ $( self.$idx.id ),+ ]
259                }
260            }
261        )+
262    };
263}
264
265impl_deps_check! {
266    1 => (Input1 Output1 0),
267    2 => (Input1 Output1 0, Input2 Output2 1),
268    3 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2),
269    4 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3),
270    5 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3, Input5 Output5 4),
271    6 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3, Input5 Output5 4, Input6 Output6 5),
272    7 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3, Input5 Output5 4, Input6 Output6 5, Input7 Output7 6),
273    8 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3, Input5 Output5 4, Input6 Output6 5, Input7 Output7 6, Input8 Output8 7),
274}
275
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap, marker::PhantomData, num::ParseIntError, ops::Range, time::Duration,
    };

    use apalis_core::{
        backend::json::JsonStorage, error::BoxDynError, task::Task, task_fn::task_fn,
        worker::context::WorkerContext,
    };
    use petgraph::graph::NodeIndex;
    use serde_json::Value;

    use crate::{step::Identity, workflow::Workflow};

    use super::*;

    /// Output of the routing node in `test_basic_workflow`: names the entry
    /// node to route to and carries that node's index.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum EntryRoute {
        Entry1(NodeIndex),
        Entry2(NodeIndex),
        Entry3(NodeIndex),
    }

    // Implement `From` rather than `Into`; the blanket impl supplies the
    // `Into<NodeIndex>` that `DagFlow::route` requires.
    impl From<EntryRoute> for NodeIndex {
        fn from(route: EntryRoute) -> Self {
            match route {
                EntryRoute::Entry1(idx) | EntryRoute::Entry2(idx) | EntryRoute::Entry3(idx) => idx,
            }
        }
    }

    impl DepsCheck<usize> for EntryRoute {
        fn to_node_ids(&self) -> Vec<NodeIndex> {
            vec![NodeIndex::from(*self)]
        }
    }

    /// Builds a seven-node DAG (three entries, two collectors, an exit node and
    /// a router) and checks that it validates and sorts topologically.
    #[test]
    fn test_basic_workflow() {
        let dag = DagFlow::new();

        let entry1 = dag.add_node("entry1", task_fn(|task: u32| async move { task as usize }));
        let entry2 = dag.add_node("entry2", task_fn(|task: u32| async move { task as usize }));
        let entry3 = dag.add_node("entry3", task_fn(|task: u32| async move { task as usize }));

        async fn collect(task: (usize, usize, usize)) -> usize {
            task.0 + task.1 + task.2
        }
        let collector = dag.node(collect).depends_on((&entry1, &entry2, &entry3));

        async fn vec_collect(task: Vec<usize>, _worker: WorkerContext) -> usize {
            task.iter().sum::<usize>()
        }

        let vec_collector = dag
            .node(vec_collect)
            .depends_on(vec![&entry1, &entry2, &entry3]);

        async fn exit(task: (usize, usize)) -> Result<u32, ParseIntError> {
            (task.0.to_string() + &task.1.to_string()).parse()
        }

        let on_collect = dag.node(exit).depends_on((&collector, &vec_collector));

        async fn check_approval(task: u32) -> Result<EntryRoute, BoxDynError> {
            match task % 3 {
                0 => Ok(EntryRoute::Entry1(NodeIndex::new(0))),
                1 => Ok(EntryRoute::Entry2(NodeIndex::new(1))),
                2 => Ok(EntryRoute::Entry3(NodeIndex::new(2))),
                _ => Err(BoxDynError::from("Invalid task")),
            }
        }

        dag.route(check_approval).depends_on(&on_collect);

        let dag_executor = dag.build().unwrap();
        assert_eq!(dag_executor.topological_order.len(), 7);

        println!("Start nodes: {:?}", dag_executor.start_nodes);
        println!("End nodes: {:?}", dag_executor.end_nodes);

        println!(
            "DAG Topological Order: {:?}",
            dag_executor.topological_order
        );

        println!("DAG in DOT format:\n{}", dag_executor.to_dot());

        // let inner_basic: Workflow<u32, usize, JsonStorage<Value>, _> = Workflow::new("basic")
        //     .and_then(async |input: u32| (input + 1) as usize)
        //     .and_then(async |input: usize| input.to_string())
        //     .and_then(async |input: String| input.parse::<usize>());

        // let workflow = Workflow::new("example_workflow")
        //     .and_then(async |input: u32| Ok::<Range<u32>, BoxDynError>(input..100))
        //     .filter_map(
        //         async |input: u32| {
        //             if input > 50 { Some(input) } else { None }
        //         },
        //     )
        //     .and_then(async |items: Vec<u32>| Ok::<_, BoxDynError>(items))
        //     .fold(0, async |(acc, item): (u32, u32)| {
        //         Ok::<_, BoxDynError>(item + acc)
        //     })
        //     // .delay_for(Duration::from_secs(2))
        //     // .delay_with(|_| Duration::from_secs(1))
        //     // .and_then(async |items: Range<u32>| Ok::<_, BoxDynError>(items.sum::<u32>()))
        //     // .repeat_until(async |i: u32| {
        //     //     if i < 20 {
        //     //         Ok::<_, BoxDynError>(Some(i))
        //     //     } else {
        //     //         Ok(None)
        //     //     }
        //     // })
        //     // .chain(inner_basic)
        //     // .chain(
        //     //     Workflow::new("sub_workflow")
        //     //         .and_then(async |input: usize| Ok::<_, BoxDynError>(input as u32 * 2))
        //     //         .and_then(async |input: u32| Ok::<_, BoxDynError>(input + 10)),
        //     // )
        //     // .chain(dag_executor)
        //     .build();
    }
}