1use std::{collections::HashMap, marker::PhantomData, sync::Mutex};
2
3use apalis_core::{
4 error::BoxDynError,
5 task::Task,
6 task_fn::{TaskFn, task_fn},
7};
8use petgraph::{
9 Direction,
10 algo::toposort,
11 dot::Config,
12 graph::{DiGraph, EdgeIndex, NodeIndex},
13};
14use tower::{Service, ServiceBuilder};
15
16use crate::{BoxedService, SteppedService};
17
/// Builder for a directed graph of workflow steps.
///
/// Node weights are the boxed services registered through [`DagFlow::add_node`];
/// edges are added by [`NodeBuilder::depends_on`] and point from a dependency to
/// the node that depends on it. Both the graph and the name table sit behind a
/// [`Mutex`] because nodes and edges are registered through a shared `&DagFlow`.
#[derive(Debug)]
pub struct DagFlow<Input = (), Output = ()> {
    /// Graph of registered services; edges carry no payload.
    graph: Mutex<DiGraph<SteppedService<(), (), ()>, ()>>,
    /// Maps the name a node was registered under to its index in `graph`.
    node_mapping: Mutex<HashMap<String, NodeIndex>>,
    /// Keeps the `Input`/`Output` type parameters on the struct without storing values.
    _marker: PhantomData<(Input, Output)>,
}
25
26impl Default for DagFlow {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl DagFlow {
33 #[must_use]
35 pub fn new() -> Self {
36 Self {
37 graph: Mutex::new(DiGraph::new()),
38 node_mapping: Mutex::new(HashMap::new()),
39 _marker: PhantomData,
40 }
41 }
42
43 #[must_use]
45 #[allow(clippy::todo)]
46 pub fn add_node<S, Input>(&self, name: &str, service: S) -> NodeBuilder<'_, Input, S::Response>
47 where
48 S: Service<Task<Input, (), ()>> + Send + 'static,
49 S::Future: Send + 'static,
50 {
51 let svc = ServiceBuilder::new()
52 .map_request(|r: Task<(), (), ()>| todo!())
53 .map_response(|r: S::Response| todo!())
54 .map_err(|_e: S::Error| {
55 let boxed: BoxDynError = todo!();
56 boxed
57 })
58 .service(service);
59 let node = self.graph.lock().unwrap().add_node(BoxedService::new(svc));
60 self.node_mapping
61 .lock()
62 .unwrap()
63 .insert(name.to_owned(), node);
64 NodeBuilder {
65 id: node,
66 dag: self,
67 io: PhantomData,
68 }
69 }
70
71 pub fn node<F, Input, O, FnArgs>(&self, node: F) -> NodeBuilder<'_, Input, O>
73 where
74 TaskFn<F, Input, (), FnArgs>: Service<Task<Input, (), ()>, Response = O>,
75 F: Send + 'static,
76 Input: Send + 'static,
77 FnArgs: Send + 'static,
78 <TaskFn<F, Input, (), FnArgs> as Service<Task<Input, (), ()>>>::Future: Send + 'static,
79 {
80 self.add_node(std::any::type_name::<F>(), task_fn(node))
81 }
82
83 pub fn route<F, Input, O, FnArgs>(&self, router: F) -> NodeBuilder<'_, Input, O>
85 where
86 TaskFn<F, Input, (), FnArgs>: Service<Task<Input, (), ()>, Response = O>,
87 F: Send + 'static,
88 Input: Send + 'static,
89 FnArgs: Send + 'static,
90 <TaskFn<F, Input, (), FnArgs> as Service<Task<Input, (), ()>>>::Future: Send + 'static,
91 O: Into<NodeIndex>,
92 {
93 self.add_node::<TaskFn<F, Input, (), FnArgs>, Input>(
94 std::any::type_name::<F>(),
95 task_fn(router),
96 )
97 }
98
99 pub fn build(self) -> Result<DagExecutor, String> {
101 let sorted =
103 toposort(&*self.graph.lock().unwrap(), None).map_err(|_| "DAG contains cycles")?;
104
105 fn find_edge_nodes<N, E>(graph: &DiGraph<N, E>, direction: Direction) -> Vec<NodeIndex> {
106 graph
107 .node_indices()
108 .filter(|&n| graph.neighbors_directed(n, direction).count() == 0)
109 .collect()
110 }
111
112 let graph = self.graph.into_inner().unwrap();
113
114 Ok(DagExecutor {
115 start_nodes: find_edge_nodes(&graph, Direction::Incoming),
116 end_nodes: find_edge_nodes(&graph, Direction::Outgoing),
117 graph,
118 node_mapping: self.node_mapping.into_inner().unwrap(),
119 topological_order: sorted,
120 })
121 }
122}
123
/// A validated workflow graph together with the lookups derived from it.
///
/// Produced by [`DagFlow::build`], which only succeeds when the graph has a
/// topological ordering.
#[derive(Debug)]
pub struct DagExecutor<Ctx = (), IdType = ()> {
    /// The workflow graph; node weights are the registered services.
    graph: DiGraph<SteppedService<(), Ctx, IdType>, ()>,
    /// Maps a registered node name to its index in `graph`.
    node_mapping: HashMap<String, NodeIndex>,
    /// All node indices in topological order.
    topological_order: Vec<NodeIndex>,
    /// Nodes that have no incoming edges.
    start_nodes: Vec<NodeIndex>,
    /// Nodes that have no outgoing edges.
    end_nodes: Vec<NodeIndex>,
}
133
134impl DagExecutor {
135 pub fn get_node_by_name_mut(&mut self, name: &str) -> Option<&mut SteppedService<(), (), ()>> {
137 self.node_mapping
138 .get(name)
139 .and_then(|&idx| self.graph.node_weight_mut(idx))
140 }
141
142 #[must_use]
144 pub fn to_dot(&self) -> String {
145 let names = self
146 .node_mapping
147 .iter()
148 .map(|(name, &idx)| (idx, name.clone()))
149 .collect::<HashMap<_, _>>();
150 let get_node_attributes = |_, (index, _)| {
151 format!(
152 "label=\"{}\"",
153 names.get(&index).cloned().unwrap_or_default()
154 )
155 };
156 let dot = petgraph::dot::Dot::with_attr_getters(
157 &self.graph,
158 &[Config::NodeNoLabel, Config::EdgeNoLabel],
159 &|_, _| String::new(),
160 &get_node_attributes,
161 );
162 format!("{dot:?}")
163 }
164}
165
/// A node that has been added to a [`DagFlow`] but has not declared its
/// dependencies yet.
///
/// Returned by [`DagFlow::add_node`]; [`NodeBuilder::depends_on`] consumes it and
/// yields a [`NodeHandle`]. `Input` and `Output` mirror the node's service input
/// and response types and exist only at the type level.
#[derive(Clone, Debug)]
pub struct NodeBuilder<'a, Input, Output = ()> {
    /// Index of this node in the flow's graph.
    pub(crate) id: NodeIndex,
    /// The flow this node was added to.
    pub(crate) dag: &'a DagFlow,
    /// Marker for the node's input/output types; holds no data.
    pub(crate) io: PhantomData<(Input, Output)>,
}
173
174impl<Input, Output> NodeBuilder<'_, Input, Output> {
175 #[allow(clippy::needless_pass_by_value)]
177 pub fn depends_on<D>(self, deps: D) -> NodeHandle<Input, Output>
178 where
179 D: DepsCheck<Input>,
180 {
181 let mut edges = Vec::new();
182 for dep in deps.to_node_ids() {
183 edges.push(self.dag.graph.lock().unwrap().add_edge(dep, self.id, ()));
184 }
185 NodeHandle {
186 id: self.id,
187 edges,
188 _phantom: PhantomData,
189 }
190 }
191}
192
/// A node whose dependencies have been declared.
///
/// Produced by [`NodeBuilder::depends_on`]. Unlike [`NodeBuilder`] it does not
/// borrow the [`DagFlow`].
#[derive(Clone, Debug)]
pub struct NodeHandle<Input, Output = ()> {
    /// Index of this node in the flow's graph.
    pub(crate) id: NodeIndex,
    /// Edges created by `depends_on`, one per dependency, each pointing at this node.
    pub(crate) edges: Vec<EdgeIndex>,
    /// Marker for the node's input/output types; holds no data.
    pub(crate) _phantom: PhantomData<(Input, Output)>,
}
200
/// A set of upstream nodes that can be passed to [`NodeBuilder::depends_on`].
///
/// `Input` is the input type of the node declaring the dependencies. The
/// implementations in this file set it from the `Output` types of the referenced
/// nodes, so a mismatch between what the dependencies produce and what the node
/// consumes is rejected at compile time.
pub trait DepsCheck<Input> {
    /// Returns the graph index of every node in this dependency set.
    fn to_node_ids(&self) -> Vec<NodeIndex>;
}
206
207impl DepsCheck<()> for () {
208 fn to_node_ids(&self) -> Vec<NodeIndex> {
209 Vec::new()
210 }
211}
212
213impl<'a, Input, Output> DepsCheck<Output> for &NodeBuilder<'a, Input, Output> {
214 fn to_node_ids(&self) -> Vec<NodeIndex> {
215 vec![self.id]
216 }
217}
218
219impl<Input, Output> DepsCheck<Output> for &NodeHandle<Input, Output> {
220 fn to_node_ids(&self) -> Vec<NodeIndex> {
221 vec![self.id]
222 }
223}
224
225impl<Input, Output> DepsCheck<Output> for (&NodeHandle<Input, Output>,) {
226 fn to_node_ids(&self) -> Vec<NodeIndex> {
227 vec![self.0.id]
228 }
229}
230
231impl<'a, Input, Output> DepsCheck<Output> for (&NodeBuilder<'a, Input, Output>,) {
232 fn to_node_ids(&self) -> Vec<NodeIndex> {
233 vec![self.0.id]
234 }
235}
236
237impl<Output, T: DepsCheck<Output>> DepsCheck<Vec<Output>> for Vec<T> {
238 fn to_node_ids(&self) -> Vec<NodeIndex> {
239 self.iter().flat_map(|item| item.to_node_ids()).collect()
240 }
241}
242
243macro_rules! impl_deps_check {
244 ($( $len:literal => ( $( $in:ident $out:ident $idx:tt ),+ ) ),+ $(,)?) => {
245 $(
246 impl<'a, $( $in, )+ $( $out, )+> DepsCheck<( $( $out, )+ )>
247 for ( $( &NodeBuilder<'a, $in, $out>, )+ )
248 {
249 fn to_node_ids(&self) -> Vec<NodeIndex> {
250 vec![ $( self.$idx.id ),+ ]
251 }
252 }
253
254 impl<$( $in, )+ $( $out, )+> DepsCheck<( $( $out, )+ )>
255 for ( $( &NodeHandle<$in, $out>, )+ )
256 {
257 fn to_node_ids(&self) -> Vec<NodeIndex> {
258 vec![ $( self.$idx.id ),+ ]
259 }
260 }
261 )+
262 };
263}
264
// Tuple dependency sets of one to eight `&NodeBuilder`s or `&NodeHandle`s.
// Each entry lists, per tuple element: input type parameter, output type
// parameter, and the element's position in the tuple.
impl_deps_check! {
    1 => (Input1 Output1 0),
    2 => (Input1 Output1 0, Input2 Output2 1),
    3 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2),
    4 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3),
    5 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3, Input5 Output5 4),
    6 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3, Input5 Output5 4, Input6 Output6 5),
    7 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3, Input5 Output5 4, Input6 Output6 5, Input7 Output7 6),
    8 => (Input1 Output1 0, Input2 Output2 1, Input3 Output3 2, Input4 Output4 3, Input5 Output5 4, Input6 Output6 5, Input7 Output7 6, Input8 Output8 7),
}
275
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap, marker::PhantomData, num::ParseIntError, ops::Range, time::Duration,
    };

    use apalis_core::{
        backend::json::JsonStorage, error::BoxDynError, task::Task, task_fn::task_fn,
        worker::context::WorkerContext,
    };
    use petgraph::graph::NodeIndex;
    use serde_json::Value;

    use crate::{step::Identity, workflow::Workflow};

    use super::*;

    /// Builds a seven-node flow (three entries, two collectors, an exit and a
    /// router) and checks the structure that `build` derives from it.
    #[test]
    fn test_basic_workflow() {
        let dag = DagFlow::new();

        let entry1 = dag.add_node("entry1", task_fn(|task: u32| async move { task as usize }));
        let entry2 = dag.add_node("entry2", task_fn(|task: u32| async move { task as usize }));
        let entry3 = dag.add_node("entry3", task_fn(|task: u32| async move { task as usize }));

        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        enum EntryRoute {
            Entry1(NodeIndex),
            Entry2(NodeIndex),
            Entry3(NodeIndex),
        }

        // Implement `From` rather than `Into`; the blanket impl supplies the
        // `Into<NodeIndex>` that `route` requires.
        impl From<EntryRoute> for NodeIndex {
            fn from(route: EntryRoute) -> Self {
                match route {
                    EntryRoute::Entry1(idx) | EntryRoute::Entry2(idx) | EntryRoute::Entry3(idx) => {
                        idx
                    }
                }
            }
        }

        impl DepsCheck<usize> for EntryRoute {
            fn to_node_ids(&self) -> Vec<NodeIndex> {
                vec![(*self).into()]
            }
        }

        async fn collect(task: (usize, usize, usize)) -> usize {
            task.0 + task.1 + task.2
        }
        let collector = dag.node(collect).depends_on((&entry1, &entry2, &entry3));

        // The extra argument is deliberately unused: it only exercises a
        // handler that takes more than the task payload.
        async fn vec_collect(task: Vec<usize>, _worker: WorkerContext) -> usize {
            task.iter().sum::<usize>()
        }

        let vec_collector = dag
            .node(vec_collect)
            .depends_on(vec![&entry1, &entry2, &entry3]);

        async fn exit(task: (usize, usize)) -> Result<u32, ParseIntError> {
            (task.0.to_string() + &task.1.to_string()).parse()
        }

        let on_collect = dag.node(exit).depends_on((&collector, &vec_collector));

        async fn check_approval(task: u32) -> Result<EntryRoute, BoxDynError> {
            match task % 3 {
                0 => Ok(EntryRoute::Entry1(NodeIndex::new(0))),
                1 => Ok(EntryRoute::Entry2(NodeIndex::new(1))),
                2 => Ok(EntryRoute::Entry3(NodeIndex::new(2))),
                _ => Err(BoxDynError::from("Invalid task")),
            }
        }

        let router = dag.route(check_approval).depends_on(&on_collect);

        // Capture the expected ids before `build` consumes the flow (the
        // builders borrow it).
        let expected_start = vec![entry1.id, entry2.id, entry3.id];
        let expected_end = vec![router.id];

        let dag_executor = dag.build().unwrap();
        assert_eq!(dag_executor.topological_order.len(), 7);
        // Only the three entries have no incoming edge; only the router has no outgoing edge.
        assert_eq!(dag_executor.start_nodes, expected_start);
        assert_eq!(dag_executor.end_nodes, expected_end);

        println!("Start nodes: {:?}", dag_executor.start_nodes);
        println!("End nodes: {:?}", dag_executor.end_nodes);

        println!(
            "DAG Topological Order: {:?}",
            dag_executor.topological_order
        );

        let dot = dag_executor.to_dot();
        assert!(
            dot.contains("label=\"entry1\""),
            "named nodes should be labelled in DOT output:\n{dot}"
        );
        println!("DAG in DOT format:\n{dot}");
    }
}