1use std::collections::HashMap;
13
14use crate::dialects::DialectType;
15use crate::expressions::{
16 BinaryOp, DataType, Expression, Function, Literal, Map, Struct, StructField, Subscript,
17};
18use crate::schema::Schema;
19
/// Coarse type families used when two differing types have to be reconciled.
///
/// The explicit discriminants fix the derived ordering to
/// `Text < Numeric < Timelike`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TypeCoercionClass {
    /// Lowest-ranked class.
    Text = 0,
    /// Middle-ranked class.
    Numeric = 1,
    /// Highest-ranked class.
    Timelike = 2,
}
31
32impl TypeCoercionClass {
33 pub fn from_data_type(dt: &DataType) -> Option<Self> {
35 match dt {
36 DataType::Char { .. }
38 | DataType::VarChar { .. }
39 | DataType::Text
40 | DataType::Binary { .. }
41 | DataType::VarBinary { .. }
42 | DataType::Blob => Some(TypeCoercionClass::Text),
43
44 DataType::Boolean
46 | DataType::TinyInt { .. }
47 | DataType::SmallInt { .. }
48 | DataType::Int { .. }
49 | DataType::BigInt { .. }
50 | DataType::Float { .. }
51 | DataType::Double { .. }
52 | DataType::Decimal { .. } => Some(TypeCoercionClass::Numeric),
53
54 DataType::Date
56 | DataType::Time { .. }
57 | DataType::Timestamp { .. }
58 | DataType::Interval { .. } => Some(TypeCoercionClass::Timelike),
59
60 _ => None,
62 }
63 }
64}
65
/// Infers [`DataType`]s for expressions, optionally resolving column types
/// through a schema.
pub struct TypeAnnotator<'a> {
    /// Schema consulted when annotating column references; `None` leaves
    /// columns untyped.
    _schema: Option<&'a dyn Schema>,
    /// Dialect passed at construction. It is stored but not read by the
    /// methods of this type.
    _dialect: Option<DialectType>,
    /// When `false`, generic `AggregateFunction` expressions are not typed.
    annotate_aggregates: bool,
    /// Return types keyed by upper-case function name.
    function_return_types: HashMap<String, DataType>,
}
77
78impl<'a> TypeAnnotator<'a> {
79 pub fn new(schema: Option<&'a dyn Schema>, dialect: Option<DialectType>) -> Self {
81 let mut annotator = Self {
82 _schema: schema,
83 _dialect: dialect,
84 annotate_aggregates: true,
85 function_return_types: HashMap::new(),
86 };
87 annotator.init_function_return_types();
88 annotator
89 }
90
91 fn init_function_return_types(&mut self) {
93 self.function_return_types
95 .insert("COUNT".to_string(), DataType::BigInt { length: None });
96 self.function_return_types.insert(
97 "SUM".to_string(),
98 DataType::Decimal {
99 precision: None,
100 scale: None,
101 },
102 );
103 self.function_return_types.insert(
104 "AVG".to_string(),
105 DataType::Double {
106 precision: None,
107 scale: None,
108 },
109 );
110
111 self.function_return_types.insert(
113 "CONCAT".to_string(),
114 DataType::VarChar {
115 length: None,
116 parenthesized_length: false,
117 },
118 );
119 self.function_return_types.insert(
120 "UPPER".to_string(),
121 DataType::VarChar {
122 length: None,
123 parenthesized_length: false,
124 },
125 );
126 self.function_return_types.insert(
127 "LOWER".to_string(),
128 DataType::VarChar {
129 length: None,
130 parenthesized_length: false,
131 },
132 );
133 self.function_return_types.insert(
134 "TRIM".to_string(),
135 DataType::VarChar {
136 length: None,
137 parenthesized_length: false,
138 },
139 );
140 self.function_return_types.insert(
141 "LTRIM".to_string(),
142 DataType::VarChar {
143 length: None,
144 parenthesized_length: false,
145 },
146 );
147 self.function_return_types.insert(
148 "RTRIM".to_string(),
149 DataType::VarChar {
150 length: None,
151 parenthesized_length: false,
152 },
153 );
154 self.function_return_types.insert(
155 "SUBSTRING".to_string(),
156 DataType::VarChar {
157 length: None,
158 parenthesized_length: false,
159 },
160 );
161 self.function_return_types.insert(
162 "SUBSTR".to_string(),
163 DataType::VarChar {
164 length: None,
165 parenthesized_length: false,
166 },
167 );
168 self.function_return_types.insert(
169 "REPLACE".to_string(),
170 DataType::VarChar {
171 length: None,
172 parenthesized_length: false,
173 },
174 );
175 self.function_return_types.insert(
176 "LENGTH".to_string(),
177 DataType::Int {
178 length: None,
179 integer_spelling: false,
180 },
181 );
182 self.function_return_types.insert(
183 "CHAR_LENGTH".to_string(),
184 DataType::Int {
185 length: None,
186 integer_spelling: false,
187 },
188 );
189
190 self.function_return_types.insert(
192 "NOW".to_string(),
193 DataType::Timestamp {
194 precision: None,
195 timezone: false,
196 },
197 );
198 self.function_return_types.insert(
199 "CURRENT_TIMESTAMP".to_string(),
200 DataType::Timestamp {
201 precision: None,
202 timezone: false,
203 },
204 );
205 self.function_return_types
206 .insert("CURRENT_DATE".to_string(), DataType::Date);
207 self.function_return_types.insert(
208 "CURRENT_TIME".to_string(),
209 DataType::Time {
210 precision: None,
211 timezone: false,
212 },
213 );
214 self.function_return_types
215 .insert("DATE".to_string(), DataType::Date);
216 self.function_return_types.insert(
217 "YEAR".to_string(),
218 DataType::Int {
219 length: None,
220 integer_spelling: false,
221 },
222 );
223 self.function_return_types.insert(
224 "MONTH".to_string(),
225 DataType::Int {
226 length: None,
227 integer_spelling: false,
228 },
229 );
230 self.function_return_types.insert(
231 "DAY".to_string(),
232 DataType::Int {
233 length: None,
234 integer_spelling: false,
235 },
236 );
237 self.function_return_types.insert(
238 "HOUR".to_string(),
239 DataType::Int {
240 length: None,
241 integer_spelling: false,
242 },
243 );
244 self.function_return_types.insert(
245 "MINUTE".to_string(),
246 DataType::Int {
247 length: None,
248 integer_spelling: false,
249 },
250 );
251 self.function_return_types.insert(
252 "SECOND".to_string(),
253 DataType::Int {
254 length: None,
255 integer_spelling: false,
256 },
257 );
258 self.function_return_types.insert(
259 "EXTRACT".to_string(),
260 DataType::Int {
261 length: None,
262 integer_spelling: false,
263 },
264 );
265 self.function_return_types.insert(
266 "DATE_DIFF".to_string(),
267 DataType::Int {
268 length: None,
269 integer_spelling: false,
270 },
271 );
272 self.function_return_types.insert(
273 "DATEDIFF".to_string(),
274 DataType::Int {
275 length: None,
276 integer_spelling: false,
277 },
278 );
279
280 self.function_return_types.insert(
282 "ABS".to_string(),
283 DataType::Double {
284 precision: None,
285 scale: None,
286 },
287 );
288 self.function_return_types.insert(
289 "ROUND".to_string(),
290 DataType::Double {
291 precision: None,
292 scale: None,
293 },
294 );
295 self.function_return_types.insert(
296 "DATE_FORMAT".to_string(),
297 DataType::VarChar {
298 length: None,
299 parenthesized_length: false,
300 },
301 );
302 self.function_return_types.insert(
303 "FORMAT_DATE".to_string(),
304 DataType::VarChar {
305 length: None,
306 parenthesized_length: false,
307 },
308 );
309 self.function_return_types.insert(
310 "TIME_TO_STR".to_string(),
311 DataType::VarChar {
312 length: None,
313 parenthesized_length: false,
314 },
315 );
316 self.function_return_types.insert(
317 "SQRT".to_string(),
318 DataType::Double {
319 precision: None,
320 scale: None,
321 },
322 );
323 self.function_return_types.insert(
324 "POWER".to_string(),
325 DataType::Double {
326 precision: None,
327 scale: None,
328 },
329 );
330 self.function_return_types.insert(
331 "MOD".to_string(),
332 DataType::Int {
333 length: None,
334 integer_spelling: false,
335 },
336 );
337 self.function_return_types.insert(
338 "LOG".to_string(),
339 DataType::Double {
340 precision: None,
341 scale: None,
342 },
343 );
344 self.function_return_types.insert(
345 "LN".to_string(),
346 DataType::Double {
347 precision: None,
348 scale: None,
349 },
350 );
351 self.function_return_types.insert(
352 "EXP".to_string(),
353 DataType::Double {
354 precision: None,
355 scale: None,
356 },
357 );
358
359 self.function_return_types
361 .insert("COALESCE".to_string(), DataType::Unknown);
362 self.function_return_types
363 .insert("NULLIF".to_string(), DataType::Unknown);
364 self.function_return_types
365 .insert("GREATEST".to_string(), DataType::Unknown);
366 self.function_return_types
367 .insert("LEAST".to_string(), DataType::Unknown);
368 }
369
370 pub fn annotate(&mut self, expr: &Expression) -> Option<DataType> {
372 match expr {
373 Expression::Literal(lit) => self.annotate_literal(lit),
375 Expression::Boolean(_) => Some(DataType::Boolean),
376 Expression::Null(_) => None, Expression::Add(op)
380 | Expression::Sub(op)
381 | Expression::Mul(op)
382 | Expression::Div(op)
383 | Expression::Mod(op) => self.annotate_arithmetic(op),
384
385 Expression::Eq(_)
387 | Expression::Neq(_)
388 | Expression::Lt(_)
389 | Expression::Lte(_)
390 | Expression::Gt(_)
391 | Expression::Gte(_)
392 | Expression::Like(_)
393 | Expression::ILike(_) => Some(DataType::Boolean),
394
395 Expression::And(_) | Expression::Or(_) | Expression::Not(_) => Some(DataType::Boolean),
397
398 Expression::Between(_)
400 | Expression::In(_)
401 | Expression::IsNull(_)
402 | Expression::IsTrue(_)
403 | Expression::IsFalse(_)
404 | Expression::Is(_)
405 | Expression::Exists(_) => Some(DataType::Boolean),
406
407 Expression::Concat(_) => Some(DataType::VarChar {
409 length: None,
410 parenthesized_length: false,
411 }),
412
413 Expression::BitwiseAnd(_)
415 | Expression::BitwiseOr(_)
416 | Expression::BitwiseXor(_)
417 | Expression::BitwiseNot(_) => Some(DataType::BigInt { length: None }),
418
419 Expression::Neg(op) => self.annotate(&op.this),
421
422 Expression::Function(func) => self.annotate_function(func),
424
425 Expression::Count(_) => Some(DataType::BigInt { length: None }),
427 Expression::Sum(agg) => self.annotate_sum(&agg.this),
428 Expression::Avg(_) => Some(DataType::Double {
429 precision: None,
430 scale: None,
431 }),
432 Expression::Min(agg) => self.annotate(&agg.this),
433 Expression::Max(agg) => self.annotate(&agg.this),
434 Expression::GroupConcat(_) | Expression::StringAgg(_) | Expression::ListAgg(_) => {
435 Some(DataType::VarChar {
436 length: None,
437 parenthesized_length: false,
438 })
439 }
440
441 Expression::AggregateFunction(agg) => {
443 if !self.annotate_aggregates {
444 return None;
445 }
446 let func_name = agg.name.to_uppercase();
447 self.get_aggregate_return_type(&func_name, &agg.args)
448 }
449
450 Expression::Column(col) => {
452 if let Some(schema) = &self._schema {
453 let table_name = col.table.as_ref().map(|t| t.name.as_str()).unwrap_or("");
454 schema.get_column_type(table_name, &col.name.name).ok()
455 } else {
456 None
457 }
458 }
459
460 Expression::Cast(cast) => Some(cast.to.clone()),
462 Expression::SafeCast(cast) => Some(cast.to.clone()),
463 Expression::TryCast(cast) => Some(cast.to.clone()),
464
465 Expression::Subquery(subq) => {
467 if let Expression::Select(select) = &subq.this {
468 if let Some(first) = select.expressions.first() {
469 self.annotate(first)
470 } else {
471 None
472 }
473 } else {
474 None
475 }
476 }
477
478 Expression::Case(case) => {
480 if let Some(else_expr) = &case.else_ {
481 self.annotate(else_expr)
482 } else if let Some((_, then_expr)) = case.whens.first() {
483 self.annotate(then_expr)
484 } else {
485 None
486 }
487 }
488
489 Expression::Array(arr) => {
491 if let Some(first) = arr.expressions.first() {
492 if let Some(elem_type) = self.annotate(first) {
493 Some(DataType::Array {
494 element_type: Box::new(elem_type),
495 dimension: None,
496 })
497 } else {
498 Some(DataType::Array {
499 element_type: Box::new(DataType::Unknown),
500 dimension: None,
501 })
502 }
503 } else {
504 Some(DataType::Array {
505 element_type: Box::new(DataType::Unknown),
506 dimension: None,
507 })
508 }
509 }
510
511 Expression::Interval(_) => Some(DataType::Interval {
513 unit: None,
514 to: None,
515 }),
516
517 Expression::WindowFunction(window) => self.annotate(&window.this),
519
520 Expression::CurrentDate(_) => Some(DataType::Date),
522 Expression::CurrentTime(_) => Some(DataType::Time {
523 precision: None,
524 timezone: false,
525 }),
526 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
527 Some(DataType::Timestamp {
528 precision: None,
529 timezone: false,
530 })
531 }
532
533 Expression::DateAdd(_)
535 | Expression::DateSub(_)
536 | Expression::ToDate(_)
537 | Expression::Date(_) => Some(DataType::Date),
538 Expression::DateDiff(_) | Expression::Extract(_) => Some(DataType::Int {
539 length: None,
540 integer_spelling: false,
541 }),
542 Expression::ToTimestamp(_) => Some(DataType::Timestamp {
543 precision: None,
544 timezone: false,
545 }),
546
547 Expression::Upper(_)
549 | Expression::Lower(_)
550 | Expression::Trim(_)
551 | Expression::LTrim(_)
552 | Expression::RTrim(_)
553 | Expression::Replace(_)
554 | Expression::Substring(_)
555 | Expression::Reverse(_)
556 | Expression::Left(_)
557 | Expression::Right(_)
558 | Expression::Repeat(_)
559 | Expression::Lpad(_)
560 | Expression::Rpad(_)
561 | Expression::ConcatWs(_)
562 | Expression::Overlay(_) => Some(DataType::VarChar {
563 length: None,
564 parenthesized_length: false,
565 }),
566 Expression::Length(_) => Some(DataType::Int {
567 length: None,
568 integer_spelling: false,
569 }),
570
571 Expression::Abs(_)
573 | Expression::Sqrt(_)
574 | Expression::Cbrt(_)
575 | Expression::Ln(_)
576 | Expression::Exp(_)
577 | Expression::Power(_)
578 | Expression::Log(_) => Some(DataType::Double {
579 precision: None,
580 scale: None,
581 }),
582 Expression::Round(_) => Some(DataType::Double {
583 precision: None,
584 scale: None,
585 }),
586 Expression::Floor(f) => self.annotate_math_function(&f.this),
587 Expression::Ceil(f) => self.annotate_math_function(&f.this),
588 Expression::Sign(s) => self.annotate(&s.this),
589 Expression::DateFormat(_) | Expression::FormatDate(_) | Expression::TimeToStr(_) => {
590 Some(DataType::VarChar {
591 length: None,
592 parenthesized_length: false,
593 })
594 }
595
596 Expression::Greatest(v) | Expression::Least(v) => self.coerce_arg_types(&v.expressions),
598
599 Expression::Alias(alias) => self.annotate(&alias.this),
601
602 Expression::Select(_) => None,
604
605 Expression::Subscript(sub) => self.annotate_subscript(sub),
609
610 Expression::Dot(_) => None,
612
613 Expression::Struct(s) => self.annotate_struct(s),
617
618 Expression::Map(map) => self.annotate_map(map),
622 Expression::MapFromEntries(mfe) => {
623 if let Some(DataType::Array { element_type, .. }) = self.annotate(&mfe.this) {
625 if let DataType::Struct { fields, .. } = *element_type {
626 if fields.len() >= 2 {
627 return Some(DataType::Map {
628 key_type: Box::new(fields[0].data_type.clone()),
629 value_type: Box::new(fields[1].data_type.clone()),
630 });
631 }
632 }
633 }
634 Some(DataType::Map {
635 key_type: Box::new(DataType::Unknown),
636 value_type: Box::new(DataType::Unknown),
637 })
638 }
639
640 Expression::Union(union) => self.annotate_set_operation(&union.left, &union.right),
644 Expression::Intersect(intersect) => {
645 self.annotate_set_operation(&intersect.left, &intersect.right)
646 }
647 Expression::Except(except) => self.annotate_set_operation(&except.left, &except.right),
648
649 Expression::Lateral(lateral) => {
653 self.annotate(&lateral.this)
655 }
656 Expression::LateralView(lv) => {
657 self.annotate_lateral_view(lv)
659 }
660 Expression::Unnest(unnest) => {
661 if let Some(DataType::Array { element_type, .. }) = self.annotate(&unnest.this) {
663 Some(*element_type)
664 } else {
665 None
666 }
667 }
668 Expression::Explode(explode) => {
669 if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
671 Some(*element_type)
672 } else if let Some(DataType::Map {
673 key_type,
674 value_type,
675 }) = self.annotate(&explode.this)
676 {
677 Some(DataType::Struct {
679 fields: vec![
680 StructField::new("key".to_string(), *key_type),
681 StructField::new("value".to_string(), *value_type),
682 ],
683 nested: false,
684 })
685 } else {
686 None
687 }
688 }
689 Expression::ExplodeOuter(explode) => {
690 if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
692 Some(*element_type)
693 } else {
694 None
695 }
696 }
697 Expression::GenerateSeries(gs) => {
698 if let Some(ref start) = gs.start {
700 self.annotate(start)
701 } else if let Some(ref end) = gs.end {
702 self.annotate(end)
703 } else {
704 Some(DataType::Int {
705 length: None,
706 integer_spelling: false,
707 })
708 }
709 }
710
711 _ => None,
713 }
714 }
715
716 pub fn annotate_in_place(&mut self, expr: &mut Expression) {
722 self.annotate_children_in_place(expr);
724
725 let dt = self.annotate(expr);
729
730 if let Some(data_type) = dt {
732 expr.set_inferred_type(data_type);
733 }
734 }
735
736 fn annotate_children_in_place(&mut self, expr: &mut Expression) {
738 match expr {
739 Expression::And(op)
741 | Expression::Or(op)
742 | Expression::Add(op)
743 | Expression::Sub(op)
744 | Expression::Mul(op)
745 | Expression::Div(op)
746 | Expression::Mod(op)
747 | Expression::Eq(op)
748 | Expression::Neq(op)
749 | Expression::Lt(op)
750 | Expression::Lte(op)
751 | Expression::Gt(op)
752 | Expression::Gte(op)
753 | Expression::Concat(op)
754 | Expression::BitwiseAnd(op)
755 | Expression::BitwiseOr(op)
756 | Expression::BitwiseXor(op)
757 | Expression::Adjacent(op)
758 | Expression::TsMatch(op)
759 | Expression::PropertyEQ(op)
760 | Expression::ArrayContainsAll(op)
761 | Expression::ArrayContainedBy(op)
762 | Expression::ArrayOverlaps(op)
763 | Expression::JSONBContainsAllTopKeys(op)
764 | Expression::JSONBContainsAnyTopKeys(op)
765 | Expression::JSONBDeleteAtPath(op)
766 | Expression::ExtendsLeft(op)
767 | Expression::ExtendsRight(op)
768 | Expression::Is(op)
769 | Expression::MemberOf(op)
770 | Expression::Match(op)
771 | Expression::NullSafeEq(op)
772 | Expression::NullSafeNeq(op)
773 | Expression::Glob(op)
774 | Expression::BitwiseLeftShift(op)
775 | Expression::BitwiseRightShift(op) => {
776 self.annotate_in_place(&mut op.left);
777 self.annotate_in_place(&mut op.right);
778 }
779
780 Expression::Like(op) | Expression::ILike(op) => {
782 self.annotate_in_place(&mut op.left);
783 self.annotate_in_place(&mut op.right);
784 }
785
786 Expression::Not(op) | Expression::Neg(op) | Expression::BitwiseNot(op) => {
788 self.annotate_in_place(&mut op.this);
789 }
790
791 Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
793 self.annotate_in_place(&mut c.this);
794 }
795
796 Expression::Case(c) => {
798 if let Some(ref mut operand) = c.operand {
799 self.annotate_in_place(operand);
800 }
801 for (cond, then_expr) in &mut c.whens {
802 self.annotate_in_place(cond);
803 self.annotate_in_place(then_expr);
804 }
805 if let Some(ref mut else_expr) = c.else_ {
806 self.annotate_in_place(else_expr);
807 }
808 }
809
810 Expression::Alias(a) => {
812 self.annotate_in_place(&mut a.this);
813 }
814
815 Expression::Column(_) => {}
817
818 Expression::Function(f) => {
820 for arg in &mut f.args {
821 self.annotate_in_place(arg);
822 }
823 }
824
825 Expression::AggregateFunction(f) => {
827 for arg in &mut f.args {
828 self.annotate_in_place(arg);
829 }
830 }
831
832 Expression::WindowFunction(w) => {
834 self.annotate_in_place(&mut w.this);
835 }
836
837 Expression::Subquery(s) => {
839 self.annotate_in_place(&mut s.this);
840 }
841
842 Expression::Upper(f)
844 | Expression::Lower(f)
845 | Expression::Length(f)
846 | Expression::LTrim(f)
847 | Expression::RTrim(f)
848 | Expression::Reverse(f)
849 | Expression::Abs(f)
850 | Expression::Sqrt(f)
851 | Expression::Cbrt(f)
852 | Expression::Ln(f)
853 | Expression::Exp(f)
854 | Expression::Sign(f)
855 | Expression::Date(f)
856 | Expression::Time(f)
857 | Expression::Explode(f)
858 | Expression::ExplodeOuter(f)
859 | Expression::MapFromEntries(f)
860 | Expression::MapKeys(f)
861 | Expression::MapValues(f)
862 | Expression::ArrayLength(f)
863 | Expression::ArraySize(f)
864 | Expression::Cardinality(f)
865 | Expression::ArrayReverse(f)
866 | Expression::ArrayDistinct(f)
867 | Expression::ArrayFlatten(f)
868 | Expression::ArrayCompact(f)
869 | Expression::ToArray(f)
870 | Expression::JsonArrayLength(f)
871 | Expression::JsonKeys(f)
872 | Expression::JsonType(f)
873 | Expression::ParseJson(f)
874 | Expression::ToJson(f)
875 | Expression::Year(f)
876 | Expression::Month(f)
877 | Expression::Day(f)
878 | Expression::Hour(f)
879 | Expression::Minute(f)
880 | Expression::Second(f)
881 | Expression::Initcap(f)
882 | Expression::Ascii(f)
883 | Expression::Chr(f)
884 | Expression::Soundex(f)
885 | Expression::ByteLength(f)
886 | Expression::Hex(f)
887 | Expression::LowerHex(f)
888 | Expression::Unicode(f)
889 | Expression::Typeof(f)
890 | Expression::BitwiseCount(f)
891 | Expression::Epoch(f)
892 | Expression::EpochMs(f)
893 | Expression::Radians(f)
894 | Expression::Degrees(f)
895 | Expression::Sin(f)
896 | Expression::Cos(f)
897 | Expression::Tan(f)
898 | Expression::Asin(f)
899 | Expression::Acos(f)
900 | Expression::Atan(f)
901 | Expression::IsNan(f)
902 | Expression::IsInf(f) => {
903 self.annotate_in_place(&mut f.this);
904 }
905
906 Expression::Power(f)
908 | Expression::NullIf(f)
909 | Expression::IfNull(f)
910 | Expression::Nvl(f)
911 | Expression::Contains(f)
912 | Expression::StartsWith(f)
913 | Expression::EndsWith(f)
914 | Expression::Levenshtein(f)
915 | Expression::ModFunc(f)
916 | Expression::IntDiv(f)
917 | Expression::Atan2(f)
918 | Expression::AddMonths(f)
919 | Expression::MonthsBetween(f)
920 | Expression::NextDay(f)
921 | Expression::UnixToTimeStr(f)
922 | Expression::ArrayContains(f)
923 | Expression::ArrayPosition(f)
924 | Expression::ArrayAppend(f)
925 | Expression::ArrayPrepend(f)
926 | Expression::ArrayUnion(f)
927 | Expression::ArrayExcept(f)
928 | Expression::ArrayRemove(f)
929 | Expression::StarMap(f)
930 | Expression::MapFromArrays(f)
931 | Expression::MapContainsKey(f)
932 | Expression::ElementAt(f)
933 | Expression::JsonMergePatch(f) => {
934 self.annotate_in_place(&mut f.this);
935 self.annotate_in_place(&mut f.expression);
936 }
937
938 Expression::Coalesce(f)
940 | Expression::Greatest(f)
941 | Expression::Least(f)
942 | Expression::ArrayConcat(f)
943 | Expression::ArrayIntersect(f)
944 | Expression::ArrayZip(f)
945 | Expression::MapConcat(f)
946 | Expression::JsonArray(f) => {
947 for e in &mut f.expressions {
948 self.annotate_in_place(e);
949 }
950 }
951
952 Expression::Sum(f)
954 | Expression::Avg(f)
955 | Expression::Min(f)
956 | Expression::Max(f)
957 | Expression::ArrayAgg(f)
958 | Expression::CountIf(f)
959 | Expression::Stddev(f)
960 | Expression::StddevPop(f)
961 | Expression::StddevSamp(f)
962 | Expression::Variance(f)
963 | Expression::VarPop(f)
964 | Expression::VarSamp(f)
965 | Expression::Median(f)
966 | Expression::Mode(f)
967 | Expression::First(f)
968 | Expression::Last(f)
969 | Expression::AnyValue(f)
970 | Expression::ApproxDistinct(f)
971 | Expression::ApproxCountDistinct(f)
972 | Expression::LogicalAnd(f)
973 | Expression::LogicalOr(f)
974 | Expression::Skewness(f)
975 | Expression::ArrayConcatAgg(f)
976 | Expression::ArrayUniqueAgg(f)
977 | Expression::BoolXorAgg(f)
978 | Expression::BitwiseAndAgg(f)
979 | Expression::BitwiseOrAgg(f)
980 | Expression::BitwiseXorAgg(f) => {
981 self.annotate_in_place(&mut f.this);
982 }
983
984 Expression::Select(s) => {
986 for e in &mut s.expressions {
987 self.annotate_in_place(e);
988 }
989 }
990
991 _ => {}
993 }
994 }
995
996 fn annotate_math_function(&mut self, arg: &Expression) -> Option<DataType> {
999 let input_type = self.annotate(arg)?;
1000 match input_type {
1001 DataType::TinyInt { .. }
1002 | DataType::SmallInt { .. }
1003 | DataType::Int { .. }
1004 | DataType::BigInt { .. } => Some(DataType::Double {
1005 precision: None,
1006 scale: None,
1007 }),
1008 other => Some(other),
1009 }
1010 }
1011
1012 fn annotate_subscript(&mut self, sub: &Subscript) -> Option<DataType> {
1014 let base_type = self.annotate(&sub.this)?;
1015
1016 match base_type {
1017 DataType::Array { element_type, .. } => Some(*element_type),
1018 DataType::Map { value_type, .. } => Some(*value_type),
1019 DataType::Json | DataType::JsonB => Some(DataType::Json), DataType::VarChar { .. } | DataType::Text => {
1021 Some(DataType::VarChar {
1023 length: Some(1),
1024 parenthesized_length: false,
1025 })
1026 }
1027 _ => None,
1028 }
1029 }
1030
1031 fn annotate_struct(&mut self, s: &Struct) -> Option<DataType> {
1033 let fields: Vec<StructField> = s
1034 .fields
1035 .iter()
1036 .map(|(name, expr)| {
1037 let field_type = self.annotate(expr).unwrap_or(DataType::Unknown);
1038 StructField::new(name.clone().unwrap_or_default(), field_type)
1039 })
1040 .collect();
1041 Some(DataType::Struct {
1042 fields,
1043 nested: false,
1044 })
1045 }
1046
1047 fn annotate_map(&mut self, map: &Map) -> Option<DataType> {
1049 let key_type = if let Some(first_key) = map.keys.first() {
1050 self.annotate(first_key).unwrap_or(DataType::Unknown)
1051 } else {
1052 DataType::Unknown
1053 };
1054
1055 let value_type = if let Some(first_value) = map.values.first() {
1056 self.annotate(first_value).unwrap_or(DataType::Unknown)
1057 } else {
1058 DataType::Unknown
1059 };
1060
1061 Some(DataType::Map {
1062 key_type: Box::new(key_type),
1063 value_type: Box::new(value_type),
1064 })
1065 }
1066
    /// Types a set operation (`UNION`, `INTERSECT`, `EXCEPT`).
    ///
    /// Currently a placeholder: both operands are ignored and the result is
    /// always `None`.
    fn annotate_set_operation(
        &mut self,
        _left: &Expression,
        _right: &Expression,
    ) -> Option<DataType> {
        None
    }
