Skip to main content

polyglot_sql/optimizer/
isolate_table_selects.rs

1//! Isolate Table Selects Optimization Pass
2//!
3//! This module wraps plain table references in subqueries (`SELECT * FROM table`)
4//! when multiple tables are present in a scope. This normalization is needed for
5//! other optimizations (like merge_subqueries) to work correctly, since they
6//! expect each source in a multi-table query to be a subquery rather than a bare
7//! table reference.
8//!
9//! Ported from sqlglot's optimizer/isolate_table_selects.py
10
11use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15/// Error type for the isolate_table_selects pass
16#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18    #[error("Tables require an alias: {0}")]
19    MissingAlias(String),
20}
21
22/// Wrap plain table references in subqueries when multiple sources are present.
23///
24/// When a SELECT has multiple sources (FROM + JOINs, or multiple FROM tables),
25/// each bare `Table` reference is replaced with:
26///
27/// ```sql
28/// (SELECT * FROM table AS alias) AS alias
29/// ```
30///
31/// This makes every source a subquery, which simplifies downstream
32/// optimizations such as `merge_subqueries`.
33///
34/// # Arguments
35///
36/// * `expression` - The SQL expression tree to transform
37/// * `schema` - Optional schema for looking up column names (used to skip
38///   tables whose columns are unknown, matching the Python behavior)
39/// * `_dialect` - Optional dialect (reserved for future use)
40///
41/// # Returns
42///
43/// The transformed expression with isolated table selects
44pub fn isolate_table_selects(
45    expression: Expression,
46    schema: Option<&dyn Schema>,
47    _dialect: Option<DialectType>,
48) -> Expression {
49    match expression {
50        Expression::Select(select) => {
51            let transformed = isolate_select(*select, schema);
52            Expression::Select(Box::new(transformed))
53        }
54        Expression::Union(mut union) => {
55            union.left = isolate_table_selects(union.left, schema, _dialect);
56            union.right = isolate_table_selects(union.right, schema, _dialect);
57            Expression::Union(union)
58        }
59        Expression::Intersect(mut intersect) => {
60            intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61            intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62            Expression::Intersect(intersect)
63        }
64        Expression::Except(mut except) => {
65            except.left = isolate_table_selects(except.left, schema, _dialect);
66            except.right = isolate_table_selects(except.right, schema, _dialect);
67            Expression::Except(except)
68        }
69        other => other,
70    }
71}
72
73/// Process a single SELECT statement, wrapping bare table references in
74/// subqueries when multiple sources are present.
75fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76    // First, recursively process CTEs
77    if let Some(ref mut with) = select.with {
78        for cte in &mut with.ctes {
79            cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80        }
81    }
82
83    // Recursively process subqueries in FROM and JOINs
84    if let Some(ref mut from) = select.from {
85        for expr in &mut from.expressions {
86            if let Expression::Subquery(ref mut sq) = expr {
87                sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88            }
89        }
90    }
91    for join in &mut select.joins {
92        if let Expression::Subquery(ref mut sq) = join.this {
93            sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94        }
95    }
96
97    // Count the total number of sources (FROM expressions + JOINs)
98    let source_count = count_sources(&select);
99
100    // Only isolate when there are multiple sources
101    if source_count <= 1 {
102        return select;
103    }
104
105    // Wrap bare table references in FROM clause
106    if let Some(ref mut from) = select.from {
107        from.expressions = from
108            .expressions
109            .drain(..)
110            .map(|expr| maybe_wrap_table(expr, schema))
111            .collect();
112    }
113
114    // Wrap bare table references in JOINs
115    for join in &mut select.joins {
116        join.this = maybe_wrap_table(join.this.clone(), schema);
117    }
118
119    select
120}
121
122/// Count the total number of source tables/subqueries in a SELECT.
123///
124/// This counts FROM expressions plus JOINs.
125fn count_sources(select: &Select) -> usize {
126    let from_count = select
127        .from
128        .as_ref()
129        .map(|f| f.expressions.len())
130        .unwrap_or(0);
131    let join_count = select.joins.len();
132    from_count + join_count
133}
134
135/// If the expression is a bare `Table` reference that should be isolated,
136/// wrap it in a `(SELECT * FROM table AS alias) AS alias` subquery.
137///
138/// A table is wrapped when:
139/// - It is an `Expression::Table` (not already a subquery)
140/// - It has an alias (required by the Python reference)
141/// - If a schema is provided, the table must have known columns in the schema
142///
143/// If no schema is provided, all aliased tables are wrapped (simplified mode).
144fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145    match expression {
146        Expression::Table(ref table) => {
147            // If a schema is provided, check that the table has known columns.
148            // If we cannot find columns for the table, skip wrapping it (matching
149            // the Python behavior where `schema.column_names(source)` must be truthy).
150            if let Some(s) = schema {
151                let table_name = full_table_name(table);
152                if s.column_names(&table_name).unwrap_or_default().is_empty() {
153                    return expression;
154                }
155            }
156
157            // The table must have an alias; if it does not, we leave it as-is.
158            // The Python version raises an OptimizeError here, but in practice
159            // earlier passes (qualify_tables) ensure aliases are present.
160            let alias_name = match &table.alias {
161                Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162                _ => return expression,
163            };
164
165            wrap_table_in_subquery(*table.clone(), &alias_name)
166
167        }
168        _ => expression,
169    }
170}
171
172/// Build `(SELECT * FROM table_ref AS alias) AS alias` from a table reference.
173///
174/// The inner table reference keeps the original alias so that
175/// `FROM t AS t` becomes `(SELECT * FROM t AS t) AS t`.
176fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
177    // Build: SELECT * FROM <table>
178    let inner_select = Select::new()
179        .column(Expression::Star(Star {
180            table: None,
181            except: None,
182            replace: None,
183            rename: None,
184            trailing_comments: Vec::new(),
185            span: None,
186        }))
187        .from(Expression::Table(Box::new(table)));
188
189    // Wrap the SELECT in a Subquery with the original alias
190    Expression::Subquery(Box::new(Subquery {
191        this: Expression::Select(Box::new(inner_select)),
192        alias: Some(Identifier::new(alias_name)),
193        column_aliases: Vec::new(),
194        order_by: None,
195        limit: None,
196        offset: None,
197        distribute_by: None,
198        sort_by: None,
199        cluster_by: None,
200        lateral: false,
201        modifiers_inside: false,
202        trailing_comments: Vec::new(),
203        inferred_type: None,
204    }))
205}
206
207/// Construct the fully qualified table name from a `TableRef`.
208///
209/// Produces `catalog.schema.name` or `schema.name` or just `name`
210/// depending on which parts are present.
211fn full_table_name(table: &TableRef) -> String {
212    let mut parts = Vec::new();
213    if let Some(ref catalog) = table.catalog {
214        parts.push(catalog.name.as_str());
215    }
216    if let Some(ref schema) = table.schema {
217        parts.push(schema.name.as_str());
218    }
219    parts.push(&table.name.name);
220    parts.join(".")
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Helper: parse SQL and return the first statement as an Expression
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Helper: render an Expression back to SQL text
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Helper: run the pass without a schema and return the generated SQL
    fn isolate(sql: &str) -> String {
        gen(&isolate_table_selects(parse(sql), None, None))
    }

    /// Helper: run the pass with the given schema and return the generated SQL
    fn isolate_with_schema(sql: &str, schema: &MappingSchema) -> String {
        gen(&isolate_table_selects(parse(sql), Some(schema), None))
    }

    /// Helper: build a schema where each listed table has one integer `id` column
    fn schema_with_id_tables(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            let columns = [(
                "id".to_string(),
                DataType::Int {
                    length: None,
                    integer_spelling: false,
                },
            )];
            schema.add_table(table, &columns, None).unwrap();
        }
        schema
    }

    // ---------------------------------------------------------------
    // Basic: single source should NOT be wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t");
        // The lone table must stay a plain table reference
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t");
        // One source only, so no extra subquery layer may appear
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Multiple sources: tables should be wrapped
    // ---------------------------------------------------------------

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id");
        // Both the FROM table and the JOIN table become subqueries
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        // One source is already a subquery, the other is a bare table:
        // only the bare table gets wrapped.
        let output = isolate("SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id");
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        // The existing subquery over `b` must appear exactly once
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        // Unaliased tables are left alone (in real pipelines,
        // qualify_tables runs first and assigns aliases).
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id");
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Schema-aware mode: only wrap tables with known columns
    // ---------------------------------------------------------------

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_id_tables(&["a", "b"]);

        let output = isolate_with_schema(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            &schema,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // Only 'a' is registered; 'b' is unknown to the schema
        let schema = schema_with_id_tables(&["a"]);

        let output = isolate_with_schema(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            &schema,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        // 'b' has no known columns, so it stays a plain table
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Recursive: CTEs and nested subqueries
    // ---------------------------------------------------------------

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
        );
        // The CTE body has two sources, so both are wrapped
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output =
            isolate("SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub");
        // The inner SELECT has two sources; they should be wrapped
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Set operations: UNION, INTERSECT, EXCEPT
    // ---------------------------------------------------------------

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
        );
        // Left operand has two sources and gets wrapped
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        // Right operand has a single source and stays as it is
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    // ---------------------------------------------------------------
    // Edge cases
    // ---------------------------------------------------------------

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b");
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        // Comma-separated FROM (implicit cross join)
        let output = isolate("SELECT * FROM a AS a, b AS b");
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output =
            isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id");
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_id_tables(&["mydb.a", "mydb.b"]);

        let output = isolate_with_schema(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            &schema,
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        // Non-SELECT expressions (e.g., INSERT, CREATE) pass through unchanged
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = gen(&expr);
        let output = gen(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}