1#![allow(dead_code)]
2use std::collections::{HashMap, HashSet, VecDeque};
9
/// Identifier of a node in a [`DependencyGraph`].
///
/// `PartialOrd`/`Ord` order ids by their raw value. Callers can use this to
/// sort the (unordered) id lists returned by graph queries, or to key ordered
/// collections such as `BTreeMap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DepNodeId(pub u64);
13
14impl std::fmt::Display for DepNodeId {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 write!(f, "node_{}", self.0)
17 }
18}
19
20#[derive(Debug, Clone)]
22pub struct DepNode {
23 pub id: DepNodeId,
25 pub label: String,
27 pub cost: f64,
29}
30
31impl DepNode {
32 pub fn new(id: u64, label: &str, cost: f64) -> Self {
34 Self {
35 id: DepNodeId(id),
36 label: label.to_string(),
37 cost,
38 }
39 }
40}
41
42#[derive(Debug, Default)]
44pub struct DependencyGraph {
45 nodes: HashMap<DepNodeId, DepNode>,
47 forward_edges: HashMap<DepNodeId, HashSet<DepNodeId>>,
49 reverse_edges: HashMap<DepNodeId, HashSet<DepNodeId>>,
51}
52
53impl DependencyGraph {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn add_node(&mut self, node: DepNode) {
61 let id = node.id;
62 self.nodes.insert(id, node);
63 self.forward_edges.entry(id).or_default();
64 self.reverse_edges.entry(id).or_default();
65 }
66
67 pub fn add_edge(&mut self, from: DepNodeId, to: DepNodeId) -> bool {
71 self.forward_edges.entry(from).or_default().insert(to);
72 self.reverse_edges.entry(to).or_default().insert(from)
73 }
74
75 pub fn node_count(&self) -> usize {
77 self.nodes.len()
78 }
79
80 pub fn edge_count(&self) -> usize {
82 self.forward_edges.values().map(|s| s.len()).sum()
83 }
84
85 pub fn dependencies_of(&self, id: DepNodeId) -> Vec<DepNodeId> {
87 self.reverse_edges
88 .get(&id)
89 .map(|s| s.iter().copied().collect())
90 .unwrap_or_default()
91 }
92
93 pub fn dependents_of(&self, id: DepNodeId) -> Vec<DepNodeId> {
95 self.forward_edges
96 .get(&id)
97 .map(|s| s.iter().copied().collect())
98 .unwrap_or_default()
99 }
100
101 pub fn roots(&self) -> Vec<DepNodeId> {
103 self.nodes
104 .keys()
105 .filter(|id| self.reverse_edges.get(id).map_or(true, HashSet::is_empty))
106 .copied()
107 .collect()
108 }
109
110 pub fn leaves(&self) -> Vec<DepNodeId> {
112 self.nodes
113 .keys()
114 .filter(|id| self.forward_edges.get(id).map_or(true, HashSet::is_empty))
115 .copied()
116 .collect()
117 }
118
119 pub fn topological_order(&self) -> Option<Vec<DepNodeId>> {
123 let mut in_degree: HashMap<DepNodeId, usize> = HashMap::new();
124 for id in self.nodes.keys() {
125 in_degree.insert(*id, self.reverse_edges.get(id).map_or(0, HashSet::len));
126 }
127
128 let mut queue: VecDeque<DepNodeId> = in_degree
129 .iter()
130 .filter(|(_, °)| deg == 0)
131 .map(|(id, _)| *id)
132 .collect();
133
134 let mut order = Vec::with_capacity(self.nodes.len());
135
136 while let Some(node) = queue.pop_front() {
137 order.push(node);
138 if let Some(successors) = self.forward_edges.get(&node) {
139 for &succ in successors {
140 if let Some(deg) = in_degree.get_mut(&succ) {
141 *deg -= 1;
142 if *deg == 0 {
143 queue.push_back(succ);
144 }
145 }
146 }
147 }
148 }
149
150 if order.len() == self.nodes.len() {
151 Some(order)
152 } else {
153 None
154 }
155 }
156
157 pub fn transitive_dependencies(&self, id: DepNodeId) -> HashSet<DepNodeId> {
159 let mut visited = HashSet::new();
160 let mut stack = vec![id];
161 while let Some(current) = stack.pop() {
162 if let Some(deps) = self.reverse_edges.get(¤t) {
163 for &dep in deps {
164 if visited.insert(dep) {
165 stack.push(dep);
166 }
167 }
168 }
169 }
170 visited
171 }
172
173 pub fn compute_depths(&self) -> HashMap<DepNodeId, u32> {
175 let mut depths: HashMap<DepNodeId, u32> = HashMap::new();
176 if let Some(order) = self.topological_order() {
177 for &node in &order {
178 let max_pred_depth = self
179 .reverse_edges
180 .get(&node)
181 .map(|preds| {
182 preds
183 .iter()
184 .filter_map(|p| depths.get(p))
185 .max()
186 .copied()
187 .unwrap_or(0)
188 })
189 .unwrap_or(0);
190 let depth = if self
191 .reverse_edges
192 .get(&node)
193 .map_or(true, HashSet::is_empty)
194 {
195 0
196 } else {
197 max_pred_depth + 1
198 };
199 depths.insert(node, depth);
200 }
201 }
202 depths
203 }
204
205 pub fn parallel_levels(&self) -> Vec<Vec<DepNodeId>> {
209 let depths = self.compute_depths();
210 if depths.is_empty() {
211 return Vec::new();
212 }
213 let max_depth = depths.values().copied().max().unwrap_or(0);
214 let mut levels = vec![Vec::new(); (max_depth + 1) as usize];
215 for (id, depth) in &depths {
216 levels[*depth as usize].push(*id);
217 }
218 levels
219 }
220
221 #[allow(clippy::cast_precision_loss)]
223 pub fn critical_path(&self) -> (Vec<DepNodeId>, f64) {
224 let order = match self.topological_order() {
225 Some(o) => o,
226 None => return (Vec::new(), 0.0),
227 };
228
229 let mut dist: HashMap<DepNodeId, f64> = HashMap::new();
230 let mut prev: HashMap<DepNodeId, DepNodeId> = HashMap::new();
231
232 for &node in &order {
233 let node_cost = self.nodes.get(&node).map_or(0.0, |n| n.cost);
234 let max_pred = self.reverse_edges.get(&node).and_then(|preds| {
235 preds
236 .iter()
237 .filter_map(|p| dist.get(p).map(|d| (*p, *d)))
238 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
239 });
240
241 let total = if let Some((pred_id, pred_dist)) = max_pred {
242 prev.insert(node, pred_id);
243 pred_dist + node_cost
244 } else {
245 node_cost
246 };
247 dist.insert(node, total);
248 }
249
250 let end_node = dist
252 .iter()
253 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
254 .map(|(id, _)| *id);
255
256 let end_node = match end_node {
257 Some(n) => n,
258 None => return (Vec::new(), 0.0),
259 };
260
261 let total_cost = dist[&end_node];
262
263 let mut path = vec![end_node];
265 let mut current = end_node;
266 while let Some(&pred) = prev.get(¤t) {
267 path.push(pred);
268 current = pred;
269 }
270 path.reverse();
271
272 (path, total_cost)
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// `A -> B -> C` with costs 1, 2 and 3.
    fn make_linear_graph() -> DependencyGraph {
        let mut g = DependencyGraph::new();
        g.add_node(DepNode::new(0, "A", 1.0));
        g.add_node(DepNode::new(1, "B", 2.0));
        g.add_node(DepNode::new(2, "C", 3.0));
        g.add_edge(DepNodeId(0), DepNodeId(1));
        g.add_edge(DepNodeId(1), DepNodeId(2));
        g
    }

    /// `A -> {B, C} -> D`; the branch through C (cost 4) is the expensive one.
    fn make_diamond_graph() -> DependencyGraph {
        let mut g = DependencyGraph::new();
        g.add_node(DepNode::new(0, "A", 1.0));
        g.add_node(DepNode::new(1, "B", 2.0));
        g.add_node(DepNode::new(2, "C", 4.0));
        g.add_node(DepNode::new(3, "D", 1.0));
        g.add_edge(DepNodeId(0), DepNodeId(1));
        g.add_edge(DepNodeId(0), DepNodeId(2));
        g.add_edge(DepNodeId(1), DepNodeId(3));
        g.add_edge(DepNodeId(2), DepNodeId(3));
        g
    }

    /// Two nodes that depend on each other: `A -> B -> A`.
    fn make_cycle_graph() -> DependencyGraph {
        let mut g = DependencyGraph::new();
        g.add_node(DepNode::new(0, "A", 1.0));
        g.add_node(DepNode::new(1, "B", 1.0));
        g.add_edge(DepNodeId(0), DepNodeId(1));
        g.add_edge(DepNodeId(1), DepNodeId(0));
        g
    }

    /// Index of `DepNodeId(id)` in `order`; panics and shows the whole order if absent.
    fn position_of(order: &[DepNodeId], id: u64) -> usize {
        order
            .iter()
            .position(|&x| x == DepNodeId(id))
            .unwrap_or_else(|| panic!("node {id} missing from order {order:?}"))
    }

    #[test]
    fn test_add_node_and_edge() {
        let g = make_linear_graph();
        assert_eq!(g.node_count(), 3);
        assert_eq!(g.edge_count(), 2);
    }

    #[test]
    fn test_add_edge_reports_duplicates() {
        let mut g = make_linear_graph();
        assert!(!g.add_edge(DepNodeId(0), DepNodeId(1)));
        assert_eq!(g.edge_count(), 2);
        assert!(g.add_edge(DepNodeId(0), DepNodeId(2)));
        assert_eq!(g.edge_count(), 3);
    }

    #[test]
    fn test_roots_and_leaves() {
        let g = make_linear_graph();
        assert_eq!(g.roots(), vec![DepNodeId(0)]);
        assert_eq!(g.leaves(), vec![DepNodeId(2)]);
    }

    #[test]
    fn test_dependencies_of() {
        let g = make_linear_graph();
        assert_eq!(g.dependencies_of(DepNodeId(2)), vec![DepNodeId(1)]);
        assert!(g.dependencies_of(DepNodeId(0)).is_empty());
    }

    #[test]
    fn test_dependents_of() {
        let g = make_linear_graph();
        assert_eq!(g.dependents_of(DepNodeId(0)), vec![DepNodeId(1)]);
        assert!(g.dependents_of(DepNodeId(2)).is_empty());
    }

    #[test]
    fn test_topological_order() {
        let order = make_linear_graph()
            .topological_order()
            .expect("linear graph is acyclic");
        assert_eq!(order.len(), 3);
        assert!(position_of(&order, 0) < position_of(&order, 1));
        assert!(position_of(&order, 1) < position_of(&order, 2));
    }

    #[test]
    fn test_topological_order_diamond() {
        let order = make_diamond_graph()
            .topological_order()
            .expect("diamond graph is acyclic");
        assert_eq!(order.len(), 4);
        let pos_a = position_of(&order, 0);
        let pos_b = position_of(&order, 1);
        let pos_c = position_of(&order, 2);
        let pos_d = position_of(&order, 3);
        assert!(pos_a < pos_b && pos_a < pos_c);
        assert!(pos_b < pos_d && pos_c < pos_d);
    }

    #[test]
    fn test_cycle_has_no_topological_order() {
        let g = make_cycle_graph();
        assert!(g.topological_order().is_none());
        assert!(g.compute_depths().is_empty());
        assert!(g.parallel_levels().is_empty());
        let (path, cost) = g.critical_path();
        assert!(path.is_empty());
        assert!(cost.abs() < f64::EPSILON);
    }

    #[test]
    fn test_self_loop_is_a_cycle() {
        let mut g = DependencyGraph::new();
        g.add_node(DepNode::new(0, "A", 1.0));
        g.add_edge(DepNodeId(0), DepNodeId(0));
        assert!(g.topological_order().is_none());
        assert!(g.roots().is_empty());
        assert!(g.leaves().is_empty());
    }

    #[test]
    fn test_transitive_dependencies() {
        let trans = make_linear_graph().transitive_dependencies(DepNodeId(2));
        assert!(trans.contains(&DepNodeId(0)));
        assert!(trans.contains(&DepNodeId(1)));
        assert_eq!(trans.len(), 2);
    }

    #[test]
    fn test_transitive_dependencies_of_root_is_empty() {
        assert!(make_linear_graph()
            .transitive_dependencies(DepNodeId(0))
            .is_empty());
    }

    #[test]
    fn test_transitive_dependencies_in_cycle_include_start() {
        let trans = make_cycle_graph().transitive_dependencies(DepNodeId(0));
        assert!(trans.contains(&DepNodeId(0)));
        assert!(trans.contains(&DepNodeId(1)));
        assert_eq!(trans.len(), 2);
    }

    #[test]
    fn test_compute_depths() {
        let depths = make_linear_graph().compute_depths();
        assert_eq!(depths[&DepNodeId(0)], 0);
        assert_eq!(depths[&DepNodeId(1)], 1);
        assert_eq!(depths[&DepNodeId(2)], 2);
    }

    #[test]
    fn test_compute_depths_diamond() {
        let depths = make_diamond_graph().compute_depths();
        assert_eq!(depths[&DepNodeId(0)], 0);
        assert_eq!(depths[&DepNodeId(1)], 1);
        assert_eq!(depths[&DepNodeId(2)], 1);
        assert_eq!(depths[&DepNodeId(3)], 2);
    }

    #[test]
    fn test_parallel_levels_diamond() {
        let levels = make_diamond_graph().parallel_levels();
        assert_eq!(levels.len(), 3);
        assert_eq!(levels[0].len(), 1);
        assert_eq!(levels[1].len(), 2);
        assert_eq!(levels[2].len(), 1);
    }

    #[test]
    fn test_critical_path_linear() {
        let (path, cost) = make_linear_graph().critical_path();
        assert_eq!(path, vec![DepNodeId(0), DepNodeId(1), DepNodeId(2)]);
        assert!((cost - 6.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_critical_path_diamond() {
        let (path, cost) = make_diamond_graph().critical_path();
        // A(1) + C(4) + D(1) beats A(1) + B(2) + D(1).
        assert!((cost - 6.0).abs() < f64::EPSILON);
        assert_eq!(path, vec![DepNodeId(0), DepNodeId(2), DepNodeId(3)]);
    }

    #[test]
    fn test_empty_graph() {
        let g = DependencyGraph::new();
        assert_eq!(g.node_count(), 0);
        assert_eq!(g.edge_count(), 0);
        assert!(g.roots().is_empty());
        assert!(g.leaves().is_empty());
        assert_eq!(g.topological_order(), Some(Vec::new()));
        assert!(g.parallel_levels().is_empty());
        let (path, cost) = g.critical_path();
        assert!(path.is_empty());
        assert!(cost.abs() < f64::EPSILON);
    }

    #[test]
    fn test_dep_node_id_display() {
        let id = DepNodeId(42);
        assert_eq!(format!("{id}"), "node_42");
    }
}