Skip to main content

pg2sqlite_core/transform/
topo.rs

//! Topological sort of tables by foreign-key dependencies.
2use std::collections::{HashMap, HashSet, VecDeque};
3
4use crate::diagnostics::{Severity, Warning, warning};
5use crate::ir::{Table, TableConstraint};
6
7/// Sort tables in dependency order (tables referenced by FKs come first).
8/// Falls back to alphabetical order if cycles are detected.
9pub fn topological_sort(tables: &mut Vec<Table>, warnings: &mut Vec<Warning>) {
10    let name_to_idx: HashMap<String, usize> = tables
11        .iter()
12        .enumerate()
13        .map(|(i, t)| (t.name.name.normalized.clone(), i))
14        .collect();
15
16    // Build adjacency list: edges from table → tables it depends on
17    let mut in_degree: Vec<usize> = vec![0; tables.len()];
18    let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); tables.len()];
19
20    for (i, table) in tables.iter().enumerate() {
21        let deps = get_fk_dependencies(table);
22        for dep_name in deps {
23            if let Some(&dep_idx) = name_to_idx.get(&dep_name)
24                && dep_idx != i
25            {
26                // dep must come before i
27                dependents[dep_idx].push(i);
28                in_degree[i] += 1;
29            }
30        }
31    }
32
33    // Kahn's algorithm
34    let mut queue: VecDeque<usize> = VecDeque::new();
35    for (i, &deg) in in_degree.iter().enumerate() {
36        if deg == 0 {
37            queue.push_back(i);
38        }
39    }
40
41    // Sort queue entries alphabetically for determinism
42    let mut sorted_queue: Vec<usize> = queue.into_iter().collect();
43    sorted_queue.sort_by(|a, b| {
44        tables[*a]
45            .name
46            .name
47            .normalized
48            .cmp(&tables[*b].name.name.normalized)
49    });
50    let mut queue: VecDeque<usize> = sorted_queue.into_iter().collect();
51
52    let mut order: Vec<usize> = Vec::new();
53
54    while let Some(idx) = queue.pop_front() {
55        order.push(idx);
56        let mut next_ready = Vec::new();
57        for &dep in &dependents[idx] {
58            in_degree[dep] -= 1;
59            if in_degree[dep] == 0 {
60                next_ready.push(dep);
61            }
62        }
63        // Sort newly ready nodes alphabetically
64        next_ready.sort_by(|a, b| {
65            tables[*a]
66                .name
67                .name
68                .normalized
69                .cmp(&tables[*b].name.name.normalized)
70        });
71        queue.extend(next_ready);
72    }
73
74    if order.len() == tables.len() {
75        // Successful topological sort
76        let mut sorted: Vec<Table> = Vec::with_capacity(tables.len());
77        for idx in order {
78            sorted.push(std::mem::replace(
79                &mut tables[idx],
80                Table {
81                    name: crate::ir::QualifiedName::new(crate::ir::Ident::new("")),
82                    columns: vec![],
83                    constraints: vec![],
84                },
85            ));
86        }
87        *tables = sorted;
88    } else {
89        // Cycle detected — fall back to alphabetical
90        let cycle_tables: Vec<String> = tables
91            .iter()
92            .enumerate()
93            .filter(|(i, _)| in_degree[*i] > 0)
94            .map(|(_, t)| t.name.name.normalized.clone())
95            .collect();
96        warnings.push(Warning::new(
97            warning::FK_CYCLE_DETECTED,
98            Severity::Lossy,
99            format!(
100                "Foreign key dependency cycle detected among tables: {}; falling back to alphabetical order",
101                cycle_tables.join(", ")
102            ),
103        ));
104        tables.sort_by(|a, b| a.name.name.normalized.cmp(&b.name.name.normalized));
105    }
106}
107
108/// Extract FK dependency table names from a table.
109fn get_fk_dependencies(table: &Table) -> HashSet<String> {
110    let mut deps = HashSet::new();
111
112    for constraint in &table.constraints {
113        if let TableConstraint::ForeignKey { ref_table, .. } = constraint {
114            deps.insert(ref_table.name.normalized.clone());
115        }
116    }
117
118    for col in &table.columns {
119        if let Some(fk) = &col.references {
120            deps.insert(fk.table.name.normalized.clone());
121        }
122    }
123
124    deps
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::*;

    /// Build a table with one table-level FK constraint per name in `fk_refs`.
    fn make_table(name: &str, fk_refs: Vec<&str>) -> Table {
        let constraints: Vec<TableConstraint> = fk_refs
            .into_iter()
            .map(|ref_name| TableConstraint::ForeignKey {
                name: None,
                columns: vec![Ident::new("ref_id")],
                ref_table: QualifiedName::new(Ident::new(ref_name)),
                ref_columns: vec![Ident::new("id")],
                on_delete: None,
                on_update: None,
                deferrable: false,
            })
            .collect();

        Table {
            name: QualifiedName::new(Ident::new(name)),
            columns: vec![],
            constraints,
        }
    }

    /// Normalized table names in their current order.
    fn table_names(tables: &[Table]) -> Vec<&str> {
        tables
            .iter()
            .map(|t| t.name.name.normalized.as_str())
            .collect()
    }

    #[test]
    fn test_no_deps_alphabetical() {
        let mut warnings = Vec::new();
        let mut tables = vec![
            make_table("c", vec![]),
            make_table("a", vec![]),
            make_table("b", vec![]),
        ];
        topological_sort(&mut tables, &mut warnings);
        assert_eq!(table_names(&tables), vec!["a", "b", "c"]);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_simple_dependency() {
        let mut warnings = Vec::new();
        let mut tables = vec![
            make_table("orders", vec!["users"]),
            make_table("users", vec![]),
        ];
        topological_sort(&mut tables, &mut warnings);
        assert_eq!(table_names(&tables), vec!["users", "orders"]);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_cycle_falls_back_to_alphabetical() {
        let mut warnings = Vec::new();
        let mut tables = vec![make_table("b", vec!["a"]), make_table("a", vec!["b"])];
        topological_sort(&mut tables, &mut warnings);
        assert_eq!(table_names(&tables), vec!["a", "b"]);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].code, "FK_CYCLE_DETECTED");
    }

    #[test]
    fn test_chain_dependency() {
        let mut warnings = Vec::new();
        let mut tables = vec![
            make_table("c", vec!["b"]),
            make_table("b", vec!["a"]),
            make_table("a", vec![]),
        ];
        topological_sort(&mut tables, &mut warnings);
        assert_eq!(table_names(&tables), vec!["a", "b", "c"]);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_diamond_dependency() {
        let mut warnings = Vec::new();
        let mut tables = vec![
            make_table("d", vec!["b", "c"]),
            make_table("c", vec!["a"]),
            make_table("b", vec!["a"]),
            make_table("a", vec![]),
        ];
        topological_sort(&mut tables, &mut warnings);
        assert_eq!(table_names(&tables), vec!["a", "b", "c", "d"]);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_self_reference_is_ignored() {
        let mut warnings = Vec::new();
        let mut tables = vec![make_table("tree", vec!["tree"]), make_table("a", vec![])];
        topological_sort(&mut tables, &mut warnings);
        assert_eq!(table_names(&tables), vec!["a", "tree"]);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_unknown_reference_is_ignored() {
        let mut warnings = Vec::new();
        let mut tables = vec![
            make_table("orders", vec!["missing"]),
            make_table("items", vec![]),
        ];
        topological_sort(&mut tables, &mut warnings);
        assert_eq!(table_names(&tables), vec!["items", "orders"]);
        assert!(warnings.is_empty());
    }
}