datafusion_sql/
resolve.rs1use crate::TableReference;
19use std::collections::BTreeSet;
20use std::ops::ControlFlow;
21
22use crate::parser::{CopyToSource, CopyToStatement, Statement as DFStatement};
23use crate::planner::object_name_to_table_reference;
24use sqlparser::ast::*;
25
/// Name of the schema that holds the metadata tables listed in
/// [`INFORMATION_SCHEMA_TABLES`].
const INFORMATION_SCHEMA: &str = "information_schema";

// Individual table names inside `information_schema`.
const TABLES: &str = "tables";
const VIEWS: &str = "views";
const COLUMNS: &str = "columns";
const DF_SETTINGS: &str = "df_settings";
const SCHEMATA: &str = "schemata";
const ROUTINES: &str = "routines";
const PARAMETERS: &str = "parameters";

/// Every `information_schema` table name, each of which is registered as a
/// referenced relation when a SHOW-style statement is encountered.
const INFORMATION_SCHEMA_TABLES: &[&str] =
    &[TABLES, VIEWS, COLUMNS, DF_SETTINGS, SCHEMATA, ROUTINES, PARAMETERS];
47
48struct RelationVisitor {
49 relations: BTreeSet<ObjectName>,
50 all_ctes: BTreeSet<ObjectName>,
51 ctes_in_scope: Vec<ObjectName>,
52}
53
54impl RelationVisitor {
55 fn insert_relation(&mut self, relation: &ObjectName) {
57 if !self.relations.contains(relation) && !self.ctes_in_scope.contains(relation) {
58 self.relations.insert(relation.clone());
59 }
60 }
61}
62
63impl Visitor for RelationVisitor {
64 type Break = ();
65
66 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
67 self.insert_relation(relation);
68 ControlFlow::Continue(())
69 }
70
71 fn pre_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
72 if let Some(with) = &q.with {
73 for cte in &with.cte_tables {
74 if !with.recursive {
79 let _ = cte.visit(self);
82 }
83 self.ctes_in_scope
84 .push(ObjectName::from(vec![cte.alias.name.clone()]));
85 }
86 }
87 ControlFlow::Continue(())
88 }
89
90 fn post_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
91 if let Some(with) = &q.with {
92 for _ in &with.cte_tables {
93 self.all_ctes.insert(self.ctes_in_scope.pop().unwrap());
95 }
96 }
97 ControlFlow::Continue(())
98 }
99
100 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
101 if let Statement::ShowCreate {
102 obj_type: ShowCreateObject::Table | ShowCreateObject::View,
103 obj_name,
104 } = statement
105 {
106 self.insert_relation(obj_name)
107 }
108
109 let requires_information_schema = matches!(
111 statement,
112 Statement::ShowFunctions { .. }
113 | Statement::ShowVariable { .. }
114 | Statement::ShowStatus { .. }
115 | Statement::ShowVariables { .. }
116 | Statement::ShowCreate { .. }
117 | Statement::ShowColumns { .. }
118 | Statement::ShowTables { .. }
119 | Statement::ShowCollation { .. }
120 );
121 if requires_information_schema {
122 for s in INFORMATION_SCHEMA_TABLES {
123 self.relations.insert(ObjectName::from(vec![
124 Ident::new(INFORMATION_SCHEMA),
125 Ident::new(*s),
126 ]));
127 }
128 }
129 ControlFlow::Continue(())
130 }
131}
132
133fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) {
134 match statement {
135 DFStatement::Statement(s) => {
136 let _ = s.as_ref().visit(visitor);
137 }
138 DFStatement::CreateExternalTable(table) => {
139 visitor.relations.insert(table.name.clone());
140 }
141 DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
142 CopyToSource::Relation(table_name) => {
143 visitor.insert_relation(table_name);
144 }
145 CopyToSource::Query(query) => {
146 let _ = query.visit(visitor);
147 }
148 },
149 DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor),
150 DFStatement::Reset(_) => {}
151 }
152}
153
154pub fn resolve_table_references(
189 statement: &crate::parser::Statement,
190 enable_ident_normalization: bool,
191) -> datafusion_common::Result<(Vec<TableReference>, Vec<TableReference>)> {
192 let mut visitor = RelationVisitor {
193 relations: BTreeSet::new(),
194 all_ctes: BTreeSet::new(),
195 ctes_in_scope: vec![],
196 };
197
198 visit_statement(statement, &mut visitor);
199
200 let table_refs = visitor
201 .relations
202 .into_iter()
203 .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
204 .collect::<datafusion_common::Result<_>>()?;
205 let ctes = visitor
206 .all_ctes
207 .into_iter()
208 .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
209 .collect::<datafusion_common::Result<_>>()?;
210 Ok((table_refs, ctes))
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::DFParser;

    /// Renders each reference with `to_string`, preserving order.
    fn names(refs: &[TableReference]) -> Vec<String> {
        refs.iter().map(ToString::to_string).collect()
    }

    /// Parses the last statement of `sql` and returns its resolved
    /// `(table_refs, ctes)` as strings, with identifier normalization enabled.
    fn resolve(sql: &str) -> (Vec<String>, Vec<String>) {
        let statement = DFParser::parse_sql(sql).unwrap().pop_back().unwrap();
        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
        (names(&table_refs), names(&ctes))
    }

    #[test]
    fn resolve_table_references_shadowed_cte() {
        // Inside its own body the CTE `t` is not yet visible, so `t` there is a table.
        let (tables, ctes) = resolve("WITH t AS (SELECT * FROM t) SELECT * FROM t");
        assert_eq!(tables, ["t"]);
        assert_eq!(ctes, ["t"]);

        // The CTE only scopes the left operand; the right-hand `t` is a table.
        let (tables, ctes) =
            resolve("(with t as (select 1) select * from t) union (select * from t)");
        assert_eq!(tables, ["t"]);
        assert_eq!(ctes, ["t"]);

        // `u` is scoped to the body of `t`, so the outer `u` is a table.
        let (tables, ctes) = resolve(
            "(with t as (with u as (select 1) select * from u) select * from u cross join t)",
        );
        assert_eq!(tables, ["u"]);
        assert_eq!(ctes, ["t", "u"]);
    }

    #[test]
    fn resolve_table_references_recursive_cte() {
        let sql = "
            WITH RECURSIVE nodes AS (
                SELECT 1 as id
                UNION ALL
                SELECT id + 1 as id
                FROM nodes
                WHERE id < 10
            )
            SELECT * FROM nodes
        ";
        let (tables, ctes) = resolve(sql);
        assert!(tables.is_empty());
        assert_eq!(ctes, ["nodes"]);
    }
}