1079
    /// Types a lateral view as the type of the expression it wraps (`lv.this`).
    fn annotate_lateral_view(&mut self, lv: &crate::expressions::LateralView) -> Option<DataType> {
        self.annotate(&lv.this)
    }
1085
1086 fn annotate_literal(&self, lit: &Literal) -> Option<DataType> {
1088 match lit {
1089 Literal::String(_)
1090 | Literal::NationalString(_)
1091 | Literal::TripleQuotedString(_, _)
1092 | Literal::EscapeString(_)
1093 | Literal::DollarString(_)
1094 | Literal::RawString(_) => Some(DataType::VarChar {
1095 length: None,
1096 parenthesized_length: false,
1097 }),
1098 Literal::Number(n) => {
1099 if n.contains('.') || n.contains('e') || n.contains('E') {
1101 Some(DataType::Double {
1102 precision: None,
1103 scale: None,
1104 })
1105 } else {
1106 if let Ok(_) = n.parse::<i32>() {
1108 Some(DataType::Int {
1109 length: None,
1110 integer_spelling: false,
1111 })
1112 } else {
1113 Some(DataType::BigInt { length: None })
1114 }
1115 }
1116 }
1117 Literal::HexString(_) | Literal::BitString(_) | Literal::ByteString(_) => {
1118 Some(DataType::VarBinary { length: None })
1119 }
1120 Literal::HexNumber(_) => Some(DataType::BigInt { length: None }),
1121 Literal::Date(_) => Some(DataType::Date),
1122 Literal::Time(_) => Some(DataType::Time {
1123 precision: None,
1124 timezone: false,
1125 }),
1126 Literal::Timestamp(_) => Some(DataType::Timestamp {
1127 precision: None,
1128 timezone: false,
1129 }),
1130 Literal::Datetime(_) => Some(DataType::Custom {
1131 name: "DATETIME".to_string(),
1132 }),
1133 }
1134 }
1135
1136 fn annotate_arithmetic(&mut self, op: &BinaryOp) -> Option<DataType> {
1138 let left_type = self.annotate(&op.left);
1139 let right_type = self.annotate(&op.right);
1140
1141 match (left_type, right_type) {
1142 (Some(l), Some(r)) => self.coerce_types(&l, &r),
1143 (Some(t), None) | (None, Some(t)) => Some(t),
1144 (None, None) => None,
1145 }
1146 }
1147
1148 fn annotate_function(&mut self, func: &Function) -> Option<DataType> {
1150 let func_name = func.name.to_uppercase();
1151
1152 if let Some(return_type) = self.function_return_types.get(&func_name) {
1154 if *return_type != DataType::Unknown {
1155 return Some(return_type.clone());
1156 }
1157 }
1158
1159 match func_name.as_str() {
1161 "COALESCE" | "IFNULL" | "NVL" | "ISNULL" => {
1162 for arg in &func.args {
1164 if let Some(arg_type) = self.annotate(arg) {
1165 return Some(arg_type);
1166 }
1167 }
1168 None
1169 }
1170 "NULLIF" => {
1171 func.args.first().and_then(|arg| self.annotate(arg))
1173 }
1174 "GREATEST" | "LEAST" => {
1175 self.coerce_arg_types(&func.args)
1177 }
1178 "IF" | "IIF" => {
1179 if func.args.len() >= 2 {
1181 self.annotate(&func.args[1])
1182 } else {
1183 None
1184 }
1185 }
1186 _ => {
1187 func.args.first().and_then(|arg| self.annotate(arg))
1189 }
1190 }
1191 }
1192
1193 fn get_aggregate_return_type(
1195 &mut self,
1196 func_name: &str,
1197 args: &[Expression],
1198 ) -> Option<DataType> {
1199 match func_name {
1200 "COUNT" | "COUNT_IF" => Some(DataType::BigInt { length: None }),
1201 "SUM" => {
1202 if let Some(arg) = args.first() {
1203 self.annotate_sum(arg)
1204 } else {
1205 Some(DataType::Decimal {
1206 precision: None,
1207 scale: None,
1208 })
1209 }
1210 }
1211 "AVG" => Some(DataType::Double {
1212 precision: None,
1213 scale: None,
1214 }),
1215 "MIN" | "MAX" => {
1216 args.first().and_then(|arg| self.annotate(arg))
1218 }
1219 "STRING_AGG" | "GROUP_CONCAT" | "LISTAGG" | "ARRAY_AGG" => Some(DataType::VarChar {
1220 length: None,
1221 parenthesized_length: false,
1222 }),
1223 "BOOL_AND" | "BOOL_OR" | "EVERY" | "ANY" | "SOME" => Some(DataType::Boolean),
1224 "BIT_AND" | "BIT_OR" | "BIT_XOR" => Some(DataType::BigInt { length: None }),
1225 "STDDEV" | "STDDEV_POP" | "STDDEV_SAMP" | "VARIANCE" | "VAR_POP" | "VAR_SAMP" => {
1226 Some(DataType::Double {
1227 precision: None,
1228 scale: None,
1229 })
1230 }
1231 "PERCENTILE_CONT" | "PERCENTILE_DISC" | "MEDIAN" => {
1232 args.first().and_then(|arg| self.annotate(arg))
1233 }
1234 _ => None,
1235 }
1236 }
1237
1238 fn annotate_sum(&mut self, arg: &Expression) -> Option<DataType> {
1240 match self.annotate(arg) {
1241 Some(DataType::TinyInt { .. })
1242 | Some(DataType::SmallInt { .. })
1243 | Some(DataType::Int { .. }) => Some(DataType::BigInt { length: None }),
1244 Some(DataType::BigInt { .. }) => Some(DataType::BigInt { length: None }),
1245 Some(DataType::Float { .. }) | Some(DataType::Double { .. }) => {
1246 Some(DataType::Double {
1247 precision: None,
1248 scale: None,
1249 })
1250 }
1251 Some(DataType::Decimal { precision, scale }) => {
1252 Some(DataType::Decimal { precision, scale })
1253 }
1254 _ => Some(DataType::Decimal {
1255 precision: None,
1256 scale: None,
1257 }),
1258 }
1259 }
1260
1261 fn coerce_arg_types(&mut self, args: &[Expression]) -> Option<DataType> {
1263 let mut result_type: Option<DataType> = None;
1264 for arg in args {
1265 if let Some(arg_type) = self.annotate(arg) {
1266 result_type = match result_type {
1267 Some(t) => self.coerce_types(&t, &arg_type),
1268 None => Some(arg_type),
1269 };
1270 }
1271 }
1272 result_type
1273 }
1274
1275 fn coerce_types(&self, left: &DataType, right: &DataType) -> Option<DataType> {
1277 if left == right {
1279 return Some(left.clone());
1280 }
1281
1282 match (left, right) {
1284 (DataType::Date, DataType::Interval { .. })
1285 | (DataType::Interval { .. }, DataType::Date) => return Some(DataType::Date),
1286 (
1287 DataType::Timestamp {
1288 precision,
1289 timezone,
1290 },
1291 DataType::Interval { .. },
1292 )
1293 | (
1294 DataType::Interval { .. },
1295 DataType::Timestamp {
1296 precision,
1297 timezone,
1298 },
1299 ) => {
1300 return Some(DataType::Timestamp {
1301 precision: *precision,
1302 timezone: *timezone,
1303 });
1304 }
1305 _ => {}
1306 }
1307
1308 let left_class = TypeCoercionClass::from_data_type(left);
1310 let right_class = TypeCoercionClass::from_data_type(right);
1311
1312 match (left_class, right_class) {
1313 (Some(lc), Some(rc)) if lc == rc => {
1315 if lc == TypeCoercionClass::Numeric {
1317 Some(self.wider_numeric_type(left, right))
1318 } else {
1319 Some(left.clone())
1321 }
1322 }
1323 (Some(lc), Some(rc)) => {
1325 if lc > rc {
1326 Some(left.clone())
1327 } else {
1328 Some(right.clone())
1329 }
1330 }
1331 (Some(_), None) => Some(left.clone()),
1333 (None, Some(_)) => Some(right.clone()),
1334 (None, None) => Some(DataType::Unknown),
1336 }
1337 }
1338
1339 fn wider_numeric_type(&self, left: &DataType, right: &DataType) -> DataType {
1341 let order = |dt: &DataType| -> u8 {
1342 match dt {
1343 DataType::Boolean => 0,
1344 DataType::TinyInt { .. } => 1,
1345 DataType::SmallInt { .. } => 2,
1346 DataType::Int { .. } => 3,
1347 DataType::BigInt { .. } => 4,
1348 DataType::Float { .. } => 5,
1349 DataType::Double { .. } => 6,
1350 DataType::Decimal { .. } => 7,
1351 _ => 0,
1352 }
1353 };
1354
1355 if order(left) >= order(right) {
1356 left.clone()
1357 } else {
1358 right.clone()
1359 }
1360 }
1361}
1362
1363pub fn annotate_types(
1369 expr: &mut Expression,
1370 schema: Option<&dyn Schema>,
1371 dialect: Option<DialectType>,
1372) {
1373 let mut annotator = TypeAnnotator::new(schema, dialect);
1374 annotator.annotate_in_place(expr);
1375}
1376
#[cfg(test)]
mod tests {
    //! Unit tests for type annotation: literal, operator, function and container typing,
    //! the coercion helpers, and the in-place annotation entry point.

