Skip to main content

polyglot_sql/optimizer/
isolate_table_selects.rs

1//! Isolate Table Selects Optimization Pass
2//!
3//! This module wraps plain table references in subqueries (`SELECT * FROM table`)
4//! when multiple tables are present in a scope. This normalization is needed for
5//! other optimizations (like merge_subqueries) to work correctly, since they
6//! expect each source in a multi-table query to be a subquery rather than a bare
7//! table reference.
8//!
9//! Ported from sqlglot's optimizer/isolate_table_selects.py
10
11use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
/// Error type for the isolate_table_selects pass.
///
/// NOTE(review): nothing in this file constructs or returns this type.
/// `maybe_wrap_table` leaves unaliased tables untouched instead of failing.
/// The Python reference raises an error in that case, so this enum appears
/// to be kept for parity with it.
#[derive(Debug, Clone, thiserror::Error)]
pub enum IsolateTableSelectsError {
    /// A table reference that needed isolating had no alias.
    ///
    /// The payload is interpolated into the message. Presumably it names the
    /// offending table — confirm at the construction site.
    #[error("Tables require an alias: {0}")]
    MissingAlias(String),
}
21
22/// Wrap plain table references in subqueries when multiple sources are present.
23///
24/// When a SELECT has multiple sources (FROM + JOINs, or multiple FROM tables),
25/// each bare `Table` reference is replaced with:
26///
27/// ```sql
28/// (SELECT * FROM table AS alias) AS alias
29/// ```
30///
31/// This makes every source a subquery, which simplifies downstream
32/// optimizations such as `merge_subqueries`.
33///
34/// # Arguments
35///
36/// * `expression` - The SQL expression tree to transform
37/// * `schema` - Optional schema for looking up column names (used to skip
38///   tables whose columns are unknown, matching the Python behavior)
39/// * `_dialect` - Optional dialect (reserved for future use)
40///
41/// # Returns
42///
43/// The transformed expression with isolated table selects
44pub fn isolate_table_selects(
45    expression: Expression,
46    schema: Option<&dyn Schema>,
47    _dialect: Option<DialectType>,
48) -> Expression {
49    match expression {
50        Expression::Select(select) => {
51            let transformed = isolate_select(*select, schema);
52            Expression::Select(Box::new(transformed))
53        }
54        Expression::Union(mut union) => {
55            union.left = isolate_table_selects(union.left, schema, _dialect);
56            union.right = isolate_table_selects(union.right, schema, _dialect);
57            Expression::Union(union)
58        }
59        Expression::Intersect(mut intersect) => {
60            intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61            intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62            Expression::Intersect(intersect)
63        }
64        Expression::Except(mut except) => {
65            except.left = isolate_table_selects(except.left, schema, _dialect);
66            except.right = isolate_table_selects(except.right, schema, _dialect);
67            Expression::Except(except)
68        }
69        other => other,
70    }
71}
72
73/// Process a single SELECT statement, wrapping bare table references in
74/// subqueries when multiple sources are present.
75fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76    // First, recursively process CTEs
77    if let Some(ref mut with) = select.with {
78        for cte in &mut with.ctes {
79            cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80        }
81    }
82
83    // Recursively process subqueries in FROM and JOINs
84    if let Some(ref mut from) = select.from {
85        for expr in &mut from.expressions {
86            if let Expression::Subquery(ref mut sq) = expr {
87                sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88            }
89        }
90    }
91    for join in &mut select.joins {
92        if let Expression::Subquery(ref mut sq) = join.this {
93            sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94        }
95    }
96
97    // Count the total number of sources (FROM expressions + JOINs)
98    let source_count = count_sources(&select);
99
100    // Only isolate when there are multiple sources
101    if source_count <= 1 {
102        return select;
103    }
104
105    // Wrap bare table references in FROM clause
106    if let Some(ref mut from) = select.from {
107        from.expressions = from
108            .expressions
109            .drain(..)
110            .map(|expr| maybe_wrap_table(expr, schema))
111            .collect();
112    }
113
114    // Wrap bare table references in JOINs
115    for join in &mut select.joins {
116        join.this = maybe_wrap_table(join.this.clone(), schema);
117    }
118
119    select
120}
121
122/// Count the total number of source tables/subqueries in a SELECT.
123///
124/// This counts FROM expressions plus JOINs.
125fn count_sources(select: &Select) -> usize {
126    let from_count = select
127        .from
128        .as_ref()
129        .map(|f| f.expressions.len())
130        .unwrap_or(0);
131    let join_count = select.joins.len();
132    from_count + join_count
133}
134
135/// If the expression is a bare `Table` reference that should be isolated,
136/// wrap it in a `(SELECT * FROM table AS alias) AS alias` subquery.
137///
138/// A table is wrapped when:
139/// - It is an `Expression::Table` (not already a subquery)
140/// - It has an alias (required by the Python reference)
141/// - If a schema is provided, the table must have known columns in the schema
142///
143/// If no schema is provided, all aliased tables are wrapped (simplified mode).
144fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145    match expression {
146        Expression::Table(ref table) => {
147            // If a schema is provided, check that the table has known columns.
148            // If we cannot find columns for the table, skip wrapping it (matching
149            // the Python behavior where `schema.column_names(source)` must be truthy).
150            if let Some(s) = schema {
151                let table_name = full_table_name(table);
152                if s.column_names(&table_name).unwrap_or_default().is_empty() {
153                    return expression;
154                }
155            }
156
157            // The table must have an alias; if it does not, we leave it as-is.
158            // The Python version raises an OptimizeError here, but in practice
159            // earlier passes (qualify_tables) ensure aliases are present.
160            let alias_name = match &table.alias {
161                Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162                _ => return expression,
163            };
164
165            wrap_table_in_subquery(table.clone(), &alias_name)
166        }
167        _ => expression,
168    }
169}
170
171/// Build `(SELECT * FROM table_ref AS alias) AS alias` from a table reference.
172///
173/// The inner table reference keeps the original alias so that
174/// `FROM t AS t` becomes `(SELECT * FROM t AS t) AS t`.
175fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
176    // Build: SELECT * FROM <table>
177    let inner_select = Select::new()
178        .column(Expression::Star(Star {
179            table: None,
180            except: None,
181            replace: None,
182            rename: None,
183            trailing_comments: Vec::new(),
184        }))
185        .from(Expression::Table(table));
186
187    // Wrap the SELECT in a Subquery with the original alias
188    Expression::Subquery(Box::new(Subquery {
189        this: Expression::Select(Box::new(inner_select)),
190        alias: Some(Identifier::new(alias_name)),
191        column_aliases: Vec::new(),
192        order_by: None,
193        limit: None,
194        offset: None,
195        distribute_by: None,
196        sort_by: None,
197        cluster_by: None,
198        lateral: false,
199        modifiers_inside: false,
200        trailing_comments: Vec::new(),
201    }))
202}
203
204/// Construct the fully qualified table name from a `TableRef`.
205///
206/// Produces `catalog.schema.name` or `schema.name` or just `name`
207/// depending on which parts are present.
208fn full_table_name(table: &TableRef) -> String {
209    let mut parts = Vec::new();
210    if let Some(ref catalog) = table.catalog {
211        parts.push(catalog.name.as_str());
212    }
213    if let Some(ref schema) = table.schema {
214        parts.push(schema.name.as_str());
215    }
216    parts.push(&table.name.name);
217    parts.join(".")
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Helper: parse SQL into an Expression (the first statement only).
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Helper: generate SQL from an Expression.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Helper: parse `sql`, run the pass with `schema`, and render the result.
    fn isolate(sql: &str, schema: Option<&dyn Schema>) -> String {
        gen(&isolate_table_selects(parse(sql), schema, None))
    }

    /// Helper: rendered form of table `name` isolated under alias `name`,
    /// i.e. `(SELECT * FROM name AS name) AS name`.
    fn wrapped(name: &str) -> String {
        format!("(SELECT * FROM {name} AS {name}) AS {name}")
    }

    /// Helper: a schema in which each listed table has a single `id INT` column.
    fn schema_with_int_id(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            schema
                .add_table(
                    table,
                    &[(
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    )],
                    None,
                )
                .expect("failed to add table to test schema");
        }
        schema
    }

    // ---------------------------------------------------------------
    // Basic: single source should NOT be wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t", None);
        // Should remain a plain table, not wrapped in a subquery
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t", None);
        // Still just one source, no additional wrapping expected
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Multiple sources: tables should be wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains(&wrapped("a")),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains(&wrapped("b")),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        // If one source is already a subquery and the other is a table,
        // only the bare table should be wrapped.
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains(&wrapped("a")),
            "Bare table should be wrapped: {output}"
        );
        // `b` is already a subquery, so it should appear once (no double-wrapping)
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        // Tables without aliases are left alone (in real pipelines,
        // qualify_tables runs first and assigns aliases).
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id", None);
        for table in ["a", "b"] {
            assert!(
                !output.contains(&format!("(SELECT * FROM {table}")),
                "Table '{table}' without alias should not be wrapped: {output}"
            );
        }
    }

    // ---------------------------------------------------------------
    // Schema-aware mode: only wrap tables with known columns
    // ---------------------------------------------------------------

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_int_id(&["a", "b"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains(&wrapped("a")),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains(&wrapped("b")),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // Only 'a' is in the schema; 'b' is unknown
        let schema = schema_with_int_id(&["a"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains(&wrapped("a")),
            "Known table 'a' should be wrapped: {output}"
        );
        // 'b' is not in schema, so it should remain a plain table
        assert!(
            !output.contains(&wrapped("b")),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_int_id(&["mydb.a", "mydb.b"]);
        let output = isolate(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Recursive: CTEs and nested subqueries
    // ---------------------------------------------------------------

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        // Inside the CTE, x and y should be wrapped
        assert!(
            output.contains(&wrapped("x")),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains(&wrapped("y")),
            "CTE inner table 'y' should be wrapped: {output}"
        );
        // The outer SELECT has a single source and must stay unwrapped
        assert!(
            !output.contains("(SELECT * FROM cte AS c) AS c"),
            "Single-source outer SELECT should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        // The inner SELECT has two sources; both should be wrapped
        assert!(
            output.contains(&wrapped("a")),
            "Nested inner table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains(&wrapped("b")),
            "Nested inner table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_join_subquery_processed() {
        let output = isolate(
            "SELECT * FROM t AS t JOIN (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS s ON t.id = s.id",
            None,
        );
        // The JOINed subquery's inner SELECT has two sources; both should be wrapped
        assert!(
            output.contains(&wrapped("a")),
            "JOIN-subquery inner table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains(&wrapped("b")),
            "JOIN-subquery inner table 'b' should be wrapped: {output}"
        );
        // The outer SELECT also has two sources, so its bare table is wrapped
        assert!(
            output.contains(&wrapped("t")),
            "Outer table 't' should be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Set operations: UNION, INTERSECT, EXCEPT
    // ---------------------------------------------------------------

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        // Left side has two sources - both should be wrapped
        assert!(
            output.contains(&wrapped("a")),
            "UNION left side should be processed: {output}"
        );
        assert!(
            output.contains(&wrapped("b")),
            "UNION left side should be processed: {output}"
        );
        // Right side has only one source - should NOT be wrapped
        assert!(
            !output.contains(&wrapped("c")),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_intersect_left_side_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id INTERSECT SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains(&wrapped("a")),
            "INTERSECT left side should be processed: {output}"
        );
        assert!(
            !output.contains(&wrapped("c")),
            "INTERSECT right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_except_left_side_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id EXCEPT SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains(&wrapped("a")),
            "EXCEPT left side should be processed: {output}"
        );
        assert!(
            !output.contains(&wrapped("c")),
            "EXCEPT right side (single source) should not be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Edge cases
    // ---------------------------------------------------------------

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        for table in ["a", "b"] {
            assert!(
                output.contains(&wrapped(table)),
                "CROSS JOIN table '{table}' should be wrapped: {output}"
            );
        }
    }

    #[test]
    fn test_multiple_from_tables() {
        // Comma-separated FROM (implicit cross join)
        let output = isolate("SELECT * FROM a AS a, b AS b", None);
        for table in ["a", "b"] {
            assert!(
                output.contains(&wrapped(table)),
                "Comma-join table '{table}' should be wrapped: {output}"
            );
        }
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        for table in ["a", "b", "c"] {
            assert!(
                output.contains(&wrapped(table)),
                "Three-way join: '{table}' should be wrapped: {output}"
            );
        }
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        // Non-SELECT expressions (e.g., INSERT, CREATE) pass through unchanged
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = gen(&expr);
        let result = isolate_table_selects(expr, None, None);
        let output = gen(&result);
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}
575}