1use std::collections::BTreeSet;
19use std::ops::ControlFlow;
20
21use datafusion_common::{DataFusionError, Result};
22
23use crate::TableReference;
24use crate::parser::{CopyToSource, CopyToStatement, Statement as DFStatement};
25use crate::planner::object_name_to_table_reference;
26use sqlparser::ast::*;
27
/// Name of the schema that holds the built-in metadata tables.
const INFORMATION_SCHEMA: &str = "information_schema";
// Names of the individual tables that live inside `INFORMATION_SCHEMA`.
const TABLES: &str = "tables";
const VIEWS: &str = "views";
const COLUMNS: &str = "columns";
const DF_SETTINGS: &str = "df_settings";
const SCHEMATA: &str = "schemata";
const ROUTINES: &str = "routines";
const PARAMETERS: &str = "parameters";

/// Every table in `INFORMATION_SCHEMA`. When a statement needs the information
/// schema (see `RelationVisitor::pre_visit_statement`), each of these is
/// registered as a referenced relation as `information_schema.<name>`.
const INFORMATION_SCHEMA_TABLES: &[&str] = &[
    TABLES,
    VIEWS,
    COLUMNS,
    DF_SETTINGS,
    SCHEMATA,
    ROUTINES,
    PARAMETERS,
];
49
/// AST visitor that gathers the relations a statement references and the CTE
/// names it defines.
struct RelationVisitor {
    /// Referenced relations, excluding names that matched a CTE in scope at the
    /// point of reference.
    relations: BTreeSet<TableReference>,
    /// Every CTE name seen anywhere in the statement, added when its query is left.
    all_ctes: BTreeSet<TableReference>,
    /// Stack of CTE names visible at the current point of the walk. Names are
    /// pushed in `pre_visit_query` and popped in `post_visit_query`.
    ctes_in_scope: Vec<TableReference>,
    /// Forwarded to `object_name_to_table_reference`. The tests in this file show
    /// unquoted identifiers being lowercased when `true`.
    enable_ident_normalization: bool,
}
58
59impl RelationVisitor {
60 fn insert_relation(&mut self, relation: &ObjectName) -> ControlFlow<DataFusionError> {
62 match object_name_to_table_reference(
63 relation.clone(),
64 self.enable_ident_normalization,
65 ) {
66 Ok(relation) => {
67 if !self.relations.contains(&relation)
68 && !self.ctes_in_scope.contains(&relation)
69 {
70 self.relations.insert(relation);
71 }
72 ControlFlow::Continue(())
73 }
74 Err(e) => ControlFlow::Break(e),
75 }
76 }
77}
78
79impl Visitor for RelationVisitor {
80 type Break = DataFusionError;
81
82 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
83 self.insert_relation(relation)
84 }
85
86 fn pre_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
87 if let Some(with) = &q.with {
88 for cte in &with.cte_tables {
89 if !with.recursive {
94 cte.visit(self)?;
97 }
98 let cte_name = ObjectName::from(vec![cte.alias.name.clone()]);
99 match object_name_to_table_reference(
100 cte_name,
101 self.enable_ident_normalization,
102 ) {
103 Ok(cte_ref) => self.ctes_in_scope.push(cte_ref),
104 Err(e) => return ControlFlow::Break(e),
105 }
106 }
107 }
108 ControlFlow::Continue(())
109 }
110
111 fn post_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
112 if let Some(with) = &q.with {
113 for _ in &with.cte_tables {
114 self.all_ctes.insert(self.ctes_in_scope.pop().unwrap());
116 }
117 }
118 ControlFlow::Continue(())
119 }
120
121 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
122 if let Statement::ShowCreate {
123 obj_type: ShowCreateObject::Table | ShowCreateObject::View,
124 obj_name,
125 } = statement
126 {
127 self.insert_relation(obj_name)?;
128 }
129
130 let requires_information_schema = matches!(
132 statement,
133 Statement::ShowFunctions { .. }
134 | Statement::ShowVariable { .. }
135 | Statement::ShowStatus { .. }
136 | Statement::ShowVariables { .. }
137 | Statement::ShowCreate { .. }
138 | Statement::ShowColumns { .. }
139 | Statement::ShowTables { .. }
140 | Statement::ShowCollation { .. }
141 );
142 if requires_information_schema {
143 for s in INFORMATION_SCHEMA_TABLES {
144 let obj = ObjectName::from(vec![
146 Ident::new(INFORMATION_SCHEMA),
147 Ident::new(*s),
148 ]);
149 match object_name_to_table_reference(obj, self.enable_ident_normalization)
150 {
151 Ok(tbl_ref) => {
152 self.relations.insert(tbl_ref);
153 }
154 Err(e) => return ControlFlow::Break(e),
155 }
156 }
157 }
158 ControlFlow::Continue(())
159 }
160}
161
162fn control_flow_to_result(flow: ControlFlow<DataFusionError>) -> Result<()> {
163 match flow {
164 ControlFlow::Continue(()) => Ok(()),
165 ControlFlow::Break(err) => Err(err),
166 }
167}
168
169fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) -> Result<()> {
170 match statement {
171 DFStatement::Statement(s) => {
172 control_flow_to_result(s.as_ref().visit(visitor))?;
173 }
174 DFStatement::CreateExternalTable(table) => {
175 control_flow_to_result(visitor.insert_relation(&table.name))?;
176 }
177 DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
178 CopyToSource::Relation(table_name) => {
179 control_flow_to_result(visitor.insert_relation(table_name))?;
180 }
181 CopyToSource::Query(query) => {
182 control_flow_to_result(query.visit(visitor))?;
183 }
184 },
185 DFStatement::Explain(explain) => {
186 visit_statement(&explain.statement, visitor)?;
187 }
188 DFStatement::Reset(_) => {}
189 }
190 Ok(())
191}
192
193pub fn resolve_table_references(
228 statement: &crate::parser::Statement,
229 enable_ident_normalization: bool,
230) -> Result<(Vec<TableReference>, Vec<TableReference>)> {
231 let mut visitor = RelationVisitor {
232 relations: BTreeSet::new(),
233 all_ctes: BTreeSet::new(),
234 ctes_in_scope: vec![],
235 enable_ident_normalization,
236 };
237
238 visit_statement(statement, &mut visitor)?;
239
240 Ok((
241 visitor.relations.into_iter().collect(),
242 visitor.all_ctes.into_iter().collect(),
243 ))
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::DFParser;

    /// Parses the last statement in `sql` and resolves its `(relations, ctes)`.
    fn resolve(sql: &str, normalize: bool) -> (Vec<TableReference>, Vec<TableReference>) {
        let statement = DFParser::parse_sql(sql).unwrap().pop_back().unwrap();
        resolve_table_references(&statement, normalize).unwrap()
    }

    /// Display strings of `refs`, in order, for whole-vector comparisons.
    fn names(refs: &[TableReference]) -> Vec<String> {
        refs.iter().map(|r| r.to_string()).collect()
    }

    #[test]
    fn resolve_table_references_shadowed_cte() {
        // Inside a non-recursive CTE its own name is not in scope yet, so the
        // inner `t` is a table while the outer `t` is the CTE.
        let (table_refs, ctes) = resolve("WITH t AS (SELECT * FROM t) SELECT * FROM t", true);
        assert_eq!(names(&table_refs), ["t"]);
        assert_eq!(names(&ctes), ["t"]);

        // The CTE is only visible in the left branch of the UNION.
        let (table_refs, ctes) = resolve(
            "(with t as (select 1) select * from t) union (select * from t)",
            true,
        );
        assert_eq!(names(&table_refs), ["t"]);
        assert_eq!(names(&ctes), ["t"]);

        // `u` is defined inside `t`, so it is not a CTE for the outer query.
        let (table_refs, ctes) = resolve(
            "(with t as (with u as (select 1) select * from u) select * from u cross join t)",
            true,
        );
        assert_eq!(names(&table_refs), ["u"]);
        assert_eq!(names(&ctes), ["t", "u"]);
    }

    #[test]
    fn resolve_table_references_recursive_cte() {
        let (table_refs, ctes) = resolve(
            "WITH RECURSIVE nodes AS ( \
                SELECT 1 as id \
                UNION ALL \
                SELECT id + 1 as id FROM nodes WHERE id < 10 \
            ) SELECT * FROM nodes",
            true,
        );
        assert!(table_refs.is_empty());
        assert_eq!(names(&ctes), ["nodes"]);
    }

    #[test]
    fn resolve_table_references_cte_with_quoted_reference() {
        let (table_refs, ctes) = resolve(r#"with barbaz as (select 1) select * from "barbaz""#, true);
        assert_eq!(names(&ctes), ["barbaz"]);
        assert!(table_refs.is_empty());
    }

    #[test]
    fn resolve_table_references_cte_with_quoted_reference_normalization_off() {
        let (table_refs, ctes) =
            resolve(r#"with barbaz as (select 1) select * from "barbaz""#, false);
        assert_eq!(names(&ctes), ["barbaz"]);
        assert!(table_refs.is_empty());
    }

    #[test]
    fn resolve_table_references_cte_with_quoted_reference_uppercase_normalization_on() {
        // The unquoted CTE name is normalized but the quoted reference is not,
        // so the two no longer match.
        let (table_refs, ctes) = resolve(r#"with FOObar as (select 1) select * from "FOObar""#, true);
        assert_eq!(names(&ctes), ["foobar"]);
        assert_eq!(names(&table_refs), ["FOObar"]);
    }

    #[test]
    fn resolve_table_references_cte_with_quoted_reference_uppercase_normalization_off() {
        let (table_refs, ctes) =
            resolve(r#"with FOObar as (select 1) select * from "FOObar""#, false);
        assert_eq!(names(&ctes), ["FOObar"]);
        assert!(table_refs.is_empty());
    }
}