    use super::*;
    use crate::expressions::{Array, BooleanLiteral, Cast, Null, UnaryFunc};

    // ---------------------------------------------------------------------
    // Expression builders
    // ---------------------------------------------------------------------

    /// Builds an integer numeric literal.
    fn make_int_literal(val: i64) -> Expression {
        Expression::Literal(Literal::Number(val.to_string()))
    }

    /// Builds a floating-point numeric literal.
    ///
    /// `f64`'s `Display` output omits the fraction of whole values (`2.0` prints as `2`),
    /// which would produce exactly the text `make_int_literal` emits. A trailing `.0` is
    /// appended in that case so the literal always reads as a float.
    fn make_float_literal(val: f64) -> Expression {
        let mut text = val.to_string();
        if val.is_finite() && !text.contains('.') {
            text.push_str(".0");
        }
        Expression::Literal(Literal::Number(text))
    }

    /// Builds a string literal.
    fn make_string_literal(val: &str) -> Expression {
        Expression::Literal(Literal::String(val.to_string()))
    }

    /// Builds a boolean literal.
    fn make_bool_literal(val: bool) -> Expression {
        Expression::Boolean(BooleanLiteral { value: val })
    }

    /// Builds an array expression from its element expressions.
    fn make_array(items: Vec<Expression>) -> Expression {
        Expression::Array(Box::new(Array { expressions: items }))
    }

