Skip to main content

polyglot_sql/optimizer/
isolate_table_selects.rs

//! Isolate Table Selects Optimization Pass
//!
//! This module wraps plain table references in subqueries (`SELECT * FROM table`)
//! when multiple tables are present in a scope. This normalization is needed for
//! other optimizations (like `merge_subqueries`) to work correctly, since they
//! expect each source in a multi-table query to be a subquery rather than a bare
//! table reference.
//!
//! Ported from sqlglot's `optimizer/isolate_table_selects.py`.
10
11use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15/// Error type for the isolate_table_selects pass
16#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18    #[error("Tables require an alias: {0}")]
19    MissingAlias(String),
20}
21
22/// Wrap plain table references in subqueries when multiple sources are present.
23///
24/// When a SELECT has multiple sources (FROM + JOINs, or multiple FROM tables),
25/// each bare `Table` reference is replaced with:
26///
27/// ```sql
28/// (SELECT * FROM table AS alias) AS alias
29/// ```
30///
31/// This makes every source a subquery, which simplifies downstream
32/// optimizations such as `merge_subqueries`.
33///
34/// # Arguments
35///
36/// * `expression` - The SQL expression tree to transform
37/// * `schema` - Optional schema for looking up column names (used to skip
38///   tables whose columns are unknown, matching the Python behavior)
39/// * `_dialect` - Optional dialect (reserved for future use)
40///
41/// # Returns
42///
43/// The transformed expression with isolated table selects
44pub fn isolate_table_selects(
45    expression: Expression,
46    schema: Option<&dyn Schema>,
47    _dialect: Option<DialectType>,
48) -> Expression {
49    match expression {
50        Expression::Select(select) => {
51            let transformed = isolate_select(*select, schema);
52            Expression::Select(Box::new(transformed))
53        }
54        Expression::Union(mut union) => {
55            union.left = isolate_table_selects(union.left, schema, _dialect);
56            union.right = isolate_table_selects(union.right, schema, _dialect);
57            Expression::Union(union)
58        }
59        Expression::Intersect(mut intersect) => {
60            intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61            intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62            Expression::Intersect(intersect)
63        }
64        Expression::Except(mut except) => {
65            except.left = isolate_table_selects(except.left, schema, _dialect);
66            except.right = isolate_table_selects(except.right, schema, _dialect);
67            Expression::Except(except)
68        }
69        other => other,
70    }
71}
72
73/// Process a single SELECT statement, wrapping bare table references in
74/// subqueries when multiple sources are present.
75fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76    // First, recursively process CTEs
77    if let Some(ref mut with) = select.with {
78        for cte in &mut with.ctes {
79            cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80        }
81    }
82
83    // Recursively process subqueries in FROM and JOINs
84    if let Some(ref mut from) = select.from {
85        for expr in &mut from.expressions {
86            if let Expression::Subquery(ref mut sq) = expr {
87                sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88            }
89        }
90    }
91    for join in &mut select.joins {
92        if let Expression::Subquery(ref mut sq) = join.this {
93            sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94        }
95    }
96
97    // Count the total number of sources (FROM expressions + JOINs)
98    let source_count = count_sources(&select);
99
100    // Only isolate when there are multiple sources
101    if source_count <= 1 {
102        return select;
103    }
104
105    // Wrap bare table references in FROM clause
106    if let Some(ref mut from) = select.from {
107        from.expressions = from
108            .expressions
109            .drain(..)
110            .map(|expr| maybe_wrap_table(expr, schema))
111            .collect();
112    }
113
114    // Wrap bare table references in JOINs
115    for join in &mut select.joins {
116        join.this = maybe_wrap_table(join.this.clone(), schema);
117    }
118
119    select
120}
121
122/// Count the total number of source tables/subqueries in a SELECT.
123///
124/// This counts FROM expressions plus JOINs.
125fn count_sources(select: &Select) -> usize {
126    let from_count = select
127        .from
128        .as_ref()
129        .map(|f| f.expressions.len())
130        .unwrap_or(0);
131    let join_count = select.joins.len();
132    from_count + join_count
133}
134
135/// If the expression is a bare `Table` reference that should be isolated,
136/// wrap it in a `(SELECT * FROM table AS alias) AS alias` subquery.
137///
138/// A table is wrapped when:
139/// - It is an `Expression::Table` (not already a subquery)
140/// - It has an alias (required by the Python reference)
141/// - If a schema is provided, the table must have known columns in the schema
142///
143/// If no schema is provided, all aliased tables are wrapped (simplified mode).
144fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145    match expression {
146        Expression::Table(ref table) => {
147            // If a schema is provided, check that the table has known columns.
148            // If we cannot find columns for the table, skip wrapping it (matching
149            // the Python behavior where `schema.column_names(source)` must be truthy).
150            if let Some(s) = schema {
151                let table_name = full_table_name(table);
152                if s.column_names(&table_name).unwrap_or_default().is_empty() {
153                    return expression;
154                }
155            }
156
157            // The table must have an alias; if it does not, we leave it as-is.
158            // The Python version raises an OptimizeError here, but in practice
159            // earlier passes (qualify_tables) ensure aliases are present.
160            let alias_name = match &table.alias {
161                Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162                _ => return expression,
163            };
164
165            wrap_table_in_subquery(table.clone(), &alias_name)
166        }
167        _ => expression,
168    }
169}
170
171/// Build `(SELECT * FROM table_ref AS alias) AS alias` from a table reference.
172///
173/// The inner table reference keeps the original alias so that
174/// `FROM t AS t` becomes `(SELECT * FROM t AS t) AS t`.
175fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
176    // Build: SELECT * FROM <table>
177    let inner_select = Select::new()
178        .column(Expression::Star(Star {
179            table: None,
180            except: None,
181            replace: None,
182            rename: None,
183            trailing_comments: Vec::new(),
184            span: None,
185        }))
186        .from(Expression::Table(table));
187
188    // Wrap the SELECT in a Subquery with the original alias
189    Expression::Subquery(Box::new(Subquery {
190        this: Expression::Select(Box::new(inner_select)),
191        alias: Some(Identifier::new(alias_name)),
192        column_aliases: Vec::new(),
193        order_by: None,
194        limit: None,
195        offset: None,
196        distribute_by: None,
197        sort_by: None,
198        cluster_by: None,
199        lateral: false,
200        modifiers_inside: false,
201        trailing_comments: Vec::new(),
202        inferred_type: None,
203    }))
204}
205
206/// Construct the fully qualified table name from a `TableRef`.
207///
208/// Produces `catalog.schema.name` or `schema.name` or just `name`
209/// depending on which parts are present.
210fn full_table_name(table: &TableRef) -> String {
211    let mut parts = Vec::new();
212    if let Some(ref catalog) = table.catalog {
213        parts.push(catalog.name.as_str());
214    }
215    if let Some(ref schema) = table.schema {
216        parts.push(schema.name.as_str());
217    }
218    parts.push(&table.name.name);
219    parts.join(".")
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parse `sql` and return its first statement.
    fn first_statement(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Render an expression back to SQL text.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql`, run the pass with `schema` (no dialect), and render the result.
    fn isolate_sql(sql: &str, schema: Option<&dyn Schema>) -> String {
        render(&isolate_table_selects(first_statement(sql), schema, None))
    }

    /// Build a schema in which each listed table has a single `id` INT column.
    fn schema_of(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            schema
                .add_table(
                    table,
                    &[(
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    )],
                    None,
                )
                .unwrap();
        }
        schema
    }

    // ---------------------------------------------------------------
    // A single source is never wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate_sql("SELECT * FROM t AS t", None);
        // The lone table must stay a plain table reference.
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate_sql("SELECT * FROM (SELECT 1) AS t", None);
        // Only the subquery that was already there may appear.
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Several sources: bare tables get wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_two_tables_joined() {
        let output = isolate_sql("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        // One source is a bare table, the other already a subquery: only the
        // bare table may change.
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        // The existing subquery must appear exactly once (no double wrapping).
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        // Un-aliased tables are left alone; in a real pipeline qualify_tables
        // runs first and assigns aliases.
        let output = isolate_sql("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // With a schema: only tables with known columns are wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_of(&["a", "b"]);

        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // The schema knows 'a' only; 'b' is unknown to it.
        let schema = schema_of(&["a"]);

        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        // 'b' has no columns in the schema and must stay a plain table.
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Recursion into CTEs and nested subqueries
    // ---------------------------------------------------------------

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate_sql(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        // The CTE body has two sources, so both are wrapped.
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate_sql(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        // The inner SELECT has two sources and must be rewritten.
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Set operations
    // ---------------------------------------------------------------

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        // Left operand: two sources, wrapped.
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        // Right operand: single source, untouched.
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Edge cases
    // ---------------------------------------------------------------

    #[test]
    fn test_cross_join() {
        let output = isolate_sql("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        // Comma-separated FROM list (implicit cross join).
        let output = isolate_sql("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_of(&["mydb.a", "mydb.b"]);

        let output = isolate_sql(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        // Statements other than SELECT / set operations pass straight through.
        let expr = first_statement("INSERT INTO t VALUES (1)");
        let original = render(&expr);
        let output = render(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}