Skip to main content

polyglot_sql/optimizer/
isolate_table_selects.rs

//! Isolate Table Selects Optimization Pass
//!
//! This module wraps plain table references in subqueries (`SELECT * FROM table`)
//! whenever a scope reads from more than one source. Downstream optimizations
//! (such as `merge_subqueries`) rely on this normalization, because they expect
//! each source in a multi-table query to be a subquery rather than a bare
//! table reference.
//!
//! Ported from sqlglot's `optimizer/isolate_table_selects.py`.
10
11use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15/// Error type for the isolate_table_selects pass
16#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18    #[error("Tables require an alias: {0}")]
19    MissingAlias(String),
20}
21
22/// Wrap plain table references in subqueries when multiple sources are present.
23///
24/// When a SELECT has multiple sources (FROM + JOINs, or multiple FROM tables),
25/// each bare `Table` reference is replaced with:
26///
27/// ```sql
28/// (SELECT * FROM table AS alias) AS alias
29/// ```
30///
31/// This makes every source a subquery, which simplifies downstream
32/// optimizations such as `merge_subqueries`.
33///
34/// # Arguments
35///
36/// * `expression` - The SQL expression tree to transform
37/// * `schema` - Optional schema for looking up column names (used to skip
38///   tables whose columns are unknown, matching the Python behavior)
39/// * `_dialect` - Optional dialect (reserved for future use)
40///
41/// # Returns
42///
43/// The transformed expression with isolated table selects
44pub fn isolate_table_selects(
45    expression: Expression,
46    schema: Option<&dyn Schema>,
47    _dialect: Option<DialectType>,
48) -> Expression {
49    match expression {
50        Expression::Select(select) => {
51            let transformed = isolate_select(*select, schema);
52            Expression::Select(Box::new(transformed))
53        }
54        Expression::Union(mut union) => {
55            let left = std::mem::replace(&mut union.left, Expression::Null(Null));
56            union.left = isolate_table_selects(left, schema, _dialect);
57            let right = std::mem::replace(&mut union.right, Expression::Null(Null));
58            union.right = isolate_table_selects(right, schema, _dialect);
59            Expression::Union(union)
60        }
61        Expression::Intersect(mut intersect) => {
62            let left = std::mem::replace(&mut intersect.left, Expression::Null(Null));
63            intersect.left = isolate_table_selects(left, schema, _dialect);
64            let right = std::mem::replace(&mut intersect.right, Expression::Null(Null));
65            intersect.right = isolate_table_selects(right, schema, _dialect);
66            Expression::Intersect(intersect)
67        }
68        Expression::Except(mut except) => {
69            let left = std::mem::replace(&mut except.left, Expression::Null(Null));
70            except.left = isolate_table_selects(left, schema, _dialect);
71            let right = std::mem::replace(&mut except.right, Expression::Null(Null));
72            except.right = isolate_table_selects(right, schema, _dialect);
73            Expression::Except(except)
74        }
75        other => other,
76    }
77}
78
79/// Process a single SELECT statement, wrapping bare table references in
80/// subqueries when multiple sources are present.
81fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
82    // First, recursively process CTEs
83    if let Some(ref mut with) = select.with {
84        for cte in &mut with.ctes {
85            cte.this = isolate_table_selects(cte.this.clone(), schema, None);
86        }
87    }
88
89    // Recursively process subqueries in FROM and JOINs
90    if let Some(ref mut from) = select.from {
91        for expr in &mut from.expressions {
92            if let Expression::Subquery(ref mut sq) = expr {
93                sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94            }
95        }
96    }
97    for join in &mut select.joins {
98        if let Expression::Subquery(ref mut sq) = join.this {
99            sq.this = isolate_table_selects(sq.this.clone(), schema, None);
100        }
101    }
102
103    // Count the total number of sources (FROM expressions + JOINs)
104    let source_count = count_sources(&select);
105
106    // Only isolate when there are multiple sources
107    if source_count <= 1 {
108        return select;
109    }
110
111    // Wrap bare table references in FROM clause
112    if let Some(ref mut from) = select.from {
113        from.expressions = from
114            .expressions
115            .drain(..)
116            .map(|expr| maybe_wrap_table(expr, schema))
117            .collect();
118    }
119
120    // Wrap bare table references in JOINs
121    for join in &mut select.joins {
122        join.this = maybe_wrap_table(join.this.clone(), schema);
123    }
124
125    select
126}
127
128/// Count the total number of source tables/subqueries in a SELECT.
129///
130/// This counts FROM expressions plus JOINs.
131fn count_sources(select: &Select) -> usize {
132    let from_count = select
133        .from
134        .as_ref()
135        .map(|f| f.expressions.len())
136        .unwrap_or(0);
137    let join_count = select.joins.len();
138    from_count + join_count
139}
140
141/// If the expression is a bare `Table` reference that should be isolated,
142/// wrap it in a `(SELECT * FROM table AS alias) AS alias` subquery.
143///
144/// A table is wrapped when:
145/// - It is an `Expression::Table` (not already a subquery)
146/// - It has an alias (required by the Python reference)
147/// - If a schema is provided, the table must have known columns in the schema
148///
149/// If no schema is provided, all aliased tables are wrapped (simplified mode).
150fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
151    match expression {
152        Expression::Table(ref table) => {
153            // If a schema is provided, check that the table has known columns.
154            // If we cannot find columns for the table, skip wrapping it (matching
155            // the Python behavior where `schema.column_names(source)` must be truthy).
156            if let Some(s) = schema {
157                let table_name = full_table_name(table);
158                if s.column_names(&table_name).unwrap_or_default().is_empty() {
159                    return expression;
160                }
161            }
162
163            // The table must have an alias; if it does not, we leave it as-is.
164            // The Python version raises an OptimizeError here, but in practice
165            // earlier passes (qualify_tables) ensure aliases are present.
166            let alias_name = match &table.alias {
167                Some(alias) if !alias.name.is_empty() => alias.name.clone(),
168                _ => return expression,
169            };
170
171            wrap_table_in_subquery(*table.clone(), &alias_name)
172        }
173        _ => expression,
174    }
175}
176
177/// Build `(SELECT * FROM table_ref AS alias) AS alias` from a table reference.
178///
179/// The inner table reference keeps the original alias so that
180/// `FROM t AS t` becomes `(SELECT * FROM t AS t) AS t`.
181fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
182    // Build: SELECT * FROM <table>
183    let inner_select = Select::new()
184        .column(Expression::Star(Star {
185            table: None,
186            except: None,
187            replace: None,
188            rename: None,
189            trailing_comments: Vec::new(),
190            span: None,
191        }))
192        .from(Expression::Table(Box::new(table)));
193
194    // Wrap the SELECT in a Subquery with the original alias
195    Expression::Subquery(Box::new(Subquery {
196        this: Expression::Select(Box::new(inner_select)),
197        alias: Some(Identifier::new(alias_name)),
198        column_aliases: Vec::new(),
199        order_by: None,
200        limit: None,
201        offset: None,
202        distribute_by: None,
203        sort_by: None,
204        cluster_by: None,
205        lateral: false,
206        modifiers_inside: false,
207        trailing_comments: Vec::new(),
208        inferred_type: None,
209    }))
210}
211
212/// Construct the fully qualified table name from a `TableRef`.
213///
214/// Produces `catalog.schema.name` or `schema.name` or just `name`
215/// depending on which parts are present.
216fn full_table_name(table: &TableRef) -> String {
217    let mut parts = Vec::new();
218    if let Some(ref catalog) = table.catalog {
219        parts.push(catalog.name.as_str());
220    }
221    if let Some(ref schema) = table.schema {
222        parts.push(schema.name.as_str());
223    }
224    parts.push(&table.name.name);
225    parts.join(".")
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Render an expression back to SQL text.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql`, run the pass (no dialect) and render the result.
    fn isolate_sql(sql: &str, schema: Option<&dyn Schema>) -> String {
        generate_sql(&isolate_table_selects(parse(sql), schema, None))
    }

    /// Build a schema in which every listed table has a single `id INT` column.
    fn schema_with_int_id(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            schema
                .add_table(
                    table,
                    &[(
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    )],
                    None,
                )
                .unwrap();
        }
        schema
    }

    // --- A single source is never wrapped ---------------------------------

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate_sql("SELECT * FROM t AS t", None);
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate_sql("SELECT * FROM (SELECT 1) AS t", None);
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    // --- Several sources: bare tables get wrapped --------------------------

    #[test]
    fn test_two_tables_joined() {
        let output = isolate_sql("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        // Only the bare table is wrapped; the existing subquery is left alone.
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        // Real pipelines run qualify_tables first, which assigns aliases.
        let output = isolate_sql("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    // --- Schema-aware mode: only tables with known columns -----------------

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_int_id(&["a", "b"]);
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // 'b' is absent from the schema.
        let schema = schema_with_int_id(&["a"]);
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    // --- Recursion into CTEs and nested subqueries -------------------------

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate_sql(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate_sql(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    // --- Set operations -----------------------------------------------------

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        // Left operand has two sources; right operand has one.
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    // --- Edge cases -----------------------------------------------------------

    #[test]
    fn test_cross_join() {
        let output = isolate_sql("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        // Comma-separated FROM items form an implicit cross join.
        let output = isolate_sql("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_int_id(&["mydb.a", "mydb.b"]);
        let output = isolate_sql(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        // Statements other than queries (INSERT, CREATE, ...) pass through.
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = generate_sql(&expr);
        let output = generate_sql(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}