Skip to main content

polyglot_sql/optimizer/
canonicalize.rs

//! Canonicalization Module
//!
//! This module provides functionality for converting SQL expressions into a
//! standard canonical form. The intended transformations are:
//! - Converting string addition to CONCAT
//! - Replacing date functions with casts
//! - Removing redundant type casts
//! - Ensuring boolean predicates
//! - Removing unnecessary ASC from ORDER BY
//!
//! Several of these passes are currently placeholders that return their
//! input unchanged; see the documentation of the individual functions.
//!
//! Ported from sqlglot's optimizer/canonicalize.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
17/// Converts a SQL expression into a standard canonical form.
18///
19/// This transformation relies on type annotations because many of the
20/// conversions depend on type inference.
21///
22/// # Arguments
23/// * `expression` - The expression to canonicalize
24/// * `dialect` - Optional dialect for dialect-specific behavior
25///
26/// # Returns
27/// The canonicalized expression
28pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
29    canonicalize_recursive(expression, dialect)
30}
31
32/// Recursively canonicalize an expression and its children
33fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34    let expr = match expression {
35        Expression::Select(mut select) => {
36            // Canonicalize SELECT expressions
37            select.expressions = select
38                .expressions
39                .into_iter()
40                .map(|e| canonicalize_recursive(e, dialect))
41                .collect();
42
43            // Canonicalize FROM
44            if let Some(mut from) = select.from {
45                from.expressions = from
46                    .expressions
47                    .into_iter()
48                    .map(|e| canonicalize_recursive(e, dialect))
49                    .collect();
50                select.from = Some(from);
51            }
52
53            // Canonicalize WHERE
54            if let Some(mut where_clause) = select.where_clause {
55                where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56                where_clause.this = ensure_bools(where_clause.this);
57                select.where_clause = Some(where_clause);
58            }
59
60            // Canonicalize HAVING
61            if let Some(mut having) = select.having {
62                having.this = canonicalize_recursive(having.this, dialect);
63                having.this = ensure_bools(having.this);
64                select.having = Some(having);
65            }
66
67            // Canonicalize ORDER BY
68            if let Some(mut order_by) = select.order_by {
69                order_by.expressions = order_by
70                    .expressions
71                    .into_iter()
72                    .map(|mut o| {
73                        o.this = canonicalize_recursive(o.this, dialect);
74                        o = remove_ascending_order(o);
75                        o
76                    })
77                    .collect();
78                select.order_by = Some(order_by);
79            }
80
81            // Canonicalize JOINs
82            select.joins = select
83                .joins
84                .into_iter()
85                .map(|mut j| {
86                    j.this = canonicalize_recursive(j.this, dialect);
87                    if let Some(on) = j.on {
88                        j.on = Some(canonicalize_recursive(on, dialect));
89                    }
90                    j
91                })
92                .collect();
93
94            Expression::Select(select)
95        }
96
97        // Binary operations that might involve string addition
98        Expression::Add(bin) => {
99            let left = canonicalize_recursive(bin.left, dialect);
100            let right = canonicalize_recursive(bin.right, dialect);
101            let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102                left,
103                right,
104                left_comments: bin.left_comments,
105                operator_comments: bin.operator_comments,
106                trailing_comments: bin.trailing_comments,
107                inferred_type: None,
108            }));
109            add_text_to_concat(result)
110        }
111
112        // Other binary operations
113        Expression::And(bin) => {
114            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
115            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
116            Expression::And(Box::new(crate::expressions::BinaryOp {
117                left,
118                right,
119                left_comments: bin.left_comments,
120                operator_comments: bin.operator_comments,
121                trailing_comments: bin.trailing_comments,
122                inferred_type: None,
123            }))
124        }
125        Expression::Or(bin) => {
126            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
127            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
128            Expression::Or(Box::new(crate::expressions::BinaryOp {
129                left,
130                right,
131                left_comments: bin.left_comments,
132                operator_comments: bin.operator_comments,
133                trailing_comments: bin.trailing_comments,
134                inferred_type: None,
135            }))
136        }
137
138        Expression::Not(un) => {
139            let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
140            Expression::Not(Box::new(crate::expressions::UnaryOp {
141                this: inner,
142                inferred_type: None,
143            }))
144        }
145
146        // Comparison operations - check for date coercion
147        Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
148        Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
149        Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
150        Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
151        Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
152        Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
153
154        Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
155        Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
156        Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
157
158        // Cast - check for redundancy
159        Expression::Cast(cast) => {
160            let inner = canonicalize_recursive(cast.this, dialect);
161            let result = Expression::Cast(Box::new(crate::expressions::Cast {
162                this: inner,
163                to: cast.to,
164                trailing_comments: cast.trailing_comments,
165                double_colon_syntax: cast.double_colon_syntax,
166                format: cast.format,
167                default: cast.default,
168                inferred_type: None,
169            }));
170            remove_redundant_casts(result)
171        }
172
173        // Function expressions
174        Expression::Function(func) => {
175            let args = func
176                .args
177                .into_iter()
178                .map(|e| canonicalize_recursive(e, dialect))
179                .collect();
180            Expression::Function(Box::new(crate::expressions::Function {
181                name: func.name,
182                args,
183                distinct: func.distinct,
184                trailing_comments: func.trailing_comments,
185                use_bracket_syntax: func.use_bracket_syntax,
186                no_parens: func.no_parens,
187                quoted: func.quoted,
188                span: None,
189                inferred_type: None,
190            }))
191        }
192
193        Expression::AggregateFunction(agg) => {
194            let args = agg
195                .args
196                .into_iter()
197                .map(|e| canonicalize_recursive(e, dialect))
198                .collect();
199            Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
200                name: agg.name,
201                args,
202                distinct: agg.distinct,
203                filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
204                order_by: agg.order_by,
205                limit: agg.limit,
206                ignore_nulls: agg.ignore_nulls,
207                inferred_type: None,
208            }))
209        }
210
211        // Alias
212        Expression::Alias(alias) => {
213            let inner = canonicalize_recursive(alias.this, dialect);
214            Expression::Alias(Box::new(crate::expressions::Alias {
215                this: inner,
216                alias: alias.alias,
217                column_aliases: alias.column_aliases,
218                pre_alias_comments: alias.pre_alias_comments,
219                trailing_comments: alias.trailing_comments,
220                inferred_type: None,
221            }))
222        }
223
224        // Paren
225        Expression::Paren(paren) => {
226            let inner = canonicalize_recursive(paren.this, dialect);
227            Expression::Paren(Box::new(crate::expressions::Paren {
228                this: inner,
229                trailing_comments: paren.trailing_comments,
230            }))
231        }
232
233        // Case
234        Expression::Case(case) => {
235            let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
236            let whens = case
237                .whens
238                .into_iter()
239                .map(|(w, t)| {
240                    (
241                        canonicalize_recursive(w, dialect),
242                        canonicalize_recursive(t, dialect),
243                    )
244                })
245                .collect();
246            let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
247            Expression::Case(Box::new(crate::expressions::Case {
248                operand,
249                whens,
250                else_,
251                comments: Vec::new(),
252                inferred_type: None,
253            }))
254        }
255
256        // Between - check for date coercion
257        Expression::Between(between) => {
258            let this = canonicalize_recursive(between.this, dialect);
259            let low = canonicalize_recursive(between.low, dialect);
260            let high = canonicalize_recursive(between.high, dialect);
261            Expression::Between(Box::new(crate::expressions::Between {
262                this,
263                low,
264                high,
265                not: between.not,
266                symmetric: between.symmetric,
267            }))
268        }
269
270        // In
271        Expression::In(in_expr) => {
272            let this = canonicalize_recursive(in_expr.this, dialect);
273            let expressions = in_expr
274                .expressions
275                .into_iter()
276                .map(|e| canonicalize_recursive(e, dialect))
277                .collect();
278            let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
279            Expression::In(Box::new(crate::expressions::In {
280                this,
281                expressions,
282                query,
283                not: in_expr.not,
284                global: in_expr.global,
285                unnest: in_expr.unnest,
286                is_field: in_expr.is_field,
287            }))
288        }
289
290        // Subquery
291        Expression::Subquery(subquery) => {
292            let this = canonicalize_recursive(subquery.this, dialect);
293            Expression::Subquery(Box::new(crate::expressions::Subquery {
294                this,
295                alias: subquery.alias,
296                column_aliases: subquery.column_aliases,
297                order_by: subquery.order_by,
298                limit: subquery.limit,
299                offset: subquery.offset,
300                distribute_by: subquery.distribute_by,
301                sort_by: subquery.sort_by,
302                cluster_by: subquery.cluster_by,
303                lateral: subquery.lateral,
304                modifiers_inside: subquery.modifiers_inside,
305                trailing_comments: subquery.trailing_comments,
306                inferred_type: None,
307            }))
308        }
309
310        // Set operations
311        Expression::Union(union) => {
312            let left = canonicalize_recursive(union.left, dialect);
313            let right = canonicalize_recursive(union.right, dialect);
314            Expression::Union(Box::new(crate::expressions::Union {
315                left,
316                right,
317                all: union.all,
318                distinct: union.distinct,
319                with: union.with,
320                order_by: union.order_by,
321                limit: union.limit,
322                offset: union.offset,
323                distribute_by: union.distribute_by,
324                sort_by: union.sort_by,
325                cluster_by: union.cluster_by,
326                by_name: union.by_name,
327                side: union.side,
328                kind: union.kind,
329                corresponding: union.corresponding,
330                strict: union.strict,
331                on_columns: union.on_columns,
332            }))
333        }
334        Expression::Intersect(intersect) => {
335            let left = canonicalize_recursive(intersect.left, dialect);
336            let right = canonicalize_recursive(intersect.right, dialect);
337            Expression::Intersect(Box::new(crate::expressions::Intersect {
338                left,
339                right,
340                all: intersect.all,
341                distinct: intersect.distinct,
342                with: intersect.with,
343                order_by: intersect.order_by,
344                limit: intersect.limit,
345                offset: intersect.offset,
346                distribute_by: intersect.distribute_by,
347                sort_by: intersect.sort_by,
348                cluster_by: intersect.cluster_by,
349                by_name: intersect.by_name,
350                side: intersect.side,
351                kind: intersect.kind,
352                corresponding: intersect.corresponding,
353                strict: intersect.strict,
354                on_columns: intersect.on_columns,
355            }))
356        }
357        Expression::Except(except) => {
358            let left = canonicalize_recursive(except.left, dialect);
359            let right = canonicalize_recursive(except.right, dialect);
360            Expression::Except(Box::new(crate::expressions::Except {
361                left,
362                right,
363                all: except.all,
364                distinct: except.distinct,
365                with: except.with,
366                order_by: except.order_by,
367                limit: except.limit,
368                offset: except.offset,
369                distribute_by: except.distribute_by,
370                sort_by: except.sort_by,
371                cluster_by: except.cluster_by,
372                by_name: except.by_name,
373                side: except.side,
374                kind: except.kind,
375                corresponding: except.corresponding,
376                strict: except.strict,
377                on_columns: except.on_columns,
378            }))
379        }
380
381        // Leaf nodes - return unchanged
382        other => other,
383    };
384
385    expr
386}
387
388/// Convert string addition to CONCAT.
389///
390/// When two TEXT types are added with +, convert to CONCAT.
391/// This is used by dialects like T-SQL and Redshift.
392fn add_text_to_concat(expression: Expression) -> Expression {
393    // In a full implementation, we would check if the operands are TEXT types
394    // and convert to CONCAT. For now, we return unchanged.
395    expression
396}
397
398/// Remove redundant cast expressions.
399///
400/// If casting to the same type the expression already is, remove the cast.
401fn remove_redundant_casts(expression: Expression) -> Expression {
402    if let Expression::Cast(cast) = &expression {
403        // Check if the inner expression's type matches the cast target
404        // In a full implementation with type annotations, we would compare types
405        // For now, just check simple cases
406
407        // If casting a literal to its natural type, we might be able to simplify
408        if let Expression::Literal(Literal::String(_)) = &cast.this {
409            if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
410                return cast.this.clone();
411            }
412        }
413        if let Expression::Literal(Literal::Number(_)) = &cast.this {
414            if matches!(
415                &cast.to,
416                DataType::Int { .. }
417                    | DataType::BigInt { .. }
418                    | DataType::Decimal { .. }
419                    | DataType::Float { .. }
420            ) {
421                // Could potentially remove cast, but be conservative
422            }
423        }
424    }
425    expression
426}
427
428/// Ensure expressions used as boolean predicates are actually boolean.
429///
430/// For example, in some dialects, integers can be used as booleans.
431/// This function ensures proper boolean semantics.
432fn ensure_bools(expression: Expression) -> Expression {
433    // In a full implementation, we would check if the expression is an integer
434    // and convert it to a comparison (e.g., x != 0).
435    // For now, return unchanged.
436    expression
437}
438
439/// Remove explicit ASC from ORDER BY clauses.
440///
441/// Since ASC is the default, `ORDER BY a ASC` can be simplified to `ORDER BY a`.
442fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
443    // If ASC was explicitly written (not DESC), remove the explicit flag
444    // since ASC is the default ordering
445    if !ordered.desc && ordered.explicit_asc {
446        ordered.explicit_asc = false;
447    }
448    ordered
449}
450
451/// Canonicalize a binary comparison operation.
452fn canonicalize_comparison<F>(
453    constructor: F,
454    bin: crate::expressions::BinaryOp,
455    dialect: Option<DialectType>,
456) -> Expression
457where
458    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
459{
460    let left = canonicalize_recursive(bin.left, dialect);
461    let right = canonicalize_recursive(bin.right, dialect);
462
463    // Check for date coercion opportunities
464    let (left, right) = coerce_date_operands(left, right);
465
466    constructor(Box::new(crate::expressions::BinaryOp {
467        left,
468        right,
469        left_comments: bin.left_comments,
470        operator_comments: bin.operator_comments,
471        trailing_comments: bin.trailing_comments,
472        inferred_type: None,
473    }))
474}
475
476/// Canonicalize a regular binary operation.
477fn canonicalize_binary<F>(
478    constructor: F,
479    bin: crate::expressions::BinaryOp,
480    dialect: Option<DialectType>,
481) -> Expression
482where
483    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
484{
485    let left = canonicalize_recursive(bin.left, dialect);
486    let right = canonicalize_recursive(bin.right, dialect);
487
488    constructor(Box::new(crate::expressions::BinaryOp {
489        left,
490        right,
491        left_comments: bin.left_comments,
492        operator_comments: bin.operator_comments,
493        trailing_comments: bin.trailing_comments,
494        inferred_type: None,
495    }))
496}
497
498/// Coerce date operands in comparisons.
499///
500/// When comparing a date/datetime column with a string literal,
501/// add appropriate CAST to the string.
502fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
503    // Check if we should cast string literals to date/datetime
504    let left = coerce_date_string(left, &right);
505    let right = coerce_date_string(right, &left);
506    (left, right)
507}
508
509/// Coerce a string literal to date/datetime if comparing with a temporal type.
510fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
511    if let Expression::Literal(Literal::String(ref s)) = expr {
512        // Check if the string is an ISO date or datetime
513        if is_iso_date(s) {
514            // In a full implementation, we would add CAST to DATE
515            // For now, return unchanged
516        } else if is_iso_datetime(s) {
517            // In a full implementation, we would add CAST to DATETIME/TIMESTAMP
518            // For now, return unchanged
519        }
520    }
521    expr
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text.
    ///
    /// Named `to_sql` rather than `gen`, because `gen` is a reserved keyword
    /// starting with the Rust 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parse `sql`, canonicalize it without a dialect, and render it back to SQL.
    fn canonicalized_sql(sql: &str) -> String {
        let result = canonicalize(parse(sql), None);
        to_sql(&result)
    }

    #[test]
    fn test_canonicalize_simple() {
        let sql = canonicalized_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let sql = canonicalized_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let sql = canonicalized_sql("SELECT 1 WHERE a AND b OR c");
        assert!(sql.contains("AND") || sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonicalized_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonicalized_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonicalized_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let sql = canonicalized_sql("SELECT a FROM t ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_union() {
        let sql = canonicalized_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        // Test that non-text additions pass through
        let sql = canonicalized_sql("SELECT 1 + 2");
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        let sql = canonicalized_sql("SELECT MAX(a) FROM t");
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let sql = canonicalized_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(sql.contains("BETWEEN"));
    }
}