    /// Builds a plain `CAST(this AS to)` expression with every optional part left unset.
    fn make_cast(this: Expression, to: DataType) -> Expression {
        Expression::Cast(Box::new(Cast {
            this,
            to,
            trailing_comments: Vec::new(),
            double_colon_syntax: false,
            format: None,
            default: None,
            inferred_type: None,
        }))
    }

    /// Boxes a binary operation, ready to be wrapped in an operator variant.
    fn binary(left: Expression, right: Expression) -> Box<BinaryOp> {
        Box::new(BinaryOp::new(left, right))
    }

    /// Boxes a single-argument function node with no original name or inferred type.
    fn unary_func(this: Expression) -> Box<UnaryFunc> {
        Box::new(UnaryFunc {
            this,
            original_name: None,
            inferred_type: None,
        })
    }

    // ---------------------------------------------------------------------
    // Expected-type shorthands
    // ---------------------------------------------------------------------

    /// `INT` without a length.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// `BIGINT` without a length.
    fn bigint_type() -> DataType {
        DataType::BigInt { length: None }
    }

    /// `DOUBLE` without precision or scale.
    fn double_type() -> DataType {
        DataType::Double {
            precision: None,
            scale: None,
        }
    }

    /// `VARCHAR` without a length.
    fn varchar_type() -> DataType {
        DataType::VarChar {
            length: None,
            parenthesized_length: false,
        }
    }

