Skip to main content

polyglot_sql/optimizer/
canonicalize.rs

1//! Canonicalization Module
2//!
3//! This module provides functionality for converting SQL expressions into a
4//! standard canonical form. This includes:
5//! - Converting string addition to CONCAT
6//! - Replacing date functions with casts
7//! - Removing redundant type casts
8//! - Ensuring boolean predicates
9//! - Removing unnecessary ASC from ORDER BY
10//!
11//! Ported from sqlglot's optimizer/canonicalize.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal, Null};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
17/// Converts a SQL expression into a standard canonical form.
18///
19/// This transformation relies on type annotations because many of the
20/// conversions depend on type inference.
21///
22/// # Arguments
23/// * `expression` - The expression to canonicalize
24/// * `dialect` - Optional dialect for dialect-specific behavior
25///
26/// # Returns
27/// The canonicalized expression
28pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
29    canonicalize_recursive(expression, dialect)
30}
31
32/// Recursively canonicalize an expression and its children
33fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34    let expr = match expression {
35        Expression::Select(mut select) => {
36            // Canonicalize SELECT expressions
37            select.expressions = select
38                .expressions
39                .into_iter()
40                .map(|e| canonicalize_recursive(e, dialect))
41                .collect();
42
43            // Canonicalize FROM
44            if let Some(mut from) = select.from {
45                from.expressions = from
46                    .expressions
47                    .into_iter()
48                    .map(|e| canonicalize_recursive(e, dialect))
49                    .collect();
50                select.from = Some(from);
51            }
52
53            // Canonicalize WHERE
54            if let Some(mut where_clause) = select.where_clause {
55                where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56                where_clause.this = ensure_bools(where_clause.this);
57                select.where_clause = Some(where_clause);
58            }
59
60            // Canonicalize HAVING
61            if let Some(mut having) = select.having {
62                having.this = canonicalize_recursive(having.this, dialect);
63                having.this = ensure_bools(having.this);
64                select.having = Some(having);
65            }
66
67            // Canonicalize ORDER BY
68            if let Some(mut order_by) = select.order_by {
69                order_by.expressions = order_by
70                    .expressions
71                    .into_iter()
72                    .map(|mut o| {
73                        o.this = canonicalize_recursive(o.this, dialect);
74                        o = remove_ascending_order(o);
75                        o
76                    })
77                    .collect();
78                select.order_by = Some(order_by);
79            }
80
81            // Canonicalize JOINs
82            select.joins = select
83                .joins
84                .into_iter()
85                .map(|mut j| {
86                    j.this = canonicalize_recursive(j.this, dialect);
87                    if let Some(on) = j.on {
88                        j.on = Some(canonicalize_recursive(on, dialect));
89                    }
90                    j
91                })
92                .collect();
93
94            Expression::Select(select)
95        }
96
97        // Binary operations that might involve string addition
98        Expression::Add(bin) => {
99            let left = canonicalize_recursive(bin.left, dialect);
100            let right = canonicalize_recursive(bin.right, dialect);
101            let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102                left,
103                right,
104                left_comments: bin.left_comments,
105                operator_comments: bin.operator_comments,
106                trailing_comments: bin.trailing_comments,
107                inferred_type: None,
108            }));
109            add_text_to_concat(result)
110        }
111
112        // Other binary operations
113        Expression::And(bin) => {
114            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
115            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
116            Expression::And(Box::new(crate::expressions::BinaryOp {
117                left,
118                right,
119                left_comments: bin.left_comments,
120                operator_comments: bin.operator_comments,
121                trailing_comments: bin.trailing_comments,
122                inferred_type: None,
123            }))
124        }
125        Expression::Or(bin) => {
126            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
127            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
128            Expression::Or(Box::new(crate::expressions::BinaryOp {
129                left,
130                right,
131                left_comments: bin.left_comments,
132                operator_comments: bin.operator_comments,
133                trailing_comments: bin.trailing_comments,
134                inferred_type: None,
135            }))
136        }
137
138        Expression::Not(un) => {
139            let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
140            Expression::Not(Box::new(crate::expressions::UnaryOp {
141                this: inner,
142                inferred_type: None,
143            }))
144        }
145
146        // Comparison operations - check for date coercion
147        Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
148        Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
149        Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
150        Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
151        Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
152        Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
153
154        Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
155        Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
156        Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
157
158        // Cast - check for redundancy
159        Expression::Cast(cast) => {
160            let inner = canonicalize_recursive(cast.this, dialect);
161            let result = Expression::Cast(Box::new(crate::expressions::Cast {
162                this: inner,
163                to: cast.to,
164                trailing_comments: cast.trailing_comments,
165                double_colon_syntax: cast.double_colon_syntax,
166                format: cast.format,
167                default: cast.default,
168                inferred_type: None,
169            }));
170            remove_redundant_casts(result)
171        }
172
173        // Function expressions
174        Expression::Function(func) => {
175            let args = func
176                .args
177                .into_iter()
178                .map(|e| canonicalize_recursive(e, dialect))
179                .collect();
180            Expression::Function(Box::new(crate::expressions::Function {
181                name: func.name,
182                args,
183                distinct: func.distinct,
184                trailing_comments: func.trailing_comments,
185                use_bracket_syntax: func.use_bracket_syntax,
186                no_parens: func.no_parens,
187                quoted: func.quoted,
188                span: None,
189                inferred_type: None,
190            }))
191        }
192
193        Expression::AggregateFunction(agg) => {
194            let args = agg
195                .args
196                .into_iter()
197                .map(|e| canonicalize_recursive(e, dialect))
198                .collect();
199            Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
200                name: agg.name,
201                args,
202                distinct: agg.distinct,
203                filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
204                order_by: agg.order_by,
205                limit: agg.limit,
206                ignore_nulls: agg.ignore_nulls,
207                inferred_type: None,
208            }))
209        }
210
211        // Alias
212        Expression::Alias(alias) => {
213            let inner = canonicalize_recursive(alias.this, dialect);
214            Expression::Alias(Box::new(crate::expressions::Alias {
215                this: inner,
216                alias: alias.alias,
217                column_aliases: alias.column_aliases,
218                pre_alias_comments: alias.pre_alias_comments,
219                trailing_comments: alias.trailing_comments,
220                inferred_type: None,
221            }))
222        }
223
224        // Paren
225        Expression::Paren(paren) => {
226            let inner = canonicalize_recursive(paren.this, dialect);
227            Expression::Paren(Box::new(crate::expressions::Paren {
228                this: inner,
229                trailing_comments: paren.trailing_comments,
230            }))
231        }
232
233        // Case
234        Expression::Case(case) => {
235            let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
236            let whens = case
237                .whens
238                .into_iter()
239                .map(|(w, t)| {
240                    (
241                        canonicalize_recursive(w, dialect),
242                        canonicalize_recursive(t, dialect),
243                    )
244                })
245                .collect();
246            let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
247            Expression::Case(Box::new(crate::expressions::Case {
248                operand,
249                whens,
250                else_,
251                comments: Vec::new(),
252                inferred_type: None,
253            }))
254        }
255
256        // Between - check for date coercion
257        Expression::Between(between) => {
258            let this = canonicalize_recursive(between.this, dialect);
259            let low = canonicalize_recursive(between.low, dialect);
260            let high = canonicalize_recursive(between.high, dialect);
261            Expression::Between(Box::new(crate::expressions::Between {
262                this,
263                low,
264                high,
265                not: between.not,
266                symmetric: between.symmetric,
267            }))
268        }
269
270        // In
271        Expression::In(in_expr) => {
272            let this = canonicalize_recursive(in_expr.this, dialect);
273            let expressions = in_expr
274                .expressions
275                .into_iter()
276                .map(|e| canonicalize_recursive(e, dialect))
277                .collect();
278            let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
279            Expression::In(Box::new(crate::expressions::In {
280                this,
281                expressions,
282                query,
283                not: in_expr.not,
284                global: in_expr.global,
285                unnest: in_expr.unnest,
286                is_field: in_expr.is_field,
287            }))
288        }
289
290        // Subquery
291        Expression::Subquery(subquery) => {
292            let this = canonicalize_recursive(subquery.this, dialect);
293            Expression::Subquery(Box::new(crate::expressions::Subquery {
294                this,
295                alias: subquery.alias,
296                column_aliases: subquery.column_aliases,
297                order_by: subquery.order_by,
298                limit: subquery.limit,
299                offset: subquery.offset,
300                distribute_by: subquery.distribute_by,
301                sort_by: subquery.sort_by,
302                cluster_by: subquery.cluster_by,
303                lateral: subquery.lateral,
304                modifiers_inside: subquery.modifiers_inside,
305                trailing_comments: subquery.trailing_comments,
306                inferred_type: None,
307            }))
308        }
309
310        // Set operations
311        Expression::Union(union) => {
312            let mut u = *union;
313            let left = std::mem::replace(&mut u.left, Expression::Null(Null));
314            u.left = canonicalize_recursive(left, dialect);
315            let right = std::mem::replace(&mut u.right, Expression::Null(Null));
316            u.right = canonicalize_recursive(right, dialect);
317            Expression::Union(Box::new(u))
318        }
319        Expression::Intersect(intersect) => {
320            let mut i = *intersect;
321            let left = std::mem::replace(&mut i.left, Expression::Null(Null));
322            i.left = canonicalize_recursive(left, dialect);
323            let right = std::mem::replace(&mut i.right, Expression::Null(Null));
324            i.right = canonicalize_recursive(right, dialect);
325            Expression::Intersect(Box::new(i))
326        }
327        Expression::Except(except) => {
328            let mut e = *except;
329            let left = std::mem::replace(&mut e.left, Expression::Null(Null));
330            e.left = canonicalize_recursive(left, dialect);
331            let right = std::mem::replace(&mut e.right, Expression::Null(Null));
332            e.right = canonicalize_recursive(right, dialect);
333            Expression::Except(Box::new(e))
334        }
335
336        // Leaf nodes - return unchanged
337        other => other,
338    };
339
340    expr
341}
342
343/// Convert string addition to CONCAT.
344///
345/// When two TEXT types are added with +, convert to CONCAT.
346/// This is used by dialects like T-SQL and Redshift.
347fn add_text_to_concat(expression: Expression) -> Expression {
348    // In a full implementation, we would check if the operands are TEXT types
349    // and convert to CONCAT. For now, we return unchanged.
350    expression
351}
352
353/// Remove redundant cast expressions.
354///
355/// If casting to the same type the expression already is, remove the cast.
356fn remove_redundant_casts(expression: Expression) -> Expression {
357    if let Expression::Cast(cast) = &expression {
358        // Check if the inner expression's type matches the cast target
359        // In a full implementation with type annotations, we would compare types
360        // For now, just check simple cases
361
362        // If casting a literal to its natural type, we might be able to simplify
363        if let Expression::Literal(lit) = &cast.this {
364            if let Literal::String(_) = lit.as_ref() {
365                if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
366                    return cast.this.clone();
367                }
368            }
369        }
370        if let Expression::Literal(lit) = &cast.this {
371            if let Literal::Number(_) = lit.as_ref() {
372                if matches!(
373                    &cast.to,
374                    DataType::Int { .. }
375                        | DataType::BigInt { .. }
376                        | DataType::Decimal { .. }
377                        | DataType::Float { .. }
378                ) {
379                    // Could potentially remove cast, but be conservative
380                }
381            }
382        }
383    }
384    expression
385}
386
387/// Ensure expressions used as boolean predicates are actually boolean.
388///
389/// For example, in some dialects, integers can be used as booleans.
390/// This function ensures proper boolean semantics.
391fn ensure_bools(expression: Expression) -> Expression {
392    // In a full implementation, we would check if the expression is an integer
393    // and convert it to a comparison (e.g., x != 0).
394    // For now, return unchanged.
395    expression
396}
397
398/// Remove explicit ASC from ORDER BY clauses.
399///
400/// Since ASC is the default, `ORDER BY a ASC` can be simplified to `ORDER BY a`.
401fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
402    // If ASC was explicitly written (not DESC), remove the explicit flag
403    // since ASC is the default ordering
404    if !ordered.desc && ordered.explicit_asc {
405        ordered.explicit_asc = false;
406    }
407    ordered
408}
409
410/// Canonicalize a binary comparison operation.
411fn canonicalize_comparison<F>(
412    constructor: F,
413    bin: crate::expressions::BinaryOp,
414    dialect: Option<DialectType>,
415) -> Expression
416where
417    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
418{
419    let left = canonicalize_recursive(bin.left, dialect);
420    let right = canonicalize_recursive(bin.right, dialect);
421
422    // Check for date coercion opportunities
423    let (left, right) = coerce_date_operands(left, right);
424
425    constructor(Box::new(crate::expressions::BinaryOp {
426        left,
427        right,
428        left_comments: bin.left_comments,
429        operator_comments: bin.operator_comments,
430        trailing_comments: bin.trailing_comments,
431        inferred_type: None,
432    }))
433}
434
435/// Canonicalize a regular binary operation.
436fn canonicalize_binary<F>(
437    constructor: F,
438    bin: crate::expressions::BinaryOp,
439    dialect: Option<DialectType>,
440) -> Expression
441where
442    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
443{
444    let left = canonicalize_recursive(bin.left, dialect);
445    let right = canonicalize_recursive(bin.right, dialect);
446
447    constructor(Box::new(crate::expressions::BinaryOp {
448        left,
449        right,
450        left_comments: bin.left_comments,
451        operator_comments: bin.operator_comments,
452        trailing_comments: bin.trailing_comments,
453        inferred_type: None,
454    }))
455}
456
457/// Coerce date operands in comparisons.
458///
459/// When comparing a date/datetime column with a string literal,
460/// add appropriate CAST to the string.
461fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
462    // Check if we should cast string literals to date/datetime
463    let left = coerce_date_string(left, &right);
464    let right = coerce_date_string(right, &left);
465    (left, right)
466}
467
468/// Coerce a string literal to date/datetime if comparing with a temporal type.
469fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
470    if let Expression::Literal(ref lit) = expr {
471        if let Literal::String(ref s) = lit.as_ref() {
472            // Check if the string is an ISO date or datetime
473            if is_iso_date(s) {
474                // In a full implementation, we would add CAST to DATE
475                // For now, return unchanged
476            } else if is_iso_datetime(s) {
477                // In a full implementation, we would add CAST to DATETIME/TIMESTAMP
478                // For now, return unchanged
479            }
480        }
481    }
482    expr
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text.
    ///
    /// Deliberately not named `gen`, which is a reserved keyword in the
    /// Rust 2024 edition.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parse `sql`, canonicalize it without a dialect and render it again.
    fn canonical_sql(sql: &str) -> String {
        let canonical = canonicalize(parse(sql), None);
        generate_sql(&canonical)
    }

    #[test]
    fn test_canonicalize_simple() {
        let sql = canonical_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let sql = canonical_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        // Both connectors must survive; checking only one of them would let
        // a dropped operand go unnoticed.
        assert!(sql.contains("AND") && sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_union() {
        let sql = canonical_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        // Test that non-text additions pass through
        let sql = canonical_sql("SELECT 1 + 2");
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        let sql = canonical_sql("SELECT MAX(a) FROM t");
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let sql = canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(sql.contains("BETWEEN"));
    }
}