datafusion_sql/
resolve.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::TableReference;
19use std::collections::BTreeSet;
20use std::ops::ControlFlow;
21
22use crate::parser::{CopyToSource, CopyToStatement, Statement as DFStatement};
23use crate::planner::object_name_to_table_reference;
24use sqlparser::ast::*;
25
// The following constants are used while resolving table references (see
// `RelationVisitor::pre_visit_statement`) and should be kept the same as the
// ones in `datafusion/catalog/src/information_schema.rs`.
const INFORMATION_SCHEMA: &str = "information_schema";
const TABLES: &str = "tables";
const VIEWS: &str = "views";
const COLUMNS: &str = "columns";
const DF_SETTINGS: &str = "df_settings";
const SCHEMATA: &str = "schemata";
const ROUTINES: &str = "routines";
const PARAMETERS: &str = "parameters";

/// All information schema tables.
///
/// Each entry is registered as `information_schema.<name>` whenever a `SHOW`
/// statement is visited.
const INFORMATION_SCHEMA_TABLES: &[&str] = &[
    TABLES,
    VIEWS,
    COLUMNS,
    DF_SETTINGS,
    SCHEMATA,
    ROUTINES,
    PARAMETERS,
];
47
/// AST visitor state used to collect the relations and CTEs referenced by a statement.
struct RelationVisitor {
    /// Names of the tables/views referenced so far; converted into the
    /// `table_refs` returned by `resolve_table_references`.
    relations: BTreeSet<ObjectName>,
    /// Names of every CTE whose defining query has been fully visited;
    /// converted into the `ctes` returned by `resolve_table_references`.
    all_ctes: BTreeSet<ObjectName>,
    /// Stack of CTE names visible at the current point of the traversal.
    /// Pushed in `pre_visit_query` and popped in `post_visit_query`.
    ctes_in_scope: Vec<ObjectName>,
}
53
54impl RelationVisitor {
55    /// Record the reference to `relation`, if it's not a CTE reference.
56    fn insert_relation(&mut self, relation: &ObjectName) {
57        if !self.relations.contains(relation) && !self.ctes_in_scope.contains(relation) {
58            self.relations.insert(relation.clone());
59        }
60    }
61}
62
63impl Visitor for RelationVisitor {
64    type Break = ();
65
66    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
67        self.insert_relation(relation);
68        ControlFlow::Continue(())
69    }
70
71    fn pre_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
72        if let Some(with) = &q.with {
73            for cte in &with.cte_tables {
74                // The non-recursive CTE name is not in scope when evaluating the CTE itself, so this is valid:
75                // `WITH t AS (SELECT * FROM t) SELECT * FROM t`
76                // Where the first `t` refers to a predefined table. So we are careful here
77                // to visit the CTE first, before putting it in scope.
78                if !with.recursive {
79                    // This is a bit hackish as the CTE will be visited again as part of visiting `q`,
80                    // but thankfully `insert_relation` is idempotent.
81                    let _ = cte.visit(self);
82                }
83                self.ctes_in_scope
84                    .push(ObjectName::from(vec![cte.alias.name.clone()]));
85            }
86        }
87        ControlFlow::Continue(())
88    }
89
90    fn post_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
91        if let Some(with) = &q.with {
92            for _ in &with.cte_tables {
93                // Unwrap: We just pushed these in `pre_visit_query`
94                self.all_ctes.insert(self.ctes_in_scope.pop().unwrap());
95            }
96        }
97        ControlFlow::Continue(())
98    }
99
100    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
101        if let Statement::ShowCreate {
102            obj_type: ShowCreateObject::Table | ShowCreateObject::View,
103            obj_name,
104        } = statement
105        {
106            self.insert_relation(obj_name)
107        }
108
109        // SHOW statements will later be rewritten into a SELECT from the information_schema
110        let requires_information_schema = matches!(
111            statement,
112            Statement::ShowFunctions { .. }
113                | Statement::ShowVariable { .. }
114                | Statement::ShowStatus { .. }
115                | Statement::ShowVariables { .. }
116                | Statement::ShowCreate { .. }
117                | Statement::ShowColumns { .. }
118                | Statement::ShowTables { .. }
119                | Statement::ShowCollation { .. }
120        );
121        if requires_information_schema {
122            for s in INFORMATION_SCHEMA_TABLES {
123                self.relations.insert(ObjectName::from(vec![
124                    Ident::new(INFORMATION_SCHEMA),
125                    Ident::new(*s),
126                ]));
127            }
128        }
129        ControlFlow::Continue(())
130    }
131}
132
133fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) {
134    match statement {
135        DFStatement::Statement(s) => {
136            let _ = s.as_ref().visit(visitor);
137        }
138        DFStatement::CreateExternalTable(table) => {
139            visitor.relations.insert(table.name.clone());
140        }
141        DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
142            CopyToSource::Relation(table_name) => {
143                visitor.insert_relation(table_name);
144            }
145            CopyToSource::Query(query) => {
146                let _ = query.visit(visitor);
147            }
148        },
149        DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor),
150        DFStatement::Reset(_) => {}
151    }
152}
153
154/// Collects all tables and views referenced in the SQL statement. CTEs are collected separately.
155/// This can be used to determine which tables need to be in the catalog for a query to be planned.
156///
157/// # Returns
158///
159/// A `(table_refs, ctes)` tuple, the first element contains table and view references and the second
160/// element contains any CTE aliases that were defined and possibly referenced.
161///
162/// ## Example
163///
164/// ```
165/// # use datafusion_sql::parser::DFParser;
166/// # use datafusion_sql::resolve::resolve_table_references;
167/// let query = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
168/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
169/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
170/// assert_eq!(table_refs.len(), 2);
171/// assert_eq!(table_refs[0].to_string(), "bar");
172/// assert_eq!(table_refs[1].to_string(), "foo");
173/// assert_eq!(ctes.len(), 0);
174/// ```
175///
176/// ## Example with CTEs  
177///  
178/// ```  
179/// # use datafusion_sql::parser::DFParser;
180/// # use datafusion_sql::resolve::resolve_table_references;
181/// let query = "with my_cte as (values (1), (2)) SELECT * from my_cte;";
182/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
183/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
184/// assert_eq!(table_refs.len(), 0);
185/// assert_eq!(ctes.len(), 1);
186/// assert_eq!(ctes[0].to_string(), "my_cte");
187/// ```
188pub fn resolve_table_references(
189    statement: &crate::parser::Statement,
190    enable_ident_normalization: bool,
191) -> datafusion_common::Result<(Vec<TableReference>, Vec<TableReference>)> {
192    let mut visitor = RelationVisitor {
193        relations: BTreeSet::new(),
194        all_ctes: BTreeSet::new(),
195        ctes_in_scope: vec![],
196    };
197
198    visit_statement(statement, &mut visitor);
199
200    let table_refs = visitor
201        .relations
202        .into_iter()
203        .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
204        .collect::<datafusion_common::Result<_>>()?;
205    let ctes = visitor
206        .all_ctes
207        .into_iter()
208        .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
209        .collect::<datafusion_common::Result<_>>()?;
210    Ok((table_refs, ctes))
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::DFParser;

    /// Parses `sql`, resolves the references of its last statement (with
    /// identifier normalization enabled) and returns the display names of
    /// `(table_refs, ctes)`.
    fn resolve(sql: &str) -> (Vec<String>, Vec<String>) {
        let statement = DFParser::parse_sql(sql).unwrap().pop_back().unwrap();
        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
        let names = |refs: Vec<TableReference>| -> Vec<String> {
            refs.iter().map(|r| r.to_string()).collect()
        };
        (names(table_refs), names(ctes))
    }

    #[test]
    fn resolve_table_references_shadowed_cte() {
        // An interesting edge case where the `t` name is used both as an ordinary table reference
        // and as a CTE reference.
        let query = "WITH t AS (SELECT * FROM t) SELECT * FROM t";
        let (table_refs, ctes) = resolve(query);
        assert_eq!(ctes, ["t"]);
        assert_eq!(table_refs, ["t"]);

        // UNION is a special case where the CTE is not in scope for the second branch.
        let query = "(with t as (select 1) select * from t) union (select * from t)";
        let (table_refs, ctes) = resolve(query);
        assert_eq!(ctes, ["t"]);
        assert_eq!(table_refs, ["t"]);

        // Nested CTEs are also handled.
        // Here the first `u` is a CTE, but the second `u` is a table reference.
        // While `t` is always a CTE.
        let query = "(with t as (with u as (select 1) select * from u) select * from u cross join t)";
        let (table_refs, ctes) = resolve(query);
        assert_eq!(ctes, ["t", "u"]);
        assert_eq!(table_refs, ["u"]);
    }

    #[test]
    fn resolve_table_references_recursive_cte() {
        let query = "
            WITH RECURSIVE nodes AS ( 
                SELECT 1 as id
                UNION ALL 
                SELECT id + 1 as id 
                FROM nodes
                WHERE id < 10
            )
            SELECT * FROM nodes
        ";
        let (table_refs, ctes) = resolve(query);
        assert!(table_refs.is_empty());
        assert_eq!(ctes, ["nodes"]);
    }
}
273}