Skip to main content

polyglot_sql/optimizer/
qualify_tables.rs

1//! Table Qualification Module
2//!
3//! This module provides functionality for qualifying table references in SQL queries
4//! with their database and catalog names.
5//!
6//! Ported from sqlglot's optimizer/qualify_tables.py
7
8use crate::dialects::DialectType;
9use crate::expressions::{Expression, Identifier, Select, TableRef};
10use crate::helper::name_sequence;
11use crate::optimizer::normalize_identifiers::{
12    get_normalization_strategy, normalize_identifier, NormalizationStrategy,
13};
14use std::collections::{HashMap, HashSet};
15
16/// Options for table qualification
17#[derive(Debug, Clone, Default)]
18pub struct QualifyTablesOptions {
19    /// Default database name to add to unqualified tables
20    pub db: Option<String>,
21    /// Default catalog name to add to tables that have a db but no catalog
22    pub catalog: Option<String>,
23    /// The dialect to use for normalization
24    pub dialect: Option<DialectType>,
25    /// Whether to use canonical aliases (_0, _1, ...) instead of table names
26    pub canonicalize_table_aliases: bool,
27}
28
29impl QualifyTablesOptions {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    pub fn with_db(mut self, db: impl Into<String>) -> Self {
35        self.db = Some(db.into());
36        self
37    }
38
39    pub fn with_catalog(mut self, catalog: impl Into<String>) -> Self {
40        self.catalog = Some(catalog.into());
41        self
42    }
43
44    pub fn with_dialect(mut self, dialect: DialectType) -> Self {
45        self.dialect = Some(dialect);
46        self
47    }
48
49    pub fn with_canonical_aliases(mut self) -> Self {
50        self.canonicalize_table_aliases = true;
51        self
52    }
53}
54
55/// Rewrite SQL AST to have fully qualified tables.
56///
57/// This function:
58/// - Adds database/catalog prefixes to table references
59/// - Ensures all tables have aliases
60/// - Optionally canonicalizes aliases to _0, _1, etc.
61///
62/// # Examples
63///
64/// ```ignore
65/// // SELECT 1 FROM tbl -> SELECT 1 FROM db.tbl AS tbl
66/// let options = QualifyTablesOptions::new().with_db("db");
67/// let qualified = qualify_tables(expression, &options);
68/// ```
69///
70/// # Arguments
71/// * `expression` - The expression to qualify
72/// * `options` - Qualification options
73///
74/// # Returns
75/// The qualified expression
76pub fn qualify_tables(expression: Expression, options: &QualifyTablesOptions) -> Expression {
77    let strategy = get_normalization_strategy(options.dialect);
78    let mut next_alias = name_sequence("_");
79
80    match expression {
81        Expression::Select(select) => {
82            let qualified = qualify_select(*select, options, strategy, &mut next_alias);
83            Expression::Select(Box::new(qualified))
84        }
85        Expression::Union(mut union) => {
86            union.left = qualify_tables(union.left, options);
87            union.right = qualify_tables(union.right, options);
88            Expression::Union(union)
89        }
90        Expression::Intersect(mut intersect) => {
91            intersect.left = qualify_tables(intersect.left, options);
92            intersect.right = qualify_tables(intersect.right, options);
93            Expression::Intersect(intersect)
94        }
95        Expression::Except(mut except) => {
96            except.left = qualify_tables(except.left, options);
97            except.right = qualify_tables(except.right, options);
98            Expression::Except(except)
99        }
100        _ => expression,
101    }
102}
103
104/// Qualify a SELECT expression
105fn qualify_select(
106    mut select: Select,
107    options: &QualifyTablesOptions,
108    strategy: NormalizationStrategy,
109    next_alias: &mut impl FnMut() -> String,
110) -> Select {
111    // Collect CTE names to avoid qualifying them
112    let cte_names: HashSet<String> = select
113        .with
114        .as_ref()
115        .map(|w| w.ctes.iter().map(|c| c.alias.name.clone()).collect())
116        .unwrap_or_default();
117
118    // Track canonical aliases if needed
119    let mut canonical_aliases: HashMap<String, String> = HashMap::new();
120
121    // Qualify CTEs first
122    if let Some(ref mut with) = select.with {
123        for cte in &mut with.ctes {
124            cte.this = qualify_tables(cte.this.clone(), options);
125        }
126    }
127
128    // Qualify tables in FROM clause
129    if let Some(ref mut from) = select.from {
130        for expr in &mut from.expressions {
131            *expr = qualify_table_expression(
132                expr.clone(),
133                options,
134                strategy,
135                &cte_names,
136                &mut canonical_aliases,
137                next_alias,
138            );
139        }
140    }
141
142    // Qualify tables in JOINs
143    for join in &mut select.joins {
144        join.this = qualify_table_expression(
145            join.this.clone(),
146            options,
147            strategy,
148            &cte_names,
149            &mut canonical_aliases,
150            next_alias,
151        );
152    }
153
154    // Update column references if using canonical aliases
155    if options.canonicalize_table_aliases && !canonical_aliases.is_empty() {
156        select = update_column_references(select, &canonical_aliases);
157    }
158
159    select
160}
161
162/// Qualify a table expression (Table, Subquery, etc.)
163fn qualify_table_expression(
164    expression: Expression,
165    options: &QualifyTablesOptions,
166    strategy: NormalizationStrategy,
167    cte_names: &HashSet<String>,
168    canonical_aliases: &mut HashMap<String, String>,
169    next_alias: &mut impl FnMut() -> String,
170) -> Expression {
171    match expression {
172        Expression::Table(mut table) => {
173            let table_name = table.name.name.clone();
174
175            // Don't qualify CTEs
176            if cte_names.contains(&table_name) {
177                // Still ensure it has an alias
178                ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
179                return Expression::Table(table);
180            }
181
182            // Add db if specified and not already present
183            if let Some(ref db) = options.db {
184                if table.schema.is_none() {
185                    table.schema =
186                        Some(normalize_identifier(Identifier::new(db.clone()), strategy));
187                }
188            }
189
190            // Add catalog if specified, db is present, and catalog not already present
191            if let Some(ref catalog) = options.catalog {
192                if table.schema.is_some() && table.catalog.is_none() {
193                    table.catalog = Some(normalize_identifier(
194                        Identifier::new(catalog.clone()),
195                        strategy,
196                    ));
197                }
198            }
199
200            // Ensure the table has an alias
201            ensure_table_alias(&mut table, strategy, canonical_aliases, next_alias, options);
202
203            Expression::Table(table)
204        }
205        Expression::Subquery(mut subquery) => {
206            // Qualify the inner query
207            subquery.this = qualify_tables(subquery.this, options);
208
209            // Ensure the subquery has an alias
210            if subquery.alias.is_none() || options.canonicalize_table_aliases {
211                let alias_name = if options.canonicalize_table_aliases {
212                    let new_name = next_alias();
213                    if let Some(ref old_alias) = subquery.alias {
214                        canonical_aliases.insert(old_alias.name.clone(), new_name.clone());
215                    }
216                    new_name
217                } else {
218                    subquery
219                        .alias
220                        .as_ref()
221                        .map(|a| a.name.clone())
222                        .unwrap_or_else(|| next_alias())
223                };
224
225                subquery.alias = Some(normalize_identifier(Identifier::new(alias_name), strategy));
226            }
227
228            Expression::Subquery(subquery)
229        }
230        Expression::Paren(mut paren) => {
231            paren.this = qualify_table_expression(
232                paren.this,
233                options,
234                strategy,
235                cte_names,
236                canonical_aliases,
237                next_alias,
238            );
239            Expression::Paren(paren)
240        }
241        _ => expression,
242    }
243}
244
245/// Ensure a table has an alias
246fn ensure_table_alias(
247    table: &mut TableRef,
248    strategy: NormalizationStrategy,
249    canonical_aliases: &mut HashMap<String, String>,
250    next_alias: &mut impl FnMut() -> String,
251    options: &QualifyTablesOptions,
252) {
253    let table_name = table.name.name.clone();
254
255    if options.canonicalize_table_aliases {
256        // Use canonical alias (_0, _1, etc.)
257        let new_alias = next_alias();
258        let old_alias = table
259            .alias
260            .as_ref()
261            .map(|a| a.name.clone())
262            .unwrap_or(table_name.clone());
263        canonical_aliases.insert(old_alias, new_alias.clone());
264        table.alias = Some(normalize_identifier(Identifier::new(new_alias), strategy));
265    } else if table.alias.is_none() {
266        // Use table name as alias
267        table.alias = Some(normalize_identifier(Identifier::new(table_name), strategy));
268    }
269}
270
271/// Update column references to use canonical aliases
272fn update_column_references(
273    mut select: Select,
274    canonical_aliases: &HashMap<String, String>,
275) -> Select {
276    // Update SELECT expressions
277    select.expressions = select
278        .expressions
279        .into_iter()
280        .map(|e| update_column_in_expression(e, canonical_aliases))
281        .collect();
282
283    // Update WHERE
284    if let Some(mut where_clause) = select.where_clause {
285        where_clause.this = update_column_in_expression(where_clause.this, canonical_aliases);
286        select.where_clause = Some(where_clause);
287    }
288
289    // Update GROUP BY
290    if let Some(mut group_by) = select.group_by {
291        group_by.expressions = group_by
292            .expressions
293            .into_iter()
294            .map(|e| update_column_in_expression(e, canonical_aliases))
295            .collect();
296        select.group_by = Some(group_by);
297    }
298
299    // Update HAVING
300    if let Some(mut having) = select.having {
301        having.this = update_column_in_expression(having.this, canonical_aliases);
302        select.having = Some(having);
303    }
304
305    // Update ORDER BY
306    if let Some(mut order_by) = select.order_by {
307        order_by.expressions = order_by
308            .expressions
309            .into_iter()
310            .map(|mut o| {
311                o.this = update_column_in_expression(o.this, canonical_aliases);
312                o
313            })
314            .collect();
315        select.order_by = Some(order_by);
316    }
317
318    // Update JOIN ON conditions
319    for join in &mut select.joins {
320        if let Some(on) = &mut join.on {
321            *on = update_column_in_expression(on.clone(), canonical_aliases);
322        }
323    }
324
325    select
326}
327
328/// Update column references in an expression
329fn update_column_in_expression(
330    expression: Expression,
331    canonical_aliases: &HashMap<String, String>,
332) -> Expression {
333    match expression {
334        Expression::Column(mut col) => {
335            if let Some(ref table) = col.table {
336                if let Some(canonical) = canonical_aliases.get(&table.name) {
337                    col.table = Some(Identifier {
338                        name: canonical.clone(),
339                        quoted: table.quoted,
340                        trailing_comments: table.trailing_comments.clone(),
341                        span: None,
342                    });
343                }
344            }
345            Expression::Column(col)
346        }
347        Expression::And(mut bin) => {
348            bin.left = update_column_in_expression(bin.left, canonical_aliases);
349            bin.right = update_column_in_expression(bin.right, canonical_aliases);
350            Expression::And(bin)
351        }
352        Expression::Or(mut bin) => {
353            bin.left = update_column_in_expression(bin.left, canonical_aliases);
354            bin.right = update_column_in_expression(bin.right, canonical_aliases);
355            Expression::Or(bin)
356        }
357        Expression::Eq(mut bin) => {
358            bin.left = update_column_in_expression(bin.left, canonical_aliases);
359            bin.right = update_column_in_expression(bin.right, canonical_aliases);
360            Expression::Eq(bin)
361        }
362        Expression::Neq(mut bin) => {
363            bin.left = update_column_in_expression(bin.left, canonical_aliases);
364            bin.right = update_column_in_expression(bin.right, canonical_aliases);
365            Expression::Neq(bin)
366        }
367        Expression::Lt(mut bin) => {
368            bin.left = update_column_in_expression(bin.left, canonical_aliases);
369            bin.right = update_column_in_expression(bin.right, canonical_aliases);
370            Expression::Lt(bin)
371        }
372        Expression::Lte(mut bin) => {
373            bin.left = update_column_in_expression(bin.left, canonical_aliases);
374            bin.right = update_column_in_expression(bin.right, canonical_aliases);
375            Expression::Lte(bin)
376        }
377        Expression::Gt(mut bin) => {
378            bin.left = update_column_in_expression(bin.left, canonical_aliases);
379            bin.right = update_column_in_expression(bin.right, canonical_aliases);
380            Expression::Gt(bin)
381        }
382        Expression::Gte(mut bin) => {
383            bin.left = update_column_in_expression(bin.left, canonical_aliases);
384            bin.right = update_column_in_expression(bin.right, canonical_aliases);
385            Expression::Gte(bin)
386        }
387        Expression::Not(mut un) => {
388            un.this = update_column_in_expression(un.this, canonical_aliases);
389            Expression::Not(un)
390        }
391        Expression::Paren(mut paren) => {
392            paren.this = update_column_in_expression(paren.this, canonical_aliases);
393            Expression::Paren(paren)
394        }
395        Expression::Alias(mut alias) => {
396            alias.this = update_column_in_expression(alias.this, canonical_aliases);
397            Expression::Alias(alias)
398        }
399        Expression::Function(mut func) => {
400            func.args = func
401                .args
402                .into_iter()
403                .map(|a| update_column_in_expression(a, canonical_aliases))
404                .collect();
405            Expression::Function(func)
406        }
407        Expression::AggregateFunction(mut agg) => {
408            agg.args = agg
409                .args
410                .into_iter()
411                .map(|a| update_column_in_expression(a, canonical_aliases))
412                .collect();
413            Expression::AggregateFunction(agg)
414        }
415        Expression::Case(mut case) => {
416            case.operand = case
417                .operand
418                .map(|o| update_column_in_expression(o, canonical_aliases));
419            case.whens = case
420                .whens
421                .into_iter()
422                .map(|(w, t)| {
423                    (
424                        update_column_in_expression(w, canonical_aliases),
425                        update_column_in_expression(t, canonical_aliases),
426                    )
427                })
428                .collect();
429            case.else_ = case
430                .else_
431                .map(|e| update_column_in_expression(e, canonical_aliases));
432            Expression::Case(case)
433        }
434        _ => expression,
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL with the default generator.
    ///
    /// Named `to_sql` rather than `gen` because `gen` is a reserved keyword
    /// in the Rust 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Alias name of the first `FROM` source, if it is a plain table with an alias.
    fn first_from_table_alias(expr: &Expression) -> Option<String> {
        match expr {
            Expression::Select(select) => select
                .from
                .as_ref()
                .and_then(|from| from.expressions.first())
                .and_then(|source| match source {
                    Expression::Table(table) => table.alias.as_ref().map(|a| a.name.clone()),
                    _ => None,
                }),
            _ => None,
        }
    }

    #[test]
    fn test_qualify_with_db() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should contain mydb.users
        assert!(sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_qualify_with_db_and_catalog() {
        let options = QualifyTablesOptions::new()
            .with_db("mydb")
            .with_catalog("mycatalog");
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should contain mycatalog.mydb.users
        assert!(sql.contains("mycatalog") && sql.contains("mydb") && sql.contains("users"));
    }

    #[test]
    fn test_preserve_existing_schema() {
        let options = QualifyTablesOptions::new().with_db("default_db");
        let expr = parse("SELECT * FROM other_db.users");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should preserve other_db, not add default_db
        assert!(sql.contains("other_db"));
        assert!(!sql.contains("default_db"));
    }

    #[test]
    fn test_ensure_table_alias() {
        let options = QualifyTablesOptions::new();
        let expr = parse("SELECT * FROM users");
        let qualified = qualify_tables(expr, &options);
        // Inspect the AST: a substring check on the SQL would also pass for
        // the untouched input, since "FROM users" already contains " users".
        let alias = first_from_table_alias(&qualified);
        assert!(
            alias
                .as_deref()
                .map_or(false, |name| name.eq_ignore_ascii_case("users")),
            "expected the table to be aliased by its own name, got {:?}",
            alias
        );
    }

    #[test]
    fn test_canonical_aliases() {
        let options = QualifyTablesOptions::new().with_canonical_aliases();
        let expr = parse("SELECT u.id FROM users u");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Should use canonical alias like _0
        assert!(sql.contains("_0"));
    }

    #[test]
    fn test_qualify_join() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Both tables should be qualified: one schema prefix per table.
        assert_eq!(sql.matches("mydb").count(), 2, "unexpected SQL: {}", sql);
    }

    #[test]
    fn test_dont_qualify_cte() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("WITH cte AS (SELECT 1) SELECT * FROM cte");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        assert!(sql.contains("cte"));
        // The CTE body reads no table and the CTE reference must stay bare,
        // so the default schema must not appear anywhere.
        assert!(!sql.contains("mydb"), "unexpected SQL: {}", sql);
    }

    #[test]
    fn test_qualify_subquery() {
        let options = QualifyTablesOptions::new().with_db("mydb");
        let expr = parse("SELECT * FROM (SELECT * FROM users) AS sub");
        let qualified = qualify_tables(expr, &options);
        let sql = to_sql(&qualified);
        // Inner table should be qualified and the explicit alias kept
        assert!(sql.contains("mydb"));
        assert!(sql.contains("sub"));
    }
}