dag_runner/
dag.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use std::pin::Pin;
4use std::sync::Arc;
5use tokio::sync::{Mutex, mpsc};
6use tokio::task;
7
8type DynAsyncFn = dyn Fn() -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync;
9
10#[derive(Default)]
11pub struct Dag {
12    fns: HashMap<String, Arc<DynAsyncFn>>,
13    graph: HashMap<String, Vec<String>>,
14}
15
16impl Dag {
17    pub fn add_vertex<F, Fut>(&mut self, name: &str, f: F)
18    where
19        F: Fn() -> Fut + Send + Sync + 'static,
20        Fut: Future<Output = Result<()>> + Send + 'static,
21    {
22        self.fns.insert(
23            name.to_string(),
24            Arc::new(move || Box::pin(f()) as Pin<Box<dyn Future<Output = Result<()>> + Send>>),
25        );
26    }
27
28    pub fn add_edge(&mut self, from: &str, to: &str) {
29        self.graph
30            .entry(from.to_string())
31            .or_default()
32            .push(to.to_string());
33    }
34
35    fn detect_cycles(&self) -> bool {
36        fn dfs(
37            v: &str,
38            graph: &HashMap<String, Vec<String>>,
39            visited: &mut HashMap<String, bool>,
40            stack: &mut HashMap<String, bool>,
41        ) -> bool {
42            visited.insert(v.to_string(), true);
43            stack.insert(v.to_string(), true);
44
45            if let Some(neighbors) = graph.get(v) {
46                for n in neighbors {
47                    if !*visited.get(n).unwrap_or(&false) {
48                        if dfs(n, graph, visited, stack) {
49                            return true;
50                        }
51                    } else if *stack.get(n).unwrap_or(&false) {
52                        return true;
53                    }
54                }
55            }
56
57            stack.insert(v.to_string(), false);
58            false
59        }
60
61        let mut visited = HashMap::new();
62        let mut stack = HashMap::new();
63        for v in self.graph.keys() {
64            if !*visited.get(v).unwrap_or(&false) && dfs(v, &self.graph, &mut visited, &mut stack) {
65                return true;
66            }
67        }
68        false
69    }
70
71    pub async fn run(&self) -> Result<()> {
72        if self.fns.is_empty() {
73            return Ok(());
74        }
75
76        let mut deps: HashMap<String, usize> = HashMap::new();
77        for (from, tos) in &self.graph {
78            if !self.fns.contains_key(from) {
79                anyhow::bail!("missing vertex");
80            }
81            for to in tos {
82                if !self.fns.contains_key(to) {
83                    anyhow::bail!("missing vertex");
84                }
85                *deps.entry(to.clone()).or_default() += 1;
86            }
87        }
88
89        if self.detect_cycles() {
90            anyhow::bail!("dependency cycle detected");
91        }
92
93        let (tx, mut rx) = mpsc::unbounded_channel::<(String, Result<()>)>();
94        let deps = Arc::new(Mutex::new(deps));
95        let mut running = 0usize;
96        let mut err: Option<anyhow::Error> = None;
97
98        for name in self.fns.keys() {
99            if *deps.lock().await.get(name).unwrap_or(&0) == 0 {
100                running += 1;
101                Self::start(
102                    name.clone(),
103                    Arc::clone(self.fns.get(name).unwrap()),
104                    tx.clone(),
105                )
106                .await;
107            }
108        }
109
110        while running > 0 {
111            let (name, res) = rx.recv().await.unwrap();
112            running -= 1;
113
114            if res.is_err() && err.is_none() {
115                err = res.err();
116            }
117
118            if err.is_some() {
119                continue;
120            }
121
122            if let Some(nexts) = self.graph.get(&name) {
123                for n in nexts {
124                    let mut deps_lock = deps.lock().await;
125                    let entry = deps_lock.entry(n.clone()).or_default();
126                    *entry -= 1;
127                    if *entry == 0 {
128                        running += 1;
129                        Self::start(n.clone(), Arc::clone(self.fns.get(n).unwrap()), tx.clone())
130                            .await;
131                    }
132                }
133            }
134        }
135
136        if let Some(e) = err { Err(e) } else { Ok(()) }
137    }
138
139    async fn start(
140        name: String,
141        f: Arc<DynAsyncFn>,
142        tx: mpsc::UnboundedSender<(String, Result<()>)>,
143    ) {
144        task::spawn(async move {
145            let r = f().await;
146            let _ = tx.send((name, r));
147        });
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use anyhow::anyhow;
    use tokio::time::timeout;

    use super::*;

    /// Upper bound on how long a single `run` (or channel read) may take in these tests.
    const LIMIT: Duration = Duration::from_millis(100);

    /// Vertex names shared by the cycle and invalid-edge tests.
    const FOUR_VERTICES: [&str; 4] = ["one", "two", "three", "four"];

    /// Adds one immediately-succeeding vertex per name in `names`.
    fn add_ok_vertices(dag: &mut Dag, names: &[&str]) {
        for &name in names {
            dag.add_vertex(name, || async { Ok(()) });
        }
    }

    /// Adds each `(from, to)` edge to `dag`, in the given order.
    fn add_edges(dag: &mut Dag, edges: &[(&str, &str)]) {
        for &(from, to) in edges {
            dag.add_edge(from, to);
        }
    }

    /// Runs `dag`, panicking (and so failing the test) if it does not finish within `LIMIT`.
    async fn run_bounded(dag: &Dag) -> Result<()> {
        timeout(LIMIT, dag.run()).await.unwrap()
    }

    /// Asserts that `from` appears no later than `to` in the observed execution `order`.
    fn assert_precedes(from: &str, to: &str, order: &[String]) {
        let from_index = order.iter().position(|x| x == from).unwrap();
        let to_index = order.iter().position(|x| x == to).unwrap();
        assert!(
            from_index <= to_index,
            "from vertex: {} came after to vertex: {}",
            from,
            to
        );
    }

    #[tokio::test]
    async fn test_zero() {
        let dag = Dag::default();
        assert!(run_bounded(&dag).await.is_ok());
    }

    #[tokio::test]
    async fn test_one() {
        let mut dag = Dag::default();
        dag.add_vertex("one", || async { Err(anyhow!("error")) });

        let err = run_bounded(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "error");
    }

    #[tokio::test]
    async fn test_many_no_deps() {
        let mut dag = Dag::default();
        dag.add_vertex("one", || async { Err(anyhow!("error")) });
        add_ok_vertices(&mut dag, &["two", "three", "four"]);

        let err = run_bounded(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "error");
    }

    #[tokio::test]
    async fn test_many_with_cycle() {
        let mut dag = Dag::default();
        add_ok_vertices(&mut dag, &FOUR_VERTICES);
        add_edges(
            &mut dag,
            &[
                ("one", "two"),
                ("two", "three"),
                ("three", "four"),
                ("three", "one"), // closes the cycle one -> two -> three -> one
            ],
        );

        let err = run_bounded(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "dependency cycle detected");
    }

    #[tokio::test]
    async fn test_invalid_to_vertex() {
        let mut dag = Dag::default();
        add_ok_vertices(&mut dag, &FOUR_VERTICES);
        add_edges(
            &mut dag,
            &[
                ("one", "two"),
                ("two", "three"),
                ("three", "four"),
                ("three", "definitely-not-a-valid-vertex"),
            ],
        );

        let err = run_bounded(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "missing vertex");
    }

    #[tokio::test]
    async fn test_invalid_from_vertex() {
        let mut dag = Dag::default();
        add_ok_vertices(&mut dag, &FOUR_VERTICES);
        add_edges(
            &mut dag,
            &[
                ("one", "two"),
                ("two", "three"),
                ("three", "four"),
                ("definitely-not-a-valid-vertex", "three"),
            ],
        );

        let err = run_bounded(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "missing vertex");
    }

    #[tokio::test]
    async fn test_many_with_deps_success() {
        const NAMES: [&str; 7] = ["one", "two", "three", "four", "five", "six", "seven"];
        const EDGES: [(&str, &str); 5] = [
            ("one", "two"),
            ("one", "three"),
            ("two", "four"),
            ("two", "seven"),
            ("five", "six"),
        ];

        let mut dag = Dag::default();
        let (tx, mut rx) = mpsc::unbounded_channel::<String>();

        // Each vertex reports its own name when it runs, so `rx` yields the execution order.
        for name in NAMES {
            let tx = tx.clone();
            dag.add_vertex(name, move || {
                let tx = tx.clone();
                let name = name.to_string();
                async move {
                    let _ = tx.send(name);
                    Ok(())
                }
            });
        }
        add_edges(&mut dag, &EDGES);

        assert!(run_bounded(&dag).await.is_ok());

        let mut order = Vec::with_capacity(NAMES.len());
        for _ in 0..NAMES.len() {
            let received = timeout(LIMIT, rx.recv()).await.unwrap();
            order.push(received.unwrap());
        }

        for (from, to) in EDGES {
            assert_precedes(from, to, &order);
        }
    }
}