Skip to main content

sql_splitter/graph/
view.rs

1//! Graph view with filtering and focus capabilities for ERD generation.
2
3use crate::schema::{ColumnType, SchemaGraph};
4use ahash::{AHashMap, AHashSet};
5use glob::Pattern;
6use std::collections::VecDeque;
7
/// Information about a single column of a table, as rendered in an ERD.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnInfo {
    /// Column name
    pub name: String,
    /// Column type as a display string (e.g. `INT`, `VARCHAR`)
    pub col_type: String,
    /// Whether this column is a primary key
    pub is_primary_key: bool,
    /// Whether this column is a foreign key
    pub is_foreign_key: bool,
    /// Whether this column is nullable
    pub is_nullable: bool,
    /// If FK, which table it references
    pub references_table: Option<String>,
    /// If FK, which column it references
    pub references_column: Option<String>,
}

/// Information about a table for ERD rendering.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableInfo {
    /// Table name
    pub name: String,
    /// All columns, in declaration order
    pub columns: Vec<ColumnInfo>,
}
35
36/// Information about an edge (FK relationship) in the graph
37#[derive(Debug, Clone)]
38pub struct EdgeInfo {
39    /// Source table (child with FK)
40    pub from_table: String,
41    /// Source column (FK column)
42    pub from_column: String,
43    /// Target table (parent being referenced)
44    pub to_table: String,
45    /// Target column (referenced column, usually PK)
46    pub to_column: String,
47    /// Relationship cardinality (for ERD display)
48    pub cardinality: Cardinality,
49}
50
/// Relationship cardinality for ERD.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Cardinality {
    /// Many rows on the FK side reference one row on the referenced side.
    ///
    /// Most common case (child has FK to parent); this is the default.
    #[default]
    ManyToOne,
    /// Each row on one side relates to at most one row on the other.
    OneToOne,
    /// One row on the left side relates to many rows on the right side.
    OneToMany,
    /// Many rows on each side relate to many rows on the other.
    ManyToMany,
}
60
61impl Cardinality {
62    /// Mermaid ERD notation
63    pub fn as_mermaid(self) -> &'static str {
64        match self {
65            Cardinality::ManyToOne => "}o--||",
66            Cardinality::OneToOne => "||--||",
67            Cardinality::OneToMany => "||--o{",
68            Cardinality::ManyToMany => "}o--o{",
69        }
70    }
71}
72
73/// A filtered view of a schema graph for ERD visualization
74#[derive(Debug)]
75pub struct GraphView {
76    /// Tables included in this view with full column info
77    pub tables: AHashMap<String, TableInfo>,
78    /// Edges between tables (FK relationships)
79    pub edges: Vec<EdgeInfo>,
80}
81
82impl GraphView {
83    /// Create a full view from a schema graph (all tables and edges)
84    pub fn from_schema_graph(graph: &SchemaGraph) -> Self {
85        let mut tables = AHashMap::new();
86        let mut edges = Vec::new();
87
88        // Build FK lookup: which columns are FKs and what they reference
89        let mut fk_lookup: AHashMap<(String, String), (String, String)> = AHashMap::new();
90
91        for table_schema in graph.schema.iter() {
92            for fk in &table_schema.foreign_keys {
93                for (i, col_name) in fk.column_names.iter().enumerate() {
94                    let ref_col = fk.referenced_columns.get(i).cloned().unwrap_or_default();
95                    fk_lookup.insert(
96                        (table_schema.name.clone(), col_name.clone()),
97                        (fk.referenced_table.clone(), ref_col),
98                    );
99                }
100            }
101        }
102
103        // Build table info with full column details
104        for table_schema in graph.schema.iter() {
105            let mut columns = Vec::new();
106
107            for col in &table_schema.columns {
108                let is_fk = fk_lookup.contains_key(&(table_schema.name.clone(), col.name.clone()));
109                let (ref_table, ref_col) = fk_lookup
110                    .get(&(table_schema.name.clone(), col.name.clone()))
111                    .cloned()
112                    .map(|(t, c)| (Some(t), Some(c)))
113                    .unwrap_or((None, None));
114
115                columns.push(ColumnInfo {
116                    name: col.name.clone(),
117                    col_type: format_column_type(&col.col_type),
118                    is_primary_key: col.is_primary_key,
119                    is_foreign_key: is_fk,
120                    is_nullable: col.is_nullable,
121                    references_table: ref_table,
122                    references_column: ref_col,
123                });
124            }
125
126            tables.insert(
127                table_schema.name.clone(),
128                TableInfo {
129                    name: table_schema.name.clone(),
130                    columns,
131                },
132            );
133        }
134
135        // Build edges from FK relationships
136        for table_schema in graph.schema.iter() {
137            for fk in &table_schema.foreign_keys {
138                // Create one edge per FK column pair
139                for (i, col_name) in fk.column_names.iter().enumerate() {
140                    let ref_col = fk
141                        .referenced_columns
142                        .get(i)
143                        .cloned()
144                        .unwrap_or_else(|| "id".to_string());
145
146                    edges.push(EdgeInfo {
147                        from_table: table_schema.name.clone(),
148                        from_column: col_name.clone(),
149                        to_table: fk.referenced_table.clone(),
150                        to_column: ref_col,
151                        cardinality: Cardinality::ManyToOne,
152                    });
153                }
154            }
155        }
156
157        Self { tables, edges }
158    }
159
160    /// Filter to include only tables matching the given patterns
161    pub fn filter_tables(&mut self, patterns: &[Pattern]) {
162        if patterns.is_empty() {
163            return;
164        }
165
166        let matching: AHashSet<String> = self
167            .tables
168            .keys()
169            .filter(|name| patterns.iter().any(|p| p.matches(name)))
170            .cloned()
171            .collect();
172
173        self.apply_node_filter(&matching);
174    }
175
176    /// Exclude tables matching the given patterns
177    pub fn exclude_tables(&mut self, patterns: &[Pattern]) {
178        if patterns.is_empty() {
179            return;
180        }
181
182        let remaining: AHashSet<String> = self
183            .tables
184            .keys()
185            .filter(|name| !patterns.iter().any(|p| p.matches(name)))
186            .cloned()
187            .collect();
188
189        self.apply_node_filter(&remaining);
190    }
191
192    /// Focus on a specific table and its relationships
193    pub fn focus_table(
194        &mut self,
195        table: &str,
196        transitive: bool,
197        reverse: bool,
198        max_depth: Option<usize>,
199    ) {
200        if !self.tables.contains_key(table) {
201            self.tables.clear();
202            self.edges.clear();
203            return;
204        }
205
206        let mut result_nodes = AHashSet::new();
207        result_nodes.insert(table.to_string());
208
209        // Build adjacency maps for traversal
210        let (outgoing, incoming) = self.build_adjacency_maps();
211
212        if transitive {
213            // Show tables this table depends on (parents, transitively)
214            self.traverse(&outgoing, table, max_depth, &mut result_nodes);
215        }
216
217        if reverse {
218            // Show tables that depend on this table (children, transitively)
219            self.traverse(&incoming, table, max_depth, &mut result_nodes);
220        }
221
222        // If neither transitive nor reverse, show direct connections only
223        if !transitive && !reverse {
224            if let Some(parents) = outgoing.get(table) {
225                for parent in parents {
226                    result_nodes.insert(parent.clone());
227                }
228            }
229            if let Some(children) = incoming.get(table) {
230                for child in children {
231                    result_nodes.insert(child.clone());
232                }
233            }
234        }
235
236        self.apply_node_filter(&result_nodes);
237    }
238
239    /// Keep only tables that are part of cycles
240    pub fn filter_to_cyclic_tables(&mut self, cyclic_tables: &AHashSet<String>) {
241        self.apply_node_filter(cyclic_tables);
242    }
243
244    /// Get the number of tables in the view
245    pub fn table_count(&self) -> usize {
246        self.tables.len()
247    }
248
249    /// Get the number of edges in the view
250    pub fn edge_count(&self) -> usize {
251        self.edges.len()
252    }
253
254    /// Check if the view is empty
255    pub fn is_empty(&self) -> bool {
256        self.tables.is_empty()
257    }
258
259    /// Get tables sorted alphabetically
260    pub fn sorted_tables(&self) -> Vec<&TableInfo> {
261        let mut tables: Vec<_> = self.tables.values().collect();
262        tables.sort_by(|a, b| a.name.cmp(&b.name));
263        tables
264    }
265
266    /// Get table info by name
267    pub fn get_table(&self, name: &str) -> Option<&TableInfo> {
268        self.tables.get(name)
269    }
270
271    // Private helper methods
272
273    fn apply_node_filter(&mut self, keep: &AHashSet<String>) {
274        self.tables.retain(|n, _| keep.contains(n));
275        self.edges
276            .retain(|e| keep.contains(&e.from_table) && keep.contains(&e.to_table));
277    }
278
279    fn build_adjacency_maps(
280        &self,
281    ) -> (AHashMap<String, Vec<String>>, AHashMap<String, Vec<String>>) {
282        let mut outgoing: AHashMap<String, Vec<String>> = AHashMap::new();
283        let mut incoming: AHashMap<String, Vec<String>> = AHashMap::new();
284
285        for edge in &self.edges {
286            outgoing
287                .entry(edge.from_table.clone())
288                .or_default()
289                .push(edge.to_table.clone());
290            incoming
291                .entry(edge.to_table.clone())
292                .or_default()
293                .push(edge.from_table.clone());
294        }
295
296        (outgoing, incoming)
297    }
298
299    fn traverse(
300        &self,
301        adjacency: &AHashMap<String, Vec<String>>,
302        start: &str,
303        max_depth: Option<usize>,
304        result: &mut AHashSet<String>,
305    ) {
306        let mut queue: VecDeque<(String, usize)> = VecDeque::new();
307        queue.push_back((start.to_string(), 0));
308
309        while let Some((current, depth)) = queue.pop_front() {
310            if let Some(max) = max_depth {
311                if depth >= max {
312                    continue;
313                }
314            }
315
316            if let Some(neighbors) = adjacency.get(&current) {
317                for neighbor in neighbors {
318                    if result.insert(neighbor.clone()) {
319                        queue.push_back((neighbor.clone(), depth + 1));
320                    }
321                }
322            }
323        }
324    }
325}
326
327/// Format a ColumnType for display
328fn format_column_type(col_type: &ColumnType) -> String {
329    match col_type {
330        ColumnType::Int => "INT".to_string(),
331        ColumnType::BigInt => "BIGINT".to_string(),
332        ColumnType::Text => "VARCHAR".to_string(),
333        ColumnType::Uuid => "UUID".to_string(),
334        ColumnType::Decimal => "DECIMAL".to_string(),
335        ColumnType::DateTime => "DATETIME".to_string(),
336        ColumnType::Bool => "BOOL".to_string(),
337        ColumnType::Other(s) => s.to_uppercase(),
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a non-nullable column fixture; `references` is `Some((table, column))`
    /// for FK columns.
    fn column(
        name: &str,
        col_type: &str,
        is_primary_key: bool,
        references: Option<(&str, &str)>,
    ) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            col_type: col_type.to_string(),
            is_primary_key,
            is_foreign_key: references.is_some(),
            is_nullable: false,
            references_table: references.map(|(t, _)| t.to_string()),
            references_column: references.map(|(_, c)| c.to_string()),
        }
    }

    /// Two tables: `orders.user_id` -> `users.id`.
    fn create_test_view() -> GraphView {
        let mut tables = AHashMap::new();

        tables.insert(
            "users".to_string(),
            TableInfo {
                name: "users".to_string(),
                columns: vec![
                    column("id", "INT", true, None),
                    column("email", "VARCHAR", false, None),
                ],
            },
        );

        tables.insert(
            "orders".to_string(),
            TableInfo {
                name: "orders".to_string(),
                columns: vec![
                    column("id", "INT", true, None),
                    column("user_id", "INT", false, Some(("users", "id"))),
                ],
            },
        );

        let edges = vec![EdgeInfo {
            from_table: "orders".to_string(),
            from_column: "user_id".to_string(),
            to_table: "users".to_string(),
            to_column: "id".to_string(),
            cardinality: Cardinality::ManyToOne,
        }];

        GraphView { tables, edges }
    }

    #[test]
    fn test_table_info() {
        let view = create_test_view();
        assert_eq!(view.table_count(), 2);

        let users = view.get_table("users").unwrap();
        assert_eq!(users.columns.len(), 2);
        assert!(users.columns[0].is_primary_key);
    }

    #[test]
    fn test_edge_info() {
        let view = create_test_view();
        assert_eq!(view.edge_count(), 1);

        let edge = &view.edges[0];
        assert_eq!(edge.from_table, "orders");
        assert_eq!(edge.from_column, "user_id");
        assert_eq!(edge.to_table, "users");
        assert_eq!(edge.to_column, "id");
    }

    #[test]
    fn test_exclude_tables() {
        let mut view = create_test_view();
        let patterns = vec![Pattern::new("orders").unwrap()];
        view.exclude_tables(&patterns);

        assert!(!view.tables.contains_key("orders"));
        assert!(view.tables.contains_key("users"));
        assert_eq!(view.edge_count(), 0); // Edge removed since orders is gone
    }

    #[test]
    fn test_filter_tables() {
        let mut view = create_test_view();
        let patterns = vec![Pattern::new("user*").unwrap()];
        view.filter_tables(&patterns);

        assert!(view.tables.contains_key("users"));
        assert!(!view.tables.contains_key("orders"));
        assert_eq!(view.edge_count(), 0);
    }

    #[test]
    fn test_filter_tables_without_patterns_is_noop() {
        let mut view = create_test_view();
        view.filter_tables(&[]);
        view.exclude_tables(&[]);

        assert_eq!(view.table_count(), 2);
        assert_eq!(view.edge_count(), 1);
    }

    #[test]
    fn test_focus_direct_connections() {
        let mut view = create_test_view();
        view.focus_table("orders", false, false, None);

        assert_eq!(view.table_count(), 2);
        assert_eq!(view.edge_count(), 1);
    }

    #[test]
    fn test_focus_transitive_follows_parents_only() {
        let mut view = create_test_view();
        view.focus_table("users", true, false, None);

        // `users` references nothing, so only it remains.
        assert_eq!(view.table_count(), 1);
        assert!(view.tables.contains_key("users"));
        assert_eq!(view.edge_count(), 0);
    }

    #[test]
    fn test_focus_reverse_respects_max_depth() {
        let mut view = create_test_view();
        view.focus_table("users", false, true, Some(0));
        assert_eq!(view.table_count(), 1);

        let mut view = create_test_view();
        view.focus_table("users", false, true, Some(1));
        assert_eq!(view.table_count(), 2);
        assert_eq!(view.edge_count(), 1);
    }

    #[test]
    fn test_focus_missing_table_empties_view() {
        let mut view = create_test_view();
        view.focus_table("nope", true, true, None);

        assert!(view.is_empty());
        assert_eq!(view.edge_count(), 0);
    }

    #[test]
    fn test_sorted_tables() {
        let view = create_test_view();
        let names: Vec<&str> = view
            .sorted_tables()
            .iter()
            .map(|t| t.name.as_str())
            .collect();
        assert_eq!(names, vec!["orders", "users"]);
    }

    #[test]
    fn test_cardinality_mermaid() {
        assert_eq!(Cardinality::ManyToOne.as_mermaid(), "}o--||");
        assert_eq!(Cardinality::OneToOne.as_mermaid(), "||--||");
        assert_eq!(Cardinality::OneToMany.as_mermaid(), "||--o{");
        assert_eq!(Cardinality::ManyToMany.as_mermaid(), "}o--o{");
        assert_eq!(Cardinality::default(), Cardinality::ManyToOne);
    }
}