Skip to main content

polyglot_sql/optimizer/
canonicalize.rs

//! Canonicalization Module
//!
//! This module provides functionality for converting SQL expressions into a
//! standard canonical form. This includes:
//! - Converting string addition to CONCAT
//! - Replacing date functions with casts
//! - Removing redundant type casts
//! - Ensuring boolean predicates
//! - Removing unnecessary ASC from ORDER BY
//!
//! Ported from sqlglot's optimizer/canonicalize.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal, Null};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
17/// Converts a SQL expression into a standard canonical form.
18///
19/// This transformation relies on type annotations because many of the
20/// conversions depend on type inference.
21///
22/// # Arguments
23/// * `expression` - The expression to canonicalize
24/// * `dialect` - Optional dialect for dialect-specific behavior
25///
26/// # Returns
27/// The canonicalized expression
28pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
29    canonicalize_recursive(expression, dialect)
30}
31
32/// Recursively canonicalize an expression and its children
33fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34    let expr = match expression {
35        Expression::Select(mut select) => {
36            // Canonicalize SELECT expressions
37            select.expressions = select
38                .expressions
39                .into_iter()
40                .map(|e| canonicalize_recursive(e, dialect))
41                .collect();
42
43            // Canonicalize FROM
44            if let Some(mut from) = select.from {
45                from.expressions = from
46                    .expressions
47                    .into_iter()
48                    .map(|e| canonicalize_recursive(e, dialect))
49                    .collect();
50                select.from = Some(from);
51            }
52
53            // Canonicalize WHERE
54            if let Some(mut where_clause) = select.where_clause {
55                where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56                where_clause.this = ensure_bools(where_clause.this);
57                select.where_clause = Some(where_clause);
58            }
59
60            // Canonicalize HAVING
61            if let Some(mut having) = select.having {
62                having.this = canonicalize_recursive(having.this, dialect);
63                having.this = ensure_bools(having.this);
64                select.having = Some(having);
65            }
66
67            // Canonicalize ORDER BY
68            if let Some(mut order_by) = select.order_by {
69                order_by.expressions = order_by
70                    .expressions
71                    .into_iter()
72                    .map(|mut o| {
73                        o.this = canonicalize_recursive(o.this, dialect);
74                        o = remove_ascending_order(o);
75                        o
76                    })
77                    .collect();
78                select.order_by = Some(order_by);
79            }
80
81            // Canonicalize JOINs
82            select.joins = select
83                .joins
84                .into_iter()
85                .map(|mut j| {
86                    j.this = canonicalize_recursive(j.this, dialect);
87                    if let Some(on) = j.on {
88                        j.on = Some(canonicalize_recursive(on, dialect));
89                    }
90                    j
91                })
92                .collect();
93
94            Expression::Select(select)
95        }
96
97        // Binary operations that might involve string addition
98        Expression::Add(bin) => {
99            let left = canonicalize_recursive(bin.left, dialect);
100            let right = canonicalize_recursive(bin.right, dialect);
101            let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102                left,
103                right,
104                left_comments: bin.left_comments,
105                operator_comments: bin.operator_comments,
106                trailing_comments: bin.trailing_comments,
107                inferred_type: None,
108            }));
109            add_text_to_concat(result)
110        }
111
112        // Other binary operations
113        Expression::And(bin) => {
114            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
115            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
116            Expression::And(Box::new(crate::expressions::BinaryOp {
117                left,
118                right,
119                left_comments: bin.left_comments,
120                operator_comments: bin.operator_comments,
121                trailing_comments: bin.trailing_comments,
122                inferred_type: None,
123            }))
124        }
125        Expression::Or(bin) => {
126            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
127            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
128            Expression::Or(Box::new(crate::expressions::BinaryOp {
129                left,
130                right,
131                left_comments: bin.left_comments,
132                operator_comments: bin.operator_comments,
133                trailing_comments: bin.trailing_comments,
134                inferred_type: None,
135            }))
136        }
137
138        Expression::Not(un) => {
139            let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
140            Expression::Not(Box::new(crate::expressions::UnaryOp {
141                this: inner,
142                inferred_type: None,
143            }))
144        }
145
146        // Comparison operations - check for date coercion
147        Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
148        Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
149        Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
150        Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
151        Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
152        Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
153
154        Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
155        Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
156        Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
157
158        // Cast - check for redundancy
159        Expression::Cast(cast) => {
160            let inner = canonicalize_recursive(cast.this, dialect);
161            let result = Expression::Cast(Box::new(crate::expressions::Cast {
162                this: inner,
163                to: cast.to,
164                trailing_comments: cast.trailing_comments,
165                double_colon_syntax: cast.double_colon_syntax,
166                format: cast.format,
167                default: cast.default,
168                inferred_type: None,
169            }));
170            remove_redundant_casts(result)
171        }
172
173        // Function expressions
174        Expression::Function(func) => {
175            let args = func
176                .args
177                .into_iter()
178                .map(|e| canonicalize_recursive(e, dialect))
179                .collect();
180            Expression::Function(Box::new(crate::expressions::Function {
181                name: func.name,
182                args,
183                distinct: func.distinct,
184                trailing_comments: func.trailing_comments,
185                use_bracket_syntax: func.use_bracket_syntax,
186                no_parens: func.no_parens,
187                quoted: func.quoted,
188                span: None,
189                inferred_type: None,
190            }))
191        }
192
193        Expression::AggregateFunction(agg) => {
194            let args = agg
195                .args
196                .into_iter()
197                .map(|e| canonicalize_recursive(e, dialect))
198                .collect();
199            Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
200                name: agg.name,
201                args,
202                distinct: agg.distinct,
203                filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
204                order_by: agg.order_by,
205                limit: agg.limit,
206                ignore_nulls: agg.ignore_nulls,
207                inferred_type: None,
208            }))
209        }
210
211        // Alias
212        Expression::Alias(alias) => {
213            let inner = canonicalize_recursive(alias.this, dialect);
214            Expression::Alias(Box::new(crate::expressions::Alias {
215                this: inner,
216                alias: alias.alias,
217                column_aliases: alias.column_aliases,
218                alias_explicit_as: false,
219                alias_keyword: None,
220                pre_alias_comments: alias.pre_alias_comments,
221                trailing_comments: alias.trailing_comments,
222                inferred_type: None,
223            }))
224        }
225
226        // Paren
227        Expression::Paren(paren) => {
228            let inner = canonicalize_recursive(paren.this, dialect);
229            Expression::Paren(Box::new(crate::expressions::Paren {
230                this: inner,
231                trailing_comments: paren.trailing_comments,
232            }))
233        }
234
235        // Case
236        Expression::Case(case) => {
237            let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
238            let whens = case
239                .whens
240                .into_iter()
241                .map(|(w, t)| {
242                    (
243                        canonicalize_recursive(w, dialect),
244                        canonicalize_recursive(t, dialect),
245                    )
246                })
247                .collect();
248            let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
249            Expression::Case(Box::new(crate::expressions::Case {
250                operand,
251                whens,
252                else_,
253                comments: Vec::new(),
254                inferred_type: None,
255            }))
256        }
257
258        // Between - check for date coercion
259        Expression::Between(between) => {
260            let this = canonicalize_recursive(between.this, dialect);
261            let low = canonicalize_recursive(between.low, dialect);
262            let high = canonicalize_recursive(between.high, dialect);
263            Expression::Between(Box::new(crate::expressions::Between {
264                this,
265                low,
266                high,
267                not: between.not,
268                symmetric: between.symmetric,
269            }))
270        }
271
272        // In
273        Expression::In(in_expr) => {
274            let this = canonicalize_recursive(in_expr.this, dialect);
275            let expressions = in_expr
276                .expressions
277                .into_iter()
278                .map(|e| canonicalize_recursive(e, dialect))
279                .collect();
280            let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
281            Expression::In(Box::new(crate::expressions::In {
282                this,
283                expressions,
284                query,
285                not: in_expr.not,
286                global: in_expr.global,
287                unnest: in_expr.unnest,
288                is_field: in_expr.is_field,
289            }))
290        }
291
292        // Subquery
293        Expression::Subquery(subquery) => {
294            let this = canonicalize_recursive(subquery.this, dialect);
295            Expression::Subquery(Box::new(crate::expressions::Subquery {
296                this,
297                alias: subquery.alias,
298                column_aliases: subquery.column_aliases,
299                alias_explicit_as: subquery.alias_explicit_as,
300                alias_keyword: subquery.alias_keyword,
301                order_by: subquery.order_by,
302                limit: subquery.limit,
303                offset: subquery.offset,
304                distribute_by: subquery.distribute_by,
305                sort_by: subquery.sort_by,
306                cluster_by: subquery.cluster_by,
307                lateral: subquery.lateral,
308                modifiers_inside: subquery.modifiers_inside,
309                trailing_comments: subquery.trailing_comments,
310                inferred_type: None,
311            }))
312        }
313
314        // Set operations
315        Expression::Union(union) => {
316            let mut u = *union;
317            let left = std::mem::replace(&mut u.left, Expression::Null(Null));
318            u.left = canonicalize_recursive(left, dialect);
319            let right = std::mem::replace(&mut u.right, Expression::Null(Null));
320            u.right = canonicalize_recursive(right, dialect);
321            Expression::Union(Box::new(u))
322        }
323        Expression::Intersect(intersect) => {
324            let mut i = *intersect;
325            let left = std::mem::replace(&mut i.left, Expression::Null(Null));
326            i.left = canonicalize_recursive(left, dialect);
327            let right = std::mem::replace(&mut i.right, Expression::Null(Null));
328            i.right = canonicalize_recursive(right, dialect);
329            Expression::Intersect(Box::new(i))
330        }
331        Expression::Except(except) => {
332            let mut e = *except;
333            let left = std::mem::replace(&mut e.left, Expression::Null(Null));
334            e.left = canonicalize_recursive(left, dialect);
335            let right = std::mem::replace(&mut e.right, Expression::Null(Null));
336            e.right = canonicalize_recursive(right, dialect);
337            Expression::Except(Box::new(e))
338        }
339
340        // Leaf nodes - return unchanged
341        other => other,
342    };
343
344    expr
345}
346
347/// Convert string addition to CONCAT.
348///
349/// When two TEXT types are added with +, convert to CONCAT.
350/// This is used by dialects like T-SQL and Redshift.
351fn add_text_to_concat(expression: Expression) -> Expression {
352    // In a full implementation, we would check if the operands are TEXT types
353    // and convert to CONCAT. For now, we return unchanged.
354    expression
355}
356
357/// Remove redundant cast expressions.
358///
359/// If casting to the same type the expression already is, remove the cast.
360fn remove_redundant_casts(expression: Expression) -> Expression {
361    if let Expression::Cast(cast) = &expression {
362        // Check if the inner expression's type matches the cast target
363        // In a full implementation with type annotations, we would compare types
364        // For now, just check simple cases
365
366        // If casting a literal to its natural type, we might be able to simplify
367        if let Expression::Literal(lit) = &cast.this {
368            if let Literal::String(_) = lit.as_ref() {
369                if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
370                    return cast.this.clone();
371                }
372            }
373        }
374        if let Expression::Literal(lit) = &cast.this {
375            if let Literal::Number(_) = lit.as_ref() {
376                if matches!(
377                    &cast.to,
378                    DataType::Int { .. }
379                        | DataType::BigInt { .. }
380                        | DataType::Decimal { .. }
381                        | DataType::Float { .. }
382                ) {
383                    // Could potentially remove cast, but be conservative
384                }
385            }
386        }
387    }
388    expression
389}
390
391/// Ensure expressions used as boolean predicates are actually boolean.
392///
393/// For example, in some dialects, integers can be used as booleans.
394/// This function ensures proper boolean semantics.
395fn ensure_bools(expression: Expression) -> Expression {
396    // In a full implementation, we would check if the expression is an integer
397    // and convert it to a comparison (e.g., x != 0).
398    // For now, return unchanged.
399    expression
400}
401
402/// Remove explicit ASC from ORDER BY clauses.
403///
404/// Since ASC is the default, `ORDER BY a ASC` can be simplified to `ORDER BY a`.
405fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
406    // If ASC was explicitly written (not DESC), remove the explicit flag
407    // since ASC is the default ordering
408    if !ordered.desc && ordered.explicit_asc {
409        ordered.explicit_asc = false;
410    }
411    ordered
412}
413
414/// Canonicalize a binary comparison operation.
415fn canonicalize_comparison<F>(
416    constructor: F,
417    bin: crate::expressions::BinaryOp,
418    dialect: Option<DialectType>,
419) -> Expression
420where
421    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
422{
423    let left = canonicalize_recursive(bin.left, dialect);
424    let right = canonicalize_recursive(bin.right, dialect);
425
426    // Check for date coercion opportunities
427    let (left, right) = coerce_date_operands(left, right);
428
429    constructor(Box::new(crate::expressions::BinaryOp {
430        left,
431        right,
432        left_comments: bin.left_comments,
433        operator_comments: bin.operator_comments,
434        trailing_comments: bin.trailing_comments,
435        inferred_type: None,
436    }))
437}
438
439/// Canonicalize a regular binary operation.
440fn canonicalize_binary<F>(
441    constructor: F,
442    bin: crate::expressions::BinaryOp,
443    dialect: Option<DialectType>,
444) -> Expression
445where
446    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
447{
448    let left = canonicalize_recursive(bin.left, dialect);
449    let right = canonicalize_recursive(bin.right, dialect);
450
451    constructor(Box::new(crate::expressions::BinaryOp {
452        left,
453        right,
454        left_comments: bin.left_comments,
455        operator_comments: bin.operator_comments,
456        trailing_comments: bin.trailing_comments,
457        inferred_type: None,
458    }))
459}
460
461/// Coerce date operands in comparisons.
462///
463/// When comparing a date/datetime column with a string literal,
464/// add appropriate CAST to the string.
465fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
466    // Check if we should cast string literals to date/datetime
467    let left = coerce_date_string(left, &right);
468    let right = coerce_date_string(right, &left);
469    (left, right)
470}
471
472/// Coerce a string literal to date/datetime if comparing with a temporal type.
473fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
474    if let Expression::Literal(ref lit) = expr {
475        if let Literal::String(ref s) = lit.as_ref() {
476            // Check if the string is an ISO date or datetime
477            if is_iso_date(s) {
478                // In a full implementation, we would add CAST to DATE
479                // For now, return unchanged
480            } else if is_iso_datetime(s) {
481                // In a full implementation, we would add CAST to DATETIME/TIMESTAMP
482                // For now, return unchanged
483            }
484        }
485    }
486    expr
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Parses `sql`, canonicalizes its first statement with no dialect and
    /// renders the result back to a SQL string.
    fn canonical_sql(sql: &str) -> String {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        let canonical = canonicalize(statements[0].clone(), None);
        Generator::new().generate(&canonical).unwrap()
    }

    #[test]
    fn test_canonicalize_simple() {
        assert!(canonical_sql("SELECT a FROM t").contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        assert!(canonical_sql("SELECT a, b FROM t WHERE c = 1").contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        assert!(sql.contains("AND") || sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        assert!(canonical_sql("SELECT a FROM t ORDER BY a").contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_union() {
        assert!(canonical_sql("SELECT a FROM t UNION SELECT b FROM s").contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        // Numeric addition must not be rewritten into CONCAT.
        assert!(canonical_sql("SELECT 1 + 2").contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        assert!(canonical_sql("SELECT MAX(a) FROM t").contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        assert!(canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10").contains("BETWEEN"));
    }
}