1use anyhow::Result;
2use std::collections::HashMap;
3use std::pin::Pin;
4use std::sync::Arc;
5use tokio::sync::{Mutex, mpsc};
6use tokio::task;
7
8type DynAsyncFn = dyn Fn() -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync;
9
10#[derive(Default)]
11pub struct Dag {
12 fns: HashMap<String, Arc<DynAsyncFn>>,
13 graph: HashMap<String, Vec<String>>,
14}
15
16impl Dag {
17 pub fn add_vertex<F, Fut>(&mut self, name: &str, f: F)
18 where
19 F: Fn() -> Fut + Send + Sync + 'static,
20 Fut: Future<Output = Result<()>> + Send + 'static,
21 {
22 self.fns.insert(
23 name.to_string(),
24 Arc::new(move || Box::pin(f()) as Pin<Box<dyn Future<Output = Result<()>> + Send>>),
25 );
26 }
27
28 pub fn add_edge(&mut self, from: &str, to: &str) {
29 self.graph
30 .entry(from.to_string())
31 .or_default()
32 .push(to.to_string());
33 }
34
35 fn detect_cycles(&self) -> bool {
36 fn dfs(
37 v: &str,
38 graph: &HashMap<String, Vec<String>>,
39 visited: &mut HashMap<String, bool>,
40 stack: &mut HashMap<String, bool>,
41 ) -> bool {
42 visited.insert(v.to_string(), true);
43 stack.insert(v.to_string(), true);
44
45 if let Some(neighbors) = graph.get(v) {
46 for n in neighbors {
47 if !*visited.get(n).unwrap_or(&false) {
48 if dfs(n, graph, visited, stack) {
49 return true;
50 }
51 } else if *stack.get(n).unwrap_or(&false) {
52 return true;
53 }
54 }
55 }
56
57 stack.insert(v.to_string(), false);
58 false
59 }
60
61 let mut visited = HashMap::new();
62 let mut stack = HashMap::new();
63 for v in self.graph.keys() {
64 if !*visited.get(v).unwrap_or(&false) && dfs(v, &self.graph, &mut visited, &mut stack) {
65 return true;
66 }
67 }
68 false
69 }
70
71 pub async fn run(&self) -> Result<()> {
72 if self.fns.is_empty() {
73 return Ok(());
74 }
75
76 let mut deps: HashMap<String, usize> = HashMap::new();
77 for (from, tos) in &self.graph {
78 if !self.fns.contains_key(from) {
79 anyhow::bail!("missing vertex");
80 }
81 for to in tos {
82 if !self.fns.contains_key(to) {
83 anyhow::bail!("missing vertex");
84 }
85 *deps.entry(to.clone()).or_default() += 1;
86 }
87 }
88
89 if self.detect_cycles() {
90 anyhow::bail!("dependency cycle detected");
91 }
92
93 let (tx, mut rx) = mpsc::unbounded_channel::<(String, Result<()>)>();
94 let deps = Arc::new(Mutex::new(deps));
95 let mut running = 0usize;
96 let mut err: Option<anyhow::Error> = None;
97
98 for name in self.fns.keys() {
99 if *deps.lock().await.get(name).unwrap_or(&0) == 0 {
100 running += 1;
101 Self::start(
102 name.clone(),
103 Arc::clone(self.fns.get(name).unwrap()),
104 tx.clone(),
105 )
106 .await;
107 }
108 }
109
110 while running > 0 {
111 let (name, res) = rx.recv().await.unwrap();
112 running -= 1;
113
114 if res.is_err() && err.is_none() {
115 err = res.err();
116 }
117
118 if err.is_some() {
119 continue;
120 }
121
122 if let Some(nexts) = self.graph.get(&name) {
123 for n in nexts {
124 let mut deps_lock = deps.lock().await;
125 let entry = deps_lock.entry(n.clone()).or_default();
126 *entry -= 1;
127 if *entry == 0 {
128 running += 1;
129 Self::start(n.clone(), Arc::clone(self.fns.get(n).unwrap()), tx.clone())
130 .await;
131 }
132 }
133 }
134 }
135
136 if let Some(e) = err { Err(e) } else { Ok(()) }
137 }
138
139 async fn start(
140 name: String,
141 f: Arc<DynAsyncFn>,
142 tx: mpsc::UnboundedSender<(String, Result<()>)>,
143 ) {
144 task::spawn(async move {
145 let r = f().await;
146 let _ = tx.send((name, r));
147 });
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use anyhow::anyhow;
    use tokio::time::timeout;

    use super::*;

    /// Upper bound on how long a `Dag::run` call or a channel read may take in these tests.
    const TIMEOUT: Duration = Duration::from_millis(100);

    /// Runs `dag`, failing the test if it does not finish within [`TIMEOUT`].
    async fn run_with_timeout(dag: &Dag) -> anyhow::Result<()> {
        timeout(TIMEOUT, dag.run())
            .await
            .expect("Dag::run timed out")
    }

    /// Builds succeeding vertices `one`..`four` linked as one -> two -> three -> four.
    fn chain_of_four() -> Dag {
        let mut dag = Dag::default();
        for name in ["one", "two", "three", "four"] {
            dag.add_vertex(name, || async { Ok(()) });
        }
        dag.add_edge("one", "two");
        dag.add_edge("two", "three");
        dag.add_edge("three", "four");
        dag
    }

    #[tokio::test]
    async fn test_zero() {
        assert!(run_with_timeout(&Dag::default()).await.is_ok());
    }

    #[tokio::test]
    async fn test_one() {
        let mut dag = Dag::default();
        dag.add_vertex("one", || async { Err(anyhow!("error")) });

        let err = run_with_timeout(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "error");
    }

    #[tokio::test]
    async fn test_many_no_deps() {
        let mut dag = Dag::default();
        dag.add_vertex("one", || async { Err(anyhow!("error")) });
        for name in ["two", "three", "four"] {
            dag.add_vertex(name, || async { Ok(()) });
        }

        let err = run_with_timeout(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "error");
    }

    #[tokio::test]
    async fn test_many_with_cycle() {
        let mut dag = chain_of_four();
        dag.add_edge("three", "one");

        let err = run_with_timeout(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "dependency cycle detected");
    }

    #[tokio::test]
    async fn test_invalid_to_vertex() {
        let mut dag = chain_of_four();
        dag.add_edge("three", "definitely-not-a-valid-vertex");

        let err = run_with_timeout(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "missing vertex");
    }

    #[tokio::test]
    async fn test_invalid_from_vertex() {
        let mut dag = chain_of_four();
        dag.add_edge("definitely-not-a-valid-vertex", "three");

        let err = run_with_timeout(&dag).await.unwrap_err();
        assert_eq!(err.to_string(), "missing vertex");
    }

    #[tokio::test]
    async fn test_many_with_deps_success() {
        let mut dag = Dag::default();
        // Each vertex sends its name when it runs, so the channel records execution order.
        let (tx, mut rx) = mpsc::unbounded_channel::<String>();
        let names = ["one", "two", "three", "four", "five", "six", "seven"];

        for name in names {
            let tx = tx.clone();
            dag.add_vertex(name, move || {
                let tx = tx.clone();
                async move {
                    let _ = tx.send(name.to_string());
                    Ok(())
                }
            });
        }

        dag.add_edge("one", "two");
        dag.add_edge("one", "three");
        dag.add_edge("two", "four");
        dag.add_edge("two", "seven");
        dag.add_edge("five", "six");

        run_with_timeout(&dag)
            .await
            .expect("dag with valid dependencies should succeed");

        let mut results = Vec::with_capacity(names.len());
        for _ in 0..names.len() {
            let val = timeout(TIMEOUT, rx.recv())
                .await
                .expect("vertex never reported");
            results.push(val.expect("channel closed early"));
        }

        fn check_order(from: &str, to: &str, results: &[String]) {
            let position = |name: &str| {
                results
                    .iter()
                    .position(|x| x == name)
                    .unwrap_or_else(|| panic!("vertex {name} never ran"))
            };
            assert!(
                position(from) < position(to),
                "from vertex: {} came after to vertex: {}",
                from,
                to
            );
        }

        check_order("one", "two", &results);
        check_order("one", "three", &results);
        check_order("two", "four", &results);
        check_order("two", "seven", &results);
        check_order("five", "six", &results);
    }
}