Skip to main content

polyglot_sql/optimizer/
canonicalize.rs

1//! Canonicalization Module
2//!
3//! This module provides functionality for converting SQL expressions into a
4//! standard canonical form. This includes:
5//! - Converting string addition to CONCAT
6//! - Replacing date functions with casts
7//! - Removing redundant type casts
8//! - Ensuring boolean predicates
9//! - Removing unnecessary ASC from ORDER BY
10//!
11//! Ported from sqlglot's optimizer/canonicalize.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
17/// Converts a SQL expression into a standard canonical form.
18///
19/// This transformation relies on type annotations because many of the
20/// conversions depend on type inference.
21///
22/// # Arguments
23/// * `expression` - The expression to canonicalize
24/// * `dialect` - Optional dialect for dialect-specific behavior
25///
26/// # Returns
27/// The canonicalized expression
28pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
29    canonicalize_recursive(expression, dialect)
30}
31
32/// Recursively canonicalize an expression and its children
33fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34    let expr = match expression {
35        Expression::Select(mut select) => {
36            // Canonicalize SELECT expressions
37            select.expressions = select
38                .expressions
39                .into_iter()
40                .map(|e| canonicalize_recursive(e, dialect))
41                .collect();
42
43            // Canonicalize FROM
44            if let Some(mut from) = select.from {
45                from.expressions = from
46                    .expressions
47                    .into_iter()
48                    .map(|e| canonicalize_recursive(e, dialect))
49                    .collect();
50                select.from = Some(from);
51            }
52
53            // Canonicalize WHERE
54            if let Some(mut where_clause) = select.where_clause {
55                where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56                where_clause.this = ensure_bools(where_clause.this);
57                select.where_clause = Some(where_clause);
58            }
59
60            // Canonicalize HAVING
61            if let Some(mut having) = select.having {
62                having.this = canonicalize_recursive(having.this, dialect);
63                having.this = ensure_bools(having.this);
64                select.having = Some(having);
65            }
66
67            // Canonicalize ORDER BY
68            if let Some(mut order_by) = select.order_by {
69                order_by.expressions = order_by
70                    .expressions
71                    .into_iter()
72                    .map(|mut o| {
73                        o.this = canonicalize_recursive(o.this, dialect);
74                        o = remove_ascending_order(o);
75                        o
76                    })
77                    .collect();
78                select.order_by = Some(order_by);
79            }
80
81            // Canonicalize JOINs
82            select.joins = select
83                .joins
84                .into_iter()
85                .map(|mut j| {
86                    j.this = canonicalize_recursive(j.this, dialect);
87                    if let Some(on) = j.on {
88                        j.on = Some(canonicalize_recursive(on, dialect));
89                    }
90                    j
91                })
92                .collect();
93
94            Expression::Select(select)
95        }
96
97        // Binary operations that might involve string addition
98        Expression::Add(bin) => {
99            let left = canonicalize_recursive(bin.left, dialect);
100            let right = canonicalize_recursive(bin.right, dialect);
101            let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102                left,
103                right,
104                left_comments: bin.left_comments,
105                operator_comments: bin.operator_comments,
106                trailing_comments: bin.trailing_comments,
107            }));
108            add_text_to_concat(result)
109        }
110
111        // Other binary operations
112        Expression::And(bin) => {
113            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
114            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
115            Expression::And(Box::new(crate::expressions::BinaryOp {
116                left,
117                right,
118                left_comments: bin.left_comments,
119                operator_comments: bin.operator_comments,
120                trailing_comments: bin.trailing_comments,
121            }))
122        }
123        Expression::Or(bin) => {
124            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
125            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
126            Expression::Or(Box::new(crate::expressions::BinaryOp {
127                left,
128                right,
129                left_comments: bin.left_comments,
130                operator_comments: bin.operator_comments,
131                trailing_comments: bin.trailing_comments,
132            }))
133        }
134
135        Expression::Not(un) => {
136            let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
137            Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
138        }
139
140        // Comparison operations - check for date coercion
141        Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
142        Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
143        Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
144        Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
145        Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
146        Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
147
148        Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
149        Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
150        Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
151
152        // Cast - check for redundancy
153        Expression::Cast(cast) => {
154            let inner = canonicalize_recursive(cast.this, dialect);
155            let result = Expression::Cast(Box::new(crate::expressions::Cast {
156                this: inner,
157                to: cast.to,
158                trailing_comments: cast.trailing_comments,
159                double_colon_syntax: cast.double_colon_syntax,
160                format: cast.format,
161                default: cast.default,
162            }));
163            remove_redundant_casts(result)
164        }
165
166        // Function expressions
167        Expression::Function(func) => {
168            let args = func
169                .args
170                .into_iter()
171                .map(|e| canonicalize_recursive(e, dialect))
172                .collect();
173            Expression::Function(Box::new(crate::expressions::Function {
174                name: func.name,
175                args,
176                distinct: func.distinct,
177                trailing_comments: func.trailing_comments,
178                use_bracket_syntax: func.use_bracket_syntax,
179                no_parens: func.no_parens,
180                quoted: func.quoted,
181                span: None,
182            }))
183        }
184
185        Expression::AggregateFunction(agg) => {
186            let args = agg
187                .args
188                .into_iter()
189                .map(|e| canonicalize_recursive(e, dialect))
190                .collect();
191            Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
192                name: agg.name,
193                args,
194                distinct: agg.distinct,
195                filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
196                order_by: agg.order_by,
197                limit: agg.limit,
198                ignore_nulls: agg.ignore_nulls,
199            }))
200        }
201
202        // Alias
203        Expression::Alias(alias) => {
204            let inner = canonicalize_recursive(alias.this, dialect);
205            Expression::Alias(Box::new(crate::expressions::Alias {
206                this: inner,
207                alias: alias.alias,
208                column_aliases: alias.column_aliases,
209                pre_alias_comments: alias.pre_alias_comments,
210                trailing_comments: alias.trailing_comments,
211            }))
212        }
213
214        // Paren
215        Expression::Paren(paren) => {
216            let inner = canonicalize_recursive(paren.this, dialect);
217            Expression::Paren(Box::new(crate::expressions::Paren {
218                this: inner,
219                trailing_comments: paren.trailing_comments,
220            }))
221        }
222
223        // Case
224        Expression::Case(case) => {
225            let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
226            let whens = case
227                .whens
228                .into_iter()
229                .map(|(w, t)| {
230                    (
231                        canonicalize_recursive(w, dialect),
232                        canonicalize_recursive(t, dialect),
233                    )
234                })
235                .collect();
236            let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
237            Expression::Case(Box::new(crate::expressions::Case {
238                operand,
239                whens,
240                else_,
241                comments: Vec::new(),
242            }))
243        }
244
245        // Between - check for date coercion
246        Expression::Between(between) => {
247            let this = canonicalize_recursive(between.this, dialect);
248            let low = canonicalize_recursive(between.low, dialect);
249            let high = canonicalize_recursive(between.high, dialect);
250            Expression::Between(Box::new(crate::expressions::Between {
251                this,
252                low,
253                high,
254                not: between.not,
255                symmetric: between.symmetric,
256            }))
257        }
258
259        // In
260        Expression::In(in_expr) => {
261            let this = canonicalize_recursive(in_expr.this, dialect);
262            let expressions = in_expr
263                .expressions
264                .into_iter()
265                .map(|e| canonicalize_recursive(e, dialect))
266                .collect();
267            let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
268            Expression::In(Box::new(crate::expressions::In {
269                this,
270                expressions,
271                query,
272                not: in_expr.not,
273                global: in_expr.global,
274                unnest: in_expr.unnest,
275                is_field: in_expr.is_field,
276            }))
277        }
278
279        // Subquery
280        Expression::Subquery(subquery) => {
281            let this = canonicalize_recursive(subquery.this, dialect);
282            Expression::Subquery(Box::new(crate::expressions::Subquery {
283                this,
284                alias: subquery.alias,
285                column_aliases: subquery.column_aliases,
286                order_by: subquery.order_by,
287                limit: subquery.limit,
288                offset: subquery.offset,
289                distribute_by: subquery.distribute_by,
290                sort_by: subquery.sort_by,
291                cluster_by: subquery.cluster_by,
292                lateral: subquery.lateral,
293                modifiers_inside: subquery.modifiers_inside,
294                trailing_comments: subquery.trailing_comments,
295            }))
296        }
297
298        // Set operations
299        Expression::Union(union) => {
300            let left = canonicalize_recursive(union.left, dialect);
301            let right = canonicalize_recursive(union.right, dialect);
302            Expression::Union(Box::new(crate::expressions::Union {
303                left,
304                right,
305                all: union.all,
306                distinct: union.distinct,
307                with: union.with,
308                order_by: union.order_by,
309                limit: union.limit,
310                offset: union.offset,
311                distribute_by: union.distribute_by,
312                sort_by: union.sort_by,
313                cluster_by: union.cluster_by,
314                by_name: union.by_name,
315                side: union.side,
316                kind: union.kind,
317                corresponding: union.corresponding,
318                strict: union.strict,
319                on_columns: union.on_columns,
320            }))
321        }
322        Expression::Intersect(intersect) => {
323            let left = canonicalize_recursive(intersect.left, dialect);
324            let right = canonicalize_recursive(intersect.right, dialect);
325            Expression::Intersect(Box::new(crate::expressions::Intersect {
326                left,
327                right,
328                all: intersect.all,
329                distinct: intersect.distinct,
330                with: intersect.with,
331                order_by: intersect.order_by,
332                limit: intersect.limit,
333                offset: intersect.offset,
334                distribute_by: intersect.distribute_by,
335                sort_by: intersect.sort_by,
336                cluster_by: intersect.cluster_by,
337                by_name: intersect.by_name,
338                side: intersect.side,
339                kind: intersect.kind,
340                corresponding: intersect.corresponding,
341                strict: intersect.strict,
342                on_columns: intersect.on_columns,
343            }))
344        }
345        Expression::Except(except) => {
346            let left = canonicalize_recursive(except.left, dialect);
347            let right = canonicalize_recursive(except.right, dialect);
348            Expression::Except(Box::new(crate::expressions::Except {
349                left,
350                right,
351                all: except.all,
352                distinct: except.distinct,
353                with: except.with,
354                order_by: except.order_by,
355                limit: except.limit,
356                offset: except.offset,
357                distribute_by: except.distribute_by,
358                sort_by: except.sort_by,
359                cluster_by: except.cluster_by,
360                by_name: except.by_name,
361                side: except.side,
362                kind: except.kind,
363                corresponding: except.corresponding,
364                strict: except.strict,
365                on_columns: except.on_columns,
366            }))
367        }
368
369        // Leaf nodes - return unchanged
370        other => other,
371    };
372
373    expr
374}
375
376/// Convert string addition to CONCAT.
377///
378/// When two TEXT types are added with +, convert to CONCAT.
379/// This is used by dialects like T-SQL and Redshift.
380fn add_text_to_concat(expression: Expression) -> Expression {
381    // In a full implementation, we would check if the operands are TEXT types
382    // and convert to CONCAT. For now, we return unchanged.
383    expression
384}
385
386/// Remove redundant cast expressions.
387///
388/// If casting to the same type the expression already is, remove the cast.
389fn remove_redundant_casts(expression: Expression) -> Expression {
390    if let Expression::Cast(cast) = &expression {
391        // Check if the inner expression's type matches the cast target
392        // In a full implementation with type annotations, we would compare types
393        // For now, just check simple cases
394
395        // If casting a literal to its natural type, we might be able to simplify
396        if let Expression::Literal(Literal::String(_)) = &cast.this {
397            if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
398                return cast.this.clone();
399            }
400        }
401        if let Expression::Literal(Literal::Number(_)) = &cast.this {
402            if matches!(
403                &cast.to,
404                DataType::Int { .. }
405                    | DataType::BigInt { .. }
406                    | DataType::Decimal { .. }
407                    | DataType::Float { .. }
408            ) {
409                // Could potentially remove cast, but be conservative
410            }
411        }
412    }
413    expression
414}
415
416/// Ensure expressions used as boolean predicates are actually boolean.
417///
418/// For example, in some dialects, integers can be used as booleans.
419/// This function ensures proper boolean semantics.
420fn ensure_bools(expression: Expression) -> Expression {
421    // In a full implementation, we would check if the expression is an integer
422    // and convert it to a comparison (e.g., x != 0).
423    // For now, return unchanged.
424    expression
425}
426
427/// Remove explicit ASC from ORDER BY clauses.
428///
429/// Since ASC is the default, `ORDER BY a ASC` can be simplified to `ORDER BY a`.
430fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
431    // If ASC was explicitly written (not DESC), remove the explicit flag
432    // since ASC is the default ordering
433    if !ordered.desc && ordered.explicit_asc {
434        ordered.explicit_asc = false;
435    }
436    ordered
437}
438
439/// Canonicalize a binary comparison operation.
440fn canonicalize_comparison<F>(
441    constructor: F,
442    bin: crate::expressions::BinaryOp,
443    dialect: Option<DialectType>,
444) -> Expression
445where
446    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
447{
448    let left = canonicalize_recursive(bin.left, dialect);
449    let right = canonicalize_recursive(bin.right, dialect);
450
451    // Check for date coercion opportunities
452    let (left, right) = coerce_date_operands(left, right);
453
454    constructor(Box::new(crate::expressions::BinaryOp {
455        left,
456        right,
457        left_comments: bin.left_comments,
458        operator_comments: bin.operator_comments,
459        trailing_comments: bin.trailing_comments,
460    }))
461}
462
463/// Canonicalize a regular binary operation.
464fn canonicalize_binary<F>(
465    constructor: F,
466    bin: crate::expressions::BinaryOp,
467    dialect: Option<DialectType>,
468) -> Expression
469where
470    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
471{
472    let left = canonicalize_recursive(bin.left, dialect);
473    let right = canonicalize_recursive(bin.right, dialect);
474
475    constructor(Box::new(crate::expressions::BinaryOp {
476        left,
477        right,
478        left_comments: bin.left_comments,
479        operator_comments: bin.operator_comments,
480        trailing_comments: bin.trailing_comments,
481    }))
482}
483
484/// Coerce date operands in comparisons.
485///
486/// When comparing a date/datetime column with a string literal,
487/// add appropriate CAST to the string.
488fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
489    // Check if we should cast string literals to date/datetime
490    let left = coerce_date_string(left, &right);
491    let right = coerce_date_string(right, &left);
492    (left, right)
493}
494
495/// Coerce a string literal to date/datetime if comparing with a temporal type.
496fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
497    if let Expression::Literal(Literal::String(ref s)) = expr {
498        // Check if the string is an ISO date or datetime
499        if is_iso_date(s) {
500            // In a full implementation, we would add CAST to DATE
501            // For now, return unchanged
502        } else if is_iso_datetime(s) {
503            // In a full implementation, we would add CAST to DATETIME/TIMESTAMP
504            // For now, return unchanged
505        }
506    }
507    expr
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text.
    // Named `to_sql` rather than `gen`, which is a reserved keyword in the
    // 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new()
            .generate(expr)
            .expect("Failed to generate SQL")
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("Parser returned no statements")
    }

    /// Parse `sql`, canonicalize it without a dialect and render the result.
    fn canonical_sql(sql: &str) -> String {
        to_sql(&canonicalize(parse(sql), None))
    }

    #[test]
    fn test_canonicalize_simple() {
        let sql = canonical_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let sql = canonical_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        // Both connectors must survive; checking only one of them would let a
        // dropped operand go unnoticed.
        assert!(sql.contains("AND") && sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_union() {
        let sql = canonical_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        // Test that non-text additions pass through
        let sql = canonical_sql("SELECT 1 + 2");
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        let sql = canonical_sql("SELECT MAX(a) FROM t");
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let sql = canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(sql.contains("BETWEEN"));
    }
}