1use crate::dialects::DialectType;
14use crate::expressions::{DataType, Expression, Literal};
15use crate::helper::{is_iso_date, is_iso_datetime};
16
17pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
29 canonicalize_recursive(expression, dialect)
30}
31
32fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
34 let expr = match expression {
35 Expression::Select(mut select) => {
36 select.expressions = select
38 .expressions
39 .into_iter()
40 .map(|e| canonicalize_recursive(e, dialect))
41 .collect();
42
43 if let Some(mut from) = select.from {
45 from.expressions = from
46 .expressions
47 .into_iter()
48 .map(|e| canonicalize_recursive(e, dialect))
49 .collect();
50 select.from = Some(from);
51 }
52
53 if let Some(mut where_clause) = select.where_clause {
55 where_clause.this = canonicalize_recursive(where_clause.this, dialect);
56 where_clause.this = ensure_bools(where_clause.this);
57 select.where_clause = Some(where_clause);
58 }
59
60 if let Some(mut having) = select.having {
62 having.this = canonicalize_recursive(having.this, dialect);
63 having.this = ensure_bools(having.this);
64 select.having = Some(having);
65 }
66
67 if let Some(mut order_by) = select.order_by {
69 order_by.expressions = order_by
70 .expressions
71 .into_iter()
72 .map(|mut o| {
73 o.this = canonicalize_recursive(o.this, dialect);
74 o = remove_ascending_order(o);
75 o
76 })
77 .collect();
78 select.order_by = Some(order_by);
79 }
80
81 select.joins = select
83 .joins
84 .into_iter()
85 .map(|mut j| {
86 j.this = canonicalize_recursive(j.this, dialect);
87 if let Some(on) = j.on {
88 j.on = Some(canonicalize_recursive(on, dialect));
89 }
90 j
91 })
92 .collect();
93
94 Expression::Select(select)
95 }
96
97 Expression::Add(bin) => {
99 let left = canonicalize_recursive(bin.left, dialect);
100 let right = canonicalize_recursive(bin.right, dialect);
101 let result = Expression::Add(Box::new(crate::expressions::BinaryOp {
102 left,
103 right,
104 left_comments: bin.left_comments,
105 operator_comments: bin.operator_comments,
106 trailing_comments: bin.trailing_comments,
107 }));
108 add_text_to_concat(result)
109 }
110
111 Expression::And(bin) => {
113 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
114 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
115 Expression::And(Box::new(crate::expressions::BinaryOp {
116 left,
117 right,
118 left_comments: bin.left_comments,
119 operator_comments: bin.operator_comments,
120 trailing_comments: bin.trailing_comments,
121 }))
122 }
123 Expression::Or(bin) => {
124 let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
125 let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
126 Expression::Or(Box::new(crate::expressions::BinaryOp {
127 left,
128 right,
129 left_comments: bin.left_comments,
130 operator_comments: bin.operator_comments,
131 trailing_comments: bin.trailing_comments,
132 }))
133 }
134
135 Expression::Not(un) => {
136 let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
137 Expression::Not(Box::new(crate::expressions::UnaryOp { this: inner }))
138 }
139
140 Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
142 Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
143 Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
144 Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
145 Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
146 Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
147
148 Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
149 Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
150 Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
151
152 Expression::Cast(cast) => {
154 let inner = canonicalize_recursive(cast.this, dialect);
155 let result = Expression::Cast(Box::new(crate::expressions::Cast {
156 this: inner,
157 to: cast.to,
158 trailing_comments: cast.trailing_comments,
159 double_colon_syntax: cast.double_colon_syntax,
160 format: cast.format,
161 default: cast.default,
162 }));
163 remove_redundant_casts(result)
164 }
165
166 Expression::Function(func) => {
168 let args = func
169 .args
170 .into_iter()
171 .map(|e| canonicalize_recursive(e, dialect))
172 .collect();
173 Expression::Function(Box::new(crate::expressions::Function {
174 name: func.name,
175 args,
176 distinct: func.distinct,
177 trailing_comments: func.trailing_comments,
178 use_bracket_syntax: func.use_bracket_syntax,
179 no_parens: func.no_parens,
180 quoted: func.quoted,
181 }))
182 }
183
184 Expression::AggregateFunction(agg) => {
185 let args = agg
186 .args
187 .into_iter()
188 .map(|e| canonicalize_recursive(e, dialect))
189 .collect();
190 Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
191 name: agg.name,
192 args,
193 distinct: agg.distinct,
194 filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
195 order_by: agg.order_by,
196 limit: agg.limit,
197 ignore_nulls: agg.ignore_nulls,
198 }))
199 }
200
201 Expression::Alias(alias) => {
203 let inner = canonicalize_recursive(alias.this, dialect);
204 Expression::Alias(Box::new(crate::expressions::Alias {
205 this: inner,
206 alias: alias.alias,
207 column_aliases: alias.column_aliases,
208 pre_alias_comments: alias.pre_alias_comments,
209 trailing_comments: alias.trailing_comments,
210 }))
211 }
212
213 Expression::Paren(paren) => {
215 let inner = canonicalize_recursive(paren.this, dialect);
216 Expression::Paren(Box::new(crate::expressions::Paren {
217 this: inner,
218 trailing_comments: paren.trailing_comments,
219 }))
220 }
221
222 Expression::Case(case) => {
224 let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
225 let whens = case
226 .whens
227 .into_iter()
228 .map(|(w, t)| {
229 (
230 canonicalize_recursive(w, dialect),
231 canonicalize_recursive(t, dialect),
232 )
233 })
234 .collect();
235 let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
236 Expression::Case(Box::new(crate::expressions::Case {
237 operand,
238 whens,
239 else_,
240 }))
241 }
242
243 Expression::Between(between) => {
245 let this = canonicalize_recursive(between.this, dialect);
246 let low = canonicalize_recursive(between.low, dialect);
247 let high = canonicalize_recursive(between.high, dialect);
248 Expression::Between(Box::new(crate::expressions::Between {
249 this,
250 low,
251 high,
252 not: between.not,
253 }))
254 }
255
256 Expression::In(in_expr) => {
258 let this = canonicalize_recursive(in_expr.this, dialect);
259 let expressions = in_expr
260 .expressions
261 .into_iter()
262 .map(|e| canonicalize_recursive(e, dialect))
263 .collect();
264 let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
265 Expression::In(Box::new(crate::expressions::In {
266 this,
267 expressions,
268 query,
269 not: in_expr.not,
270 global: in_expr.global,
271 unnest: in_expr.unnest,
272 }))
273 }
274
275 Expression::Subquery(subquery) => {
277 let this = canonicalize_recursive(subquery.this, dialect);
278 Expression::Subquery(Box::new(crate::expressions::Subquery {
279 this,
280 alias: subquery.alias,
281 column_aliases: subquery.column_aliases,
282 order_by: subquery.order_by,
283 limit: subquery.limit,
284 offset: subquery.offset,
285 distribute_by: subquery.distribute_by,
286 sort_by: subquery.sort_by,
287 cluster_by: subquery.cluster_by,
288 lateral: subquery.lateral,
289 modifiers_inside: subquery.modifiers_inside,
290 trailing_comments: subquery.trailing_comments,
291 }))
292 }
293
294 Expression::Union(union) => {
296 let left = canonicalize_recursive(union.left, dialect);
297 let right = canonicalize_recursive(union.right, dialect);
298 Expression::Union(Box::new(crate::expressions::Union {
299 left,
300 right,
301 all: union.all,
302 distinct: union.distinct,
303 with: union.with,
304 order_by: union.order_by,
305 limit: union.limit,
306 offset: union.offset,
307 distribute_by: union.distribute_by,
308 sort_by: union.sort_by,
309 cluster_by: union.cluster_by,
310 by_name: union.by_name,
311 side: union.side,
312 kind: union.kind,
313 corresponding: union.corresponding,
314 strict: union.strict,
315 on_columns: union.on_columns,
316 }))
317 }
318 Expression::Intersect(intersect) => {
319 let left = canonicalize_recursive(intersect.left, dialect);
320 let right = canonicalize_recursive(intersect.right, dialect);
321 Expression::Intersect(Box::new(crate::expressions::Intersect {
322 left,
323 right,
324 all: intersect.all,
325 distinct: intersect.distinct,
326 with: intersect.with,
327 order_by: intersect.order_by,
328 limit: intersect.limit,
329 offset: intersect.offset,
330 distribute_by: intersect.distribute_by,
331 sort_by: intersect.sort_by,
332 cluster_by: intersect.cluster_by,
333 by_name: intersect.by_name,
334 side: intersect.side,
335 kind: intersect.kind,
336 corresponding: intersect.corresponding,
337 strict: intersect.strict,
338 on_columns: intersect.on_columns,
339 }))
340 }
341 Expression::Except(except) => {
342 let left = canonicalize_recursive(except.left, dialect);
343 let right = canonicalize_recursive(except.right, dialect);
344 Expression::Except(Box::new(crate::expressions::Except {
345 left,
346 right,
347 all: except.all,
348 distinct: except.distinct,
349 with: except.with,
350 order_by: except.order_by,
351 limit: except.limit,
352 offset: except.offset,
353 distribute_by: except.distribute_by,
354 sort_by: except.sort_by,
355 cluster_by: except.cluster_by,
356 by_name: except.by_name,
357 side: except.side,
358 kind: except.kind,
359 corresponding: except.corresponding,
360 strict: except.strict,
361 on_columns: except.on_columns,
362 }))
363 }
364
365 other => other,
367 };
368
369 expr
370}
371
372fn add_text_to_concat(expression: Expression) -> Expression {
377 expression
380}
381
382fn remove_redundant_casts(expression: Expression) -> Expression {
386 if let Expression::Cast(cast) = &expression {
387 if let Expression::Literal(Literal::String(_)) = &cast.this {
393 if matches!(&cast.to, DataType::VarChar { .. } | DataType::Text) {
394 return cast.this.clone();
395 }
396 }
397 if let Expression::Literal(Literal::Number(_)) = &cast.this {
398 if matches!(
399 &cast.to,
400 DataType::Int { .. } | DataType::BigInt { .. } | DataType::Decimal { .. } | DataType::Float { .. }
401 ) {
402 }
404 }
405 }
406 expression
407}
408
409fn ensure_bools(expression: Expression) -> Expression {
414 expression
418}
419
420fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
424 if !ordered.desc && ordered.explicit_asc {
427 ordered.explicit_asc = false;
428 }
429 ordered
430}
431
432fn canonicalize_comparison<F>(
434 constructor: F,
435 bin: crate::expressions::BinaryOp,
436 dialect: Option<DialectType>,
437) -> Expression
438where
439 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
440{
441 let left = canonicalize_recursive(bin.left, dialect);
442 let right = canonicalize_recursive(bin.right, dialect);
443
444 let (left, right) = coerce_date_operands(left, right);
446
447 constructor(Box::new(crate::expressions::BinaryOp {
448 left,
449 right,
450 left_comments: bin.left_comments,
451 operator_comments: bin.operator_comments,
452 trailing_comments: bin.trailing_comments,
453 }))
454}
455
456fn canonicalize_binary<F>(
458 constructor: F,
459 bin: crate::expressions::BinaryOp,
460 dialect: Option<DialectType>,
461) -> Expression
462where
463 F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
464{
465 let left = canonicalize_recursive(bin.left, dialect);
466 let right = canonicalize_recursive(bin.right, dialect);
467
468 constructor(Box::new(crate::expressions::BinaryOp {
469 left,
470 right,
471 left_comments: bin.left_comments,
472 operator_comments: bin.operator_comments,
473 trailing_comments: bin.trailing_comments,
474 }))
475}
476
477fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
482 let left = coerce_date_string(left, &right);
484 let right = coerce_date_string(right, &left);
485 (left, right)
486}
487
488fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
490 if let Expression::Literal(Literal::String(ref s)) = expr {
491 if is_iso_date(s) {
493 } else if is_iso_datetime(s) {
496 }
499 }
500 expr
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL with the default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_canonicalize_simple() {
        let expr = parse("SELECT a FROM t");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        let expr = parse("SELECT a, b FROM t WHERE c = 1");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let expr = parse("SELECT 1 WHERE a AND b OR c");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        // Both connectives must survive; `||` would hide the loss of one.
        assert!(sql.contains("AND") && sql.contains("OR"), "got: {sql}");
    }

    #[test]
    fn test_canonicalize_comparison() {
        let expr = parse("SELECT 1 WHERE a = 1 AND b > 2");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let expr = parse("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let expr = parse("SELECT a FROM (SELECT b FROM t) AS sub");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        let expr = parse("SELECT a FROM t ORDER BY a");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_removes_explicit_asc() {
        // ASC is the default direction, so canonicalization drops it.
        let expr = parse("SELECT a FROM t ORDER BY a ASC");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("ORDER BY"));
        assert!(!sql.contains("ASC"), "explicit ASC should be removed: {sql}");
    }

    #[test]
    fn test_canonicalize_keeps_desc() {
        let expr = parse("SELECT a FROM t ORDER BY a DESC");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("DESC"), "DESC must be preserved: {sql}");
    }

    #[test]
    fn test_canonicalize_union() {
        let expr = parse("SELECT a FROM t UNION SELECT b FROM s");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        // `add_text_to_concat` is currently a pass-through, so `+` survives.
        let expr = parse("SELECT 1 + 2");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        let expr = parse("SELECT MAX(a) FROM t");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        let expr = parse("SELECT 1 WHERE a BETWEEN 1 AND 10");
        let result = canonicalize(expr, None);
        let sql = gen(&result);
        assert!(sql.contains("BETWEEN"));
    }
}