flowscope_core/extractors/mod.rs

1use sqlparser::ast::{Statement, TableFactor};
2
3pub fn extract_tables(statements: &[Statement]) -> Vec<String> {
4    let mut tables = Vec::new();
5
6    for statement in statements {
7        match statement {
8            Statement::Query(query) => {
9                extract_tables_from_query_body(&query.body, &mut tables);
10            }
11            Statement::Insert(insert) => {
12                tables.push(insert.table.to_string());
13                if let Some(source) = &insert.source {
14                    extract_tables_from_query_body(&source.body, &mut tables);
15                }
16            }
17            Statement::Update { table, from, .. } => {
18                extract_tables_from_table_factor(&table.relation, &mut tables);
19                for join in &table.joins {
20                    extract_tables_from_table_factor(&join.relation, &mut tables);
21                }
22
23                if let Some(from_kind) = from {
24                    match from_kind {
25                        sqlparser::ast::UpdateTableFromKind::BeforeSet(ts)
26                        | sqlparser::ast::UpdateTableFromKind::AfterSet(ts) => {
27                            for t in ts {
28                                extract_tables_from_table_factor(&t.relation, &mut tables);
29                                for join in &t.joins {
30                                    extract_tables_from_table_factor(&join.relation, &mut tables);
31                                }
32                            }
33                        }
34                    }
35                }
36            }
37            Statement::Delete(delete) => {
38                for obj in &delete.tables {
39                    tables.push(obj.to_string());
40                }
41
42                let from_tables = match &delete.from {
43                    sqlparser::ast::FromTable::WithFromKeyword(ts)
44                    | sqlparser::ast::FromTable::WithoutKeyword(ts) => ts,
45                };
46
47                for t in from_tables {
48                    extract_tables_from_table_factor(&t.relation, &mut tables);
49                    for join in &t.joins {
50                        extract_tables_from_table_factor(&join.relation, &mut tables);
51                    }
52                }
53
54                if let Some(using) = &delete.using {
55                    for t in using {
56                        extract_tables_from_table_factor(&t.relation, &mut tables);
57                        for join in &t.joins {
58                            extract_tables_from_table_factor(&join.relation, &mut tables);
59                        }
60                    }
61                }
62            }
63            Statement::Merge { table, source, .. } => {
64                extract_tables_from_table_factor(table, &mut tables);
65                extract_tables_from_table_factor(source, &mut tables);
66            }
67            _ => {}
68        }
69    }
70
71    tables
72}
73
74fn extract_tables_from_query_body(body: &sqlparser::ast::SetExpr, tables: &mut Vec<String>) {
75    use sqlparser::ast::SetExpr;
76
77    match body {
78        SetExpr::Select(select) => {
79            for table_with_joins in &select.from {
80                extract_tables_from_table_factor(&table_with_joins.relation, tables);
81
82                for join in &table_with_joins.joins {
83                    extract_tables_from_table_factor(&join.relation, tables);
84                }
85            }
86        }
87        SetExpr::Query(query) => {
88            extract_tables_from_query_body(&query.body, tables);
89        }
90        SetExpr::SetOperation { left, right, .. } => {
91            extract_tables_from_query_body(left, tables);
92            extract_tables_from_query_body(right, tables);
93        }
94        SetExpr::Values(_) => {}
95        SetExpr::Insert(stmt) => {
96            if let sqlparser::ast::Statement::Insert(insert) = stmt {
97                tables.push(insert.table.to_string());
98                if let Some(source) = &insert.source {
99                    extract_tables_from_query_body(&source.body, tables);
100                }
101            }
102        }
103        SetExpr::Update(stmt) => {
104            if let sqlparser::ast::Statement::Update { table, from, .. } = stmt {
105                extract_tables_from_table_factor(&table.relation, tables);
106                for join in &table.joins {
107                    extract_tables_from_table_factor(&join.relation, tables);
108                }
109
110                if let Some(from_kind) = from {
111                    match from_kind {
112                        sqlparser::ast::UpdateTableFromKind::BeforeSet(ts)
113                        | sqlparser::ast::UpdateTableFromKind::AfterSet(ts) => {
114                            for t in ts {
115                                extract_tables_from_table_factor(&t.relation, tables);
116                                for join in &t.joins {
117                                    extract_tables_from_table_factor(&join.relation, tables);
118                                }
119                            }
120                        }
121                    }
122                }
123            }
124        }
125        SetExpr::Table(table) => {
126            if let Some(name) = &table.table_name {
127                tables.push(name.clone());
128            }
129        }
130        SetExpr::Delete(stmt) => {
131            if let sqlparser::ast::Statement::Delete(delete) = stmt {
132                for obj in &delete.tables {
133                    tables.push(obj.to_string());
134                }
135
136                let from_tables = match &delete.from {
137                    sqlparser::ast::FromTable::WithFromKeyword(ts)
138                    | sqlparser::ast::FromTable::WithoutKeyword(ts) => ts,
139                };
140
141                for t in from_tables {
142                    extract_tables_from_table_factor(&t.relation, tables);
143                    for join in &t.joins {
144                        extract_tables_from_table_factor(&join.relation, tables);
145                    }
146                }
147
148                if let Some(using) = &delete.using {
149                    for t in using {
150                        extract_tables_from_table_factor(&t.relation, tables);
151                        for join in &t.joins {
152                            extract_tables_from_table_factor(&join.relation, tables);
153                        }
154                    }
155                }
156            }
157        }
158        SetExpr::Merge(stmt) => {
159            if let sqlparser::ast::Statement::Merge { table, source, .. } = stmt {
160                extract_tables_from_table_factor(table, tables);
161                extract_tables_from_table_factor(source, tables);
162            }
163        }
164    }
165}
166
167fn extract_tables_from_table_factor(table_factor: &TableFactor, tables: &mut Vec<String>) {
168    match table_factor {
169        TableFactor::Table { name, .. } => {
170            tables.push(name.to_string());
171        }
172        TableFactor::Derived { subquery, .. } => {
173            extract_tables_from_query_body(&subquery.body, tables);
174        }
175        TableFactor::TableFunction { .. } => {}
176        TableFactor::Function { .. } => {}
177        TableFactor::UNNEST { .. } => {}
178        TableFactor::NestedJoin {
179            table_with_joins, ..
180        } => {
181            extract_tables_from_table_factor(&table_with_joins.relation, tables);
182            for join in &table_with_joins.joins {
183                extract_tables_from_table_factor(&join.relation, tables);
184            }
185        }
186        TableFactor::Pivot { .. } => {}
187        TableFactor::Unpivot { .. } => {}
188        TableFactor::MatchRecognize { .. } => {}
189        TableFactor::JsonTable { .. } => {}
190        // TODO: Implement table extraction for OPENJSON (SQL Server)
191        TableFactor::OpenJsonTable { .. } => {}
192        // TODO: Implement table extraction for XMLTABLE
193        TableFactor::XmlTable { .. } => {}
194        // TODO: Implement table extraction for semantic views
195        TableFactor::SemanticView { .. } => {}
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` (panicking on parse errors) and returns the extracted table names.
    fn tables_in(sql: &str) -> Vec<String> {
        let statements = parse_sql(sql).unwrap();
        extract_tables(&statements)
    }

    #[test]
    fn test_extract_single_table() {
        assert_eq!(tables_in("SELECT * FROM users"), vec!["users"]);
    }

    #[test]
    fn test_extract_multiple_tables_join() {
        let tables = tables_in("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        assert_eq!(tables.len(), 2);
        assert!(tables.contains(&"users".to_string()));
        assert!(tables.contains(&"orders".to_string()));
    }

    #[test]
    fn test_extract_with_schema() {
        assert_eq!(tables_in("SELECT * FROM public.users"), vec!["public.users"]);
    }

    #[test]
    fn test_extract_statement_without_tables() {
        assert!(tables_in("SELECT 1").is_empty());
    }

    #[test]
    fn test_extract_union() {
        let tables = tables_in("SELECT id FROM users UNION SELECT id FROM admins");
        assert_eq!(tables, vec!["users", "admins"]);
    }

    #[test]
    fn test_extract_derived_table() {
        let tables = tables_in("SELECT * FROM (SELECT * FROM users) AS u");
        assert_eq!(tables, vec!["users"]);
    }

    #[test]
    fn test_extract_insert_select() {
        // Target table first, then the tables read by the source query.
        let tables = tables_in("INSERT INTO archive SELECT * FROM users");
        assert_eq!(tables, vec!["archive", "users"]);
    }

    #[test]
    fn test_extract_delete() {
        assert_eq!(tables_in("DELETE FROM users WHERE id = 1"), vec!["users"]);
    }

    #[test]
    fn test_extract_merge() {
        let tables = tables_in(
            "MERGE INTO target t USING source s ON t.id = s.id WHEN MATCHED THEN UPDATE SET t.val = s.val",
        );
        assert_eq!(tables.len(), 2);
        assert!(tables.contains(&"target".to_string()));
        assert!(tables.contains(&"source".to_string()));
    }
}