Skip to main content

polyglot_sql/optimizer/
canonicalize.rs

//! Canonicalization Module
//!
//! This module provides functionality for converting SQL expressions into a
//! standard canonical form. This includes:
//! - Converting string addition to CONCAT
//! - Replacing date functions with casts
//! - Removing redundant type casts
//! - Ensuring boolean predicates
//! - Removing unnecessary ASC from ORDER BY
//!
//! Ported from sqlglot's optimizer/canonicalize.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
17/// Converts a SQL expression into a standard canonical form.
18///
19/// This transformation relies on type annotations because many of the
20/// conversions depend on type inference.
21///
22/// # Arguments
23/// * `expression` - The expression to canonicalize
24/// * `dialect` - Optional dialect for dialect-specific behavior
25///
26/// # Returns
27/// The canonicalized expression
28pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
29    canonicalize_recursive(expression, dialect)
30}
31
32/// Recursively canonicalize an expression and its children
33fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34    let expr = match expression {
35        Expression::Select(mut select) => {
36            // Canonicalize SELECT expressions
37            select.expressions = select
38                .expressions
39                .into_iter()
40                .map(|e| canonicalize_recursive(e, dialect))
41                .collect();
42
43            // Canonicalize FROM
44            if let Some(mut from) = select.from {
45                from.expressions = from
46                    .expressions
47                    .into_iter()
48                    .map(|e| canonicalize_recursive(e, dialect))
49                    .collect();
50                select.from = Some(from);
51            }
52
53            // Canonicalize WHERE
54            if let Some(mut where_clause) = select.where_clause {
55                where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56                where_clause.this = ensure_bools(where_clause.this);
57                select.where_clause = Some(where_clause);
58            }
59
60            // Canonicalize HAVING
61            if let Some(mut having) = select.having {
62                having.this = canonicalize_recursive(having.this, dialect);
63                having.this = ensure_bools(having.this);
64                select.having = Some(having);
65            }
66
67            // Canonicalize ORDER BY
68            if let Some(mut order_by) = select.order_by {
69                order_by.expressions = order_by
70                    .expressions
71                    .into_iter()
72                    .map(|mut o| {
73                        o.this = canonicalize_recursive(o.this, dialect);
74                        o = remove_ascending_order(o);
75                        o
76                    })
77                    .collect();
78                select.order_by = Some(order_by);
79            }
80
81            // Canonicalize JOINs
82            select.joins = select
83                .joins
84                .into_iter()
85                .map(|mut j| {
86                    j.this = canonicalize_recursive(j.this, dialect);
87                    if let Some(on) = j.on {
88                        j.on = Some(canonicalize_recursive(on, dialect));
89                    }
90                    j
91                })
92                .collect();
93
94            Expression::Select(select)
95        }
96
97        // Binary operations that might involve string addition
98        Expression::Add(bin) => {
99            let left = canonicalize_recursive(bin.left, dialect);
100            let right = canonicalize_recursive(bin.right, dialect);
101            let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102                left,
103                right,
104                left_comments: bin.left_comments,
105                operator_comments: bin.operator_comments,
106                trailing_comments: bin.trailing_comments,
107            }));
108            add_text_to_concat(result)
109        }
110
111        // Other binary operations
112        Expression::And(bin) => {
113            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
114            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
115            Expression::And(Box::new(crate::expressions::BinaryOp {
116                left,
117                right,
118                left_comments: bin.left_comments,
119                operator_comments: bin.operator_comments,
120                trailing_comments: bin.trailing_comments,
121            }))
122        }
123        Expression::Or(bin) => {
124            let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
125            let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
126            Expression::Or(Box::new(crate::expressions::BinaryOp {
127                left,
128                right,
129                left_comments: bin.left_comments,
130                operator_comments: bin.operator_comments,
131                trailing_comments: bin.trailing_comments,
132            }))
133        }
134
135        Expression::Not(un) => {
136            let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
137            Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
138        }
139
140        // Comparison operations - check for date coercion
141        Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
142        Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
143        Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
144        Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
145        Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
146        Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
147
148        Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
149        Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
150        Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
151
152        // Cast - check for redundancy
153        Expression::Cast(cast) => {
154            let inner = canonicalize_recursive(cast.this, dialect);
155            let result = Expression::Cast(Box::new(crate::expressions::Cast {
156                this: inner,
157                to: cast.to,
158                trailing_comments: cast.trailing_comments,
159                double_colon_syntax: cast.double_colon_syntax,
160                format: cast.format,
161                default: cast.default,
162            }));
163            remove_redundant_casts(result)
164        }
165
166        // Function expressions
167        Expression::Function(func) => {
168            let args = func
169                .args
170                .into_iter()
171                .map(|e| canonicalize_recursive(e, dialect))
172                .collect();
173            Expression::Function(Box::new(crate::expressions::Function {
174                name: func.name,
175                args,
176                distinct: func.distinct,
177                trailing_comments: func.trailing_comments,
178                use_bracket_syntax: func.use_bracket_syntax,
179                no_parens: func.no_parens,
180                quoted: func.quoted,
181            }))
182        }
183
184        Expression::AggregateFunction(agg) => {
185            let args = agg
186                .args
187                .into_iter()
188                .map(|e| canonicalize_recursive(e, dialect))
189                .collect();
190            Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
191                name: agg.name,
192                args,
193                distinct: agg.distinct,
194                filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
195                order_by: agg.order_by,
196                limit: agg.limit,
197                ignore_nulls: agg.ignore_nulls,
198            }))
199        }
200
201        // Alias
202        Expression::Alias(alias) => {
203            let inner = canonicalize_recursive(alias.this, dialect);
204            Expression::Alias(Box::new(crate::expressions::Alias {
205                this: inner,
206                alias: alias.alias,
207                column_aliases: alias.column_aliases,
208                pre_alias_comments: alias.pre_alias_comments,
209                trailing_comments: alias.trailing_comments,
210            }))
211        }
212
213        // Paren
214        Expression::Paren(paren) => {
215            let inner = canonicalize_recursive(paren.this, dialect);
216            Expression::Paren(Box::new(crate::expressions::Paren {
217                this: inner,
218                trailing_comments: paren.trailing_comments,
219            }))
220        }
221
222        // Case
223        Expression::Case(case) => {
224            let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
225            let whens = case
226                .whens
227                .into_iter()
228                .map(|(w, t)| {
229                    (
230                        canonicalize_recursive(w, dialect),
231                        canonicalize_recursive(t, dialect),
232                    )
233                })
234                .collect();
235            let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
236            Expression::Case(Box::new(crate::expressions::Case {
237                operand,
238                whens,
239                else_,
240            }))
241        }
242
243        // Between - check for date coercion
244        Expression::Between(between) => {
245            let this = canonicalize_recursive(between.this, dialect);
246            let low = canonicalize_recursive(between.low, dialect);
247            let high = canonicalize_recursive(between.high, dialect);
248            Expression::Between(Box::new(crate::expressions::Between {
249                this,
250                low,
251                high,
252                not: between.not,
253            }))
254        }
255
256        // In
257        Expression::In(in_expr) => {
258            let this = canonicalize_recursive(in_expr.this, dialect);
259            let expressions = in_expr
260                .expressions
261                .into_iter()
262                .map(|e| canonicalize_recursive(e, dialect))
263                .collect();
264            let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
265            Expression::In(Box::new(crate::expressions::In {
266                this,
267                expressions,
268                query,
269                not: in_expr.not,
270                global: in_expr.global,
271                unnest: in_expr.unnest,
272            }))
273        }
274
275        // Subquery
276        Expression::Subquery(subquery) => {
277            let this = canonicalize_recursive(subquery.this, dialect);
278            Expression::Subquery(Box::new(crate::expressions::Subquery {
279                this,
280                alias: subquery.alias,
281                column_aliases: subquery.column_aliases,
282                order_by: subquery.order_by,
283                limit: subquery.limit,
284                offset: subquery.offset,
285                distribute_by: subquery.distribute_by,
286                sort_by: subquery.sort_by,
287                cluster_by: subquery.cluster_by,
288                lateral: subquery.lateral,
289                modifiers_inside: subquery.modifiers_inside,
290                trailing_comments: subquery.trailing_comments,
291            }))
292        }
293
294        // Set operations
295        Expression::Union(union) => {
296            let left = canonicalize_recursive(union.left, dialect);
297            let right = canonicalize_recursive(union.right, dialect);
298            Expression::Union(Box::new(crate::expressions::Union {
299                left,
300                right,
301                all: union.all,
302                distinct: union.distinct,
303                with: union.with,
304                order_by: union.order_by,
305                limit: union.limit,
306                offset: union.offset,
307                distribute_by: union.distribute_by,
308                sort_by: union.sort_by,
309                cluster_by: union.cluster_by,
310                by_name: union.by_name,
311                side: union.side,
312                kind: union.kind,
313                corresponding: union.corresponding,
314                strict: union.strict,
315                on_columns: union.on_columns,
316            }))
317        }
318        Expression::Intersect(intersect) => {
319            let left = canonicalize_recursive(intersect.left, dialect);
320            let right = canonicalize_recursive(intersect.right, dialect);
321            Expression::Intersect(Box::new(crate::expressions::Intersect {
322                left,
323                right,
324                all: intersect.all,
325                distinct: intersect.distinct,
326                with: intersect.with,
327                order_by: intersect.order_by,
328                limit: intersect.limit,
329                offset: intersect.offset,
330                distribute_by: intersect.distribute_by,
331                sort_by: intersect.sort_by,
332                cluster_by: intersect.cluster_by,
333                by_name: intersect.by_name,
334                side: intersect.side,
335                kind: intersect.kind,
336                corresponding: intersect.corresponding,
337                strict: intersect.strict,
338                on_columns: intersect.on_columns,
339            }))
340        }
341        Expression::Except(except) => {
342            let left = canonicalize_recursive(except.left, dialect);
343            let right = canonicalize_recursive(except.right, dialect);
344            Expression::Except(Box::new(crate::expressions::Except {
345                left,
346                right,
347                all: except.all,
348                distinct: except.distinct,
349                with: except.with,
350                order_by: except.order_by,
351                limit: except.limit,
352                offset: except.offset,
353                distribute_by: except.distribute_by,
354                sort_by: except.sort_by,
355                cluster_by: except.cluster_by,
356                by_name: except.by_name,
357                side: except.side,
358                kind: except.kind,
359                corresponding: except.corresponding,
360                strict: except.strict,
361                on_columns: except.on_columns,
362            }))
363        }
364
365        // Leaf nodes - return unchanged
366        other => other,
367    };
368
369    expr
370}
371
372/// Convert string addition to CONCAT.
373///
374/// When two TEXT types are added with +, convert to CONCAT.
375/// This is used by dialects like T-SQL and Redshift.
376fn add_text_to_concat(expression: Expression) -> Expression {
377    // In a full implementation, we would check if the operands are TEXT types
378    // and convert to CONCAT. For now, we return unchanged.
379    expression
380}
381
382/// Remove redundant cast expressions.
383///
384/// If casting to the same type the expression already is, remove the cast.
385fn remove_redundant_casts(expression: Expression) -> Expression {
386    if let Expression::Cast(cast) = &expression {
387        // Check if the inner expression's type matches the cast target
388        // In a full implementation with type annotations, we would compare types
389        // For now, just check simple cases
390
391        // If casting a literal to its natural type, we might be able to simplify
392        if let Expression::Literal(Literal::String(_)) = &cast.this {
393            if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
394                return cast.this.clone();
395            }
396        }
397        if let Expression::Literal(Literal::Number(_)) = &cast.this {
398            if matches!(
399                &cast.to,
400                DataType::Int { .. } | DataType::BigInt { .. } | DataType::Decimal { .. } | DataType::Float { .. }
401            ) {
402                // Could potentially remove cast, but be conservative
403            }
404        }
405    }
406    expression
407}
408
409/// Ensure expressions used as boolean predicates are actually boolean.
410///
411/// For example, in some dialects, integers can be used as booleans.
412/// This function ensures proper boolean semantics.
413fn ensure_bools(expression: Expression) -> Expression {
414    // In a full implementation, we would check if the expression is an integer
415    // and convert it to a comparison (e.g., x != 0).
416    // For now, return unchanged.
417    expression
418}
419
420/// Remove explicit ASC from ORDER BY clauses.
421///
422/// Since ASC is the default, `ORDER BY a ASC` can be simplified to `ORDER BY a`.
423fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
424    // If ASC was explicitly written (not DESC), remove the explicit flag
425    // since ASC is the default ordering
426    if !ordered.desc && ordered.explicit_asc {
427        ordered.explicit_asc = false;
428    }
429    ordered
430}
431
432/// Canonicalize a binary comparison operation.
433fn canonicalize_comparison<F>(
434    constructor: F,
435    bin: crate::expressions::BinaryOp,
436    dialect: Option<DialectType>,
437) -> Expression
438where
439    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
440{
441    let left = canonicalize_recursive(bin.left, dialect);
442    let right = canonicalize_recursive(bin.right, dialect);
443
444    // Check for date coercion opportunities
445    let (left, right) = coerce_date_operands(left, right);
446
447    constructor(Box::new(crate::expressions::BinaryOp {
448        left,
449        right,
450        left_comments: bin.left_comments,
451        operator_comments: bin.operator_comments,
452        trailing_comments: bin.trailing_comments,
453    }))
454}
455
456/// Canonicalize a regular binary operation.
457fn canonicalize_binary<F>(
458    constructor: F,
459    bin: crate::expressions::BinaryOp,
460    dialect: Option<DialectType>,
461) -> Expression
462where
463    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
464{
465    let left = canonicalize_recursive(bin.left, dialect);
466    let right = canonicalize_recursive(bin.right, dialect);
467
468    constructor(Box::new(crate::expressions::BinaryOp {
469        left,
470        right,
471        left_comments: bin.left_comments,
472        operator_comments: bin.operator_comments,
473        trailing_comments: bin.trailing_comments,
474    }))
475}
476
477/// Coerce date operands in comparisons.
478///
479/// When comparing a date/datetime column with a string literal,
480/// add appropriate CAST to the string.
481fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
482    // Check if we should cast string literals to date/datetime
483    let left = coerce_date_string(left, &right);
484    let right = coerce_date_string(right, &left);
485    (left, right)
486}
487
488/// Coerce a string literal to date/datetime if comparing with a temporal type.
489fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
490    if let Expression::Literal(Literal::String(ref s)) = expr {
491        // Check if the string is an ISO date or datetime
492        if is_iso_date(s) {
493            // In a full implementation, we would add CAST to DATE
494            // For now, return unchanged
495        } else if is_iso_datetime(s) {
496            // In a full implementation, we would add CAST to DATETIME/TIMESTAMP
497            // For now, return unchanged
498        }
499    }
500    expr
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text with a default generator.
    ///
    /// Named `to_sql` rather than `gen`, because `gen` is a reserved keyword
    /// in the Rust 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new()
            .generate(expr)
            .expect("Failed to generate SQL")
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("Parser returned no statements")
    }

    /// Parses `sql`, canonicalizes it without a dialect, and renders it back to SQL.
    fn canonical_sql(sql: &str) -> String {
        to_sql(&canonicalize(parse(sql), None))
    }

    #[test]
    fn test_canonicalize_simple() {
        let sql = canonical_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let sql = canonical_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        // Both connectors must survive; `||` would hide the loss of either one.
        assert!(sql.contains("AND") && sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_union() {
        let sql = canonical_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        // Test that non-text additions pass through
        let sql = canonical_sql("SELECT 1 + 2");
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        let sql = canonical_sql("SELECT MAX(a) FROM t");
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let sql = canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(sql.contains("BETWEEN"));
    }
}