datafusion_sql/
resolve.rs1use crate::TableReference;
19use std::collections::BTreeSet;
20use std::ops::ControlFlow;
21
22use crate::parser::{CopyToSource, CopyToStatement, Statement as DFStatement};
23use crate::planner::object_name_to_table_reference;
24use sqlparser::ast::*;
25
/// Schema name under which the tables listed in
/// [`INFORMATION_SCHEMA_TABLES`] are registered as relations.
const INFORMATION_SCHEMA: &str = "information_schema";

// Names of the individual tables inside `information_schema`.
const TABLES: &str = "tables";
const VIEWS: &str = "views";
const COLUMNS: &str = "columns";
const DF_SETTINGS: &str = "df_settings";
const SCHEMATA: &str = "schemata";
const ROUTINES: &str = "routines";
const PARAMETERS: &str = "parameters";

/// Every `information_schema` table name declared above, in declaration order.
///
/// `RelationVisitor::pre_visit_statement` registers all of them as relations
/// whenever it encounters one of the `SHOW ...` statements it matches on.
const INFORMATION_SCHEMA_TABLES: &[&str] = &[
    TABLES,
    VIEWS,
    COLUMNS,
    DF_SETTINGS,
    SCHEMATA,
    ROUTINES,
    PARAMETERS,
];
47
48struct RelationVisitor {
49 relations: BTreeSet<ObjectName>,
50 all_ctes: BTreeSet<ObjectName>,
51 ctes_in_scope: Vec<ObjectName>,
52}
53
54impl RelationVisitor {
55 fn insert_relation(&mut self, relation: &ObjectName) {
57 if !self.relations.contains(relation) && !self.ctes_in_scope.contains(relation) {
58 self.relations.insert(relation.clone());
59 }
60 }
61}
62
63impl Visitor for RelationVisitor {
64 type Break = ();
65
66 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
67 self.insert_relation(relation);
68 ControlFlow::Continue(())
69 }
70
71 fn pre_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
72 if let Some(with) = &q.with {
73 for cte in &with.cte_tables {
74 if !with.recursive {
79 cte.visit(self);
82 }
83 self.ctes_in_scope
84 .push(ObjectName(vec![cte.alias.name.clone()]));
85 }
86 }
87 ControlFlow::Continue(())
88 }
89
90 fn post_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
91 if let Some(with) = &q.with {
92 for _ in &with.cte_tables {
93 self.all_ctes.insert(self.ctes_in_scope.pop().unwrap());
95 }
96 }
97 ControlFlow::Continue(())
98 }
99
100 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
101 if let Statement::ShowCreate {
102 obj_type: ShowCreateObject::Table | ShowCreateObject::View,
103 obj_name,
104 } = statement
105 {
106 self.insert_relation(obj_name)
107 }
108
109 let requires_information_schema = matches!(
111 statement,
112 Statement::ShowFunctions { .. }
113 | Statement::ShowVariable { .. }
114 | Statement::ShowStatus { .. }
115 | Statement::ShowVariables { .. }
116 | Statement::ShowCreate { .. }
117 | Statement::ShowColumns { .. }
118 | Statement::ShowTables { .. }
119 | Statement::ShowCollation { .. }
120 );
121 if requires_information_schema {
122 for s in INFORMATION_SCHEMA_TABLES {
123 self.relations.insert(ObjectName(vec![
124 Ident::new(INFORMATION_SCHEMA),
125 Ident::new(*s),
126 ]));
127 }
128 }
129 ControlFlow::Continue(())
130 }
131}
132
133fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) {
134 match statement {
135 DFStatement::Statement(s) => {
136 let _ = s.as_ref().visit(visitor);
137 }
138 DFStatement::CreateExternalTable(table) => {
139 visitor.relations.insert(table.name.clone());
140 }
141 DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
142 CopyToSource::Relation(table_name) => {
143 visitor.insert_relation(table_name);
144 }
145 CopyToSource::Query(query) => {
146 query.visit(visitor);
147 }
148 },
149 DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor),
150 }
151}
152
153pub fn resolve_table_references(
188 statement: &crate::parser::Statement,
189 enable_ident_normalization: bool,
190) -> datafusion_common::Result<(Vec<TableReference>, Vec<TableReference>)> {
191 let mut visitor = RelationVisitor {
192 relations: BTreeSet::new(),
193 all_ctes: BTreeSet::new(),
194 ctes_in_scope: vec![],
195 };
196
197 visit_statement(statement, &mut visitor);
198
199 let table_refs = visitor
200 .relations
201 .into_iter()
202 .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
203 .collect::<datafusion_common::Result<_>>()?;
204 let ctes = visitor
205 .all_ctes
206 .into_iter()
207 .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
208 .collect::<datafusion_common::Result<_>>()?;
209 Ok((table_refs, ctes))
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::DFParser;

    /// Parses `sql`, resolves the last parsed statement with identifier
    /// normalization enabled, and returns `(table_refs, ctes)` rendered as
    /// strings so they can be compared against literals.
    fn resolved_names(sql: &str) -> (Vec<String>, Vec<String>) {
        let statement = DFParser::parse_sql(sql).unwrap().pop_back().unwrap();
        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
        let render =
            |refs: Vec<TableReference>| refs.iter().map(|r| r.to_string()).collect::<Vec<_>>();
        (render(table_refs), render(ctes))
    }

    #[test]
    fn resolve_table_references_shadowed_cte() {
        // A non-recursive CTE is not in scope inside its own body, so the
        // inner `t` is expected to be reported as a table as well.
        let (table_refs, ctes) = resolved_names("WITH t AS (SELECT * FROM t) SELECT * FROM t");
        assert_eq!(table_refs, ["t"]);
        assert_eq!(ctes, ["t"]);

        // The CTE `t` is scoped to the first UNION branch only; the second
        // branch is expected to refer to a table `t`.
        let (table_refs, ctes) =
            resolved_names("(with t as (select 1) select * from t) union (select * from t)");
        assert_eq!(table_refs, ["t"]);
        assert_eq!(ctes, ["t"]);

        // `u` is declared inside `t`'s body only, so the outer `u` is
        // expected to be reported as a table.
        let (table_refs, ctes) = resolved_names(
            "(with t as (with u as (select 1) select * from u) select * from u cross join t)",
        );
        assert_eq!(table_refs, ["u"]);
        assert_eq!(ctes, ["t", "u"]);
    }

    #[test]
    fn resolve_table_references_recursive_cte() {
        // A recursive CTE may reference itself; that self-reference must not
        // be reported as a table.
        let query = "
            WITH RECURSIVE nodes AS (
                SELECT 1 as id
                UNION ALL
                SELECT id + 1 as id
                FROM nodes
                WHERE id < 10
            )
            SELECT * FROM nodes
        ";
        let (table_refs, ctes) = resolved_names(query);
        assert!(table_refs.is_empty());
        assert_eq!(ctes, ["nodes"]);
    }
}