Skip to main content

polyglot_sql/optimizer/
isolate_table_selects.rs

//! Isolate Table Selects Optimization Pass
//!
//! This module wraps plain table references in subqueries (`SELECT * FROM table`)
//! when multiple tables are present in a scope. This normalization is needed for
//! other optimizations (like merge_subqueries) to work correctly, since they
//! expect each source in a multi-table query to be a subquery rather than a bare
//! table reference.
//!
//! Ported from sqlglot's optimizer/isolate_table_selects.py
10
11use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15/// Error type for the isolate_table_selects pass
16#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18    #[error("Tables require an alias: {0}")]
19    MissingAlias(String),
20}
21
22/// Wrap plain table references in subqueries when multiple sources are present.
23///
24/// When a SELECT has multiple sources (FROM + JOINs, or multiple FROM tables),
25/// each bare `Table` reference is replaced with:
26///
27/// ```sql
28/// (SELECT * FROM table AS alias) AS alias
29/// ```
30///
31/// This makes every source a subquery, which simplifies downstream
32/// optimizations such as `merge_subqueries`.
33///
34/// # Arguments
35///
36/// * `expression` - The SQL expression tree to transform
37/// * `schema` - Optional schema for looking up column names (used to skip
38///   tables whose columns are unknown, matching the Python behavior)
39/// * `_dialect` - Optional dialect (reserved for future use)
40///
41/// # Returns
42///
43/// The transformed expression with isolated table selects
44pub fn isolate_table_selects(
45    expression: Expression,
46    schema: Option<&dyn Schema>,
47    _dialect: Option<DialectType>,
48) -> Expression {
49    match expression {
50        Expression::Select(select) => {
51            let transformed = isolate_select(*select, schema);
52            Expression::Select(Box::new(transformed))
53        }
54        Expression::Union(mut union) => {
55            union.left = isolate_table_selects(union.left, schema, _dialect);
56            union.right = isolate_table_selects(union.right, schema, _dialect);
57            Expression::Union(union)
58        }
59        Expression::Intersect(mut intersect) => {
60            intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61            intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62            Expression::Intersect(intersect)
63        }
64        Expression::Except(mut except) => {
65            except.left = isolate_table_selects(except.left, schema, _dialect);
66            except.right = isolate_table_selects(except.right, schema, _dialect);
67            Expression::Except(except)
68        }
69        other => other,
70    }
71}
72
73/// Process a single SELECT statement, wrapping bare table references in
74/// subqueries when multiple sources are present.
75fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76    // First, recursively process CTEs
77    if let Some(ref mut with) = select.with {
78        for cte in &mut with.ctes {
79            cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80        }
81    }
82
83    // Recursively process subqueries in FROM and JOINs
84    if let Some(ref mut from) = select.from {
85        for expr in &mut from.expressions {
86            if let Expression::Subquery(ref mut sq) = expr {
87                sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88            }
89        }
90    }
91    for join in &mut select.joins {
92        if let Expression::Subquery(ref mut sq) = join.this {
93            sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94        }
95    }
96
97    // Count the total number of sources (FROM expressions + JOINs)
98    let source_count = count_sources(&select);
99
100    // Only isolate when there are multiple sources
101    if source_count <= 1 {
102        return select;
103    }
104
105    // Wrap bare table references in FROM clause
106    if let Some(ref mut from) = select.from {
107        from.expressions = from
108            .expressions
109            .drain(..)
110            .map(|expr| maybe_wrap_table(expr, schema))
111            .collect();
112    }
113
114    // Wrap bare table references in JOINs
115    for join in &mut select.joins {
116        join.this = maybe_wrap_table(join.this.clone(), schema);
117    }
118
119    select
120}
121
122/// Count the total number of source tables/subqueries in a SELECT.
123///
124/// This counts FROM expressions plus JOINs.
125fn count_sources(select: &Select) -> usize {
126    let from_count = select
127        .from
128        .as_ref()
129        .map(|f| f.expressions.len())
130        .unwrap_or(0);
131    let join_count = select.joins.len();
132    from_count + join_count
133}
134
135/// If the expression is a bare `Table` reference that should be isolated,
136/// wrap it in a `(SELECT * FROM table AS alias) AS alias` subquery.
137///
138/// A table is wrapped when:
139/// - It is an `Expression::Table` (not already a subquery)
140/// - It has an alias (required by the Python reference)
141/// - If a schema is provided, the table must have known columns in the schema
142///
143/// If no schema is provided, all aliased tables are wrapped (simplified mode).
144fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145    match expression {
146        Expression::Table(ref table) => {
147            // If a schema is provided, check that the table has known columns.
148            // If we cannot find columns for the table, skip wrapping it (matching
149            // the Python behavior where `schema.column_names(source)` must be truthy).
150            if let Some(s) = schema {
151                let table_name = full_table_name(table);
152                if s.column_names(&table_name).unwrap_or_default().is_empty() {
153                    return expression;
154                }
155            }
156
157            // The table must have an alias; if it does not, we leave it as-is.
158            // The Python version raises an OptimizeError here, but in practice
159            // earlier passes (qualify_tables) ensure aliases are present.
160            let alias_name = match &table.alias {
161                Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162                _ => return expression,
163            };
164
165            wrap_table_in_subquery(table.clone(), &alias_name)
166        }
167        _ => expression,
168    }
169}
170
171/// Build `(SELECT * FROM table_ref AS alias) AS alias` from a table reference.
172///
173/// The inner table reference keeps the original alias so that
174/// `FROM t AS t` becomes `(SELECT * FROM t AS t) AS t`.
175fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
176    // Build: SELECT * FROM <table>
177    let inner_select = Select::new()
178        .column(Expression::Star(Star {
179            table: None,
180            except: None,
181            replace: None,
182            rename: None,
183            trailing_comments: Vec::new(),
184        }))
185        .from(Expression::Table(table));
186
187    // Wrap the SELECT in a Subquery with the original alias
188    Expression::Subquery(Box::new(Subquery {
189        this: Expression::Select(Box::new(inner_select)),
190        alias: Some(Identifier::new(alias_name)),
191        column_aliases: Vec::new(),
192        order_by: None,
193        limit: None,
194        offset: None,
195        distribute_by: None,
196        sort_by: None,
197        cluster_by: None,
198        lateral: false,
199        modifiers_inside: false,
200        trailing_comments: Vec::new(),
201    }))
202}
203
204/// Construct the fully qualified table name from a `TableRef`.
205///
206/// Produces `catalog.schema.name` or `schema.name` or just `name`
207/// depending on which parts are present.
208fn full_table_name(table: &TableRef) -> String {
209    let mut parts = Vec::new();
210    if let Some(ref catalog) = table.catalog {
211        parts.push(catalog.name.as_str());
212    }
213    if let Some(ref schema) = table.schema {
214        parts.push(schema.name.as_str());
215    }
216    parts.push(&table.name.name);
217    parts.join(".")
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Render an expression back to SQL text.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql`, run the pass with the given schema, and render the result.
    fn isolate(sql: &str, schema: Option<&dyn Schema>) -> String {
        let tree = parse(sql);
        gen(&isolate_table_selects(tree, schema, None))
    }

    /// Build a schema in which each listed table has a single integer `id` column.
    fn schema_with_id_tables(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            schema
                .add_table(
                    table,
                    &[("id".to_string(), DataType::Int { length: None, integer_spelling: false })],
                    None,
                )
                .unwrap();
        }
        schema
    }

    // ===============================================================
    // Single source: nothing is wrapped
    // ===============================================================

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t", None);
        // The lone table must stay a plain table reference.
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t", None);
        // One source only, so the existing subquery is the only one present.
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    // ===============================================================
    // Several sources: bare tables are wrapped
    // ===============================================================

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        // Each side of the join becomes a subquery.
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        // A mix of a bare table and an existing subquery: only the bare
        // table gains a wrapper.
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        // The pre-existing subquery must appear exactly once.
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        // Unaliased tables are skipped; real pipelines run qualify_tables
        // beforehand, which assigns aliases.
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    // ===============================================================
    // With a schema: only tables with known columns are wrapped
    // ===============================================================

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_id_tables(&["a", "b"]);

        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // The schema knows 'a' only; 'b' is absent from it.
        let schema = schema_with_id_tables(&["a"]);

        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        // With no schema entry, 'b' must remain a plain table.
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    // ===============================================================
    // Recursion into CTEs and nested subqueries
    // ===============================================================

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        // The CTE body joins two tables, so both get wrapped.
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        // The inner SELECT has two sources and is therefore rewritten.
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    // ===============================================================
    // Set operations
    // ===============================================================

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        // Left operand: two sources, so wrapping applies.
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        // Right operand: a single source, so no wrapping.
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    // ===============================================================
    // Edge cases
    // ===============================================================

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        // Comma-separated FROM list, i.e. an implicit cross join.
        let output = isolate("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_id_tables(&["mydb.a", "mydb.b"]);

        let output = isolate(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        // Statements other than SELECT (e.g. INSERT, CREATE) pass straight through.
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = gen(&expr);
        let output = gen(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}
544}