dirtydata_core/
graph_utils.rs1use crate::ir::{Graph, EdgeKind};
2use crate::types::StableId;
3use std::collections::{HashMap, HashSet, VecDeque};
4
5pub fn topological_sort(graph: &Graph) -> (Vec<StableId>, Vec<Vec<StableId>>) {
8 let mut in_degree = HashMap::new();
9 let mut adj = HashMap::new();
10 let mut all_nodes = HashSet::new();
11
12 for id in graph.nodes.keys() {
13 all_nodes.insert(*id);
14 in_degree.insert(*id, 0);
15 adj.insert(*id, Vec::new());
16 }
17
18 for edge in graph.edges.values() {
19 if edge.kind == EdgeKind::Normal {
20 adj.get_mut(&edge.source.node_id).unwrap().push(edge.target.node_id);
21 *in_degree.get_mut(&edge.target.node_id).unwrap() += 1;
22 }
23 }
24
25 let mut queue = VecDeque::new();
26 for (id, degree) in &in_degree {
27 if *degree == 0 {
28 queue.push_back(*id);
29 }
30 }
31
32 let mut sorted = Vec::new();
33 while let Some(u) = queue.pop_front() {
34 sorted.push(u);
35 if let Some(neighbors) = adj.get(&u) {
36 for &v in neighbors {
37 let degree = in_degree.get_mut(&v).unwrap();
38 *degree -= 1;
39 if *degree == 0 {
40 queue.push_back(v);
41 }
42 }
43 }
44 }
45
46 let mut cycles = Vec::new();
48 if sorted.len() < all_nodes.len() {
49 let remaining: HashSet<_> = all_nodes
51 .into_iter()
52 .filter(|id| !sorted.contains(id))
53 .collect();
54 cycles.push(remaining.into_iter().collect());
57 }
58
59 (sorted, cycles)
60}
61
62pub fn get_upstream_nodes(graph: &Graph, target_node: StableId) -> HashSet<StableId> {
65 let mut upstream = HashSet::new();
66 let mut stack = vec![target_node];
67
68 while let Some(current) = stack.pop() {
69 if upstream.insert(current) {
70 for edge in graph.edges.values() {
71 if edge.target.node_id == current && edge.kind == EdgeKind::Normal {
72 stack.push(edge.source.node_id);
73 }
74 }
75 }
76 }
77 upstream
78}
79
80pub fn clone_subgraph(graph: &Graph, node_ids: &HashSet<StableId>) -> Graph {
82 let mut new_graph = Graph::new();
83
84 for &id in node_ids {
85 if let Some(node) = graph.nodes.get(&id) {
86 new_graph.nodes.insert(id, node.clone());
87 }
88 }
89
90 for edge in graph.edges.values() {
91 if node_ids.contains(&edge.source.node_id) && node_ids.contains(&edge.target.node_id) {
92 new_graph.edges.insert(edge.id, edge.clone());
93 }
94 }
95
96 for m in graph.modulations.values() {
98 if node_ids.contains(&m.source.node_id) && node_ids.contains(&m.target_node) {
99 new_graph.modulations.insert(m.id, m.clone());
100 }
101 }
102
103 new_graph
104}