// dirtydata_core/graph_utils.rs

use crate::ir::{Graph, EdgeKind};
use crate::types::StableId;
use std::collections::{HashMap, HashSet, VecDeque};

5/// Sorts the graph nodes topologically.
6/// If cycles are detected, they are returned as well.
7pub fn topological_sort(graph: &Graph) -> (Vec<StableId>, Vec<Vec<StableId>>) {
8    let mut in_degree = HashMap::new();
9    let mut adj = HashMap::new();
10    let mut all_nodes = HashSet::new();
11
12    for id in graph.nodes.keys() {
13        all_nodes.insert(*id);
14        in_degree.insert(*id, 0);
15        adj.insert(*id, Vec::new());
16    }
17
18    for edge in graph.edges.values() {
19        if edge.kind == EdgeKind::Normal {
20            adj.get_mut(&edge.source.node_id).unwrap().push(edge.target.node_id);
21            *in_degree.get_mut(&edge.target.node_id).unwrap() += 1;
22        }
23    }
24
25    let mut queue = VecDeque::new();
26    for (id, degree) in &in_degree {
27        if *degree == 0 {
28            queue.push_back(*id);
29        }
30    }
31
32    let mut sorted = Vec::new();
33    while let Some(u) = queue.pop_front() {
34        sorted.push(u);
35        if let Some(neighbors) = adj.get(&u) {
36            for &v in neighbors {
37                let degree = in_degree.get_mut(&v).unwrap();
38                *degree -= 1;
39                if *degree == 0 {
40                    queue.push_back(v);
41                }
42            }
43        }
44    }
45
46    // Detect cycles
47    let mut cycles = Vec::new();
48    if sorted.len() < all_nodes.len() {
49        // Simple cycle detection: nodes with remaining in-degree are part of cycles
50        let remaining: HashSet<_> = all_nodes
51            .into_iter()
52            .filter(|id| !sorted.contains(id))
53            .collect();
54        // For MVP, we just return them as a single group of "cyclic nodes"
55        // In a real system, we'd use Tarjan's or similar to find SCCs.
56        cycles.push(remaining.into_iter().collect());
57    }
58
59    (sorted, cycles)
60}
61
62/// Returns all nodes that are upstream (ancestors) of the target node, including the target itself.
63/// Only considers Normal edges (not Feedback) to avoid capturing entire feedback loops if not necessary.
64pub fn get_upstream_nodes(graph: &Graph, target_node: StableId) -> HashSet<StableId> {
65    let mut upstream = HashSet::new();
66    let mut stack = vec![target_node];
67
68    while let Some(current) = stack.pop() {
69        if upstream.insert(current) {
70            for edge in graph.edges.values() {
71                if edge.target.node_id == current && edge.kind == EdgeKind::Normal {
72                    stack.push(edge.source.node_id);
73                }
74            }
75        }
76    }
77    upstream
78}
79
80/// Creates a new minimal Graph containing only the specified nodes and the edges connecting them.
81pub fn clone_subgraph(graph: &Graph, node_ids: &HashSet<StableId>) -> Graph {
82    let mut new_graph = Graph::new();
83    
84    for &id in node_ids {
85        if let Some(node) = graph.nodes.get(&id) {
86            new_graph.nodes.insert(id, node.clone());
87        }
88    }
89    
90    for edge in graph.edges.values() {
91        if node_ids.contains(&edge.source.node_id) && node_ids.contains(&edge.target.node_id) {
92            new_graph.edges.insert(edge.id, edge.clone());
93        }
94    }
95    
96    // Copy modulations if both source and target are in the subgraph
97    for m in graph.modulations.values() {
98        if node_ids.contains(&m.source.node_id) && node_ids.contains(&m.target_node) {
99            new_graph.modulations.insert(m.id, m.clone());
100        }
101    }
102    
103    new_graph
104}