Skip to main content

polyglot_sql/optimizer/
isolate_table_selects.rs

1//! Isolate Table Selects Optimization Pass
2//!
3//! This module wraps plain table references in subqueries (`SELECT * FROM table`)
4//! when multiple tables are present in a scope. This normalization is needed for
5//! other optimizations (like merge_subqueries) to work correctly, since they
6//! expect each source in a multi-table query to be a subquery rather than a bare
7//! table reference.
8//!
9//! Ported from sqlglot's optimizer/isolate_table_selects.py
10
11use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
/// Error type for the isolate_table_selects pass.
///
/// NOTE(review): nothing in the visible part of this file constructs or
/// returns this error; tables without an alias are currently left untouched
/// by the pass instead of being rejected. Confirm whether callers elsewhere
/// rely on this type before removing or wiring it in.
#[derive(Debug, Clone, thiserror::Error)]
pub enum IsolateTableSelectsError {
    /// A table reference has no alias. The carried string is interpolated
    /// into the rendered message after `Tables require an alias: `.
    #[error("Tables require an alias: {0}")]
    MissingAlias(String),
}
21
22/// Wrap plain table references in subqueries when multiple sources are present.
23///
24/// When a SELECT has multiple sources (FROM + JOINs, or multiple FROM tables),
25/// each bare `Table` reference is replaced with:
26///
27/// ```sql
28/// (SELECT * FROM table AS alias) AS alias
29/// ```
30///
31/// This makes every source a subquery, which simplifies downstream
32/// optimizations such as `merge_subqueries`.
33///
34/// # Arguments
35///
36/// * `expression` - The SQL expression tree to transform
37/// * `schema` - Optional schema for looking up column names (used to skip
38///   tables whose columns are unknown, matching the Python behavior)
39/// * `_dialect` - Optional dialect (reserved for future use)
40///
41/// # Returns
42///
43/// The transformed expression with isolated table selects
44pub fn isolate_table_selects(
45    expression: Expression,
46    schema: Option<&dyn Schema>,
47    _dialect: Option<DialectType>,
48) -> Expression {
49    match expression {
50        Expression::Select(select) => {
51            let transformed = isolate_select(*select, schema);
52            Expression::Select(Box::new(transformed))
53        }
54        Expression::Union(mut union) => {
55            let left = std::mem::replace(&mut union.left, Expression::Null(Null));
56            union.left = isolate_table_selects(left, schema, _dialect);
57            let right = std::mem::replace(&mut union.right, Expression::Null(Null));
58            union.right = isolate_table_selects(right, schema, _dialect);
59            Expression::Union(union)
60        }
61        Expression::Intersect(mut intersect) => {
62            let left = std::mem::replace(&mut intersect.left, Expression::Null(Null));
63            intersect.left = isolate_table_selects(left, schema, _dialect);
64            let right = std::mem::replace(&mut intersect.right, Expression::Null(Null));
65            intersect.right = isolate_table_selects(right, schema, _dialect);
66            Expression::Intersect(intersect)
67        }
68        Expression::Except(mut except) => {
69            let left = std::mem::replace(&mut except.left, Expression::Null(Null));
70            except.left = isolate_table_selects(left, schema, _dialect);
71            let right = std::mem::replace(&mut except.right, Expression::Null(Null));
72            except.right = isolate_table_selects(right, schema, _dialect);
73            Expression::Except(except)
74        }
75        other => other,
76    }
77}
78
79/// Process a single SELECT statement, wrapping bare table references in
80/// subqueries when multiple sources are present.
81fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
82    // First, recursively process CTEs
83    if let Some(ref mut with) = select.with {
84        for cte in &mut with.ctes {
85            cte.this = isolate_table_selects(cte.this.clone(), schema, None);
86        }
87    }
88
89    // Recursively process subqueries in FROM and JOINs
90    if let Some(ref mut from) = select.from {
91        for expr in &mut from.expressions {
92            if let Expression::Subquery(ref mut sq) = expr {
93                sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94            }
95        }
96    }
97    for join in &mut select.joins {
98        if let Expression::Subquery(ref mut sq) = join.this {
99            sq.this = isolate_table_selects(sq.this.clone(), schema, None);
100        }
101    }
102
103    // Count the total number of sources (FROM expressions + JOINs)
104    let source_count = count_sources(&select);
105
106    // Only isolate when there are multiple sources
107    if source_count <= 1 {
108        return select;
109    }
110
111    // Wrap bare table references in FROM clause
112    if let Some(ref mut from) = select.from {
113        from.expressions = from
114            .expressions
115            .drain(..)
116            .map(|expr| maybe_wrap_table(expr, schema))
117            .collect();
118    }
119
120    // Wrap bare table references in JOINs
121    for join in &mut select.joins {
122        join.this = maybe_wrap_table(join.this.clone(), schema);
123    }
124
125    select
126}
127
128/// Count the total number of source tables/subqueries in a SELECT.
129///
130/// This counts FROM expressions plus JOINs.
131fn count_sources(select: &Select) -> usize {
132    let from_count = select
133        .from
134        .as_ref()
135        .map(|f| f.expressions.len())
136        .unwrap_or(0);
137    let join_count = select.joins.len();
138    from_count + join_count
139}
140
141/// If the expression is a bare `Table` reference that should be isolated,
142/// wrap it in a `(SELECT * FROM table AS alias) AS alias` subquery.
143///
144/// A table is wrapped when:
145/// - It is an `Expression::Table` (not already a subquery)
146/// - It has an alias (required by the Python reference)
147/// - If a schema is provided, the table must have known columns in the schema
148///
149/// If no schema is provided, all aliased tables are wrapped (simplified mode).
150fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
151    match expression {
152        Expression::Table(ref table) => {
153            // If a schema is provided, check that the table has known columns.
154            // If we cannot find columns for the table, skip wrapping it (matching
155            // the Python behavior where `schema.column_names(source)` must be truthy).
156            if let Some(s) = schema {
157                let table_name = full_table_name(table);
158                if s.column_names(&table_name).unwrap_or_default().is_empty() {
159                    return expression;
160                }
161            }
162
163            // The table must have an alias; if it does not, we leave it as-is.
164            // The Python version raises an OptimizeError here, but in practice
165            // earlier passes (qualify_tables) ensure aliases are present.
166            let alias_name = match &table.alias {
167                Some(alias) if !alias.name.is_empty() => alias.name.clone(),
168                _ => return expression,
169            };
170
171            wrap_table_in_subquery(*table.clone(), &alias_name)
172        }
173        _ => expression,
174    }
175}
176
177/// Build `(SELECT * FROM table_ref AS alias) AS alias` from a table reference.
178///
179/// The inner table reference keeps the original alias so that
180/// `FROM t AS t` becomes `(SELECT * FROM t AS t) AS t`.
181fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
182    // Build: SELECT * FROM <table>
183    let inner_select = Select::new()
184        .column(Expression::Star(Star {
185            table: None,
186            except: None,
187            replace: None,
188            rename: None,
189            trailing_comments: Vec::new(),
190            span: None,
191        }))
192        .from(Expression::Table(Box::new(table)));
193
194    // Wrap the SELECT in a Subquery with the original alias
195    Expression::Subquery(Box::new(Subquery {
196        this: Expression::Select(Box::new(inner_select)),
197        alias: Some(Identifier::new(alias_name)),
198        column_aliases: Vec::new(),
199        alias_explicit_as: false,
200        alias_keyword: None,
201        order_by: None,
202        limit: None,
203        offset: None,
204        distribute_by: None,
205        sort_by: None,
206        cluster_by: None,
207        lateral: false,
208        modifiers_inside: false,
209        trailing_comments: Vec::new(),
210        inferred_type: None,
211    }))
212}
213
214/// Construct the fully qualified table name from a `TableRef`.
215///
216/// Produces `catalog.schema.name` or `schema.name` or just `name`
217/// depending on which parts are present.
218fn full_table_name(table: &TableRef) -> String {
219    let mut parts = Vec::new();
220    if let Some(ref catalog) = table.catalog {
221        parts.push(catalog.name.as_str());
222    }
223    if let Some(ref schema) = table.schema {
224        parts.push(schema.name.as_str());
225    }
226    parts.push(&table.name.name);
227    parts.join(".")
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Render an expression back to SQL text.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql`, run the pass with the given schema and no dialect, and
    /// return the generated SQL.
    fn isolate(sql: &str, schema: Option<&dyn Schema>) -> String {
        to_sql(&isolate_table_selects(parse(sql), schema, None))
    }

    /// The single `id INT` column every schema fixture uses.
    fn id_column() -> (String, DataType) {
        (
            "id".to_string(),
            DataType::Int {
                length: None,
                integer_spelling: false,
            },
        )
    }

    /// Build a schema in which each listed table has one `id` column.
    fn schema_with(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            schema.add_table(table, &[id_column()], None).unwrap();
        }
        schema
    }

    // ---------------------------------------------------------------
    // Single source: nothing is wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t", None);
        // A lone table stays a plain table.
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t", None);
        // One source only, so the existing subquery is the only one.
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Several sources: aliased tables become subqueries
    // ---------------------------------------------------------------

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        // A bare table joined to an existing subquery: only the table changes.
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        // The subquery over `b` must still appear exactly once.
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        // Unaliased tables are skipped (in real pipelines qualify_tables
        // runs first and assigns aliases).
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // With a schema: only tables with known columns are wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with(&["a", "b"]);

        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // The schema knows 'a' but not 'b'.
        let schema = schema_with(&["a"]);

        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        // 'b' has no known columns, so it stays a plain table.
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Recursion into CTEs and nested subqueries
    // ---------------------------------------------------------------

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        // Both tables inside the CTE body are wrapped.
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        // The inner SELECT has two sources, so its tables are wrapped.
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Set operations
    // ---------------------------------------------------------------

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        // Left operand has two sources and is rewritten.
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        // Right operand has a single source and is left alone.
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Edge cases
    // ---------------------------------------------------------------

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        // Comma-separated FROM list (implicit cross join).
        let output = isolate("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with(&["mydb.a", "mydb.b"]);

        let output = isolate(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        // Statements other than SELECT / set operations pass through as-is.
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = to_sql(&expr);
        let output = to_sql(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}