Skip to main content

polyglot_sql/optimizer/
isolate_table_selects.rs

1//! Isolate Table Selects Optimization Pass
2//!
3//! This module wraps plain table references in subqueries (`SELECT * FROM table`)
4//! when multiple tables are present in a scope. This normalization is needed for
5//! other optimizations (like merge_subqueries) to work correctly, since they
6//! expect each source in a multi-table query to be a subquery rather than a bare
7//! table reference.
8//!
9//! Ported from sqlglot's optimizer/isolate_table_selects.py
10
11use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15/// Error type for the isolate_table_selects pass
16#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18    #[error("Tables require an alias: {0}")]
19    MissingAlias(String),
20}
21
22/// Wrap plain table references in subqueries when multiple sources are present.
23///
24/// When a SELECT has multiple sources (FROM + JOINs, or multiple FROM tables),
25/// each bare `Table` reference is replaced with:
26///
27/// ```sql
28/// (SELECT * FROM table AS alias) AS alias
29/// ```
30///
31/// This makes every source a subquery, which simplifies downstream
32/// optimizations such as `merge_subqueries`.
33///
34/// # Arguments
35///
36/// * `expression` - The SQL expression tree to transform
37/// * `schema` - Optional schema for looking up column names (used to skip
38///   tables whose columns are unknown, matching the Python behavior)
39/// * `_dialect` - Optional dialect (reserved for future use)
40///
41/// # Returns
42///
43/// The transformed expression with isolated table selects
44pub fn isolate_table_selects(
45    expression: Expression,
46    schema: Option<&dyn Schema>,
47    _dialect: Option<DialectType>,
48) -> Expression {
49    match expression {
50        Expression::Select(select) => {
51            let transformed = isolate_select(*select, schema);
52            Expression::Select(Box::new(transformed))
53        }
54        Expression::Union(mut union) => {
55            union.left = isolate_table_selects(union.left, schema, _dialect);
56            union.right = isolate_table_selects(union.right, schema, _dialect);
57            Expression::Union(union)
58        }
59        Expression::Intersect(mut intersect) => {
60            intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61            intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62            Expression::Intersect(intersect)
63        }
64        Expression::Except(mut except) => {
65            except.left = isolate_table_selects(except.left, schema, _dialect);
66            except.right = isolate_table_selects(except.right, schema, _dialect);
67            Expression::Except(except)
68        }
69        other => other,
70    }
71}
72
73/// Process a single SELECT statement, wrapping bare table references in
74/// subqueries when multiple sources are present.
75fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76    // First, recursively process CTEs
77    if let Some(ref mut with) = select.with {
78        for cte in &mut with.ctes {
79            cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80        }
81    }
82
83    // Recursively process subqueries in FROM and JOINs
84    if let Some(ref mut from) = select.from {
85        for expr in &mut from.expressions {
86            if let Expression::Subquery(ref mut sq) = expr {
87                sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88            }
89        }
90    }
91    for join in &mut select.joins {
92        if let Expression::Subquery(ref mut sq) = join.this {
93            sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94        }
95    }
96
97    // Count the total number of sources (FROM expressions + JOINs)
98    let source_count = count_sources(&select);
99
100    // Only isolate when there are multiple sources
101    if source_count <= 1 {
102        return select;
103    }
104
105    // Wrap bare table references in FROM clause
106    if let Some(ref mut from) = select.from {
107        from.expressions = from
108            .expressions
109            .drain(..)
110            .map(|expr| maybe_wrap_table(expr, schema))
111            .collect();
112    }
113
114    // Wrap bare table references in JOINs
115    for join in &mut select.joins {
116        join.this = maybe_wrap_table(join.this.clone(), schema);
117    }
118
119    select
120}
121
122/// Count the total number of source tables/subqueries in a SELECT.
123///
124/// This counts FROM expressions plus JOINs.
125fn count_sources(select: &Select) -> usize {
126    let from_count = select
127        .from
128        .as_ref()
129        .map(|f| f.expressions.len())
130        .unwrap_or(0);
131    let join_count = select.joins.len();
132    from_count + join_count
133}
134
135/// If the expression is a bare `Table` reference that should be isolated,
136/// wrap it in a `(SELECT * FROM table AS alias) AS alias` subquery.
137///
138/// A table is wrapped when:
139/// - It is an `Expression::Table` (not already a subquery)
140/// - It has an alias (required by the Python reference)
141/// - If a schema is provided, the table must have known columns in the schema
142///
143/// If no schema is provided, all aliased tables are wrapped (simplified mode).
144fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145    match expression {
146        Expression::Table(ref table) => {
147            // If a schema is provided, check that the table has known columns.
148            // If we cannot find columns for the table, skip wrapping it (matching
149            // the Python behavior where `schema.column_names(source)` must be truthy).
150            if let Some(s) = schema {
151                let table_name = full_table_name(table);
152                if s.column_names(&table_name).unwrap_or_default().is_empty() {
153                    return expression;
154                }
155            }
156
157            // The table must have an alias; if it does not, we leave it as-is.
158            // The Python version raises an OptimizeError here, but in practice
159            // earlier passes (qualify_tables) ensure aliases are present.
160            let alias_name = match &table.alias {
161                Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162                _ => return expression,
163            };
164
165            wrap_table_in_subquery(table.clone(), &alias_name)
166        }
167        _ => expression,
168    }
169}
170
171/// Build `(SELECT * FROM table_ref AS alias) AS alias` from a table reference.
172///
173/// The inner table reference keeps the original alias so that
174/// `FROM t AS t` becomes `(SELECT * FROM t AS t) AS t`.
175fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
176    // Build: SELECT * FROM <table>
177    let inner_select = Select::new()
178        .column(Expression::Star(Star {
179            table: None,
180            except: None,
181            replace: None,
182            rename: None,
183            trailing_comments: Vec::new(),
184            span: None,
185        }))
186        .from(Expression::Table(table));
187
188    // Wrap the SELECT in a Subquery with the original alias
189    Expression::Subquery(Box::new(Subquery {
190        this: Expression::Select(Box::new(inner_select)),
191        alias: Some(Identifier::new(alias_name)),
192        column_aliases: Vec::new(),
193        order_by: None,
194        limit: None,
195        offset: None,
196        distribute_by: None,
197        sort_by: None,
198        cluster_by: None,
199        lateral: false,
200        modifiers_inside: false,
201        trailing_comments: Vec::new(),
202    }))
203}
204
205/// Construct the fully qualified table name from a `TableRef`.
206///
207/// Produces `catalog.schema.name` or `schema.name` or just `name`
208/// depending on which parts are present.
209fn full_table_name(table: &TableRef) -> String {
210    let mut parts = Vec::new();
211    if let Some(ref catalog) = table.catalog {
212        parts.push(catalog.name.as_str());
213    }
214    if let Some(ref schema) = table.schema {
215        parts.push(schema.name.as_str());
216    }
217    parts.push(&table.name.name);
218    parts.join(".")
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Render `expr` as SQL text with the default generator.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Run the pass on `sql` without a schema and render the result.
    fn isolate(sql: &str) -> String {
        render(&isolate_table_selects(parse(sql), None, None))
    }

    /// Run the pass on `sql` against `schema` and render the result.
    fn isolate_with(sql: &str, schema: &MappingSchema) -> String {
        render(&isolate_table_selects(parse(sql), Some(schema), None))
    }

    /// A schema in which each listed table has a single `id` column of type INT.
    fn schema_with_tables(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            let columns = [(
                "id".to_string(),
                DataType::Int {
                    length: None,
                    integer_spelling: false,
                },
            )];
            schema.add_table(table, &columns, None).unwrap();
        }
        schema
    }

    /// The text that a `name AS alias` source is expected to become once isolated.
    fn isolated(name: &str, alias: &str) -> String {
        format!("(SELECT * FROM {name} AS {alias}) AS {alias}")
    }

    /// Assert that `needle` occurs in `output`.
    fn expect_present(output: &str, needle: &str, what: &str) {
        assert!(output.contains(needle), "{what}: {output}");
    }

    /// Assert that `needle` does not occur in `output`.
    fn expect_absent(output: &str, needle: &str, what: &str) {
        assert!(!output.contains(needle), "{what}: {output}");
    }

    // ---------------------------------------------------------------
    // A single source is never wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t");
        // A lone table stays a plain table reference.
        expect_absent(&output, "(SELECT", "Single table should not be wrapped");
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t");
        // Still one source: the existing subquery must be the only one.
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Several sources: bare tables get wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id");
        expect_present(&output, &isolated("a", "a"), "FROM table should be wrapped");
        expect_present(&output, &isolated("b", "b"), "JOIN table should be wrapped");
    }

    #[test]
    fn test_table_with_join_subquery() {
        // Only the bare table is wrapped; the existing subquery is left alone.
        let output = isolate("SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id");
        expect_present(&output, &isolated("a", "a"), "Bare table should be wrapped");
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        // Unaliased tables are skipped; qualify_tables normally aliases them first.
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id");
        expect_absent(
            &output,
            "(SELECT * FROM a",
            "Table without alias should not be wrapped",
        );
    }

    // ---------------------------------------------------------------
    // Schema-aware mode: only tables with known columns are wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_tables(&["a", "b"]);
        let output = isolate_with("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", &schema);
        expect_present(&output, &isolated("a", "a"), "Known table 'a' should be wrapped");
        expect_present(&output, &isolated("b", "b"), "Known table 'b' should be wrapped");
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // 'b' is deliberately missing from the schema.
        let schema = schema_with_tables(&["a"]);
        let output = isolate_with("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", &schema);
        expect_present(&output, &isolated("a", "a"), "Known table 'a' should be wrapped");
        expect_absent(
            &output,
            &isolated("b", "b"),
            "Unknown table 'b' should NOT be wrapped",
        );
    }

    // ---------------------------------------------------------------
    // Recursion into CTEs and nested subqueries
    // ---------------------------------------------------------------

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
        );
        expect_present(&output, &isolated("x", "x"), "CTE inner table 'x' should be wrapped");
        expect_present(&output, &isolated("y", "y"), "CTE inner table 'y' should be wrapped");
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output =
            isolate("SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub");
        expect_present(
            &output,
            &isolated("a", "a"),
            "Nested inner table 'a' should be wrapped",
        );
    }

    // ---------------------------------------------------------------
    // Set operations
    // ---------------------------------------------------------------

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
        );
        // The left branch has two sources, the right branch only one.
        expect_present(&output, &isolated("a", "a"), "UNION left side should be processed");
        expect_absent(
            &output,
            &isolated("c", "c"),
            "UNION right side (single source) should not be wrapped",
        );
    }

    // ---------------------------------------------------------------
    // Edge cases
    // ---------------------------------------------------------------

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b");
        expect_present(&output, &isolated("a", "a"), "CROSS JOIN table 'a' should be wrapped");
        expect_present(&output, &isolated("b", "b"), "CROSS JOIN table 'b' should be wrapped");
    }

    #[test]
    fn test_multiple_from_tables() {
        // Comma-separated FROM list (implicit cross join).
        let output = isolate("SELECT * FROM a AS a, b AS b");
        expect_present(&output, &isolated("a", "a"), "Comma-join table 'a' should be wrapped");
        expect_present(&output, &isolated("b", "b"), "Comma-join table 'b' should be wrapped");
    }

    #[test]
    fn test_three_way_join() {
        let output =
            isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id");
        for name in ["a", "b", "c"] {
            expect_present(
                &output,
                &isolated(name, name),
                &format!("Three-way join: '{name}' should be wrapped"),
            );
        }
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_tables(&["mydb.a", "mydb.b"]);
        let output = isolate_with(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            &schema,
        );
        expect_present(
            &output,
            &isolated("mydb.a", "a"),
            "Qualified table 'mydb.a' should be wrapped",
        );
        expect_present(
            &output,
            &isolated("mydb.b", "b"),
            "Qualified table 'mydb.b' should be wrapped",
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        // INSERT, CREATE and similar statements pass through untouched.
        let expr = parse("INSERT INTO t VALUES (1)");
        let before = render(&expr);
        let after = render(&isolate_table_selects(expr, None, None));
        assert_eq!(before, after, "Non-SELECT should be unchanged");
    }
}