1use std::collections::HashMap;
13
14use crate::dialects::DialectType;
15use crate::expressions::{
16 BinaryOp, DataType, Expression, Function, IfFunc, ListAggOverflow, Literal, Map, Nvl2Func,
17 Struct, StructField, Subscript,
18};
19use crate::schema::Schema;
20
/// Broad families of data types used when coercing two operand types.
///
/// The explicit discriminants fix the relative order of the classes
/// (`Text < Numeric < Timelike`) as seen through the derived
/// `PartialOrd`/`Ord` implementations.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum TypeCoercionClass {
    /// Character and binary string types.
    Text = 0,
    /// Boolean, integer, floating-point and decimal types.
    Numeric = 1,
    /// Date, time, timestamp and interval types.
    Timelike = 2,
}
32
33impl TypeCoercionClass {
34 pub fn from_data_type(dt: &DataType) -> Option<Self> {
36 match dt {
37 DataType::Char { .. }
39 | DataType::VarChar { .. }
40 | DataType::Text
41 | DataType::Binary { .. }
42 | DataType::VarBinary { .. }
43 | DataType::Blob => Some(TypeCoercionClass::Text),
44
45 DataType::Boolean
47 | DataType::TinyInt { .. }
48 | DataType::SmallInt { .. }
49 | DataType::Int { .. }
50 | DataType::BigInt { .. }
51 | DataType::Float { .. }
52 | DataType::Double { .. }
53 | DataType::Decimal { .. } => Some(TypeCoercionClass::Numeric),
54
55 DataType::Date
57 | DataType::Time { .. }
58 | DataType::Timestamp { .. }
59 | DataType::Interval { .. } => Some(TypeCoercionClass::Timelike),
60
61 _ => None,
63 }
64 }
65}
66
67pub struct TypeAnnotator<'a> {
69 _schema: Option<&'a dyn Schema>,
71 _dialect: Option<DialectType>,
73 annotate_aggregates: bool,
75 function_return_types: HashMap<String, DataType>,
77}
78
79impl<'a> TypeAnnotator<'a> {
80 pub fn new(schema: Option<&'a dyn Schema>, dialect: Option<DialectType>) -> Self {
82 let mut annotator = Self {
83 _schema: schema,
84 _dialect: dialect,
85 annotate_aggregates: true,
86 function_return_types: HashMap::new(),
87 };
88 annotator.init_function_return_types();
89 annotator
90 }
91
92 fn init_function_return_types(&mut self) {
94 self.function_return_types
96 .insert("COUNT".to_string(), DataType::BigInt { length: None });
97 self.function_return_types.insert(
98 "SUM".to_string(),
99 DataType::Decimal {
100 precision: None,
101 scale: None,
102 },
103 );
104 self.function_return_types.insert(
105 "AVG".to_string(),
106 DataType::Double {
107 precision: None,
108 scale: None,
109 },
110 );
111
112 self.function_return_types.insert(
114 "CONCAT".to_string(),
115 DataType::VarChar {
116 length: None,
117 parenthesized_length: false,
118 },
119 );
120 self.function_return_types.insert(
121 "UPPER".to_string(),
122 DataType::VarChar {
123 length: None,
124 parenthesized_length: false,
125 },
126 );
127 self.function_return_types.insert(
128 "LOWER".to_string(),
129 DataType::VarChar {
130 length: None,
131 parenthesized_length: false,
132 },
133 );
134 self.function_return_types.insert(
135 "TRIM".to_string(),
136 DataType::VarChar {
137 length: None,
138 parenthesized_length: false,
139 },
140 );
141 self.function_return_types.insert(
142 "LTRIM".to_string(),
143 DataType::VarChar {
144 length: None,
145 parenthesized_length: false,
146 },
147 );
148 self.function_return_types.insert(
149 "RTRIM".to_string(),
150 DataType::VarChar {
151 length: None,
152 parenthesized_length: false,
153 },
154 );
155 self.function_return_types.insert(
156 "SUBSTRING".to_string(),
157 DataType::VarChar {
158 length: None,
159 parenthesized_length: false,
160 },
161 );
162 self.function_return_types.insert(
163 "SUBSTR".to_string(),
164 DataType::VarChar {
165 length: None,
166 parenthesized_length: false,
167 },
168 );
169 self.function_return_types.insert(
170 "REPLACE".to_string(),
171 DataType::VarChar {
172 length: None,
173 parenthesized_length: false,
174 },
175 );
176 self.function_return_types.insert(
177 "LENGTH".to_string(),
178 DataType::Int {
179 length: None,
180 integer_spelling: false,
181 },
182 );
183 self.function_return_types.insert(
184 "CHAR_LENGTH".to_string(),
185 DataType::Int {
186 length: None,
187 integer_spelling: false,
188 },
189 );
190
191 self.function_return_types.insert(
193 "NOW".to_string(),
194 DataType::Timestamp {
195 precision: None,
196 timezone: false,
197 },
198 );
199 self.function_return_types.insert(
200 "CURRENT_TIMESTAMP".to_string(),
201 DataType::Timestamp {
202 precision: None,
203 timezone: false,
204 },
205 );
206 self.function_return_types
207 .insert("CURRENT_DATE".to_string(), DataType::Date);
208 self.function_return_types.insert(
209 "CURRENT_TIME".to_string(),
210 DataType::Time {
211 precision: None,
212 timezone: false,
213 },
214 );
215 self.function_return_types
216 .insert("DATE".to_string(), DataType::Date);
217 self.function_return_types.insert(
218 "YEAR".to_string(),
219 DataType::Int {
220 length: None,
221 integer_spelling: false,
222 },
223 );
224 self.function_return_types.insert(
225 "MONTH".to_string(),
226 DataType::Int {
227 length: None,
228 integer_spelling: false,
229 },
230 );
231 self.function_return_types.insert(
232 "DAY".to_string(),
233 DataType::Int {
234 length: None,
235 integer_spelling: false,
236 },
237 );
238 self.function_return_types.insert(
239 "HOUR".to_string(),
240 DataType::Int {
241 length: None,
242 integer_spelling: false,
243 },
244 );
245 self.function_return_types.insert(
246 "MINUTE".to_string(),
247 DataType::Int {
248 length: None,
249 integer_spelling: false,
250 },
251 );
252 self.function_return_types.insert(
253 "SECOND".to_string(),
254 DataType::Int {
255 length: None,
256 integer_spelling: false,
257 },
258 );
259 self.function_return_types.insert(
260 "EXTRACT".to_string(),
261 DataType::Int {
262 length: None,
263 integer_spelling: false,
264 },
265 );
266 self.function_return_types.insert(
267 "DATE_DIFF".to_string(),
268 DataType::Int {
269 length: None,
270 integer_spelling: false,
271 },
272 );
273 self.function_return_types.insert(
274 "DATEDIFF".to_string(),
275 DataType::Int {
276 length: None,
277 integer_spelling: false,
278 },
279 );
280
281 self.function_return_types.insert(
283 "ABS".to_string(),
284 DataType::Double {
285 precision: None,
286 scale: None,
287 },
288 );
289 self.function_return_types.insert(
290 "ROUND".to_string(),
291 DataType::Double {
292 precision: None,
293 scale: None,
294 },
295 );
296 self.function_return_types.insert(
297 "DATE_FORMAT".to_string(),
298 DataType::VarChar {
299 length: None,
300 parenthesized_length: false,
301 },
302 );
303 self.function_return_types.insert(
304 "FORMAT_DATE".to_string(),
305 DataType::VarChar {
306 length: None,
307 parenthesized_length: false,
308 },
309 );
310 self.function_return_types.insert(
311 "TIME_TO_STR".to_string(),
312 DataType::VarChar {
313 length: None,
314 parenthesized_length: false,
315 },
316 );
317 self.function_return_types.insert(
318 "SQRT".to_string(),
319 DataType::Double {
320 precision: None,
321 scale: None,
322 },
323 );
324 self.function_return_types.insert(
325 "POWER".to_string(),
326 DataType::Double {
327 precision: None,
328 scale: None,
329 },
330 );
331 self.function_return_types.insert(
332 "MOD".to_string(),
333 DataType::Int {
334 length: None,
335 integer_spelling: false,
336 },
337 );
338 self.function_return_types.insert(
339 "LOG".to_string(),
340 DataType::Double {
341 precision: None,
342 scale: None,
343 },
344 );
345 self.function_return_types.insert(
346 "LN".to_string(),
347 DataType::Double {
348 precision: None,
349 scale: None,
350 },
351 );
352 self.function_return_types.insert(
353 "EXP".to_string(),
354 DataType::Double {
355 precision: None,
356 scale: None,
357 },
358 );
359
360 self.function_return_types
362 .insert("COALESCE".to_string(), DataType::Unknown);
363 self.function_return_types
364 .insert("NULLIF".to_string(), DataType::Unknown);
365 self.function_return_types
366 .insert("GREATEST".to_string(), DataType::Unknown);
367 self.function_return_types
368 .insert("LEAST".to_string(), DataType::Unknown);
369 }
370
371 pub fn annotate(&mut self, expr: &Expression) -> Option<DataType> {
373 match expr {
374 Expression::Literal(lit) => self.annotate_literal(lit),
376 Expression::Boolean(_) => Some(DataType::Boolean),
377 Expression::Null(_) => None, Expression::Add(op)
381 | Expression::Sub(op)
382 | Expression::Mul(op)
383 | Expression::Div(op)
384 | Expression::Mod(op) => self.annotate_arithmetic(op),
385
386 Expression::Eq(_)
388 | Expression::Neq(_)
389 | Expression::Lt(_)
390 | Expression::Lte(_)
391 | Expression::Gt(_)
392 | Expression::Gte(_)
393 | Expression::Like(_)
394 | Expression::ILike(_) => Some(DataType::Boolean),
395
396 Expression::And(_) | Expression::Or(_) | Expression::Not(_) => Some(DataType::Boolean),
398
399 Expression::Between(_)
401 | Expression::In(_)
402 | Expression::IsNull(_)
403 | Expression::IsTrue(_)
404 | Expression::IsFalse(_)
405 | Expression::Is(_)
406 | Expression::Exists(_) => Some(DataType::Boolean),
407
408 Expression::Concat(_) => Some(DataType::VarChar {
410 length: None,
411 parenthesized_length: false,
412 }),
413
414 Expression::BitwiseAnd(_)
416 | Expression::BitwiseOr(_)
417 | Expression::BitwiseXor(_)
418 | Expression::BitwiseNot(_) => Some(DataType::BigInt { length: None }),
419
420 Expression::Neg(op) => self.annotate(&op.this),
422
423 Expression::Function(func) => self.annotate_function(func),
425 Expression::IfFunc(if_func) => self.annotate_if_func(if_func),
426 Expression::Nvl2(nvl2) => self.annotate_nvl2(nvl2),
427
428 Expression::Count(_) => Some(DataType::BigInt { length: None }),
430 Expression::Sum(agg) => self.annotate_sum(&agg.this),
431 Expression::SumIf(f) => self.annotate_sum(&f.this),
432 Expression::Avg(_) => Some(DataType::Double {
433 precision: None,
434 scale: None,
435 }),
436 Expression::Min(agg) => self.annotate(&agg.this),
437 Expression::Max(agg) => self.annotate(&agg.this),
438 Expression::GroupConcat(_) | Expression::StringAgg(_) | Expression::ListAgg(_) => {
439 Some(DataType::VarChar {
440 length: None,
441 parenthesized_length: false,
442 })
443 }
444
445 Expression::AggregateFunction(agg) => {
447 if !self.annotate_aggregates {
448 return None;
449 }
450 let func_name = agg.name.to_uppercase();
451 self.get_aggregate_return_type(&func_name, &agg.args)
452 }
453
454 Expression::Column(col) => {
456 if let Some(schema) = &self._schema {
457 let table_name = col.table.as_ref().map(|t| t.name.as_str()).unwrap_or("");
458 schema.get_column_type(table_name, &col.name.name).ok()
459 } else {
460 None
461 }
462 }
463
464 Expression::Cast(cast) => Some(cast.to.clone()),
466 Expression::SafeCast(cast) => Some(cast.to.clone()),
467 Expression::TryCast(cast) => Some(cast.to.clone()),
468
469 Expression::Subquery(subq) => {
471 if let Expression::Select(select) = &subq.this {
472 if let Some(first) = select.expressions.first() {
473 self.annotate(first)
474 } else {
475 None
476 }
477 } else {
478 None
479 }
480 }
481
482 Expression::Case(case) => {
484 if let Some(else_expr) = &case.else_ {
485 self.annotate(else_expr)
486 } else if let Some((_, then_expr)) = case.whens.first() {
487 self.annotate(then_expr)
488 } else {
489 None
490 }
491 }
492
493 Expression::Array(arr) => {
495 if let Some(first) = arr.expressions.first() {
496 if let Some(elem_type) = self.annotate(first) {
497 Some(DataType::Array {
498 element_type: Box::new(elem_type),
499 dimension: None,
500 })
501 } else {
502 Some(DataType::Array {
503 element_type: Box::new(DataType::Unknown),
504 dimension: None,
505 })
506 }
507 } else {
508 Some(DataType::Array {
509 element_type: Box::new(DataType::Unknown),
510 dimension: None,
511 })
512 }
513 }
514
515 Expression::Interval(_) => Some(DataType::Interval {
517 unit: None,
518 to: None,
519 }),
520
521 Expression::WindowFunction(window) => self.annotate(&window.this),
523
524 Expression::CurrentDate(_) => Some(DataType::Date),
526 Expression::CurrentTime(_) => Some(DataType::Time {
527 precision: None,
528 timezone: false,
529 }),
530 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
531 Some(DataType::Timestamp {
532 precision: None,
533 timezone: false,
534 })
535 }
536
537 Expression::DateAdd(_)
539 | Expression::DateSub(_)
540 | Expression::ToDate(_)
541 | Expression::Date(_) => Some(DataType::Date),
542 Expression::DateDiff(_) | Expression::Extract(_) => Some(DataType::Int {
543 length: None,
544 integer_spelling: false,
545 }),
546 Expression::ToTimestamp(_) => Some(DataType::Timestamp {
547 precision: None,
548 timezone: false,
549 }),
550
551 Expression::Upper(_)
553 | Expression::Lower(_)
554 | Expression::Trim(_)
555 | Expression::LTrim(_)
556 | Expression::RTrim(_)
557 | Expression::Replace(_)
558 | Expression::Substring(_)
559 | Expression::Reverse(_)
560 | Expression::Left(_)
561 | Expression::Right(_)
562 | Expression::Repeat(_)
563 | Expression::Lpad(_)
564 | Expression::Rpad(_)
565 | Expression::ConcatWs(_)
566 | Expression::Overlay(_) => Some(DataType::VarChar {
567 length: None,
568 parenthesized_length: false,
569 }),
570 Expression::Length(_) => Some(DataType::Int {
571 length: None,
572 integer_spelling: false,
573 }),
574
575 Expression::Abs(_)
577 | Expression::Sqrt(_)
578 | Expression::Cbrt(_)
579 | Expression::Ln(_)
580 | Expression::Exp(_)
581 | Expression::Power(_)
582 | Expression::Log(_) => Some(DataType::Double {
583 precision: None,
584 scale: None,
585 }),
586 Expression::Round(_) => Some(DataType::Double {
587 precision: None,
588 scale: None,
589 }),
590 Expression::Floor(f) => self.annotate_math_function(&f.this),
591 Expression::Ceil(f) => self.annotate_math_function(&f.this),
592 Expression::Sign(s) => self.annotate(&s.this),
593 Expression::DateFormat(_) | Expression::FormatDate(_) | Expression::TimeToStr(_) => {
594 Some(DataType::VarChar {
595 length: None,
596 parenthesized_length: false,
597 })
598 }
599
600 Expression::Greatest(v) | Expression::Least(v) => self.coerce_arg_types(&v.expressions),
602
603 Expression::Alias(alias) => self.annotate(&alias.this),
605
606 Expression::Select(_) => None,
608
609 Expression::Subscript(sub) => self.annotate_subscript(sub),
613
614 Expression::Dot(_) => None,
616
617 Expression::Struct(s) => self.annotate_struct(s),
621
622 Expression::Map(map) => self.annotate_map(map),
626 Expression::MapFromEntries(mfe) => {
627 if let Some(DataType::Array { element_type, .. }) = self.annotate(&mfe.this) {
629 if let DataType::Struct { fields, .. } = *element_type {
630 if fields.len() >= 2 {
631 return Some(DataType::Map {
632 key_type: Box::new(fields[0].data_type.clone()),
633 value_type: Box::new(fields[1].data_type.clone()),
634 });
635 }
636 }
637 }
638 Some(DataType::Map {
639 key_type: Box::new(DataType::Unknown),
640 value_type: Box::new(DataType::Unknown),
641 })
642 }
643
644 Expression::Union(union) => self.annotate_set_operation(&union.left, &union.right),
648 Expression::Intersect(intersect) => {
649 self.annotate_set_operation(&intersect.left, &intersect.right)
650 }
651 Expression::Except(except) => self.annotate_set_operation(&except.left, &except.right),
652
653 Expression::Lateral(lateral) => {
657 self.annotate(&lateral.this)
659 }
660 Expression::LateralView(lv) => {
661 self.annotate_lateral_view(lv)
663 }
664 Expression::Unnest(unnest) => {
665 if let Some(DataType::Array { element_type, .. }) = self.annotate(&unnest.this) {
667 Some(*element_type)
668 } else {
669 None
670 }
671 }
672 Expression::Explode(explode) => {
673 if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
675 Some(*element_type)
676 } else if let Some(DataType::Map {
677 key_type,
678 value_type,
679 }) = self.annotate(&explode.this)
680 {
681 Some(DataType::Struct {
683 fields: vec![
684 StructField::new("key".to_string(), *key_type),
685 StructField::new("value".to_string(), *value_type),
686 ],
687 nested: false,
688 })
689 } else {
690 None
691 }
692 }
693 Expression::ExplodeOuter(explode) => {
694 if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
696 Some(*element_type)
697 } else {
698 None
699 }
700 }
701 Expression::GenerateSeries(gs) => {
702 if let Some(ref start) = gs.start {
704 self.annotate(start)
705 } else if let Some(ref end) = gs.end {
706 self.annotate(end)
707 } else {
708 Some(DataType::Int {
709 length: None,
710 integer_spelling: false,
711 })
712 }
713 }
714
715 _ => None,
717 }
718 }
719
720 pub fn annotate_in_place(&mut self, expr: &mut Expression) {
726 self.annotate_children_in_place(expr);
728
729 let dt = self.annotate(expr);
733
734 if let Some(data_type) = dt {
736 expr.set_inferred_type(data_type);
737 }
738 }
739
740 fn annotate_children_in_place(&mut self, expr: &mut Expression) {
742 match expr {
743 Expression::And(op)
745 | Expression::Or(op)
746 | Expression::Add(op)
747 | Expression::Sub(op)
748 | Expression::Mul(op)
749 | Expression::Div(op)
750 | Expression::Mod(op)
751 | Expression::Eq(op)
752 | Expression::Neq(op)
753 | Expression::Lt(op)
754 | Expression::Lte(op)
755 | Expression::Gt(op)
756 | Expression::Gte(op)
757 | Expression::Concat(op)
758 | Expression::BitwiseAnd(op)
759 | Expression::BitwiseOr(op)
760 | Expression::BitwiseXor(op)
761 | Expression::Adjacent(op)
762 | Expression::TsMatch(op)
763 | Expression::PropertyEQ(op)
764 | Expression::ArrayContainsAll(op)
765 | Expression::ArrayContainedBy(op)
766 | Expression::ArrayOverlaps(op)
767 | Expression::JSONBContainsAllTopKeys(op)
768 | Expression::JSONBContainsAnyTopKeys(op)
769 | Expression::JSONBDeleteAtPath(op)
770 | Expression::ExtendsLeft(op)
771 | Expression::ExtendsRight(op)
772 | Expression::Is(op)
773 | Expression::MemberOf(op)
774 | Expression::Match(op)
775 | Expression::NullSafeEq(op)
776 | Expression::NullSafeNeq(op)
777 | Expression::Glob(op)
778 | Expression::BitwiseLeftShift(op)
779 | Expression::BitwiseRightShift(op) => {
780 self.annotate_in_place(&mut op.left);
781 self.annotate_in_place(&mut op.right);
782 }
783
784 Expression::Like(op) | Expression::ILike(op) => {
786 self.annotate_in_place(&mut op.left);
787 self.annotate_in_place(&mut op.right);
788 }
789
790 Expression::Not(op) | Expression::Neg(op) | Expression::BitwiseNot(op) => {
792 self.annotate_in_place(&mut op.this);
793 }
794
795 Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
797 self.annotate_in_place(&mut c.this);
798 }
799
800 Expression::Case(c) => {
802 if let Some(ref mut operand) = c.operand {
803 self.annotate_in_place(operand);
804 }
805 for (cond, then_expr) in &mut c.whens {
806 self.annotate_in_place(cond);
807 self.annotate_in_place(then_expr);
808 }
809 if let Some(ref mut else_expr) = c.else_ {
810 self.annotate_in_place(else_expr);
811 }
812 }
813
814 Expression::Alias(a) => {
816 self.annotate_in_place(&mut a.this);
817 }
818
819 Expression::Column(_) => {}
821
822 Expression::Function(f) => {
824 for arg in &mut f.args {
825 self.annotate_in_place(arg);
826 }
827 }
828
829 Expression::IfFunc(f) => {
831 self.annotate_in_place(&mut f.condition);
832 self.annotate_in_place(&mut f.true_value);
833 if let Some(false_value) = &mut f.false_value {
834 self.annotate_in_place(false_value);
835 }
836 }
837 Expression::Nvl2(f) => {
838 self.annotate_in_place(&mut f.this);
839 self.annotate_in_place(&mut f.true_value);
840 self.annotate_in_place(&mut f.false_value);
841 }
842
843 Expression::AggregateFunction(f) => {
845 for arg in &mut f.args {
846 self.annotate_in_place(arg);
847 }
848 }
849
850 Expression::Count(f) => {
852 if let Some(this) = &mut f.this {
853 self.annotate_in_place(this);
854 }
855 if let Some(filter) = &mut f.filter {
856 self.annotate_in_place(filter);
857 }
858 }
859 Expression::GroupConcat(f) => {
860 self.annotate_in_place(&mut f.this);
861 if let Some(separator) = &mut f.separator {
862 self.annotate_in_place(separator);
863 }
864 if let Some(order_by) = &mut f.order_by {
865 for ordered in order_by {
866 self.annotate_in_place(&mut ordered.this);
867 }
868 }
869 if let Some(filter) = &mut f.filter {
870 self.annotate_in_place(filter);
871 }
872 }
873 Expression::StringAgg(f) => {
874 self.annotate_in_place(&mut f.this);
875 if let Some(separator) = &mut f.separator {
876 self.annotate_in_place(separator);
877 }
878 if let Some(order_by) = &mut f.order_by {
879 for ordered in order_by {
880 self.annotate_in_place(&mut ordered.this);
881 }
882 }
883 if let Some(filter) = &mut f.filter {
884 self.annotate_in_place(filter);
885 }
886 if let Some(limit) = &mut f.limit {
887 self.annotate_in_place(limit);
888 }
889 }
890 Expression::ListAgg(f) => {
891 self.annotate_in_place(&mut f.this);
892 if let Some(separator) = &mut f.separator {
893 self.annotate_in_place(separator);
894 }
895 if let Some(order_by) = &mut f.order_by {
896 for ordered in order_by {
897 self.annotate_in_place(&mut ordered.this);
898 }
899 }
900 if let Some(filter) = &mut f.filter {
901 self.annotate_in_place(filter);
902 }
903 if let Some(ListAggOverflow::Truncate {
904 filler: Some(filler),
905 ..
906 }) = &mut f.on_overflow
907 {
908 self.annotate_in_place(filler);
909 }
910 }
911 Expression::SumIf(f) => {
912 self.annotate_in_place(&mut f.this);
913 self.annotate_in_place(&mut f.condition);
914 if let Some(filter) = &mut f.filter {
915 self.annotate_in_place(filter);
916 }
917 }
918
919 Expression::WindowFunction(w) => {
921 self.annotate_in_place(&mut w.this);
922 }
923
924 Expression::Subquery(s) => {
926 self.annotate_in_place(&mut s.this);
927 }
928
929 Expression::Upper(f)
931 | Expression::Lower(f)
932 | Expression::Length(f)
933 | Expression::LTrim(f)
934 | Expression::RTrim(f)
935 | Expression::Reverse(f)
936 | Expression::Abs(f)
937 | Expression::Sqrt(f)
938 | Expression::Cbrt(f)
939 | Expression::Ln(f)
940 | Expression::Exp(f)
941 | Expression::Sign(f)
942 | Expression::Date(f)
943 | Expression::Time(f)
944 | Expression::Explode(f)
945 | Expression::ExplodeOuter(f)
946 | Expression::MapFromEntries(f)
947 | Expression::MapKeys(f)
948 | Expression::MapValues(f)
949 | Expression::ArrayLength(f)
950 | Expression::ArraySize(f)
951 | Expression::Cardinality(f)
952 | Expression::ArrayReverse(f)
953 | Expression::ArrayDistinct(f)
954 | Expression::ArrayFlatten(f)
955 | Expression::ArrayCompact(f)
956 | Expression::ToArray(f)
957 | Expression::JsonArrayLength(f)
958 | Expression::JsonKeys(f)
959 | Expression::JsonType(f)
960 | Expression::ParseJson(f)
961 | Expression::ToJson(f)
962 | Expression::Year(f)
963 | Expression::Month(f)
964 | Expression::Day(f)
965 | Expression::Hour(f)
966 | Expression::Minute(f)
967 | Expression::Second(f)
968 | Expression::Initcap(f)
969 | Expression::Ascii(f)
970 | Expression::Chr(f)
971 | Expression::Soundex(f)
972 | Expression::ByteLength(f)
973 | Expression::Hex(f)
974 | Expression::LowerHex(f)
975 | Expression::Unicode(f)
976 | Expression::Typeof(f)
977 | Expression::BitwiseCount(f)
978 | Expression::Epoch(f)
979 | Expression::EpochMs(f)
980 | Expression::Radians(f)
981 | Expression::Degrees(f)
982 | Expression::Sin(f)
983 | Expression::Cos(f)
984 | Expression::Tan(f)
985 | Expression::Asin(f)
986 | Expression::Acos(f)
987 | Expression::Atan(f)
988 | Expression::IsNan(f)
989 | Expression::IsInf(f) => {
990 self.annotate_in_place(&mut f.this);
991 }
992
993 Expression::Power(f)
995 | Expression::NullIf(f)
996 | Expression::IfNull(f)
997 | Expression::Nvl(f)
998 | Expression::Contains(f)
999 | Expression::StartsWith(f)
1000 | Expression::EndsWith(f)
1001 | Expression::Levenshtein(f)
1002 | Expression::ModFunc(f)
1003 | Expression::IntDiv(f)
1004 | Expression::Atan2(f)
1005 | Expression::AddMonths(f)
1006 | Expression::MonthsBetween(f)
1007 | Expression::NextDay(f)
1008 | Expression::UnixToTimeStr(f)
1009 | Expression::ArrayContains(f)
1010 | Expression::ArrayPosition(f)
1011 | Expression::ArrayAppend(f)
1012 | Expression::ArrayPrepend(f)
1013 | Expression::ArrayUnion(f)
1014 | Expression::ArrayExcept(f)
1015 | Expression::ArrayRemove(f)
1016 | Expression::StarMap(f)
1017 | Expression::MapFromArrays(f)
1018 | Expression::MapContainsKey(f)
1019 | Expression::ElementAt(f)
1020 | Expression::JsonMergePatch(f) => {
1021 self.annotate_in_place(&mut f.this);
1022 self.annotate_in_place(&mut f.expression);
1023 }
1024
1025 Expression::Coalesce(f)
1027 | Expression::Greatest(f)
1028 | Expression::Least(f)
1029 | Expression::ArrayConcat(f)
1030 | Expression::ArrayIntersect(f)
1031 | Expression::ArrayZip(f)
1032 | Expression::MapConcat(f)
1033 | Expression::JsonArray(f) => {
1034 for e in &mut f.expressions {
1035 self.annotate_in_place(e);
1036 }
1037 }
1038
1039 Expression::Sum(f)
1041 | Expression::Avg(f)
1042 | Expression::Min(f)
1043 | Expression::Max(f)
1044 | Expression::ArrayAgg(f)
1045 | Expression::CountIf(f)
1046 | Expression::Stddev(f)
1047 | Expression::StddevPop(f)
1048 | Expression::StddevSamp(f)
1049 | Expression::Variance(f)
1050 | Expression::VarPop(f)
1051 | Expression::VarSamp(f)
1052 | Expression::Median(f)
1053 | Expression::Mode(f)
1054 | Expression::First(f)
1055 | Expression::Last(f)
1056 | Expression::AnyValue(f)
1057 | Expression::ApproxDistinct(f)
1058 | Expression::ApproxCountDistinct(f)
1059 | Expression::LogicalAnd(f)
1060 | Expression::LogicalOr(f)
1061 | Expression::Skewness(f)
1062 | Expression::ArrayConcatAgg(f)
1063 | Expression::ArrayUniqueAgg(f)
1064 | Expression::BoolXorAgg(f)
1065 | Expression::BitwiseAndAgg(f)
1066 | Expression::BitwiseOrAgg(f)
1067 | Expression::BitwiseXorAgg(f) => {
1068 self.annotate_in_place(&mut f.this);
1069 }
1070
1071 Expression::Select(s) => {
1073 for e in &mut s.expressions {
1074 self.annotate_in_place(e);
1075 }
1076 }
1077
1078 _ => {}
1080 }
1081 }
1082
1083 fn annotate_math_function(&mut self, arg: &Expression) -> Option<DataType> {
1086 let input_type = self.annotate(arg)?;
1087 match input_type {
1088 DataType::TinyInt { .. }
1089 | DataType::SmallInt { .. }
1090 | DataType::Int { .. }
1091 | DataType::BigInt { .. } => Some(DataType::Double {
1092 precision: None,
1093 scale: None,
1094 }),
1095 other => Some(other),
1096 }
1097 }
1098
1099 fn annotate_subscript(&mut self, sub: &Subscript) -> Option<DataType> {
1101 let base_type = self.annotate(&sub.this)?;
1102
1103 match base_type {
1104 DataType::Array { element_type, .. } => Some(*element_type),
1105 DataType::Map { value_type, .. } => Some(*value_type),
1106 DataType::Json | DataType::JsonB => Some(DataType::Json), DataType::VarChar { .. } | DataType::Text => {
1108 Some(DataType::VarChar {
1110 length: Some(1),
1111 parenthesized_length: false,
1112 })
1113 }
1114 _ => None,
1115 }
1116 }
1117
1118 fn annotate_struct(&mut self, s: &Struct) -> Option<DataType> {
1120 let fields: Vec<StructField> = s
1121 .fields
1122 .iter()
1123 .map(|(name, expr)| {
1124 let field_type = self.annotate(expr).unwrap_or(DataType::Unknown);
1125 StructField::new(name.clone().unwrap_or_default(), field_type)
1126 })
1127 .collect();
1128 Some(DataType::Struct {
1129 fields,
1130 nested: false,
1131 })
1132 }
1133
1134 fn annotate_map(&mut self, map: &Map) -> Option<DataType> {
1136 let key_type = if let Some(first_key) = map.keys.first() {
1137 self.annotate(first_key).unwrap_or(DataType::Unknown)
1138 } else {
1139 DataType::Unknown
1140 };
1141
1142 let value_type = if let Some(first_value) = map.values.first() {
1143 self.annotate(first_value).unwrap_or(DataType::Unknown)
1144 } else {
1145 DataType::Unknown
1146 };
1147
1148 Some(DataType::Map {
1149 key_type: Box::new(key_type),
1150 value_type: Box::new(value_type),
1151 })
1152 }
1153
1154 fn annotate_set_operation(
1157 &mut self,
1158 _left: &Expression,
1159 _right: &Expression,
1160 ) -> Option<DataType> {
1161 None
1165 }
1166
1167 fn annotate_lateral_view(&mut self, lv: &crate::expressions::LateralView) -> Option<DataType> {
1169 self.annotate(&lv.this)
1171 }
1172
1173 fn annotate_literal(&self, lit: &Literal) -> Option<DataType> {
1175 match lit {
1176 Literal::String(_)
1177 | Literal::NationalString(_)
1178 | Literal::TripleQuotedString(_, _)
1179 | Literal::EscapeString(_)
1180 | Literal::DollarString(_)
1181 | Literal::RawString(_) => Some(DataType::VarChar {
1182 length: None,
1183 parenthesized_length: false,
1184 }),
1185 Literal::Number(n) => {
1186 if n.contains('.') || n.contains('e') || n.contains('E') {
1188 Some(DataType::Double {
1189 precision: None,
1190 scale: None,
1191 })
1192 } else {
1193 if let Ok(_) = n.parse::<i32>() {
1195 Some(DataType::Int {
1196 length: None,
1197 integer_spelling: false,
1198 })
1199 } else {
1200 Some(DataType::BigInt { length: None })
1201 }
1202 }
1203 }
1204 Literal::HexString(_) | Literal::BitString(_) | Literal::ByteString(_) => {
1205 Some(DataType::VarBinary { length: None })
1206 }
1207 Literal::HexNumber(_) => Some(DataType::BigInt { length: None }),
1208 Literal::Date(_) => Some(DataType::Date),
1209 Literal::Time(_) => Some(DataType::Time {
1210 precision: None,
1211 timezone: false,
1212 }),
1213 Literal::Timestamp(_) => Some(DataType::Timestamp {
1214 precision: None,
1215 timezone: false,
1216 }),
1217 Literal::Datetime(_) => Some(DataType::Custom {
1218 name: "DATETIME".to_string(),
1219 }),
1220 }
1221 }
1222
1223 fn annotate_arithmetic(&mut self, op: &BinaryOp) -> Option<DataType> {
1225 let left_type = self.annotate(&op.left);
1226 let right_type = self.annotate(&op.right);
1227
1228 match (left_type, right_type) {
1229 (Some(l), Some(r)) => self.coerce_types(&l, &r),
1230 (Some(t), None) | (None, Some(t)) => Some(t),
1231 (None, None) => None,
1232 }
1233 }
1234
1235 fn annotate_function(&mut self, func: &Function) -> Option<DataType> {
1237 let func_name = func.name.to_uppercase();
1238
1239 if let Some(return_type) = self.function_return_types.get(&func_name) {
1241 if *return_type != DataType::Unknown {
1242 return Some(return_type.clone());
1243 }
1244 }
1245
1246 match func_name.as_str() {
1248 "COALESCE" | "IFNULL" | "NVL" | "ISNULL" => {
1249 for arg in &func.args {
1251 if let Some(arg_type) = self.annotate(arg) {
1252 return Some(arg_type);
1253 }
1254 }
1255 None
1256 }
1257 "NULLIF" => {
1258 func.args.first().and_then(|arg| self.annotate(arg))
1260 }
1261 "GREATEST" | "LEAST" => {
1262 self.coerce_arg_types(&func.args)
1264 }
1265 "IF" | "IIF" => {
1266 if func.args.len() >= 2 {
1268 self.annotate(&func.args[1])
1269 } else {
1270 None
1271 }
1272 }
1273 _ => {
1274 func.args.first().and_then(|arg| self.annotate(arg))
1276 }
1277 }
1278 }
1279
1280 fn annotate_if_func(&mut self, func: &IfFunc) -> Option<DataType> {
1282 let true_type = self.annotate(&func.true_value);
1283 let false_type = func
1284 .false_value
1285 .as_ref()
1286 .and_then(|expr| self.annotate(expr));
1287
1288 match (true_type, false_type) {
1289 (Some(left), Some(right)) => self.coerce_types(&left, &right),
1290 (Some(dt), None) | (None, Some(dt)) => Some(dt),
1291 (None, None) => None,
1292 }
1293 }
1294
1295 fn annotate_nvl2(&mut self, func: &Nvl2Func) -> Option<DataType> {
1297 let true_type = self.annotate(&func.true_value);
1298 let false_type = self.annotate(&func.false_value);
1299
1300 match (true_type, false_type) {
1301 (Some(left), Some(right)) => self.coerce_types(&left, &right),
1302 (Some(dt), None) | (None, Some(dt)) => Some(dt),
1303 (None, None) => None,
1304 }
1305 }
1306
1307 fn get_aggregate_return_type(
1309 &mut self,
1310 func_name: &str,
1311 args: &[Expression],
1312 ) -> Option<DataType> {
1313 match func_name {
1314 "COUNT" | "COUNT_IF" => Some(DataType::BigInt { length: None }),
1315 "SUM_IF" => {
1316 if let Some(arg) = args.first() {
1317 self.annotate_sum(arg)
1318 } else {
1319 Some(DataType::Decimal {
1320 precision: None,
1321 scale: None,
1322 })
1323 }
1324 }
1325 "SUM" => {
1326 if let Some(arg) = args.first() {
1327 self.annotate_sum(arg)
1328 } else {
1329 Some(DataType::Decimal {
1330 precision: None,
1331 scale: None,
1332 })
1333 }
1334 }
1335 "AVG" => Some(DataType::Double {
1336 precision: None,
1337 scale: None,
1338 }),
1339 "MIN" | "MAX" => {
1340 args.first().and_then(|arg| self.annotate(arg))
1342 }
1343 "STRING_AGG" | "GROUP_CONCAT" | "LISTAGG" | "ARRAY_AGG" => Some(DataType::VarChar {
1344 length: None,
1345 parenthesized_length: false,
1346 }),
1347 "BOOL_AND" | "BOOL_OR" | "EVERY" | "ANY" | "SOME" => Some(DataType::Boolean),
1348 "BIT_AND" | "BIT_OR" | "BIT_XOR" => Some(DataType::BigInt { length: None }),
1349 "STDDEV" | "STDDEV_POP" | "STDDEV_SAMP" | "VARIANCE" | "VAR_POP" | "VAR_SAMP" => {
1350 Some(DataType::Double {
1351 precision: None,
1352 scale: None,
1353 })
1354 }
1355 "PERCENTILE_CONT" | "PERCENTILE_DISC" | "MEDIAN" => {
1356 args.first().and_then(|arg| self.annotate(arg))
1357 }
1358 _ => None,
1359 }
1360 }
1361
1362 fn annotate_sum(&mut self, arg: &Expression) -> Option<DataType> {
1364 match self.annotate(arg) {
1365 Some(DataType::TinyInt { .. })
1366 | Some(DataType::SmallInt { .. })
1367 | Some(DataType::Int { .. }) => Some(DataType::BigInt { length: None }),
1368 Some(DataType::BigInt { .. }) => Some(DataType::BigInt { length: None }),
1369 Some(DataType::Float { .. }) | Some(DataType::Double { .. }) => {
1370 Some(DataType::Double {
1371 precision: None,
1372 scale: None,
1373 })
1374 }
1375 Some(DataType::Decimal { precision, scale }) => {
1376 Some(DataType::Decimal { precision, scale })
1377 }
1378 _ => Some(DataType::Decimal {
1379 precision: None,
1380 scale: None,
1381 }),
1382 }
1383 }
1384
1385 fn coerce_arg_types(&mut self, args: &[Expression]) -> Option<DataType> {
1387 let mut result_type: Option<DataType> = None;
1388 for arg in args {
1389 if let Some(arg_type) = self.annotate(arg) {
1390 result_type = match result_type {
1391 Some(t) => self.coerce_types(&t, &arg_type),
1392 None => Some(arg_type),
1393 };
1394 }
1395 }
1396 result_type
1397 }
1398
1399 fn coerce_types(&self, left: &DataType, right: &DataType) -> Option<DataType> {
1401 if left == right {
1403 return Some(left.clone());
1404 }
1405
1406 match (left, right) {
1408 (DataType::Date, DataType::Interval { .. })
1409 | (DataType::Interval { .. }, DataType::Date) => return Some(DataType::Date),
1410 (
1411 DataType::Timestamp {
1412 precision,
1413 timezone,
1414 },
1415 DataType::Interval { .. },
1416 )
1417 | (
1418 DataType::Interval { .. },
1419 DataType::Timestamp {
1420 precision,
1421 timezone,
1422 },
1423 ) => {
1424 return Some(DataType::Timestamp {
1425 precision: *precision,
1426 timezone: *timezone,
1427 });
1428 }
1429 _ => {}
1430 }
1431
1432 let left_class = TypeCoercionClass::from_data_type(left);
1434 let right_class = TypeCoercionClass::from_data_type(right);
1435
1436 match (left_class, right_class) {
1437 (Some(lc), Some(rc)) if lc == rc => {
1439 if lc == TypeCoercionClass::Numeric {
1441 Some(self.wider_numeric_type(left, right))
1442 } else {
1443 Some(left.clone())
1445 }
1446 }
1447 (Some(lc), Some(rc)) => {
1449 if lc > rc {
1450 Some(left.clone())
1451 } else {
1452 Some(right.clone())
1453 }
1454 }
1455 (Some(_), None) => Some(left.clone()),
1457 (None, Some(_)) => Some(right.clone()),
1458 (None, None) => Some(DataType::Unknown),
1460 }
1461 }
1462
1463 fn wider_numeric_type(&self, left: &DataType, right: &DataType) -> DataType {
1465 let order = |dt: &DataType| -> u8 {
1466 match dt {
1467 DataType::Boolean => 0,
1468 DataType::TinyInt { .. } => 1,
1469 DataType::SmallInt { .. } => 2,
1470 DataType::Int { .. } => 3,
1471 DataType::BigInt { .. } => 4,
1472 DataType::Float { .. } => 5,
1473 DataType::Double { .. } => 6,
1474 DataType::Decimal { .. } => 7,
1475 _ => 0,
1476 }
1477 };
1478
1479 if order(left) >= order(right) {
1480 left.clone()
1481 } else {
1482 right.clone()
1483 }
1484 }
1485}
1486
1487pub fn annotate_types(
1493 expr: &mut Expression,
1494 schema: Option<&dyn Schema>,
1495 dialect: Option<DialectType>,
1496) {
1497 let mut annotator = TypeAnnotator::new(schema, dialect);
1498 annotator.annotate_in_place(expr);
1499}
1500
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{BooleanLiteral, Cast, Null};
    use crate::{parse_one, DialectType, MappingSchema, Schema};

    /// Builds a numeric literal holding an integer.
    fn make_int_literal(val: i64) -> Expression {
        Expression::Literal(Literal::Number(val.to_string()))
    }

    /// Builds a numeric literal holding a float.
    ///
    /// `Debug` formatting is used because `Display` drops the fractional part of
    /// whole values (`2.0` would render as `"2"`, i.e. an integer literal).
    fn make_float_literal(val: f64) -> Expression {
        Expression::Literal(Literal::Number(format!("{val:?}")))
    }

    /// Builds a string literal.
    fn make_string_literal(val: &str) -> Expression {
        Expression::Literal(Literal::String(val.to_string()))
    }

    /// Builds a boolean literal.
    fn make_bool_literal(val: bool) -> Expression {
        Expression::Boolean(BooleanLiteral { value: val })
    }

    /// `INT` as inferred for integer literals.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// Unparameterised `BIGINT`.
    fn bigint_type() -> DataType {
        DataType::BigInt { length: None }
    }

    /// Unparameterised `DOUBLE`.
    fn double_type() -> DataType {
        DataType::Double {
            precision: None,
            scale: None,
        }
    }

    /// Unparameterised `VARCHAR`.
    fn varchar_type() -> DataType {
        DataType::VarChar {
            length: None,
            parenthesized_length: false,
        }
    }

    /// Returns the first projection of a `SELECT` together with the expression
    /// wrapped by its alias; panics if `expr` does not have that shape.
    fn first_aliased_projection(expr: &Expression) -> (&Expression, &Expression) {
        let Expression::Select(select) = expr else {
            panic!("expected select");
        };
        let projection = &select.expressions[0];
        let Expression::Alias(alias) = projection else {
            panic!("expected alias");
        };
        (projection, &alias.this)
    }

    #[test]
    fn test_literal_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        assert_eq!(annotator.annotate(&make_int_literal(42)), Some(int_type()));
        assert_eq!(
            annotator.annotate(&make_float_literal(1.5)),
            Some(double_type())
        );
        assert_eq!(
            annotator.annotate(&make_string_literal("hello")),
            Some(varchar_type())
        );
        assert_eq!(
            annotator.annotate(&make_bool_literal(true)),
            Some(DataType::Boolean)
        );
        // NULL carries no type of its own.
        assert_eq!(annotator.annotate(&Expression::Null(Null)), None);
    }

    #[test]
    fn test_comparison_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let cmp = Expression::Gt(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        assert_eq!(annotator.annotate(&cmp), Some(DataType::Boolean));

        let eq = Expression::Eq(Box::new(BinaryOp::new(
            make_string_literal("a"),
            make_string_literal("b"),
        )));
        assert_eq!(annotator.annotate(&eq), Some(DataType::Boolean));
    }

    #[test]
    fn test_arithmetic_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let add_int = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        assert_eq!(annotator.annotate(&add_int), Some(int_type()));

        // INT + DOUBLE widens to DOUBLE.
        let add_mixed = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_float_literal(2.5),
        )));
        assert_eq!(annotator.annotate(&add_mixed), Some(double_type()));
    }

    #[test]
    fn test_string_concat_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let concat = Expression::Concat(Box::new(BinaryOp::new(
            make_string_literal("hello"),
            make_string_literal(" world"),
        )));
        assert_eq!(annotator.annotate(&concat), Some(varchar_type()));
    }

    #[test]
    fn test_cast_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // A CAST is typed by its target, including the length parameter.
        let cast = Expression::Cast(Box::new(Cast {
            this: make_int_literal(1),
            to: DataType::VarChar {
                length: Some(10),
                parenthesized_length: false,
            },
            trailing_comments: vec![],
            double_colon_syntax: false,
            format: None,
            default: None,
            inferred_type: None,
        }));
        assert_eq!(
            annotator.annotate(&cast),
            Some(DataType::VarChar {
                length: Some(10),
                parenthesized_length: false
            })
        );
    }

    #[test]
    fn test_function_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let count =
            Expression::Function(Box::new(Function::new("COUNT", vec![make_int_literal(1)])));
        assert_eq!(annotator.annotate(&count), Some(bigint_type()));

        let upper = Expression::Function(Box::new(Function::new(
            "UPPER",
            vec![make_string_literal("hello")],
        )));
        assert_eq!(annotator.annotate(&upper), Some(varchar_type()));

        let now = Expression::Function(Box::new(Function::new("NOW", vec![])));
        assert_eq!(
            annotator.annotate(&now),
            Some(DataType::Timestamp {
                precision: None,
                timezone: false
            })
        );
    }

    #[test]
    fn test_coalesce_type_inference() {
        let mut annotator = TypeAnnotator::new(None, None);

        // The untyped NULL argument is skipped; the result comes from the INT.
        let coalesce = Expression::Function(Box::new(Function::new(
            "COALESCE",
            vec![Expression::Null(Null), make_int_literal(1)],
        )));
        assert_eq!(annotator.annotate(&coalesce), Some(int_type()));
    }

    #[test]
    fn test_type_coercion_class() {
        // Text-like types.
        assert_eq!(
            TypeCoercionClass::from_data_type(&varchar_type()),
            Some(TypeCoercionClass::Text)
        );
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Text),
            Some(TypeCoercionClass::Text)
        );

        // Numeric types.
        assert_eq!(
            TypeCoercionClass::from_data_type(&int_type()),
            Some(TypeCoercionClass::Numeric)
        );
        assert_eq!(
            TypeCoercionClass::from_data_type(&double_type()),
            Some(TypeCoercionClass::Numeric)
        );

        // Date/time types.
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Date),
            Some(TypeCoercionClass::Timelike)
        );
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Timestamp {
                precision: None,
                timezone: false
            }),
            Some(TypeCoercionClass::Timelike)
        );

        // Types outside every coercion class.
        assert_eq!(TypeCoercionClass::from_data_type(&DataType::Json), None);
    }

    #[test]
    fn test_wider_numeric_type() {
        let annotator = TypeAnnotator::new(None, None);

        assert_eq!(
            annotator.wider_numeric_type(&int_type(), &bigint_type()),
            bigint_type()
        );

        let float = DataType::Float {
            precision: None,
            scale: None,
            real_spelling: false,
        };
        assert_eq!(
            annotator.wider_numeric_type(&float, &double_type()),
            double_type()
        );

        assert_eq!(
            annotator.wider_numeric_type(&int_type(), &double_type()),
            double_type()
        );
    }

    #[test]
    fn test_aggregate_return_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        // SUM over INT accumulates into BIGINT.
        let sum_type = annotator.get_aggregate_return_type("SUM", &[make_int_literal(1)]);
        assert_eq!(sum_type, Some(bigint_type()));

        let avg_type = annotator.get_aggregate_return_type("AVG", &[make_int_literal(1)]);
        assert_eq!(avg_type, Some(double_type()));

        // MIN keeps the argument's type.
        let min_type = annotator.get_aggregate_return_type("MIN", &[make_string_literal("a")]);
        assert_eq!(min_type, Some(varchar_type()));
    }

    #[test]
    fn test_date_literal_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let date_expr = Expression::Literal(Literal::Date("2024-01-15".to_string()));
        assert_eq!(annotator.annotate(&date_expr), Some(DataType::Date));

        let time_expr = Expression::Literal(Literal::Time("10:30:00".to_string()));
        assert_eq!(
            annotator.annotate(&time_expr),
            Some(DataType::Time {
                precision: None,
                timezone: false
            })
        );

        let ts_expr = Expression::Literal(Literal::Timestamp("2024-01-15 10:30:00".to_string()));
        assert_eq!(
            annotator.annotate(&ts_expr),
            Some(DataType::Timestamp {
                precision: None,
                timezone: false
            })
        );
    }

    #[test]
    fn test_logical_operations() {
        let mut annotator = TypeAnnotator::new(None, None);

        let and_expr = Expression::And(Box::new(BinaryOp::new(
            make_bool_literal(true),
            make_bool_literal(false),
        )));
        assert_eq!(annotator.annotate(&and_expr), Some(DataType::Boolean));

        let or_expr = Expression::Or(Box::new(BinaryOp::new(
            make_bool_literal(true),
            make_bool_literal(false),
        )));
        assert_eq!(annotator.annotate(&or_expr), Some(DataType::Boolean));

        let not_expr = Expression::Not(Box::new(crate::expressions::UnaryOp::new(
            make_bool_literal(true),
        )));
        assert_eq!(annotator.annotate(&not_expr), Some(DataType::Boolean));
    }

    #[test]
    fn test_subscript_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // Indexing an INT array yields INT.
        let arr = Expression::Array(Box::new(crate::expressions::Array {
            expressions: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
            this: arr,
            index: make_int_literal(0),
        }));
        assert_eq!(annotator.annotate(&subscript), Some(int_type()));
    }

    #[test]
    fn test_subscript_map_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // Indexing a map yields its value type.
        let map = Expression::Map(Box::new(crate::expressions::Map {
            keys: vec![make_string_literal("a")],
            values: vec![make_int_literal(1)],
        }));
        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
            this: map,
            index: make_string_literal("a"),
        }));
        assert_eq!(annotator.annotate(&subscript), Some(int_type()));
    }

    #[test]
    fn test_struct_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let struct_expr = Expression::Struct(Box::new(crate::expressions::Struct {
            fields: vec![
                (Some("name".to_string()), make_string_literal("Alice")),
                (Some("age".to_string()), make_int_literal(30)),
            ],
        }));
        let result = annotator.annotate(&struct_expr);
        assert!(matches!(result, Some(DataType::Struct { fields, .. }) if fields.len() == 2));
    }

    #[test]
    fn test_map_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let map_expr = Expression::Map(Box::new(crate::expressions::Map {
            keys: vec![make_string_literal("a"), make_string_literal("b")],
            values: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let result = annotator.annotate(&map_expr);
        assert!(matches!(
            result,
            Some(DataType::Map { key_type, value_type })
                if matches!(*key_type, DataType::VarChar { .. })
                    && matches!(*value_type, DataType::Int { .. })
        ));
    }

    #[test]
    fn test_explode_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // EXPLODE yields the element type of the array.
        let arr = Expression::Array(Box::new(crate::expressions::Array {
            expressions: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let explode = Expression::Explode(Box::new(crate::expressions::UnaryFunc {
            this: arr,
            original_name: None,
            inferred_type: None,
        }));
        assert_eq!(annotator.annotate(&explode), Some(int_type()));
    }

    #[test]
    fn test_unnest_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // UNNEST yields the element type of the array.
        let arr = Expression::Array(Box::new(crate::expressions::Array {
            expressions: vec![make_string_literal("a"), make_string_literal("b")],
        }));
        let unnest = Expression::Unnest(Box::new(crate::expressions::UnnestFunc {
            this: arr,
            expressions: Vec::new(),
            with_ordinality: false,
            alias: None,
            offset_alias: None,
        }));
        assert_eq!(annotator.annotate(&unnest), Some(varchar_type()));
    }

    #[test]
    fn test_set_operation_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // A UNION of empty SELECTs has no scalar type.
        let select = Expression::Select(Box::new(crate::expressions::Select::default()));
        let union = Expression::Union(Box::new(crate::expressions::Union {
            left: select.clone(),
            right: select,
            all: false,
            distinct: false,
            with: None,
            order_by: None,
            limit: None,
            offset: None,
            by_name: false,
            side: None,
            kind: None,
            corresponding: false,
            strict: false,
            on_columns: Vec::new(),
            distribute_by: None,
            sort_by: None,
            cluster_by: None,
        }));
        assert_eq!(annotator.annotate(&union), None);
    }

    #[test]
    fn test_floor_ceil_input_dependent_types() {
        use crate::expressions::{CeilFunc, FloorFunc};

        let mut annotator = TypeAnnotator::new(None, None);

        // NOTE(review): the dedicated FLOOR/CEIL nodes currently annotate as DOUBLE even
        // for INT input, while the generic FLOOR function call below keeps INT. These
        // assertions pin the current behaviour.
        let floor_int = Expression::Floor(Box::new(FloorFunc {
            this: make_int_literal(42),
            scale: None,
            to: None,
        }));
        assert_eq!(annotator.annotate(&floor_int), Some(double_type()));

        let ceil_int = Expression::Ceil(Box::new(CeilFunc {
            this: make_int_literal(42),
            decimals: None,
            to: None,
        }));
        assert_eq!(annotator.annotate(&ceil_int), Some(double_type()));

        let floor_float = Expression::Floor(Box::new(FloorFunc {
            this: make_float_literal(1.5),
            scale: None,
            to: None,
        }));
        assert_eq!(annotator.annotate(&floor_float), Some(double_type()));

        let floor_fn =
            Expression::Function(Box::new(Function::new("FLOOR", vec![make_int_literal(1)])));
        assert_eq!(annotator.annotate(&floor_fn), Some(int_type()));
    }

    #[test]
    fn test_sign_preserves_input_type() {
        use crate::expressions::UnaryFunc;

        let mut annotator = TypeAnnotator::new(None, None);

        let sign_int = Expression::Sign(Box::new(UnaryFunc {
            this: make_int_literal(42),
            original_name: None,
            inferred_type: None,
        }));
        assert_eq!(annotator.annotate(&sign_int), Some(int_type()));

        let sign_float = Expression::Sign(Box::new(UnaryFunc {
            this: make_float_literal(1.5),
            original_name: None,
            inferred_type: None,
        }));
        assert_eq!(annotator.annotate(&sign_float), Some(double_type()));

        // The input type may also come from an explicit CAST.
        let sign_cast = Expression::Sign(Box::new(UnaryFunc {
            this: Expression::Cast(Box::new(Cast {
                this: make_int_literal(42),
                to: int_type(),
                format: None,
                trailing_comments: Vec::new(),
                double_colon_syntax: false,
                default: None,
                inferred_type: None,
            })),
            original_name: None,
            inferred_type: None,
        }));
        assert_eq!(annotator.annotate(&sign_cast), Some(int_type()));
    }

    #[test]
    fn test_date_format_types() {
        use crate::expressions::{DateFormatFunc, TimeToStr};

        let mut annotator = TypeAnnotator::new(None, None);

        let date_fmt = Expression::DateFormat(Box::new(DateFormatFunc {
            this: make_string_literal("2024-01-01"),
            format: make_string_literal("%Y-%m-%d"),
        }));
        assert_eq!(annotator.annotate(&date_fmt), Some(varchar_type()));

        let format_date = Expression::FormatDate(Box::new(DateFormatFunc {
            this: make_string_literal("2024-01-01"),
            format: make_string_literal("%Y-%m-%d"),
        }));
        assert_eq!(annotator.annotate(&format_date), Some(varchar_type()));

        let time_to_str = Expression::TimeToStr(Box::new(TimeToStr {
            this: Box::new(make_string_literal("2024-01-01")),
            format: "%Y-%m-%d".to_string(),
            culture: None,
            zone: None,
        }));
        assert_eq!(annotator.annotate(&time_to_str), Some(varchar_type()));

        let date_fmt_fn = Expression::Function(Box::new(Function::new(
            "DATE_FORMAT",
            vec![
                make_string_literal("2024-01-01"),
                make_string_literal("%Y-%m-%d"),
            ],
        )));
        assert_eq!(annotator.annotate(&date_fmt_fn), Some(varchar_type()));
    }

    #[test]
    fn test_annotate_in_place_sets_type_on_root() {
        let mut expr = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        annotate_types(&mut expr, None, None);
        assert_eq!(expr.inferred_type(), Some(&int_type()));
    }

    #[test]
    fn test_annotate_in_place_sets_types_on_children() {
        // (1 + 2.5) + (3 - 4): DOUBLE + INT.
        let inner_add = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_float_literal(2.5),
        )));
        let inner_sub = Expression::Sub(Box::new(BinaryOp::new(
            make_int_literal(3),
            make_int_literal(4),
        )));
        let mut expr = Expression::Add(Box::new(BinaryOp::new(inner_add, inner_sub)));
        annotate_types(&mut expr, None, None);

        assert_eq!(expr.inferred_type(), Some(&double_type()));

        let Expression::Add(op) = &expr else {
            panic!("Expected Add expression");
        };
        assert_eq!(op.left.inferred_type(), Some(&double_type()));
        assert_eq!(op.right.inferred_type(), Some(&int_type()));
    }

    #[test]
    fn test_annotate_in_place_comparison() {
        let mut expr = Expression::Eq(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        annotate_types(&mut expr, None, None);
        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));
    }

    #[test]
    fn test_annotate_in_place_cast() {
        let mut expr = Expression::Cast(Box::new(Cast {
            this: make_int_literal(42),
            to: varchar_type(),
            trailing_comments: vec![],
            double_colon_syntax: false,
            format: None,
            default: None,
            inferred_type: None,
        }));
        annotate_types(&mut expr, None, None);
        assert_eq!(expr.inferred_type(), Some(&varchar_type()));
    }

    #[test]
    fn test_annotate_in_place_nested_expression() {
        // (1 + 2) > 0
        let add = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        let mut expr = Expression::Gt(Box::new(BinaryOp::new(add, make_int_literal(0))));
        annotate_types(&mut expr, None, None);

        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));

        // Fail loudly instead of silently skipping the child check.
        let Expression::Gt(op) = &expr else {
            panic!("Expected Gt expression");
        };
        assert_eq!(op.left.inferred_type(), Some(&int_type()));
    }

    #[test]
    fn test_annotate_in_place_parsed_sql() {
        use crate::parser::Parser;
        let mut expr =
            Parser::parse_sql("SELECT 1 + 2.0, 'hello', TRUE").expect("parse failed")[0].clone();
        annotate_types(&mut expr, None, None);

        // The SELECT statement itself carries no scalar type.
        assert!(expr.inferred_type().is_none());
    }

    #[test]
    fn test_inferred_type_json_roundtrip() {
        let mut expr = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        annotate_types(&mut expr, None, None);

        let json = serde_json::to_string(&expr).expect("serialize failed");
        assert!(json.contains("inferred_type"));

        let deserialized: Expression = serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(deserialized.inferred_type(), Some(&int_type()));
    }

    #[test]
    fn test_inferred_type_none_not_serialized() {
        let expr = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        let json = serde_json::to_string(&expr).expect("serialize failed");
        assert!(!json.contains("inferred_type"));
    }

    #[test]
    fn test_annotate_if_func_bigquery_node_and_alias_type() {
        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
        schema
            .add_table(
                "t",
                &[("col1".to_string(), DataType::String { length: None })],
                None,
            )
            .unwrap();

        let mut expr = parse_one(
            "SELECT IF(col1 IS NOT NULL, 1, 0) AS x FROM t",
            DialectType::BigQuery,
        )
        .unwrap();
        annotate_types(&mut expr, Some(&schema), Some(DialectType::BigQuery));

        // Both the IF node and its alias wrapper carry the type.
        let (projection, aliased) = first_aliased_projection(&expr);
        assert_eq!(aliased.inferred_type(), Some(&int_type()));
        assert_eq!(projection.inferred_type(), Some(&int_type()));
    }

    #[test]
    fn test_annotate_nvl2_node_type() {
        let mut expr = parse_one("SELECT NVL2(a, 1, 0) AS x", DialectType::Generic).unwrap();
        annotate_types(&mut expr, None, None);

        let (_, aliased) = first_aliased_projection(&expr);
        assert_eq!(aliased.inferred_type(), Some(&int_type()));
    }

    #[test]
    fn test_annotate_count_node_type() {
        let mut expr = parse_one("SELECT COUNT(1) AS x", DialectType::Generic).unwrap();
        annotate_types(&mut expr, None, None);

        let (_, aliased) = first_aliased_projection(&expr);
        assert_eq!(aliased.inferred_type(), Some(&bigint_type()));
    }

    #[test]
    fn test_annotate_group_concat_node_type() {
        let mut expr = parse_one("SELECT GROUP_CONCAT(a) AS x", DialectType::Generic).unwrap();
        annotate_types(&mut expr, None, None);

        let (_, aliased) = first_aliased_projection(&expr);
        assert_eq!(aliased.inferred_type(), Some(&varchar_type()));
    }

    #[test]
    fn test_annotate_sum_if_generic_aggregate_type() {
        let mut expr =
            parse_one("SELECT SUM_IF(1, a > 0) AS x FROM t", DialectType::Generic).unwrap();
        annotate_types(&mut expr, None, None);

        let (projection, aliased) = first_aliased_projection(&expr);
        assert_eq!(projection.inferred_type(), Some(&bigint_type()));
        assert_eq!(aliased.inferred_type(), Some(&bigint_type()));
    }
}