Skip to main content

sql_splitter/graph/
analysis.rs

//! Graph analysis algorithms: cycle detection via Tarjan's strongly connected components.
2
3use crate::graph::view::GraphView;
4use ahash::{AHashMap, AHashSet};
5
/// A cycle in the foreign-key graph.
///
/// `tables` lists the names of the tables that form the cycle. A single entry
/// means the table references itself. For cycles produced by [`find_cycles`],
/// a multi-table entry is one strongly connected component: its tables are
/// mutually reachable, but consecutive entries are not guaranteed to be joined
/// by a direct edge.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Cycle {
    pub tables: Vec<String>,
}

impl Cycle {
    /// Check if this is a self-referencing cycle (exactly one table).
    pub fn is_self_reference(&self) -> bool {
        self.tables.len() == 1
    }

    /// Format the cycle for display, repeating the first table at the end to
    /// close the loop, e.g. `a -> b -> c -> a` or `t -> t (self-reference)`.
    ///
    /// `tables` is a public field, so an empty `Cycle` can be constructed; it
    /// formats as an empty string instead of panicking.
    pub fn display(&self) -> String {
        match self.tables.as_slice() {
            [] => String::new(),
            [only] => format!("{only} -> {only} (self-reference)"),
            [first, ..] => {
                let mut parts: Vec<&str> = self.tables.iter().map(String::as_str).collect();
                // Repeat the first table so the rendered path visibly closes.
                parts.push(first.as_str());
                parts.join(" -> ")
            }
        }
    }
}
29
30/// Find all cycles in the graph using Tarjan's SCC algorithm
31pub fn find_cycles(view: &GraphView) -> Vec<Cycle> {
32    let mut finder = TarjanSCC::new(view);
33    finder.find_sccs();
34
35    let mut cycles = Vec::new();
36
37    for scc in &finder.sccs {
38        if scc.len() == 1 {
39            // Check if it's a self-referencing table
40            let table = &scc[0];
41            if view
42                .edges
43                .iter()
44                .any(|e| &e.from_table == table && &e.to_table == table)
45            {
46                cycles.push(Cycle {
47                    tables: scc.clone(),
48                });
49            }
50        } else if scc.len() > 1 {
51            // Multi-table cycle
52            cycles.push(Cycle {
53                tables: scc.clone(),
54            });
55        }
56    }
57
58    cycles
59}
60
61/// Get all tables that are part of any cycle
62pub fn cyclic_tables(view: &GraphView) -> AHashSet<String> {
63    let cycles = find_cycles(view);
64    let mut tables = AHashSet::new();
65    for cycle in cycles {
66        for table in cycle.tables {
67            tables.insert(table);
68        }
69    }
70    tables
71}
72
73/// Tarjan's Strongly Connected Components algorithm
74struct TarjanSCC<'a> {
75    view: &'a GraphView,
76    index_counter: usize,
77    stack: Vec<String>,
78    on_stack: AHashSet<String>,
79    indices: AHashMap<String, usize>,
80    lowlinks: AHashMap<String, usize>,
81    sccs: Vec<Vec<String>>,
82    adjacency: AHashMap<String, Vec<String>>,
83}
84
85impl<'a> TarjanSCC<'a> {
86    fn new(view: &'a GraphView) -> Self {
87        // Build adjacency list from tables
88        let mut adjacency: AHashMap<String, Vec<String>> = AHashMap::new();
89        for table_name in view.tables.keys() {
90            adjacency.insert(table_name.clone(), Vec::new());
91        }
92        for edge in &view.edges {
93            if view.tables.contains_key(&edge.from_table)
94                && view.tables.contains_key(&edge.to_table)
95            {
96                adjacency
97                    .entry(edge.from_table.clone())
98                    .or_default()
99                    .push(edge.to_table.clone());
100            }
101        }
102
103        Self {
104            view,
105            index_counter: 0,
106            stack: Vec::new(),
107            on_stack: AHashSet::new(),
108            indices: AHashMap::new(),
109            lowlinks: AHashMap::new(),
110            sccs: Vec::new(),
111            adjacency,
112        }
113    }
114
115    fn find_sccs(&mut self) {
116        let nodes: Vec<_> = self.view.tables.keys().cloned().collect();
117        for node in nodes {
118            if !self.indices.contains_key(&node) {
119                self.strongconnect(&node);
120            }
121        }
122    }
123
124    fn strongconnect(&mut self, v: &str) {
125        // Set the depth index for v
126        self.indices.insert(v.to_string(), self.index_counter);
127        self.lowlinks.insert(v.to_string(), self.index_counter);
128        self.index_counter += 1;
129        self.stack.push(v.to_string());
130        self.on_stack.insert(v.to_string());
131
132        // Consider successors of v
133        if let Some(neighbors) = self.adjacency.get(v).cloned() {
134            for w in neighbors {
135                if !self.indices.contains_key(&w) {
136                    // Successor w has not yet been visited; recurse on it
137                    self.strongconnect(&w);
138                    let v_lowlink = self.lowlinks[v];
139                    let w_lowlink = self.lowlinks[&w];
140                    self.lowlinks
141                        .insert(v.to_string(), v_lowlink.min(w_lowlink));
142                } else if self.on_stack.contains(&w) {
143                    // Successor w is in stack S and hence in the current SCC
144                    let v_lowlink = self.lowlinks[v];
145                    let w_index = self.indices[&w];
146                    self.lowlinks.insert(v.to_string(), v_lowlink.min(w_index));
147                }
148            }
149        }
150
151        // If v is a root node, pop the stack and generate an SCC
152        if self.lowlinks[v] == self.indices[v] {
153            let mut scc = Vec::new();
154            loop {
155                let w = self.stack.pop().unwrap();
156                self.on_stack.remove(&w);
157                scc.push(w.clone());
158                if w == v {
159                    break;
160                }
161            }
162            self.sccs.push(scc);
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::view::{Cardinality, TableInfo};

    /// Build a `TableInfo` with the given name and no columns.
    fn create_simple_table(name: &str) -> TableInfo {
        TableInfo {
            name: name.to_string(),
            columns: vec![],
        }
    }

    /// Build a many-to-one edge `from.fk -> to.id`.
    fn create_edge(from: &str, to: &str) -> crate::graph::EdgeInfo {
        crate::graph::EdgeInfo {
            from_table: from.to_string(),
            from_column: "fk".to_string(),
            to_table: to.to_string(),
            to_column: "id".to_string(),
            cardinality: Cardinality::ManyToOne,
        }
    }

    fn create_acyclic_view() -> GraphView {
        let mut tables = AHashMap::new();
        tables.insert("users".to_string(), create_simple_table("users"));
        tables.insert("orders".to_string(), create_simple_table("orders"));
        tables.insert("products".to_string(), create_simple_table("products"));

        let edges = vec![create_edge("orders", "users")];

        GraphView { tables, edges }
    }

    fn create_self_ref_view() -> GraphView {
        let mut tables = AHashMap::new();
        tables.insert("categories".to_string(), create_simple_table("categories"));

        let edges = vec![create_edge("categories", "categories")];

        GraphView { tables, edges }
    }

    fn create_multi_cycle_view() -> GraphView {
        let mut tables = AHashMap::new();
        tables.insert("a".to_string(), create_simple_table("a"));
        tables.insert("b".to_string(), create_simple_table("b"));
        tables.insert("c".to_string(), create_simple_table("c"));

        let edges = vec![
            create_edge("a", "b"),
            create_edge("b", "c"),
            create_edge("c", "a"),
        ];

        GraphView { tables, edges }
    }

    #[test]
    fn test_no_cycles() {
        let view = create_acyclic_view();
        let cycles = find_cycles(&view);
        assert!(cycles.is_empty());
    }

    #[test]
    fn test_self_reference_cycle() {
        let view = create_self_ref_view();
        let cycles = find_cycles(&view);
        assert_eq!(cycles.len(), 1);
        assert!(cycles[0].is_self_reference());
        assert_eq!(cycles[0].tables, vec!["categories"]);
    }

    #[test]
    fn test_multi_table_cycle() {
        let view = create_multi_cycle_view();
        let cycles = find_cycles(&view);
        assert_eq!(cycles.len(), 1);
        assert!(!cycles[0].is_self_reference());
        assert_eq!(cycles[0].tables.len(), 3);
    }

    #[test]
    fn test_cyclic_tables() {
        let view = create_multi_cycle_view();
        let tables = cyclic_tables(&view);
        assert!(tables.contains("a"));
        assert!(tables.contains("b"));
        assert!(tables.contains("c"));
    }

    #[test]
    fn test_empty_graph_has_no_cycles() {
        let view = GraphView {
            tables: AHashMap::new(),
            edges: vec![],
        };
        assert!(find_cycles(&view).is_empty());
        assert!(cyclic_tables(&view).is_empty());
    }

    #[test]
    fn test_edges_to_unknown_tables_are_ignored() {
        // `ghost` is not a table of the view, so neither the a -> ghost -> a
        // round trip nor ghost's self-reference may be reported as a cycle.
        let mut tables = AHashMap::new();
        tables.insert("a".to_string(), create_simple_table("a"));
        let edges = vec![
            create_edge("a", "ghost"),
            create_edge("ghost", "a"),
            create_edge("ghost", "ghost"),
        ];
        let view = GraphView { tables, edges };

        assert!(find_cycles(&view).is_empty());
    }

    #[test]
    fn test_independent_cycles_are_reported_separately() {
        let mut tables = AHashMap::new();
        for &name in &["a", "b", "c", "d"] {
            tables.insert(name.to_string(), create_simple_table(name));
        }
        let edges = vec![
            create_edge("a", "b"),
            create_edge("b", "a"),
            create_edge("c", "c"),
            // d points into a cycle but is not itself part of one.
            create_edge("d", "a"),
        ];
        let view = GraphView { tables, edges };

        // Normalise ordering so the assertion does not depend on traversal order.
        let mut found: Vec<Vec<String>> = find_cycles(&view)
            .into_iter()
            .map(|cycle| {
                let mut names = cycle.tables;
                names.sort();
                names
            })
            .collect();
        found.sort();
        assert_eq!(
            found,
            vec![
                vec!["a".to_string(), "b".to_string()],
                vec!["c".to_string()],
            ]
        );

        let cyclic = cyclic_tables(&view);
        assert!(cyclic.contains("a"));
        assert!(cyclic.contains("b"));
        assert!(cyclic.contains("c"));
        assert!(!cyclic.contains("d"));
    }

    #[test]
    fn test_cycle_display() {
        let self_ref = Cycle {
            tables: vec!["t".to_string()],
        };
        assert_eq!(self_ref.display(), "t -> t (self-reference)");

        let multi = Cycle {
            tables: vec!["a".to_string(), "b".to_string()],
        };
        assert_eq!(multi.display(), "a -> b -> a");
    }
}
257}