    /// `TIMESTAMP` without precision or time zone.
    fn timestamp_type() -> DataType {
        DataType::Timestamp {
            precision: None,
            timezone: false,
        }
    }

    // ---------------------------------------------------------------------
    // Literals and operators
    // ---------------------------------------------------------------------

    #[test]
    fn test_literal_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        assert_eq!(annotator.annotate(&make_int_literal(42)), Some(int_type()));
        assert_eq!(
            annotator.annotate(&make_float_literal(3.14)),
            Some(double_type())
        );
        assert_eq!(
            annotator.annotate(&make_string_literal("hello")),
            Some(varchar_type())
        );
        assert_eq!(
            annotator.annotate(&make_bool_literal(true)),
            Some(DataType::Boolean)
        );

        // A bare NULL is expected to have no type.
        assert_eq!(annotator.annotate(&Expression::Null(Null)), None);
    }

    #[test]
    fn test_comparison_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let gt = Expression::Gt(binary(make_int_literal(1), make_int_literal(2)));
        assert_eq!(annotator.annotate(&gt), Some(DataType::Boolean));

        let eq = Expression::Eq(binary(make_string_literal("a"), make_string_literal("b")));
        assert_eq!(annotator.annotate(&eq), Some(DataType::Boolean));
    }

    #[test]
    fn test_arithmetic_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let add_int = Expression::Add(binary(make_int_literal(1), make_int_literal(2)));
        assert_eq!(annotator.annotate(&add_int), Some(int_type()));

        // Mixing an integer with a float is expected to widen to DOUBLE.
        let add_mixed = Expression::Add(binary(make_int_literal(1), make_float_literal(2.5)));
        assert_eq!(annotator.annotate(&add_mixed), Some(double_type()));
    }

    #[test]
    fn test_string_concat_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let concat = Expression::Concat(binary(
            make_string_literal("hello"),
            make_string_literal(" world"),
        ));
        assert_eq!(annotator.annotate(&concat), Some(varchar_type()));
    }

    #[test]
    fn test_cast_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let target = DataType::VarChar {
            length: Some(10),
            parenthesized_length: false,
        };
        let cast = make_cast(make_int_literal(1), target.clone());

        // A cast reports its target type, including the declared length.
        assert_eq!(annotator.annotate(&cast), Some(target));
    }

    // ---------------------------------------------------------------------
    // Functions
    // ---------------------------------------------------------------------

    #[test]
    fn test_function_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let count =
            Expression::Function(Box::new(Function::new("COUNT", vec![make_int_literal(1)])));
        assert_eq!(annotator.annotate(&count), Some(bigint_type()));

        let upper = Expression::Function(Box::new(Function::new(
            "UPPER",
            vec![make_string_literal("hello")],
        )));
        assert_eq!(annotator.annotate(&upper), Some(varchar_type()));

        let now = Expression::Function(Box::new(Function::new("NOW", vec![])));
        assert_eq!(annotator.annotate(&now), Some(timestamp_type()));
    }

    #[test]
    fn test_coalesce_type_inference() {
        let mut annotator = TypeAnnotator::new(None, None);

        // The NULL argument contributes no type, so the integer argument decides.
        let coalesce = Expression::Function(Box::new(Function::new(
            "COALESCE",
            vec![Expression::Null(Null), make_int_literal(1)],
        )));
        assert_eq!(annotator.annotate(&coalesce), Some(int_type()));
    }

    // ---------------------------------------------------------------------
    // Coercion helpers
    // ---------------------------------------------------------------------

    #[test]
    fn test_type_coercion_class() {
        let cases = [
            (varchar_type(), Some(TypeCoercionClass::Text)),
            (DataType::Text, Some(TypeCoercionClass::Text)),
            (int_type(), Some(TypeCoercionClass::Numeric)),
            (double_type(), Some(TypeCoercionClass::Numeric)),
            (DataType::Date, Some(TypeCoercionClass::Timelike)),
            (timestamp_type(), Some(TypeCoercionClass::Timelike)),
            // JSON belongs to no coercion class.
            (DataType::Json, None),
        ];

        for (data_type, expected) in cases.iter() {
            assert_eq!(
                TypeCoercionClass::from_data_type(data_type),
                *expected,
                "coercion class of {:?}",
                data_type
            );
        }
    }

    #[test]
    fn test_wider_numeric_type() {
        let annotator = TypeAnnotator::new(None, None);

        // INT vs BIGINT -> BIGINT
        assert_eq!(
            annotator.wider_numeric_type(&int_type(), &bigint_type()),
            bigint_type()
        );

        // FLOAT vs DOUBLE -> DOUBLE
        let float = DataType::Float {
            precision: None,
            scale: None,
            real_spelling: false,
        };
        assert_eq!(
            annotator.wider_numeric_type(&float, &double_type()),
            double_type()
        );

        // INT vs DOUBLE -> DOUBLE
        assert_eq!(
            annotator.wider_numeric_type(&int_type(), &double_type()),
            double_type()
        );
    }

    #[test]
    fn test_aggregate_return_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let sum_type = annotator.get_aggregate_return_type("SUM", &[make_int_literal(1)]);
        assert_eq!(sum_type, Some(bigint_type()));

        let avg_type = annotator.get_aggregate_return_type("AVG", &[make_int_literal(1)]);
        assert_eq!(avg_type, Some(double_type()));

        let min_type = annotator.get_aggregate_return_type("MIN", &[make_string_literal("a")]);
        assert_eq!(min_type, Some(varchar_type()));
    }

    // ---------------------------------------------------------------------
    // Temporal literals and boolean logic
    // ---------------------------------------------------------------------

    #[test]
    fn test_date_literal_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let date_expr = Expression::Literal(Literal::Date("2024-01-15".to_string()));
        assert_eq!(annotator.annotate(&date_expr), Some(DataType::Date));

        let time_expr = Expression::Literal(Literal::Time("10:30:00".to_string()));
        assert_eq!(
            annotator.annotate(&time_expr),
            Some(DataType::Time {
                precision: None,
                timezone: false,
            })
        );

        let ts_expr = Expression::Literal(Literal::Timestamp("2024-01-15 10:30:00".to_string()));
        assert_eq!(annotator.annotate(&ts_expr), Some(timestamp_type()));
    }

    #[test]
    fn test_logical_operations() {
        let mut annotator = TypeAnnotator::new(None, None);

        let and_expr = Expression::And(binary(make_bool_literal(true), make_bool_literal(false)));
        assert_eq!(annotator.annotate(&and_expr), Some(DataType::Boolean));

        let or_expr = Expression::Or(binary(make_bool_literal(true), make_bool_literal(false)));
        assert_eq!(annotator.annotate(&or_expr), Some(DataType::Boolean));

        let not_expr = Expression::Not(Box::new(crate::expressions::UnaryOp::new(
            make_bool_literal(true),
        )));
        assert_eq!(annotator.annotate(&not_expr), Some(DataType::Boolean));
    }

    // ---------------------------------------------------------------------
    // Containers: arrays, maps, structs
    // ---------------------------------------------------------------------

    #[test]
    fn test_subscript_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // Indexing an array of integers is expected to yield the element type.
        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
            this: make_array(vec![make_int_literal(1), make_int_literal(2)]),
            index: make_int_literal(0),
        }));
        assert_eq!(annotator.annotate(&subscript), Some(int_type()));
    }

    #[test]
    fn test_subscript_map_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let map = Expression::Map(Box::new(crate::expressions::Map {
            keys: vec![make_string_literal("a")],
            values: vec![make_int_literal(1)],
        }));

        // Looking up a key in a string -> int map is expected to yield the value type.
        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
            this: map,
            index: make_string_literal("a"),
        }));
        assert_eq!(annotator.annotate(&subscript), Some(int_type()));
    }

    #[test]
    fn test_struct_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let struct_expr = Expression::Struct(Box::new(crate::expressions::Struct {
            fields: vec![
                (Some("name".to_string()), make_string_literal("Alice")),
                (Some("age".to_string()), make_int_literal(30)),
            ],
        }));
        let result = annotator.annotate(&struct_expr);
        assert!(matches!(result, Some(DataType::Struct { fields, .. }) if fields.len() == 2));
    }

    #[test]
    fn test_map_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let map_expr = Expression::Map(Box::new(crate::expressions::Map {
            keys: vec![make_string_literal("a"), make_string_literal("b")],
            values: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let result = annotator.annotate(&map_expr);
        assert!(matches!(
            result,
            Some(DataType::Map { key_type, value_type })
                if matches!(*key_type, DataType::VarChar { .. })
                    && matches!(*value_type, DataType::Int { .. })
        ));
    }

    #[test]
    fn test_explode_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let explode = Expression::Explode(unary_func(make_array(vec![
            make_int_literal(1),
            make_int_literal(2),
        ])));
        assert_eq!(annotator.annotate(&explode), Some(int_type()));
    }

    #[test]
    fn test_unnest_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let unnest = Expression::Unnest(Box::new(crate::expressions::UnnestFunc {
            this: make_array(vec![make_string_literal("a"), make_string_literal("b")]),
            expressions: Vec::new(),
            with_ordinality: false,
            alias: None,
            offset_alias: None,
        }));
        assert_eq!(annotator.annotate(&unnest), Some(varchar_type()));
    }

    #[test]
    fn test_set_operation_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let select = Expression::Select(Box::new(crate::expressions::Select::default()));
        let union = Expression::Union(Box::new(crate::expressions::Union {
            left: select.clone(),
            right: select.clone(),
            all: false,
            distinct: false,
            with: None,
            order_by: None,
            limit: None,
            offset: None,
            by_name: false,
            side: None,
            kind: None,
            corresponding: false,
            strict: false,
            on_columns: Vec::new(),
            distribute_by: None,
            sort_by: None,
            cluster_by: None,
        }));

        // A set operation is expected to have no scalar type.
        assert_eq!(annotator.annotate(&union), None);
    }

    // ---------------------------------------------------------------------
    // Input-dependent function typing
    // ---------------------------------------------------------------------

    #[test]
    fn test_floor_ceil_input_dependent_types() {
        use crate::expressions::{CeilFunc, FloorFunc};

        let mut annotator = TypeAnnotator::new(None, None);

        // Typed Floor/Ceil nodes: DOUBLE is expected for both integer and float input.
        let floor_int = Expression::Floor(Box::new(FloorFunc {
            this: make_int_literal(42),
            scale: None,
            to: None,
        }));
        assert_eq!(annotator.annotate(&floor_int), Some(double_type()));

        let ceil_int = Expression::Ceil(Box::new(CeilFunc {
            this: make_int_literal(42),
            decimals: None,
            to: None,
        }));
        assert_eq!(annotator.annotate(&ceil_int), Some(double_type()));

        let floor_float = Expression::Floor(Box::new(FloorFunc {
            this: make_float_literal(3.14),
            scale: None,
            to: None,
        }));
        assert_eq!(annotator.annotate(&floor_float), Some(double_type()));

        // Generic `FLOOR(...)` function call: the integer argument type is expected back.
        let floor_fn =
            Expression::Function(Box::new(Function::new("FLOOR", vec![make_int_literal(1)])));
        assert_eq!(annotator.annotate(&floor_fn), Some(int_type()));
    }

    #[test]
    fn test_sign_preserves_input_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let sign_int = Expression::Sign(unary_func(make_int_literal(42)));
        assert_eq!(annotator.annotate(&sign_int), Some(int_type()));

        let sign_float = Expression::Sign(unary_func(make_float_literal(3.14)));
        assert_eq!(annotator.annotate(&sign_float), Some(double_type()));

        // The argument's type may also come from a nested expression such as a cast.
        let sign_cast = Expression::Sign(unary_func(make_cast(make_int_literal(42), int_type())));
        assert_eq!(annotator.annotate(&sign_cast), Some(int_type()));
    }

    #[test]
    fn test_date_format_types() {
        use crate::expressions::{DateFormatFunc, TimeToStr};

        let mut annotator = TypeAnnotator::new(None, None);

        let date_fmt = Expression::DateFormat(Box::new(DateFormatFunc {
            this: make_string_literal("2024-01-01"),
            format: make_string_literal("%Y-%m-%d"),
        }));
        assert_eq!(annotator.annotate(&date_fmt), Some(varchar_type()));

        let format_date = Expression::FormatDate(Box::new(DateFormatFunc {
            this: make_string_literal("2024-01-01"),
            format: make_string_literal("%Y-%m-%d"),
        }));
        assert_eq!(annotator.annotate(&format_date), Some(varchar_type()));

        let time_to_str = Expression::TimeToStr(Box::new(TimeToStr {
            this: Box::new(make_string_literal("2024-01-01")),
            format: "%Y-%m-%d".to_string(),
            culture: None,
            zone: None,
        }));
        assert_eq!(annotator.annotate(&time_to_str), Some(varchar_type()));

        // Same formatting function, expressed as a generic function call.
        let date_fmt_fn = Expression::Function(Box::new(Function::new(
            "DATE_FORMAT",
            vec![
                make_string_literal("2024-01-01"),
                make_string_literal("%Y-%m-%d"),
            ],
        )));
        assert_eq!(annotator.annotate(&date_fmt_fn), Some(varchar_type()));
    }

    // ---------------------------------------------------------------------
    // In-place annotation via `annotate_types`
    // ---------------------------------------------------------------------

    #[test]
    fn test_annotate_in_place_sets_type_on_root() {
        let mut expr = Expression::Add(binary(make_int_literal(1), make_int_literal(2)));
        annotate_types(&mut expr, None, None);
        assert_eq!(expr.inferred_type(), Some(&int_type()));
    }

    #[test]
    fn test_annotate_in_place_sets_types_on_children() {
        // (1 + 2.5) + (3 - 4)
        let inner_add = Expression::Add(binary(make_int_literal(1), make_float_literal(2.5)));
        let inner_sub = Expression::Sub(binary(make_int_literal(3), make_int_literal(4)));
        let mut expr = Expression::Add(binary(inner_add, inner_sub));
        annotate_types(&mut expr, None, None);

        // Root: DOUBLE + INT -> DOUBLE
        assert_eq!(expr.inferred_type(), Some(&double_type()));

        if let Expression::Add(op) = &expr {
            // Left child: INT + DOUBLE -> DOUBLE
            assert_eq!(op.left.inferred_type(), Some(&double_type()));
            // Right child: INT - INT -> INT
            assert_eq!(op.right.inferred_type(), Some(&int_type()));
        } else {
            panic!("Expected Add expression");
        }
    }

    #[test]
    fn test_annotate_in_place_comparison() {
        let mut expr = Expression::Eq(binary(make_int_literal(1), make_int_literal(2)));
        annotate_types(&mut expr, None, None);
        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));
    }

    #[test]
    fn test_annotate_in_place_cast() {
        let mut expr = make_cast(make_int_literal(42), varchar_type());
        annotate_types(&mut expr, None, None);
        assert_eq!(expr.inferred_type(), Some(&varchar_type()));
    }

    #[test]
    fn test_annotate_in_place_nested_expression() {
        // (1 + 2) > 0
        let add = Expression::Add(binary(make_int_literal(1), make_int_literal(2)));
        let mut expr = Expression::Gt(binary(add, make_int_literal(0)));
        annotate_types(&mut expr, None, None);

        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));

        if let Expression::Gt(op) = &expr {
            // The nested addition keeps its own INT annotation.
            assert_eq!(op.left.inferred_type(), Some(&int_type()));
        } else {
            // Fail loudly instead of silently skipping the child assertion.
            panic!("Expected Gt expression");
        }
    }

    #[test]
    fn test_annotate_in_place_parsed_sql() {
        use crate::parser::Parser;

        let mut expr =
            Parser::parse_sql("SELECT 1 + 2.0, 'hello', TRUE").expect("parse failed")[0].clone();
        annotate_types(&mut expr, None, None);

        // The root of the parsed statement is expected to stay without an inferred type.
        assert!(expr.inferred_type().is_none());
    }

    // ---------------------------------------------------------------------
    // Serialization of the inferred type
    // ---------------------------------------------------------------------

    #[test]
    fn test_inferred_type_json_roundtrip() {
        let mut expr = Expression::Add(binary(make_int_literal(1), make_int_literal(2)));
        annotate_types(&mut expr, None, None);

        // An annotated expression serializes its inferred type ...
        let json = serde_json::to_string(&expr).expect("serialize failed");
        assert!(json.contains("inferred_type"));

        // ... and gets it back unchanged after deserialization.
        let deserialized: Expression = serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(deserialized.inferred_type(), Some(&int_type()));
    }

    #[test]
    fn test_inferred_type_none_not_serialized() {
        // Without annotation the field must not appear in the JSON output at all.
        let expr = Expression::Add(binary(make_int_literal(1), make_int_literal(2)));
        let json = serde_json::to_string(&expr).expect("serialize failed");
        assert!(!json.contains("inferred_type"));
    }
}