1use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal, Null};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
/// Rewrites `expression` into a canonical form.
///
/// This is the public entry point; all work is done by `canonicalize_recursive`.
/// Among the rewrites implemented in this module:
/// - an explicit `ASC` on `ORDER BY` items is removed;
/// - a string literal cast to `VARCHAR`/`TEXT` is unwrapped.
///
/// The boolean-predicate, `+`-to-concat and date-literal coercion hooks exist
/// but currently return their input unchanged.
///
/// `dialect` is threaded through the traversal, but no rewrite in this module
/// currently inspects it.
pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
    canonicalize_recursive(expression, dialect)
}
31
32fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34 let expr = match expression {
35 Expression::Select(mut select) => {
36 select.expressions = select
38 .expressions
39 .into_iter()
40 .map(|e| canonicalize_recursive(e, dialect))
41 .collect();
42
43 if let Some(mut from) = select.from {
45 from.expressions = from
46 .expressions
47 .into_iter()
48 .map(|e| canonicalize_recursive(e, dialect))
49 .collect();
50 select.from = Some(from);
51 }
52
53 if let Some(mut where_clause) = select.where_clause {
55 where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56 where_clause.this = ensure_bools(where_clause.this);
57 select.where_clause = Some(where_clause);
58 }
59
60 if let Some(mut having) = select.having {
62 having.this = canonicalize_recursive(having.this, dialect);
63 having.this = ensure_bools(having.this);
64 select.having = Some(having);
65 }
66
67 if let Some(mut order_by) = select.order_by {
69 order_by.expressions = order_by
70 .expressions
71 .into_iter()
72 .map(|mut o| {
73 o.this = canonicalize_recursive(o.this, dialect);
74 o = remove_ascending_order(o);
75 o
76 })
77 .collect();
78 select.order_by = Some(order_by);
79 }
80
81 select.joins = select
83 .joins
84 .into_iter()
85 .map(|mut j| {
86 j.this = canonicalize_recursive(j.this, dialect);
87 if let Some(on) = j.on {
88 j.on = Some(canonicalize_recursive(on, dialect));
89 }
90 j
91 })
92 .collect();
93
94 Expression::Select(select)
95 }
96
97 Expression::Add(bin) => {
99 let left = canonicalize_recursive(bin.left, dialect);
100 let right = canonicalize_recursive(bin.right, dialect);
101 let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102 left,
103 right,
104 left_comments: bin.left_comments,
105 operator_comments: bin.operator_comments,
106 trailing_comments: bin.trailing_comments,
107 inferred_type: None,
108 }));
109 add_text_to_concat(result)
110 }
111
112 Expression::And(bin) => {
114 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
115 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
116 Expression::And(Box::new(crate::expressions::BinaryOp {
117 left,
118 right,
119 left_comments: bin.left_comments,
120 operator_comments: bin.operator_comments,
121 trailing_comments: bin.trailing_comments,
122 inferred_type: None,
123 }))
124 }
125 Expression::Or(bin) => {
126 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
127 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
128 Expression::Or(Box::new(crate::expressions::BinaryOp {
129 left,
130 right,
131 left_comments: bin.left_comments,
132 operator_comments: bin.operator_comments,
133 trailing_comments: bin.trailing_comments,
134 inferred_type: None,
135 }))
136 }
137
138 Expression::Not(un) => {
139 let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
140 Expression::Not(Box::new(crate::expressions::UnaryOp {
141 this: inner,
142 inferred_type: None,
143 }))
144 }
145
146 Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
148 Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
149 Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
150 Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
151 Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
152 Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
153
154 Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
155 Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
156 Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
157
158 Expression::Cast(cast) => {
160 let inner = canonicalize_recursive(cast.this, dialect);
161 let result = Expression::Cast(Box::new(crate::expressions::Cast {
162 this: inner,
163 to: cast.to,
164 trailing_comments: cast.trailing_comments,
165 double_colon_syntax: cast.double_colon_syntax,
166 format: cast.format,
167 default: cast.default,
168 inferred_type: None,
169 }));
170 remove_redundant_casts(result)
171 }
172
173 Expression::Function(func) => {
175 let args = func
176 .args
177 .into_iter()
178 .map(|e| canonicalize_recursive(e, dialect))
179 .collect();
180 Expression::Function(Box::new(crate::expressions::Function {
181 name: func.name,
182 args,
183 distinct: func.distinct,
184 trailing_comments: func.trailing_comments,
185 use_bracket_syntax: func.use_bracket_syntax,
186 no_parens: func.no_parens,
187 quoted: func.quoted,
188 span: None,
189 inferred_type: None,
190 }))
191 }
192
193 Expression::AggregateFunction(agg) => {
194 let args = agg
195 .args
196 .into_iter()
197 .map(|e| canonicalize_recursive(e, dialect))
198 .collect();
199 Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
200 name: agg.name,
201 args,
202 distinct: agg.distinct,
203 filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
204 order_by: agg.order_by,
205 limit: agg.limit,
206 ignore_nulls: agg.ignore_nulls,
207 inferred_type: None,
208 }))
209 }
210
211 Expression::Alias(alias) => {
213 let inner = canonicalize_recursive(alias.this, dialect);
214 Expression::Alias(Box::new(crate::expressions::Alias {
215 this: inner,
216 alias: alias.alias,
217 column_aliases: alias.column_aliases,
218 pre_alias_comments: alias.pre_alias_comments,
219 trailing_comments: alias.trailing_comments,
220 inferred_type: None,
221 }))
222 }
223
224 Expression::Paren(paren) => {
226 let inner = canonicalize_recursive(paren.this, dialect);
227 Expression::Paren(Box::new(crate::expressions::Paren {
228 this: inner,
229 trailing_comments: paren.trailing_comments,
230 }))
231 }
232
233 Expression::Case(case) => {
235 let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
236 let whens = case
237 .whens
238 .into_iter()
239 .map(|(w, t)| {
240 (
241 canonicalize_recursive(w, dialect),
242 canonicalize_recursive(t, dialect),
243 )
244 })
245 .collect();
246 let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
247 Expression::Case(Box::new(crate::expressions::Case {
248 operand,
249 whens,
250 else_,
251 comments: Vec::new(),
252 inferred_type: None,
253 }))
254 }
255
256 Expression::Between(between) => {
258 let this = canonicalize_recursive(between.this, dialect);
259 let low = canonicalize_recursive(between.low, dialect);
260 let high = canonicalize_recursive(between.high, dialect);
261 Expression::Between(Box::new(crate::expressions::Between {
262 this,
263 low,
264 high,
265 not: between.not,
266 symmetric: between.symmetric,
267 }))
268 }
269
270 Expression::In(in_expr) => {
272 let this = canonicalize_recursive(in_expr.this, dialect);
273 let expressions = in_expr
274 .expressions
275 .into_iter()
276 .map(|e| canonicalize_recursive(e, dialect))
277 .collect();
278 let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
279 Expression::In(Box::new(crate::expressions::In {
280 this,
281 expressions,
282 query,
283 not: in_expr.not,
284 global: in_expr.global,
285 unnest: in_expr.unnest,
286 is_field: in_expr.is_field,
287 }))
288 }
289
290 Expression::Subquery(subquery) => {
292 let this = canonicalize_recursive(subquery.this, dialect);
293 Expression::Subquery(Box::new(crate::expressions::Subquery {
294 this,
295 alias: subquery.alias,
296 column_aliases: subquery.column_aliases,
297 order_by: subquery.order_by,
298 limit: subquery.limit,
299 offset: subquery.offset,
300 distribute_by: subquery.distribute_by,
301 sort_by: subquery.sort_by,
302 cluster_by: subquery.cluster_by,
303 lateral: subquery.lateral,
304 modifiers_inside: subquery.modifiers_inside,
305 trailing_comments: subquery.trailing_comments,
306 inferred_type: None,
307 }))
308 }
309
310 Expression::Union(union) => {
312 let mut u = *union;
313 let left = std::mem::replace(&mut u.left, Expression::Null(Null));
314 u.left = canonicalize_recursive(left, dialect);
315 let right = std::mem::replace(&mut u.right, Expression::Null(Null));
316 u.right = canonicalize_recursive(right, dialect);
317 Expression::Union(Box::new(u))
318 }
319 Expression::Intersect(intersect) => {
320 let mut i = *intersect;
321 let left = std::mem::replace(&mut i.left, Expression::Null(Null));
322 i.left = canonicalize_recursive(left, dialect);
323 let right = std::mem::replace(&mut i.right, Expression::Null(Null));
324 i.right = canonicalize_recursive(right, dialect);
325 Expression::Intersect(Box::new(i))
326 }
327 Expression::Except(except) => {
328 let mut e = *except;
329 let left = std::mem::replace(&mut e.left, Expression::Null(Null));
330 e.left = canonicalize_recursive(left, dialect);
331 let right = std::mem::replace(&mut e.right, Expression::Null(Null));
332 e.right = canonicalize_recursive(right, dialect);
333 Expression::Except(Box::new(e))
334 }
335
336 other => other,
338 };
339
340 expr
341}
342
343fn add_text_to_concat(expression: Expression) -> Expression {
348 expression
351}
352
353fn remove_redundant_casts(expression: Expression) -> Expression {
357 if let Expression::Cast(cast) = &expression {
358 if let Expression::Literal(lit) = &cast.this {
364 if let Literal::String(_) = lit.as_ref() {
365 if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
366 return cast.this.clone();
367 }
368 }
369 }
370 if let Expression::Literal(lit) = &cast.this {
371 if let Literal::Number(_) = lit.as_ref() {
372 if matches!(
373 &cast.to,
374 DataType::Int { .. }
375 | DataType::BigInt { .. }
376 | DataType::Decimal { .. }
377 | DataType::Float { .. }
378 ) {
379 }
381 }
382 }
383 }
384 expression
385}
386
387fn ensure_bools(expression: Expression) -> Expression {
392 expression
396}
397
398fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
402 if !ordered.desc && ordered.explicit_asc {
405 ordered.explicit_asc = false;
406 }
407 ordered
408}
409
410fn canonicalize_comparison<F>(
412 constructor: F,
413 bin: crate::expressions::BinaryOp,
414 dialect: Option<DialectType>,
415) -> Expression
416where
417 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
418{
419 let left = canonicalize_recursive(bin.left, dialect);
420 let right = canonicalize_recursive(bin.right, dialect);
421
422 let (left, right) = coerce_date_operands(left, right);
424
425 constructor(Box::new(crate::expressions::BinaryOp {
426 left,
427 right,
428 left_comments: bin.left_comments,
429 operator_comments: bin.operator_comments,
430 trailing_comments: bin.trailing_comments,
431 inferred_type: None,
432 }))
433}
434
435fn canonicalize_binary<F>(
437 constructor: F,
438 bin: crate::expressions::BinaryOp,
439 dialect: Option<DialectType>,
440) -> Expression
441where
442 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
443{
444 let left = canonicalize_recursive(bin.left, dialect);
445 let right = canonicalize_recursive(bin.right, dialect);
446
447 constructor(Box::new(crate::expressions::BinaryOp {
448 left,
449 right,
450 left_comments: bin.left_comments,
451 operator_comments: bin.operator_comments,
452 trailing_comments: bin.trailing_comments,
453 inferred_type: None,
454 }))
455}
456
457fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
462 let left = coerce_date_string(left, &right);
464 let right = coerce_date_string(right, &left);
465 (left, right)
466}
467
468fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
470 if let Expression::Literal(ref lit) = expr {
471 if let Literal::String(ref s) = lit.as_ref() {
472 if is_iso_date(s) {
474 } else if is_iso_datetime(s) {
477 }
480 }
481 }
482 expr
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with the default generator.
    ///
    /// Named `to_sql` rather than `gen`, which is a reserved keyword in the
    /// Rust 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement by value (no clone).
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("SQL contained no statements")
    }

    /// Parses `sql`, canonicalizes it without a dialect and renders it back.
    fn canonical_sql(sql: &str) -> String {
        to_sql(&canonicalize(parse(sql), None))
    }

    #[test]
    fn test_canonicalize_simple() {
        let sql = canonical_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let sql = canonical_sql("SELECT a, b FROM t WHERE c = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        // Both connectives must survive, not just one of them.
        assert!(sql.contains("AND") && sql.contains("OR"), "got: {}", sql);
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains('=') && sql.contains('>'));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_removes_explicit_asc() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a ASC");
        assert!(sql.contains("ORDER BY"));
        assert!(!sql.contains("ASC"), "explicit ASC should be dropped: {}", sql);
    }

    #[test]
    fn test_canonicalize_keeps_desc() {
        let sql = canonical_sql("SELECT a FROM t ORDER BY a DESC");
        assert!(sql.contains("DESC"), "DESC must be preserved: {}", sql);
    }

    #[test]
    fn test_canonicalize_union() {
        let sql = canonical_sql("SELECT a FROM t UNION SELECT b FROM s");
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        let sql = canonical_sql("SELECT 1 + 2");
        assert!(sql.contains('+'));
    }

    #[test]
    fn test_canonicalize_function() {
        let sql = canonical_sql("SELECT MAX(a) FROM t");
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let sql = canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10");
        assert!(sql.contains("BETWEEN"));
    }
}