Skip to main content

polyglot_sql/optimizer/
canonicalize.rs

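//! Canonicalization Module
//!
//! This module converts SQL expressions into a standard canonical form.
//! The pass is organised around the following rewrites (status noted where
//! the current port is still a placeholder):
//! - Converting string addition to CONCAT (placeholder: needs operand types)
//! - Replacing date functions with casts (not yet implemented)
//! - Coercing ISO date/datetime string literals in comparisons (placeholder)
//! - Removing redundant type casts (currently only string literals cast to VARCHAR/TEXT)
//! - Ensuring boolean predicates (placeholder)
//! - Removing unnecessary ASC from ORDER BY
//!
//! Ported from sqlglot's optimizer/canonicalize.py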
12
13use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
17/// Converts a SQL expression into a standard canonical form.
18///
19/// This transformation relies on type annotations because many of the
20/// conversions depend on type inference.
21///
22/// # Arguments
23/// * `expression` - The expression to canonicalize
24/// * `dialect` - Optional dialect for dialect-specific behavior
25///
26/// # Returns
27/// The canonicalized expression
28pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
29    canonicalize_recursive(expression, dialect)
30}
31
32/// Recursively canonicalize an expression and its children
33fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34    let expr = match expression {
35        Expression::Select(mut select) => {
36            // Canonicalize SELECT expressions
37            select.expressions = select
38                .expressions
39                .into_iter()
40                .map(|e| canonicalize_recursive(e, dialect))
41                .collect();
42
43            // Canonicalize FROM
44            if let Some(mut from) = select.from {
45                from.expressions = from
46                    .expressions
47                    .into_iter()
48                    .map(|e| canonicalize_recursive(e, dialect))
49                    .collect();
50                select.from = Some(from);
51            }
52
53            // Canonicalize WHERE
54            if let Some(mut where_clause) = select.where_clause {
55                where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56                where_clause.this = ensure_bools(where_clause.this);
57                select.where_clause = Some(where_clause);
58            }
59
60            // Canonicalize HAVING
61            if let Some(mut having) = select.having {
62                having.this = canonicalize_recursive(having.this, dialect);
63                having.this = ensure_bools(having.this);
64                select.having = Some(having);
65            }
66
67            // Canonicalize ORDER BY
68            if let Some(mut order_by) = select.order_by {
69                order_by.expressions = order_by
70                    .expressions
71                    .into_iter()
72                    .map(|mut o| {
73                        o.this = canonicalize_recursive(o.this, dialect);
74                        o = remove_ascending_order(o);
75                        o
76                    })
77                    .collect();
78                select.order_by = Some(order_by);
79            }
80
81            // Canonicalize JOINs
82            select.joins = select
83                .joins
84                .into_iter()
85                .map(|mut j| {
86                    j.this = canonicalize_recursive(j.this, dialect);
87                    if let Some(on) = j.on {
88                        j.on = Some(canonicalize_recursive(on, dialect));
89                    }
90                    j
91                })
92                .collect();
93
94            Expression::Select(select)
95        }
96
97        // Binary operations that might involve string addition
98        Expression::Add(bin) => {
99            let left = canonicalize_recursive(bin.left, dialect);
100            let right = canonicalize_recursive(bin.right, dialect);
101            let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102                left,
103                right,
104                left_comments: bin.left_comments,
105                operator_comments: bin.operator_comments,
106                trailing_comments: bin.trailing_comments,
107                inferred_type: None,
108            }));
109            add_text_to_concat(result)
110        }
111
112        // Other binary operations
113        Expression::And(bin) => {
114            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
115            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
116            Expression::And(Box::new(crate::expressions::BinaryOp {
117                left,
118                right,
119                left_comments: bin.left_comments,
120                operator_comments: bin.operator_comments,
121                trailing_comments: bin.trailing_comments,
122                inferred_type: None,
123            }))
124        }
125        Expression::Or(bin) => {
126            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
127            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
128            Expression::Or(Box::new(crate::expressions::BinaryOp {
129                left,
130                right,
131                left_comments: bin.left_comments,
132                operator_comments: bin.operator_comments,
133                trailing_comments: bin.trailing_comments,
134                inferred_type: None,
135            }))
136        }
137
138        Expression::Not(un) => {
139            let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
140            Expression::Not(Box::new(crate::expressions::UnaryOp {
141                this: inner,
142                inferred_type: None,
143            }))
144        }
145
146        // Comparison operations - check for date coercion
147        Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
148        Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
149        Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
150        Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
151        Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
152        Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
153
154        Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
155        Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
156        Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
157
158        // Cast - check for redundancy
159        Expression::Cast(cast) => {
160            let inner = canonicalize_recursive(cast.this, dialect);
161            let result = Expression::Cast(Box::new(crate::expressions::Cast {
162                this: inner,
163                to: cast.to,
164                trailing_comments: cast.trailing_comments,
165                double_colon_syntax: cast.double_colon_syntax,
166                format: cast.format,
167                default: cast.default,
168                inferred_type: None,
169            }));
170            remove_redundant_casts(result)
171        }
172
173        // Function expressions
174        Expression::Function(func) => {
175            let args = func
176                .args
177                .into_iter()
178                .map(|e| canonicalize_recursive(e, dialect))
179                .collect();
180            Expression::Function(Box::new(crate::expressions::Function {
181                name: func.name,
182                args,
183                distinct: func.distinct,
184                trailing_comments: func.trailing_comments,
185                use_bracket_syntax: func.use_bracket_syntax,
186                no_parens: func.no_parens,
187                quoted: func.quoted,
188                span: None,
189                inferred_type: None,
190            }))
191        }
192
193        Expression::AggregateFunction(agg) => {
194            let args = agg
195                .args
196                .into_iter()
197                .map(|e| canonicalize_recursive(e, dialect))
198                .collect();
199            Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
200                name: agg.name,
201                args,
202                distinct: agg.distinct,
203                filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
204                order_by: agg.order_by,
205                limit: agg.limit,
206                ignore_nulls: agg.ignore_nulls,
207                inferred_type: None,
208            }))
209        }
210
211        // Alias
212        Expression::Alias(alias) => {
213            let inner = canonicalize_recursive(alias.this, dialect);
214            Expression::Alias(Box::new(crate::expressions::Alias {
215                this: inner,
216                alias: alias.alias,
217                column_aliases: alias.column_aliases,
218                pre_alias_comments: alias.pre_alias_comments,
219                trailing_comments: alias.trailing_comments,
220                inferred_type: None,
221            }))
222        }
223
224        // Paren
225        Expression::Paren(paren) => {
226            let inner = canonicalize_recursive(paren.this, dialect);
227            Expression::Paren(Box::new(crate::expressions::Paren {
228                this: inner,
229                trailing_comments: paren.trailing_comments,
230            }))
231        }
232
233        // Case
234        Expression::Case(case) => {
235            let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
236            let whens = case
237                .whens
238                .into_iter()
239                .map(|(w, t)| {
240                    (
241                        canonicalize_recursive(w, dialect),
242                        canonicalize_recursive(t, dialect),
243                    )
244                })
245                .collect();
246            let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
247            Expression::Case(Box::new(crate::expressions::Case {
248                operand,
249                whens,
250                else_,
251                comments: Vec::new(),
252                inferred_type: None,
253            }))
254        }
255
256        // Between - check for date coercion
257        Expression::Between(between) => {
258            let this = canonicalize_recursive(between.this, dialect);
259            let low = canonicalize_recursive(between.low, dialect);
260            let high = canonicalize_recursive(between.high, dialect);
261            Expression::Between(Box::new(crate::expressions::Between {
262                this,
263                low,
264                high,
265                not: between.not,
266                symmetric: between.symmetric,
267            }))
268        }
269
270        // In
271        Expression::In(in_expr) => {
272            let this = canonicalize_recursive(in_expr.this, dialect);
273            let expressions = in_expr
274                .expressions
275                .into_iter()
276                .map(|e| canonicalize_recursive(e, dialect))
277                .collect();
278            let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
279            Expression::In(Box::new(crate::expressions::In {
280                this,
281                expressions,
282                query,
283                not: in_expr.not,
284                global: in_expr.global,
285                unnest: in_expr.unnest,
286                is_field: in_expr.is_field,
287            }))
288        }
289
290        // Subquery
291        Expression::Subquery(subquery) => {
292            let this = canonicalize_recursive(subquery.this, dialect);
293            Expression::Subquery(Box::new(crate::expressions::Subquery {
294                this,
295                alias: subquery.alias,
296                column_aliases: subquery.column_aliases,
297                order_by: subquery.order_by,
298                limit: subquery.limit,
299                offset: subquery.offset,
300                distribute_by: subquery.distribute_by,
301                sort_by: subquery.sort_by,
302                cluster_by: subquery.cluster_by,
303                lateral: subquery.lateral,
304                modifiers_inside: subquery.modifiers_inside,
305                trailing_comments: subquery.trailing_comments,
306                inferred_type: None,
307            }))
308        }
309
310        // Set operations
311        Expression::Union(union) => {
312            let left = canonicalize_recursive(union.left, dialect);
313            let right = canonicalize_recursive(union.right, dialect);
314            Expression::Union(Box::new(crate::expressions::Union {
315                left,
316                right,
317                all: union.all,
318                distinct: union.distinct,
319                with: union.with,
320                order_by: union.order_by,
321                limit: union.limit,
322                offset: union.offset,
323                distribute_by: union.distribute_by,
324                sort_by: union.sort_by,
325                cluster_by: union.cluster_by,
326                by_name: union.by_name,
327                side: union.side,
328                kind: union.kind,
329                corresponding: union.corresponding,
330                strict: union.strict,
331                on_columns: union.on_columns,
332            }))
333        }
334        Expression::Intersect(intersect) => {
335            let left = canonicalize_recursive(intersect.left, dialect);
336            let right = canonicalize_recursive(intersect.right, dialect);
337            Expression::Intersect(Box::new(crate::expressions::Intersect {
338                left,
339                right,
340                all: intersect.all,
341                distinct: intersect.distinct,
342                with: intersect.with,
343                order_by: intersect.order_by,
344                limit: intersect.limit,
345                offset: intersect.offset,
346                distribute_by: intersect.distribute_by,
347                sort_by: intersect.sort_by,
348                cluster_by: intersect.cluster_by,
349                by_name: intersect.by_name,
350                side: intersect.side,
351                kind: intersect.kind,
352                corresponding: intersect.corresponding,
353                strict: intersect.strict,
354                on_columns: intersect.on_columns,
355            }))
356        }
357        Expression::Except(except) => {
358            let left = canonicalize_recursive(except.left, dialect);
359            let right = canonicalize_recursive(except.right, dialect);
360            Expression::Except(Box::new(crate::expressions::Except {
361                left,
362                right,
363                all: except.all,
364                distinct: except.distinct,
365                with: except.with,
366                order_by: except.order_by,
367                limit: except.limit,
368                offset: except.offset,
369                distribute_by: except.distribute_by,
370                sort_by: except.sort_by,
371                cluster_by: except.cluster_by,
372                by_name: except.by_name,
373                side: except.side,
374                kind: except.kind,
375                corresponding: except.corresponding,
376                strict: except.strict,
377                on_columns: except.on_columns,
378            }))
379        }
380
381        // Leaf nodes - return unchanged
382        other => other,
383    };
384
385    expr
386}
387
388/// Convert string addition to CONCAT.
389///
390/// When two TEXT types are added with +, convert to CONCAT.
391/// This is used by dialects like T-SQL and Redshift.
392fn add_text_to_concat(expression: Expression) -> Expression {
393    // In a full implementation, we would check if the operands are TEXT types
394    // and convert to CONCAT. For now, we return unchanged.
395    expression
396}
397
398/// Remove redundant cast expressions.
399///
400/// If casting to the same type the expression already is, remove the cast.
401fn remove_redundant_casts(expression: Expression) -> Expression {
402    if let Expression::Cast(cast) = &expression {
403        // Check if the inner expression's type matches the cast target
404        // In a full implementation with type annotations, we would compare types
405        // For now, just check simple cases
406
407        // If casting a literal to its natural type, we might be able to simplify
408        if let Expression::Literal(lit) = &cast.this {
409            if let Literal::String(_) = lit.as_ref() {
410            if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
411                return cast.this.clone();
412            }
413        }
414        }
415        if let Expression::Literal(lit) = &cast.this {
416            if let Literal::Number(_) = lit.as_ref() {
417            if matches!(
418                &cast.to,
419                DataType::Int { .. }
420                    | DataType::BigInt { .. }
421                    | DataType::Decimal { .. }
422                    | DataType::Float { .. }
423            ) {
424                // Could potentially remove cast, but be conservative
425            }
426        }
427        }
428    }
429    expression
430}
431
432/// Ensure expressions used as boolean predicates are actually boolean.
433///
434/// For example, in some dialects, integers can be used as booleans.
435/// This function ensures proper boolean semantics.
436fn ensure_bools(expression: Expression) -> Expression {
437    // In a full implementation, we would check if the expression is an integer
438    // and convert it to a comparison (e.g., x != 0).
439    // For now, return unchanged.
440    expression
441}
442
443/// Remove explicit ASC from ORDER BY clauses.
444///
445/// Since ASC is the default, `ORDER BY a ASC` can be simplified to `ORDER BY a`.
446fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
447    // If ASC was explicitly written (not DESC), remove the explicit flag
448    // since ASC is the default ordering
449    if !ordered.desc && ordered.explicit_asc {
450        ordered.explicit_asc = false;
451    }
452    ordered
453}
454
455/// Canonicalize a binary comparison operation.
456fn canonicalize_comparison<F>(
457    constructor: F,
458    bin: crate::expressions::BinaryOp,
459    dialect: Option<DialectType>,
460) -> Expression
461where
462    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
463{
464    let left = canonicalize_recursive(bin.left, dialect);
465    let right = canonicalize_recursive(bin.right, dialect);
466
467    // Check for date coercion opportunities
468    let (left, right) = coerce_date_operands(left, right);
469
470    constructor(Box::new(crate::expressions::BinaryOp {
471        left,
472        right,
473        left_comments: bin.left_comments,
474        operator_comments: bin.operator_comments,
475        trailing_comments: bin.trailing_comments,
476        inferred_type: None,
477    }))
478}
479
480/// Canonicalize a regular binary operation.
481fn canonicalize_binary<F>(
482    constructor: F,
483    bin: crate::expressions::BinaryOp,
484    dialect: Option<DialectType>,
485) -> Expression
486where
487    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
488{
489    let left = canonicalize_recursive(bin.left, dialect);
490    let right = canonicalize_recursive(bin.right, dialect);
491
492    constructor(Box::new(crate::expressions::BinaryOp {
493        left,
494        right,
495        left_comments: bin.left_comments,
496        operator_comments: bin.operator_comments,
497        trailing_comments: bin.trailing_comments,
498        inferred_type: None,
499    }))
500}
501
502/// Coerce date operands in comparisons.
503///
504/// When comparing a date/datetime column with a string literal,
505/// add appropriate CAST to the string.
506fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
507    // Check if we should cast string literals to date/datetime
508    let left = coerce_date_string(left, &right);
509    let right = coerce_date_string(right, &left);
510    (left, right)
511}
512
513/// Coerce a string literal to date/datetime if comparing with a temporal type.
514fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
515    if let Expression::Literal(ref lit) = expr {
516        if let Literal::String(ref s) = lit.as_ref() {
517        // Check if the string is an ISO date or datetime
518        if is_iso_date(s) {
519            // In a full implementation, we would add CAST to DATE
520            // For now, return unchanged
521        } else if is_iso_datetime(s) {
522            // In a full implementation, we would add CAST to DATETIME/TIMESTAMP
523            // For now, return unchanged
524        }
525    }
526    }
527    expr
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL with the default generator.
    ///
    /// Named `generate_sql` rather than `gen` because `gen` is a reserved
    /// keyword in the Rust 2024 edition.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement by value (no clone).
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("parser returned no statements")
    }

    /// Parse `sql`, canonicalize it without a dialect, and regenerate SQL.
    fn canonical_sql(sql: &str) -> String {
        generate_sql(&canonicalize(parse(sql), None))
    }

    #[test]
    fn test_canonicalize_simple() {
        let sql = canonical_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let sql = canonical_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        // Both connectives must survive canonicalization, not just one of them.
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        assert!(sql.contains("AND") && sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_removes_explicit_asc() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a ASC");
        assert!(sql.contains("ORDER BY"));
        assert!(!sql.contains("ASC"));
    }

    #[test]
    fn test_canonicalize_keeps_desc() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a DESC");
        assert!(sql.contains("DESC"));
    }

    #[test]
    fn test_canonicalize_union() {
        let sql = canonical_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        // Non-text additions must pass through untouched.
        let sql = canonical_sql("SELECT 1 + 2");
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        let sql = canonical_sql("SELECT MAX(a) FROM t");
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let sql = canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(sql.contains("BETWEEN"));
    }
}