Skip to main content

polyglot_sql/optimizer/
canonicalize.rs

1//! Canonicalization Module
2//!
3//! This module provides functionality for converting SQL expressions into a
4//! standard canonical form. This includes:
5//! - Converting string addition to CONCAT
6//! - Replacing date functions with casts
7//! - Removing redundant type casts
8//! - Ensuring boolean predicates
9//! - Removing unnecessary ASC from ORDER BY
10//!
11//! Ported from sqlglot's optimizer/canonicalize.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
17/// Converts a SQL expression into a standard canonical form.
18///
19/// This transformation relies on type annotations because many of the
20/// conversions depend on type inference.
21///
22/// # Arguments
23/// * `expression` - The expression to canonicalize
24/// * `dialect` - Optional dialect for dialect-specific behavior
25///
26/// # Returns
27/// The canonicalized expression
28pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
29    canonicalize_recursive(expression, dialect)
30}
31
32/// Recursively canonicalize an expression and its children
33fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34    let expr = match expression {
35        Expression::Select(mut select) => {
36            // Canonicalize SELECT expressions
37            select.expressions = select
38                .expressions
39                .into_iter()
40                .map(|e| canonicalize_recursive(e, dialect))
41                .collect();
42
43            // Canonicalize FROM
44            if let Some(mut from) = select.from {
45                from.expressions = from
46                    .expressions
47                    .into_iter()
48                    .map(|e| canonicalize_recursive(e, dialect))
49                    .collect();
50                select.from = Some(from);
51            }
52
53            // Canonicalize WHERE
54            if let Some(mut where_clause) = select.where_clause {
55                where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56                where_clause.this = ensure_bools(where_clause.this);
57                select.where_clause = Some(where_clause);
58            }
59
60            // Canonicalize HAVING
61            if let Some(mut having) = select.having {
62                having.this = canonicalize_recursive(having.this, dialect);
63                having.this = ensure_bools(having.this);
64                select.having = Some(having);
65            }
66
67            // Canonicalize ORDER BY
68            if let Some(mut order_by) = select.order_by {
69                order_by.expressions = order_by
70                    .expressions
71                    .into_iter()
72                    .map(|mut o| {
73                        o.this = canonicalize_recursive(o.this, dialect);
74                        o = remove_ascending_order(o);
75                        o
76                    })
77                    .collect();
78                select.order_by = Some(order_by);
79            }
80
81            // Canonicalize JOINs
82            select.joins = select
83                .joins
84                .into_iter()
85                .map(|mut j| {
86                    j.this = canonicalize_recursive(j.this, dialect);
87                    if let Some(on) = j.on {
88                        j.on = Some(canonicalize_recursive(on, dialect));
89                    }
90                    j
91                })
92                .collect();
93
94            Expression::Select(select)
95        }
96
97        // Binary operations that might involve string addition
98        Expression::Add(bin) => {
99            let left = canonicalize_recursive(bin.left, dialect);
100            let right = canonicalize_recursive(bin.right, dialect);
101            let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102                left,
103                right,
104                left_comments: bin.left_comments,
105                operator_comments: bin.operator_comments,
106                trailing_comments: bin.trailing_comments,
107            }));
108            add_text_to_concat(result)
109        }
110
111        // Other binary operations
112        Expression::And(bin) => {
113            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
114            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
115            Expression::And(Box::new(crate::expressions::BinaryOp {
116                left,
117                right,
118                left_comments: bin.left_comments,
119                operator_comments: bin.operator_comments,
120                trailing_comments: bin.trailing_comments,
121            }))
122        }
123        Expression::Or(bin) => {
124            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
125            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
126            Expression::Or(Box::new(crate::expressions::BinaryOp {
127                left,
128                right,
129                left_comments: bin.left_comments,
130                operator_comments: bin.operator_comments,
131                trailing_comments: bin.trailing_comments,
132            }))
133        }
134
135        Expression::Not(un) => {
136            let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
137            Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
138        }
139
140        // Comparison operations - check for date coercion
141        Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
142        Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
143        Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
144        Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
145        Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
146        Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
147
148        Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
149        Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
150        Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
151
152        // Cast - check for redundancy
153        Expression::Cast(cast) => {
154            let inner = canonicalize_recursive(cast.this, dialect);
155            let result = Expression::Cast(Box::new(crate::expressions::Cast {
156                this: inner,
157                to: cast.to,
158                trailing_comments: cast.trailing_comments,
159                double_colon_syntax: cast.double_colon_syntax,
160                format: cast.format,
161                default: cast.default,
162            }));
163            remove_redundant_casts(result)
164        }
165
166        // Function expressions
167        Expression::Function(func) => {
168            let args = func
169                .args
170                .into_iter()
171                .map(|e| canonicalize_recursive(e, dialect))
172                .collect();
173            Expression::Function(Box::new(crate::expressions::Function {
174                name: func.name,
175                args,
176                distinct: func.distinct,
177                trailing_comments: func.trailing_comments,
178                use_bracket_syntax: func.use_bracket_syntax,
179                no_parens: func.no_parens,
180                quoted: func.quoted,
181            }))
182        }
183
184        Expression::AggregateFunction(agg) => {
185            let args = agg
186                .args
187                .into_iter()
188                .map(|e| canonicalize_recursive(e, dialect))
189                .collect();
190            Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
191                name: agg.name,
192                args,
193                distinct: agg.distinct,
194                filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
195                order_by: agg.order_by,
196                limit: agg.limit,
197                ignore_nulls: agg.ignore_nulls,
198            }))
199        }
200
201        // Alias
202        Expression::Alias(alias) => {
203            let inner = canonicalize_recursive(alias.this, dialect);
204            Expression::Alias(Box::new(crate::expressions::Alias {
205                this: inner,
206                alias: alias.alias,
207                column_aliases: alias.column_aliases,
208                pre_alias_comments: alias.pre_alias_comments,
209                trailing_comments: alias.trailing_comments,
210            }))
211        }
212
213        // Paren
214        Expression::Paren(paren) => {
215            let inner = canonicalize_recursive(paren.this, dialect);
216            Expression::Paren(Box::new(crate::expressions::Paren {
217                this: inner,
218                trailing_comments: paren.trailing_comments,
219            }))
220        }
221
222        // Case
223        Expression::Case(case) => {
224            let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
225            let whens = case
226                .whens
227                .into_iter()
228                .map(|(w, t)| {
229                    (
230                        canonicalize_recursive(w, dialect),
231                        canonicalize_recursive(t, dialect),
232                    )
233                })
234                .collect();
235            let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
236            Expression::Case(Box::new(crate::expressions::Case {
237                operand,
238                whens,
239                else_,
240                comments: Vec::new(),
241            }))
242        }
243
244        // Between - check for date coercion
245        Expression::Between(between) => {
246            let this = canonicalize_recursive(between.this, dialect);
247            let low = canonicalize_recursive(between.low, dialect);
248            let high = canonicalize_recursive(between.high, dialect);
249            Expression::Between(Box::new(crate::expressions::Between {
250                this,
251                low,
252                high,
253                not: between.not,
254                symmetric: between.symmetric,
255            }))
256        }
257
258        // In
259        Expression::In(in_expr) => {
260            let this = canonicalize_recursive(in_expr.this, dialect);
261            let expressions = in_expr
262                .expressions
263                .into_iter()
264                .map(|e| canonicalize_recursive(e, dialect))
265                .collect();
266            let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
267            Expression::In(Box::new(crate::expressions::In {
268                this,
269                expressions,
270                query,
271                not: in_expr.not,
272                global: in_expr.global,
273                unnest: in_expr.unnest,
274                is_field: in_expr.is_field,
275            }))
276        }
277
278        // Subquery
279        Expression::Subquery(subquery) => {
280            let this = canonicalize_recursive(subquery.this, dialect);
281            Expression::Subquery(Box::new(crate::expressions::Subquery {
282                this,
283                alias: subquery.alias,
284                column_aliases: subquery.column_aliases,
285                order_by: subquery.order_by,
286                limit: subquery.limit,
287                offset: subquery.offset,
288                distribute_by: subquery.distribute_by,
289                sort_by: subquery.sort_by,
290                cluster_by: subquery.cluster_by,
291                lateral: subquery.lateral,
292                modifiers_inside: subquery.modifiers_inside,
293                trailing_comments: subquery.trailing_comments,
294            }))
295        }
296
297        // Set operations
298        Expression::Union(union) => {
299            let left = canonicalize_recursive(union.left, dialect);
300            let right = canonicalize_recursive(union.right, dialect);
301            Expression::Union(Box::new(crate::expressions::Union {
302                left,
303                right,
304                all: union.all,
305                distinct: union.distinct,
306                with: union.with,
307                order_by: union.order_by,
308                limit: union.limit,
309                offset: union.offset,
310                distribute_by: union.distribute_by,
311                sort_by: union.sort_by,
312                cluster_by: union.cluster_by,
313                by_name: union.by_name,
314                side: union.side,
315                kind: union.kind,
316                corresponding: union.corresponding,
317                strict: union.strict,
318                on_columns: union.on_columns,
319            }))
320        }
321        Expression::Intersect(intersect) => {
322            let left = canonicalize_recursive(intersect.left, dialect);
323            let right = canonicalize_recursive(intersect.right, dialect);
324            Expression::Intersect(Box::new(crate::expressions::Intersect {
325                left,
326                right,
327                all: intersect.all,
328                distinct: intersect.distinct,
329                with: intersect.with,
330                order_by: intersect.order_by,
331                limit: intersect.limit,
332                offset: intersect.offset,
333                distribute_by: intersect.distribute_by,
334                sort_by: intersect.sort_by,
335                cluster_by: intersect.cluster_by,
336                by_name: intersect.by_name,
337                side: intersect.side,
338                kind: intersect.kind,
339                corresponding: intersect.corresponding,
340                strict: intersect.strict,
341                on_columns: intersect.on_columns,
342            }))
343        }
344        Expression::Except(except) => {
345            let left = canonicalize_recursive(except.left, dialect);
346            let right = canonicalize_recursive(except.right, dialect);
347            Expression::Except(Box::new(crate::expressions::Except {
348                left,
349                right,
350                all: except.all,
351                distinct: except.distinct,
352                with: except.with,
353                order_by: except.order_by,
354                limit: except.limit,
355                offset: except.offset,
356                distribute_by: except.distribute_by,
357                sort_by: except.sort_by,
358                cluster_by: except.cluster_by,
359                by_name: except.by_name,
360                side: except.side,
361                kind: except.kind,
362                corresponding: except.corresponding,
363                strict: except.strict,
364                on_columns: except.on_columns,
365            }))
366        }
367
368        // Leaf nodes - return unchanged
369        other => other,
370    };
371
372    expr
373}
374
375/// Convert string addition to CONCAT.
376///
377/// When two TEXT types are added with +, convert to CONCAT.
378/// This is used by dialects like T-SQL and Redshift.
379fn add_text_to_concat(expression: Expression) -> Expression {
380    // In a full implementation, we would check if the operands are TEXT types
381    // and convert to CONCAT. For now, we return unchanged.
382    expression
383}
384
385/// Remove redundant cast expressions.
386///
387/// If casting to the same type the expression already is, remove the cast.
388fn remove_redundant_casts(expression: Expression) -> Expression {
389    if let Expression::Cast(cast) = &expression {
390        // Check if the inner expression's type matches the cast target
391        // In a full implementation with type annotations, we would compare types
392        // For now, just check simple cases
393
394        // If casting a literal to its natural type, we might be able to simplify
395        if let Expression::Literal(Literal::String(_)) = &cast.this {
396            if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
397                return cast.this.clone();
398            }
399        }
400        if let Expression::Literal(Literal::Number(_)) = &cast.this {
401            if matches!(
402                &cast.to,
403                DataType::Int { .. }
404                    | DataType::BigInt { .. }
405                    | DataType::Decimal { .. }
406                    | DataType::Float { .. }
407            ) {
408                // Could potentially remove cast, but be conservative
409            }
410        }
411    }
412    expression
413}
414
415/// Ensure expressions used as boolean predicates are actually boolean.
416///
417/// For example, in some dialects, integers can be used as booleans.
418/// This function ensures proper boolean semantics.
419fn ensure_bools(expression: Expression) -> Expression {
420    // In a full implementation, we would check if the expression is an integer
421    // and convert it to a comparison (e.g., x != 0).
422    // For now, return unchanged.
423    expression
424}
425
426/// Remove explicit ASC from ORDER BY clauses.
427///
428/// Since ASC is the default, `ORDER BY a ASC` can be simplified to `ORDER BY a`.
429fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
430    // If ASC was explicitly written (not DESC), remove the explicit flag
431    // since ASC is the default ordering
432    if !ordered.desc && ordered.explicit_asc {
433        ordered.explicit_asc = false;
434    }
435    ordered
436}
437
438/// Canonicalize a binary comparison operation.
439fn canonicalize_comparison<F>(
440    constructor: F,
441    bin: crate::expressions::BinaryOp,
442    dialect: Option<DialectType>,
443) -> Expression
444where
445    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
446{
447    let left = canonicalize_recursive(bin.left, dialect);
448    let right = canonicalize_recursive(bin.right, dialect);
449
450    // Check for date coercion opportunities
451    let (left, right) = coerce_date_operands(left, right);
452
453    constructor(Box::new(crate::expressions::BinaryOp {
454        left,
455        right,
456        left_comments: bin.left_comments,
457        operator_comments: bin.operator_comments,
458        trailing_comments: bin.trailing_comments,
459    }))
460}
461
462/// Canonicalize a regular binary operation.
463fn canonicalize_binary<F>(
464    constructor: F,
465    bin: crate::expressions::BinaryOp,
466    dialect: Option<DialectType>,
467) -> Expression
468where
469    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
470{
471    let left = canonicalize_recursive(bin.left, dialect);
472    let right = canonicalize_recursive(bin.right, dialect);
473
474    constructor(Box::new(crate::expressions::BinaryOp {
475        left,
476        right,
477        left_comments: bin.left_comments,
478        operator_comments: bin.operator_comments,
479        trailing_comments: bin.trailing_comments,
480    }))
481}
482
483/// Coerce date operands in comparisons.
484///
485/// When comparing a date/datetime column with a string literal,
486/// add appropriate CAST to the string.
487fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
488    // Check if we should cast string literals to date/datetime
489    let left = coerce_date_string(left, &right);
490    let right = coerce_date_string(right, &left);
491    (left, right)
492}
493
494/// Coerce a string literal to date/datetime if comparing with a temporal type.
495fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
496    if let Expression::Literal(Literal::String(ref s)) = expr {
497        // Check if the string is an ISO date or datetime
498        if is_iso_date(s) {
499            // In a full implementation, we would add CAST to DATE
500            // For now, return unchanged
501        } else if is_iso_datetime(s) {
502            // In a full implementation, we would add CAST to DATETIME/TIMESTAMP
503            // For now, return unchanged
504        }
505    }
506    expr
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Parse `sql`, canonicalize it without a dialect and render the result.
    fn canonical_sql(sql: &str) -> String {
        gen(&canonicalize(parse(sql), None))
    }

    #[test]
    fn test_canonicalize_simple() {
        let sql = canonical_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let sql = canonical_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        assert!(sql.contains("AND") || sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_union() {
        let sql = canonical_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        // Numeric addition must survive canonicalization as `+`.
        let sql = canonical_sql("SELECT 1 + 2");
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        let sql = canonical_sql("SELECT MAX(a) FROM t");
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let sql = canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(sql.contains("BETWEEN"));
    }
}