1use std::collections::HashMap;
13
14use crate::dialects::DialectType;
15use crate::expressions::{
16 BinaryOp, DataType, Expression, Function, Literal, Map, Struct, StructField, Subscript,
17};
18use crate::schema::Schema;
19
/// Broad type families used when two operand types have to be reconciled.
///
/// The explicit discriminants double as a priority: with the derived `Ord`,
/// `Text < Numeric < Timelike`, and `coerce_types` keeps the operand whose
/// class compares greater when the two classes differ.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TypeCoercionClass {
    /// Character and binary string types (lowest priority).
    Text = 0,
    /// Boolean, integer, floating-point and decimal types.
    Numeric = 1,
    /// Date, time, timestamp and interval types (highest priority).
    Timelike = 2,
}
31
32impl TypeCoercionClass {
33 pub fn from_data_type(dt: &DataType) -> Option<Self> {
35 match dt {
36 DataType::Char { .. }
38 | DataType::VarChar { .. }
39 | DataType::Text
40 | DataType::Binary { .. }
41 | DataType::VarBinary { .. }
42 | DataType::Blob => Some(TypeCoercionClass::Text),
43
44 DataType::Boolean
46 | DataType::TinyInt { .. }
47 | DataType::SmallInt { .. }
48 | DataType::Int { .. }
49 | DataType::BigInt { .. }
50 | DataType::Float { .. }
51 | DataType::Double { .. }
52 | DataType::Decimal { .. } => Some(TypeCoercionClass::Numeric),
53
54 DataType::Date
56 | DataType::Time { .. }
57 | DataType::Timestamp { .. }
58 | DataType::Interval { .. } => Some(TypeCoercionClass::Timelike),
59
60 _ => None,
62 }
63 }
64}
65
66pub struct TypeAnnotator<'a> {
68 _schema: Option<&'a dyn Schema>,
70 _dialect: Option<DialectType>,
72 annotate_aggregates: bool,
74 function_return_types: HashMap<String, DataType>,
76}
77
78impl<'a> TypeAnnotator<'a> {
79 pub fn new(schema: Option<&'a dyn Schema>, dialect: Option<DialectType>) -> Self {
81 let mut annotator = Self {
82 _schema: schema,
83 _dialect: dialect,
84 annotate_aggregates: true,
85 function_return_types: HashMap::new(),
86 };
87 annotator.init_function_return_types();
88 annotator
89 }
90
91 fn init_function_return_types(&mut self) {
93 self.function_return_types
95 .insert("COUNT".to_string(), DataType::BigInt { length: None });
96 self.function_return_types.insert(
97 "SUM".to_string(),
98 DataType::Decimal {
99 precision: None,
100 scale: None,
101 },
102 );
103 self.function_return_types.insert(
104 "AVG".to_string(),
105 DataType::Double {
106 precision: None,
107 scale: None,
108 },
109 );
110
111 self.function_return_types.insert(
113 "CONCAT".to_string(),
114 DataType::VarChar {
115 length: None,
116 parenthesized_length: false,
117 },
118 );
119 self.function_return_types.insert(
120 "UPPER".to_string(),
121 DataType::VarChar {
122 length: None,
123 parenthesized_length: false,
124 },
125 );
126 self.function_return_types.insert(
127 "LOWER".to_string(),
128 DataType::VarChar {
129 length: None,
130 parenthesized_length: false,
131 },
132 );
133 self.function_return_types.insert(
134 "TRIM".to_string(),
135 DataType::VarChar {
136 length: None,
137 parenthesized_length: false,
138 },
139 );
140 self.function_return_types.insert(
141 "LTRIM".to_string(),
142 DataType::VarChar {
143 length: None,
144 parenthesized_length: false,
145 },
146 );
147 self.function_return_types.insert(
148 "RTRIM".to_string(),
149 DataType::VarChar {
150 length: None,
151 parenthesized_length: false,
152 },
153 );
154 self.function_return_types.insert(
155 "SUBSTRING".to_string(),
156 DataType::VarChar {
157 length: None,
158 parenthesized_length: false,
159 },
160 );
161 self.function_return_types.insert(
162 "SUBSTR".to_string(),
163 DataType::VarChar {
164 length: None,
165 parenthesized_length: false,
166 },
167 );
168 self.function_return_types.insert(
169 "REPLACE".to_string(),
170 DataType::VarChar {
171 length: None,
172 parenthesized_length: false,
173 },
174 );
175 self.function_return_types.insert(
176 "LENGTH".to_string(),
177 DataType::Int {
178 length: None,
179 integer_spelling: false,
180 },
181 );
182 self.function_return_types.insert(
183 "CHAR_LENGTH".to_string(),
184 DataType::Int {
185 length: None,
186 integer_spelling: false,
187 },
188 );
189
190 self.function_return_types.insert(
192 "NOW".to_string(),
193 DataType::Timestamp {
194 precision: None,
195 timezone: false,
196 },
197 );
198 self.function_return_types.insert(
199 "CURRENT_TIMESTAMP".to_string(),
200 DataType::Timestamp {
201 precision: None,
202 timezone: false,
203 },
204 );
205 self.function_return_types
206 .insert("CURRENT_DATE".to_string(), DataType::Date);
207 self.function_return_types.insert(
208 "CURRENT_TIME".to_string(),
209 DataType::Time {
210 precision: None,
211 timezone: false,
212 },
213 );
214 self.function_return_types
215 .insert("DATE".to_string(), DataType::Date);
216 self.function_return_types.insert(
217 "YEAR".to_string(),
218 DataType::Int {
219 length: None,
220 integer_spelling: false,
221 },
222 );
223 self.function_return_types.insert(
224 "MONTH".to_string(),
225 DataType::Int {
226 length: None,
227 integer_spelling: false,
228 },
229 );
230 self.function_return_types.insert(
231 "DAY".to_string(),
232 DataType::Int {
233 length: None,
234 integer_spelling: false,
235 },
236 );
237 self.function_return_types.insert(
238 "HOUR".to_string(),
239 DataType::Int {
240 length: None,
241 integer_spelling: false,
242 },
243 );
244 self.function_return_types.insert(
245 "MINUTE".to_string(),
246 DataType::Int {
247 length: None,
248 integer_spelling: false,
249 },
250 );
251 self.function_return_types.insert(
252 "SECOND".to_string(),
253 DataType::Int {
254 length: None,
255 integer_spelling: false,
256 },
257 );
258 self.function_return_types.insert(
259 "EXTRACT".to_string(),
260 DataType::Int {
261 length: None,
262 integer_spelling: false,
263 },
264 );
265 self.function_return_types.insert(
266 "DATE_DIFF".to_string(),
267 DataType::Int {
268 length: None,
269 integer_spelling: false,
270 },
271 );
272 self.function_return_types.insert(
273 "DATEDIFF".to_string(),
274 DataType::Int {
275 length: None,
276 integer_spelling: false,
277 },
278 );
279
280 self.function_return_types.insert(
282 "ABS".to_string(),
283 DataType::Double {
284 precision: None,
285 scale: None,
286 },
287 );
288 self.function_return_types.insert(
289 "ROUND".to_string(),
290 DataType::Double {
291 precision: None,
292 scale: None,
293 },
294 );
295 self.function_return_types
296 .insert("FLOOR".to_string(), DataType::BigInt { length: None });
297 self.function_return_types
298 .insert("CEIL".to_string(), DataType::BigInt { length: None });
299 self.function_return_types
300 .insert("CEILING".to_string(), DataType::BigInt { length: None });
301 self.function_return_types.insert(
302 "SQRT".to_string(),
303 DataType::Double {
304 precision: None,
305 scale: None,
306 },
307 );
308 self.function_return_types.insert(
309 "POWER".to_string(),
310 DataType::Double {
311 precision: None,
312 scale: None,
313 },
314 );
315 self.function_return_types.insert(
316 "MOD".to_string(),
317 DataType::Int {
318 length: None,
319 integer_spelling: false,
320 },
321 );
322 self.function_return_types.insert(
323 "LOG".to_string(),
324 DataType::Double {
325 precision: None,
326 scale: None,
327 },
328 );
329 self.function_return_types.insert(
330 "LN".to_string(),
331 DataType::Double {
332 precision: None,
333 scale: None,
334 },
335 );
336 self.function_return_types.insert(
337 "EXP".to_string(),
338 DataType::Double {
339 precision: None,
340 scale: None,
341 },
342 );
343
344 self.function_return_types
346 .insert("COALESCE".to_string(), DataType::Unknown);
347 self.function_return_types
348 .insert("NULLIF".to_string(), DataType::Unknown);
349 self.function_return_types
350 .insert("GREATEST".to_string(), DataType::Unknown);
351 self.function_return_types
352 .insert("LEAST".to_string(), DataType::Unknown);
353 }
354
355 pub fn annotate(&mut self, expr: &Expression) -> Option<DataType> {
357 match expr {
358 Expression::Literal(lit) => self.annotate_literal(lit),
360 Expression::Boolean(_) => Some(DataType::Boolean),
361 Expression::Null(_) => None, Expression::Add(op)
365 | Expression::Sub(op)
366 | Expression::Mul(op)
367 | Expression::Div(op)
368 | Expression::Mod(op) => self.annotate_arithmetic(op),
369
370 Expression::Eq(_)
372 | Expression::Neq(_)
373 | Expression::Lt(_)
374 | Expression::Lte(_)
375 | Expression::Gt(_)
376 | Expression::Gte(_)
377 | Expression::Like(_)
378 | Expression::ILike(_) => Some(DataType::Boolean),
379
380 Expression::And(_) | Expression::Or(_) | Expression::Not(_) => Some(DataType::Boolean),
382
383 Expression::Between(_)
385 | Expression::In(_)
386 | Expression::IsNull(_)
387 | Expression::IsTrue(_)
388 | Expression::IsFalse(_)
389 | Expression::Is(_)
390 | Expression::Exists(_) => Some(DataType::Boolean),
391
392 Expression::Concat(_) => Some(DataType::VarChar {
394 length: None,
395 parenthesized_length: false,
396 }),
397
398 Expression::BitwiseAnd(_)
400 | Expression::BitwiseOr(_)
401 | Expression::BitwiseXor(_)
402 | Expression::BitwiseNot(_) => Some(DataType::BigInt { length: None }),
403
404 Expression::Neg(op) => self.annotate(&op.this),
406
407 Expression::Function(func) => self.annotate_function(func),
409
410 Expression::Count(_) => Some(DataType::BigInt { length: None }),
412 Expression::Sum(agg) => self.annotate_sum(&agg.this),
413 Expression::Avg(_) => Some(DataType::Double {
414 precision: None,
415 scale: None,
416 }),
417 Expression::Min(agg) => self.annotate(&agg.this),
418 Expression::Max(agg) => self.annotate(&agg.this),
419 Expression::GroupConcat(_) | Expression::StringAgg(_) | Expression::ListAgg(_) => {
420 Some(DataType::VarChar {
421 length: None,
422 parenthesized_length: false,
423 })
424 }
425
426 Expression::AggregateFunction(agg) => {
428 if !self.annotate_aggregates {
429 return None;
430 }
431 let func_name = agg.name.to_uppercase();
432 self.get_aggregate_return_type(&func_name, &agg.args)
433 }
434
435 Expression::Column(col) => {
437 if let Some(schema) = &self._schema {
438 let table_name = col.table.as_ref().map(|t| t.name.as_str()).unwrap_or("");
439 schema.get_column_type(table_name, &col.name.name).ok()
440 } else {
441 None
442 }
443 }
444
445 Expression::Cast(cast) => Some(cast.to.clone()),
447 Expression::SafeCast(cast) => Some(cast.to.clone()),
448 Expression::TryCast(cast) => Some(cast.to.clone()),
449
450 Expression::Subquery(subq) => {
452 if let Expression::Select(select) = &subq.this {
453 if let Some(first) = select.expressions.first() {
454 self.annotate(first)
455 } else {
456 None
457 }
458 } else {
459 None
460 }
461 }
462
463 Expression::Case(case) => {
465 if let Some(else_expr) = &case.else_ {
466 self.annotate(else_expr)
467 } else if let Some((_, then_expr)) = case.whens.first() {
468 self.annotate(then_expr)
469 } else {
470 None
471 }
472 }
473
474 Expression::Array(arr) => {
476 if let Some(first) = arr.expressions.first() {
477 if let Some(elem_type) = self.annotate(first) {
478 Some(DataType::Array {
479 element_type: Box::new(elem_type),
480 dimension: None,
481 })
482 } else {
483 Some(DataType::Array {
484 element_type: Box::new(DataType::Unknown),
485 dimension: None,
486 })
487 }
488 } else {
489 Some(DataType::Array {
490 element_type: Box::new(DataType::Unknown),
491 dimension: None,
492 })
493 }
494 }
495
496 Expression::Interval(_) => Some(DataType::Interval {
498 unit: None,
499 to: None,
500 }),
501
502 Expression::WindowFunction(window) => self.annotate(&window.this),
504
505 Expression::CurrentDate(_) => Some(DataType::Date),
507 Expression::CurrentTime(_) => Some(DataType::Time {
508 precision: None,
509 timezone: false,
510 }),
511 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
512 Some(DataType::Timestamp {
513 precision: None,
514 timezone: false,
515 })
516 }
517
518 Expression::DateAdd(_)
520 | Expression::DateSub(_)
521 | Expression::ToDate(_)
522 | Expression::Date(_) => Some(DataType::Date),
523 Expression::DateDiff(_) | Expression::Extract(_) => Some(DataType::Int {
524 length: None,
525 integer_spelling: false,
526 }),
527 Expression::ToTimestamp(_) => Some(DataType::Timestamp {
528 precision: None,
529 timezone: false,
530 }),
531
532 Expression::Upper(_)
534 | Expression::Lower(_)
535 | Expression::Trim(_)
536 | Expression::LTrim(_)
537 | Expression::RTrim(_)
538 | Expression::Replace(_)
539 | Expression::Substring(_)
540 | Expression::Reverse(_)
541 | Expression::Left(_)
542 | Expression::Right(_)
543 | Expression::Repeat(_)
544 | Expression::Lpad(_)
545 | Expression::Rpad(_)
546 | Expression::ConcatWs(_)
547 | Expression::Overlay(_) => Some(DataType::VarChar {
548 length: None,
549 parenthesized_length: false,
550 }),
551 Expression::Length(_) => Some(DataType::Int {
552 length: None,
553 integer_spelling: false,
554 }),
555
556 Expression::Abs(_)
558 | Expression::Sqrt(_)
559 | Expression::Cbrt(_)
560 | Expression::Ln(_)
561 | Expression::Exp(_)
562 | Expression::Power(_)
563 | Expression::Log(_) => Some(DataType::Double {
564 precision: None,
565 scale: None,
566 }),
567 Expression::Round(_) => Some(DataType::Double {
568 precision: None,
569 scale: None,
570 }),
571 Expression::Floor(_) | Expression::Ceil(_) | Expression::Sign(_) => {
572 Some(DataType::BigInt { length: None })
573 }
574
575 Expression::Greatest(v) | Expression::Least(v) => self.coerce_arg_types(&v.expressions),
577
578 Expression::Alias(alias) => self.annotate(&alias.this),
580
581 Expression::Select(_) => None,
583
584 Expression::Subscript(sub) => self.annotate_subscript(sub),
588
589 Expression::Dot(_) => None,
591
592 Expression::Struct(s) => self.annotate_struct(s),
596
597 Expression::Map(map) => self.annotate_map(map),
601 Expression::MapFromEntries(mfe) => {
602 if let Some(DataType::Array { element_type, .. }) = self.annotate(&mfe.this) {
604 if let DataType::Struct { fields, .. } = *element_type {
605 if fields.len() >= 2 {
606 return Some(DataType::Map {
607 key_type: Box::new(fields[0].data_type.clone()),
608 value_type: Box::new(fields[1].data_type.clone()),
609 });
610 }
611 }
612 }
613 Some(DataType::Map {
614 key_type: Box::new(DataType::Unknown),
615 value_type: Box::new(DataType::Unknown),
616 })
617 }
618
619 Expression::Union(union) => self.annotate_set_operation(&union.left, &union.right),
623 Expression::Intersect(intersect) => {
624 self.annotate_set_operation(&intersect.left, &intersect.right)
625 }
626 Expression::Except(except) => self.annotate_set_operation(&except.left, &except.right),
627
628 Expression::Lateral(lateral) => {
632 self.annotate(&lateral.this)
634 }
635 Expression::LateralView(lv) => {
636 self.annotate_lateral_view(lv)
638 }
639 Expression::Unnest(unnest) => {
640 if let Some(DataType::Array { element_type, .. }) = self.annotate(&unnest.this) {
642 Some(*element_type)
643 } else {
644 None
645 }
646 }
647 Expression::Explode(explode) => {
648 if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
650 Some(*element_type)
651 } else if let Some(DataType::Map {
652 key_type,
653 value_type,
654 }) = self.annotate(&explode.this)
655 {
656 Some(DataType::Struct {
658 fields: vec![
659 StructField::new("key".to_string(), *key_type),
660 StructField::new("value".to_string(), *value_type),
661 ],
662 nested: false,
663 })
664 } else {
665 None
666 }
667 }
668 Expression::ExplodeOuter(explode) => {
669 if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
671 Some(*element_type)
672 } else {
673 None
674 }
675 }
676 Expression::GenerateSeries(gs) => {
677 if let Some(ref start) = gs.start {
679 self.annotate(start)
680 } else if let Some(ref end) = gs.end {
681 self.annotate(end)
682 } else {
683 Some(DataType::Int {
684 length: None,
685 integer_spelling: false,
686 })
687 }
688 }
689
690 _ => None,
692 }
693 }
694
695 fn annotate_subscript(&mut self, sub: &Subscript) -> Option<DataType> {
697 let base_type = self.annotate(&sub.this)?;
698
699 match base_type {
700 DataType::Array { element_type, .. } => Some(*element_type),
701 DataType::Map { value_type, .. } => Some(*value_type),
702 DataType::Json | DataType::JsonB => Some(DataType::Json), DataType::VarChar { .. } | DataType::Text => {
704 Some(DataType::VarChar {
706 length: Some(1),
707 parenthesized_length: false,
708 })
709 }
710 _ => None,
711 }
712 }
713
714 fn annotate_struct(&mut self, s: &Struct) -> Option<DataType> {
716 let fields: Vec<StructField> = s
717 .fields
718 .iter()
719 .map(|(name, expr)| {
720 let field_type = self.annotate(expr).unwrap_or(DataType::Unknown);
721 StructField::new(name.clone().unwrap_or_default(), field_type)
722 })
723 .collect();
724 Some(DataType::Struct {
725 fields,
726 nested: false,
727 })
728 }
729
730 fn annotate_map(&mut self, map: &Map) -> Option<DataType> {
732 let key_type = if let Some(first_key) = map.keys.first() {
733 self.annotate(first_key).unwrap_or(DataType::Unknown)
734 } else {
735 DataType::Unknown
736 };
737
738 let value_type = if let Some(first_value) = map.values.first() {
739 self.annotate(first_value).unwrap_or(DataType::Unknown)
740 } else {
741 DataType::Unknown
742 };
743
744 Some(DataType::Map {
745 key_type: Box::new(key_type),
746 value_type: Box::new(value_type),
747 })
748 }
749
750 fn annotate_set_operation(
753 &mut self,
754 _left: &Expression,
755 _right: &Expression,
756 ) -> Option<DataType> {
757 None
761 }
762
763 fn annotate_lateral_view(&mut self, lv: &crate::expressions::LateralView) -> Option<DataType> {
765 self.annotate(&lv.this)
767 }
768
769 fn annotate_literal(&self, lit: &Literal) -> Option<DataType> {
771 match lit {
772 Literal::String(_)
773 | Literal::NationalString(_)
774 | Literal::TripleQuotedString(_, _)
775 | Literal::EscapeString(_)
776 | Literal::DollarString(_)
777 | Literal::RawString(_) => Some(DataType::VarChar {
778 length: None,
779 parenthesized_length: false,
780 }),
781 Literal::Number(n) => {
782 if n.contains('.') || n.contains('e') || n.contains('E') {
784 Some(DataType::Double {
785 precision: None,
786 scale: None,
787 })
788 } else {
789 if let Ok(_) = n.parse::<i32>() {
791 Some(DataType::Int {
792 length: None,
793 integer_spelling: false,
794 })
795 } else {
796 Some(DataType::BigInt { length: None })
797 }
798 }
799 }
800 Literal::HexString(_) | Literal::BitString(_) | Literal::ByteString(_) => {
801 Some(DataType::VarBinary { length: None })
802 }
803 Literal::HexNumber(_) => Some(DataType::BigInt { length: None }),
804 Literal::Date(_) => Some(DataType::Date),
805 Literal::Time(_) => Some(DataType::Time {
806 precision: None,
807 timezone: false,
808 }),
809 Literal::Timestamp(_) => Some(DataType::Timestamp {
810 precision: None,
811 timezone: false,
812 }),
813 Literal::Datetime(_) => Some(DataType::Custom {
814 name: "DATETIME".to_string(),
815 }),
816 }
817 }
818
819 fn annotate_arithmetic(&mut self, op: &BinaryOp) -> Option<DataType> {
821 let left_type = self.annotate(&op.left);
822 let right_type = self.annotate(&op.right);
823
824 match (left_type, right_type) {
825 (Some(l), Some(r)) => self.coerce_types(&l, &r),
826 (Some(t), None) | (None, Some(t)) => Some(t),
827 (None, None) => None,
828 }
829 }
830
831 fn annotate_function(&mut self, func: &Function) -> Option<DataType> {
833 let func_name = func.name.to_uppercase();
834
835 if let Some(return_type) = self.function_return_types.get(&func_name) {
837 if *return_type != DataType::Unknown {
838 return Some(return_type.clone());
839 }
840 }
841
842 match func_name.as_str() {
844 "COALESCE" | "IFNULL" | "NVL" | "ISNULL" => {
845 for arg in &func.args {
847 if let Some(arg_type) = self.annotate(arg) {
848 return Some(arg_type);
849 }
850 }
851 None
852 }
853 "NULLIF" => {
854 func.args.first().and_then(|arg| self.annotate(arg))
856 }
857 "GREATEST" | "LEAST" => {
858 self.coerce_arg_types(&func.args)
860 }
861 "IF" | "IIF" => {
862 if func.args.len() >= 2 {
864 self.annotate(&func.args[1])
865 } else {
866 None
867 }
868 }
869 _ => {
870 func.args.first().and_then(|arg| self.annotate(arg))
872 }
873 }
874 }
875
876 fn get_aggregate_return_type(
878 &mut self,
879 func_name: &str,
880 args: &[Expression],
881 ) -> Option<DataType> {
882 match func_name {
883 "COUNT" | "COUNT_IF" => Some(DataType::BigInt { length: None }),
884 "SUM" => {
885 if let Some(arg) = args.first() {
886 self.annotate_sum(arg)
887 } else {
888 Some(DataType::Decimal {
889 precision: None,
890 scale: None,
891 })
892 }
893 }
894 "AVG" => Some(DataType::Double {
895 precision: None,
896 scale: None,
897 }),
898 "MIN" | "MAX" => {
899 args.first().and_then(|arg| self.annotate(arg))
901 }
902 "STRING_AGG" | "GROUP_CONCAT" | "LISTAGG" | "ARRAY_AGG" => Some(DataType::VarChar {
903 length: None,
904 parenthesized_length: false,
905 }),
906 "BOOL_AND" | "BOOL_OR" | "EVERY" | "ANY" | "SOME" => Some(DataType::Boolean),
907 "BIT_AND" | "BIT_OR" | "BIT_XOR" => Some(DataType::BigInt { length: None }),
908 "STDDEV" | "STDDEV_POP" | "STDDEV_SAMP" | "VARIANCE" | "VAR_POP" | "VAR_SAMP" => {
909 Some(DataType::Double {
910 precision: None,
911 scale: None,
912 })
913 }
914 "PERCENTILE_CONT" | "PERCENTILE_DISC" | "MEDIAN" => {
915 args.first().and_then(|arg| self.annotate(arg))
916 }
917 _ => None,
918 }
919 }
920
921 fn annotate_sum(&mut self, arg: &Expression) -> Option<DataType> {
923 match self.annotate(arg) {
924 Some(DataType::TinyInt { .. })
925 | Some(DataType::SmallInt { .. })
926 | Some(DataType::Int { .. }) => Some(DataType::BigInt { length: None }),
927 Some(DataType::BigInt { .. }) => Some(DataType::BigInt { length: None }),
928 Some(DataType::Float { .. }) | Some(DataType::Double { .. }) => {
929 Some(DataType::Double {
930 precision: None,
931 scale: None,
932 })
933 }
934 Some(DataType::Decimal { precision, scale }) => {
935 Some(DataType::Decimal { precision, scale })
936 }
937 _ => Some(DataType::Decimal {
938 precision: None,
939 scale: None,
940 }),
941 }
942 }
943
944 fn coerce_arg_types(&mut self, args: &[Expression]) -> Option<DataType> {
946 let mut result_type: Option<DataType> = None;
947 for arg in args {
948 if let Some(arg_type) = self.annotate(arg) {
949 result_type = match result_type {
950 Some(t) => self.coerce_types(&t, &arg_type),
951 None => Some(arg_type),
952 };
953 }
954 }
955 result_type
956 }
957
958 fn coerce_types(&self, left: &DataType, right: &DataType) -> Option<DataType> {
960 if left == right {
962 return Some(left.clone());
963 }
964
965 match (left, right) {
967 (DataType::Date, DataType::Interval { .. })
968 | (DataType::Interval { .. }, DataType::Date) => return Some(DataType::Date),
969 (
970 DataType::Timestamp {
971 precision,
972 timezone,
973 },
974 DataType::Interval { .. },
975 )
976 | (
977 DataType::Interval { .. },
978 DataType::Timestamp {
979 precision,
980 timezone,
981 },
982 ) => {
983 return Some(DataType::Timestamp {
984 precision: *precision,
985 timezone: *timezone,
986 });
987 }
988 _ => {}
989 }
990
991 let left_class = TypeCoercionClass::from_data_type(left);
993 let right_class = TypeCoercionClass::from_data_type(right);
994
995 match (left_class, right_class) {
996 (Some(lc), Some(rc)) if lc == rc => {
998 if lc == TypeCoercionClass::Numeric {
1000 Some(self.wider_numeric_type(left, right))
1001 } else {
1002 Some(left.clone())
1004 }
1005 }
1006 (Some(lc), Some(rc)) => {
1008 if lc > rc {
1009 Some(left.clone())
1010 } else {
1011 Some(right.clone())
1012 }
1013 }
1014 (Some(_), None) => Some(left.clone()),
1016 (None, Some(_)) => Some(right.clone()),
1017 (None, None) => Some(DataType::Unknown),
1019 }
1020 }
1021
1022 fn wider_numeric_type(&self, left: &DataType, right: &DataType) -> DataType {
1024 let order = |dt: &DataType| -> u8 {
1025 match dt {
1026 DataType::Boolean => 0,
1027 DataType::TinyInt { .. } => 1,
1028 DataType::SmallInt { .. } => 2,
1029 DataType::Int { .. } => 3,
1030 DataType::BigInt { .. } => 4,
1031 DataType::Float { .. } => 5,
1032 DataType::Double { .. } => 6,
1033 DataType::Decimal { .. } => 7,
1034 _ => 0,
1035 }
1036 };
1037
1038 if order(left) >= order(right) {
1039 left.clone()
1040 } else {
1041 right.clone()
1042 }
1043 }
1044}
1045
1046pub fn annotate_types(
1048 expr: &Expression,
1049 schema: Option<&dyn Schema>,
1050 dialect: Option<DialectType>,
1051) -> Option<DataType> {
1052 let mut annotator = TypeAnnotator::new(schema, dialect);
1053 annotator.annotate(expr)
1054}
1055
1056#[cfg(test)]
1057mod tests {
1058 use super::*;
1059 use crate::expressions::{BooleanLiteral, Cast, Null};
1060
1061 fn make_int_literal(val: i64) -> Expression {
1062 Expression::Literal(Literal::Number(val.to_string()))
1063 }
1064
1065 fn make_float_literal(val: f64) -> Expression {
1066 Expression::Literal(Literal::Number(val.to_string()))
1067 }
1068
1069 fn make_string_literal(val: &str) -> Expression {
1070 Expression::Literal(Literal::String(val.to_string()))
1071 }
1072
1073 fn make_bool_literal(val: bool) -> Expression {
1074 Expression::Boolean(BooleanLiteral { value: val })
1075 }
1076
1077 #[test]
1078 fn test_literal_types() {
1079 let mut annotator = TypeAnnotator::new(None, None);
1080
1081 let int_expr = make_int_literal(42);
1083 assert_eq!(
1084 annotator.annotate(&int_expr),
1085 Some(DataType::Int {
1086 length: None,
1087 integer_spelling: false
1088 })
1089 );
1090
1091 let float_expr = make_float_literal(3.14);
1093 assert_eq!(
1094 annotator.annotate(&float_expr),
1095 Some(DataType::Double {
1096 precision: None,
1097 scale: None
1098 })
1099 );
1100
1101 let string_expr = make_string_literal("hello");
1103 assert_eq!(
1104 annotator.annotate(&string_expr),
1105 Some(DataType::VarChar {
1106 length: None,
1107 parenthesized_length: false
1108 })
1109 );
1110
1111 let bool_expr = make_bool_literal(true);
1113 assert_eq!(annotator.annotate(&bool_expr), Some(DataType::Boolean));
1114
1115 let null_expr = Expression::Null(Null);
1117 assert_eq!(annotator.annotate(&null_expr), None);
1118 }
1119
1120 #[test]
1121 fn test_comparison_types() {
1122 let mut annotator = TypeAnnotator::new(None, None);
1123
1124 let cmp = Expression::Gt(Box::new(BinaryOp::new(
1126 make_int_literal(1),
1127 make_int_literal(2),
1128 )));
1129 assert_eq!(annotator.annotate(&cmp), Some(DataType::Boolean));
1130
1131 let eq = Expression::Eq(Box::new(BinaryOp::new(
1133 make_string_literal("a"),
1134 make_string_literal("b"),
1135 )));
1136 assert_eq!(annotator.annotate(&eq), Some(DataType::Boolean));
1137 }
1138
1139 #[test]
1140 fn test_arithmetic_types() {
1141 let mut annotator = TypeAnnotator::new(None, None);
1142
1143 let add_int = Expression::Add(Box::new(BinaryOp::new(
1145 make_int_literal(1),
1146 make_int_literal(2),
1147 )));
1148 assert_eq!(
1149 annotator.annotate(&add_int),
1150 Some(DataType::Int {
1151 length: None,
1152 integer_spelling: false
1153 })
1154 );
1155
1156 let add_mixed = Expression::Add(Box::new(BinaryOp::new(
1158 make_int_literal(1),
1159 make_float_literal(2.5), )));
1161 assert_eq!(
1162 annotator.annotate(&add_mixed),
1163 Some(DataType::Double {
1164 precision: None,
1165 scale: None
1166 })
1167 );
1168 }
1169
1170 #[test]
1171 fn test_string_concat_type() {
1172 let mut annotator = TypeAnnotator::new(None, None);
1173
1174 let concat = Expression::Concat(Box::new(BinaryOp::new(
1176 make_string_literal("hello"),
1177 make_string_literal(" world"),
1178 )));
1179 assert_eq!(
1180 annotator.annotate(&concat),
1181 Some(DataType::VarChar {
1182 length: None,
1183 parenthesized_length: false
1184 })
1185 );
1186 }
1187
1188 #[test]
1189 fn test_cast_type() {
1190 let mut annotator = TypeAnnotator::new(None, None);
1191
1192 let cast = Expression::Cast(Box::new(Cast {
1194 this: make_int_literal(1),
1195 to: DataType::VarChar {
1196 length: Some(10),
1197 parenthesized_length: false,
1198 },
1199 trailing_comments: vec![],
1200 double_colon_syntax: false,
1201 format: None,
1202 default: None,
1203 }));
1204 assert_eq!(
1205 annotator.annotate(&cast),
1206 Some(DataType::VarChar {
1207 length: Some(10),
1208 parenthesized_length: false
1209 })
1210 );
1211 }
1212
1213 #[test]
1214 fn test_function_types() {
1215 let mut annotator = TypeAnnotator::new(None, None);
1216
1217 let count =
1219 Expression::Function(Box::new(Function::new("COUNT", vec![make_int_literal(1)])));
1220 assert_eq!(
1221 annotator.annotate(&count),
1222 Some(DataType::BigInt { length: None })
1223 );
1224
1225 let upper = Expression::Function(Box::new(Function::new(
1227 "UPPER",
1228 vec![make_string_literal("hello")],
1229 )));
1230 assert_eq!(
1231 annotator.annotate(&upper),
1232 Some(DataType::VarChar {
1233 length: None,
1234 parenthesized_length: false
1235 })
1236 );
1237
1238 let now = Expression::Function(Box::new(Function::new("NOW", vec![])));
1240 assert_eq!(
1241 annotator.annotate(&now),
1242 Some(DataType::Timestamp {
1243 precision: None,
1244 timezone: false
1245 })
1246 );
1247 }
1248
1249 #[test]
1250 fn test_coalesce_type_inference() {
1251 let mut annotator = TypeAnnotator::new(None, None);
1252
1253 let coalesce = Expression::Function(Box::new(Function::new(
1255 "COALESCE",
1256 vec![Expression::Null(Null), make_int_literal(1)],
1257 )));
1258 assert_eq!(
1259 annotator.annotate(&coalesce),
1260 Some(DataType::Int {
1261 length: None,
1262 integer_spelling: false
1263 })
1264 );
1265 }
1266
1267 #[test]
1268 fn test_type_coercion_class() {
1269 assert_eq!(
1271 TypeCoercionClass::from_data_type(&DataType::VarChar {
1272 length: None,
1273 parenthesized_length: false
1274 }),
1275 Some(TypeCoercionClass::Text)
1276 );
1277 assert_eq!(
1278 TypeCoercionClass::from_data_type(&DataType::Text),
1279 Some(TypeCoercionClass::Text)
1280 );
1281
1282 assert_eq!(
1284 TypeCoercionClass::from_data_type(&DataType::Int {
1285 length: None,
1286 integer_spelling: false
1287 }),
1288 Some(TypeCoercionClass::Numeric)
1289 );
1290 assert_eq!(
1291 TypeCoercionClass::from_data_type(&DataType::Double {
1292 precision: None,
1293 scale: None
1294 }),
1295 Some(TypeCoercionClass::Numeric)
1296 );
1297
1298 assert_eq!(
1300 TypeCoercionClass::from_data_type(&DataType::Date),
1301 Some(TypeCoercionClass::Timelike)
1302 );
1303 assert_eq!(
1304 TypeCoercionClass::from_data_type(&DataType::Timestamp {
1305 precision: None,
1306 timezone: false
1307 }),
1308 Some(TypeCoercionClass::Timelike)
1309 );
1310
1311 assert_eq!(TypeCoercionClass::from_data_type(&DataType::Json), None);
1313 }
1314
1315 #[test]
1316 fn test_wider_numeric_type() {
1317 let annotator = TypeAnnotator::new(None, None);
1318
1319 let result = annotator.wider_numeric_type(
1321 &DataType::Int {
1322 length: None,
1323 integer_spelling: false,
1324 },
1325 &DataType::BigInt { length: None },
1326 );
1327 assert_eq!(result, DataType::BigInt { length: None });
1328
1329 let result = annotator.wider_numeric_type(
1331 &DataType::Float {
1332 precision: None,
1333 scale: None,
1334 real_spelling: false,
1335 },
1336 &DataType::Double {
1337 precision: None,
1338 scale: None,
1339 },
1340 );
1341 assert_eq!(
1342 result,
1343 DataType::Double {
1344 precision: None,
1345 scale: None
1346 }
1347 );
1348
1349 let result = annotator.wider_numeric_type(
1351 &DataType::Int {
1352 length: None,
1353 integer_spelling: false,
1354 },
1355 &DataType::Double {
1356 precision: None,
1357 scale: None,
1358 },
1359 );
1360 assert_eq!(
1361 result,
1362 DataType::Double {
1363 precision: None,
1364 scale: None
1365 }
1366 );
1367 }
1368
1369 #[test]
1370 fn test_aggregate_return_types() {
1371 let mut annotator = TypeAnnotator::new(None, None);
1372
1373 let sum_type = annotator.get_aggregate_return_type("SUM", &[make_int_literal(1)]);
1375 assert_eq!(sum_type, Some(DataType::BigInt { length: None }));
1376
1377 let avg_type = annotator.get_aggregate_return_type("AVG", &[make_int_literal(1)]);
1379 assert_eq!(
1380 avg_type,
1381 Some(DataType::Double {
1382 precision: None,
1383 scale: None
1384 })
1385 );
1386
1387 let min_type = annotator.get_aggregate_return_type("MIN", &[make_string_literal("a")]);
1389 assert_eq!(
1390 min_type,
1391 Some(DataType::VarChar {
1392 length: None,
1393 parenthesized_length: false
1394 })
1395 );
1396 }
1397
1398 #[test]
1399 fn test_date_literal_types() {
1400 let mut annotator = TypeAnnotator::new(None, None);
1401
1402 let date_expr = Expression::Literal(Literal::Date("2024-01-15".to_string()));
1404 assert_eq!(annotator.annotate(&date_expr), Some(DataType::Date));
1405
1406 let time_expr = Expression::Literal(Literal::Time("10:30:00".to_string()));
1408 assert_eq!(
1409 annotator.annotate(&time_expr),
1410 Some(DataType::Time {
1411 precision: None,
1412 timezone: false
1413 })
1414 );
1415
1416 let ts_expr = Expression::Literal(Literal::Timestamp("2024-01-15 10:30:00".to_string()));
1418 assert_eq!(
1419 annotator.annotate(&ts_expr),
1420 Some(DataType::Timestamp {
1421 precision: None,
1422 timezone: false
1423 })
1424 );
1425 }
1426
1427 #[test]
1428 fn test_logical_operations() {
1429 let mut annotator = TypeAnnotator::new(None, None);
1430
1431 let and_expr = Expression::And(Box::new(BinaryOp::new(
1433 make_bool_literal(true),
1434 make_bool_literal(false),
1435 )));
1436 assert_eq!(annotator.annotate(&and_expr), Some(DataType::Boolean));
1437
1438 let or_expr = Expression::Or(Box::new(BinaryOp::new(
1440 make_bool_literal(true),
1441 make_bool_literal(false),
1442 )));
1443 assert_eq!(annotator.annotate(&or_expr), Some(DataType::Boolean));
1444
1445 let not_expr = Expression::Not(Box::new(crate::expressions::UnaryOp::new(
1447 make_bool_literal(true),
1448 )));
1449 assert_eq!(annotator.annotate(¬_expr), Some(DataType::Boolean));
1450 }
1451
1452 #[test]
1457 fn test_subscript_array_type() {
1458 let mut annotator = TypeAnnotator::new(None, None);
1459
1460 let arr = Expression::Array(Box::new(crate::expressions::Array {
1462 expressions: vec![make_int_literal(1), make_int_literal(2)],
1463 }));
1464 let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
1465 this: arr,
1466 index: make_int_literal(0),
1467 }));
1468 assert_eq!(
1469 annotator.annotate(&subscript),
1470 Some(DataType::Int {
1471 length: None,
1472 integer_spelling: false
1473 })
1474 );
1475 }
1476
1477 #[test]
1478 fn test_subscript_map_type() {
1479 let mut annotator = TypeAnnotator::new(None, None);
1480
1481 let map = Expression::Map(Box::new(crate::expressions::Map {
1483 keys: vec![make_string_literal("a")],
1484 values: vec![make_int_literal(1)],
1485 }));
1486 let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
1487 this: map,
1488 index: make_string_literal("a"),
1489 }));
1490 assert_eq!(
1491 annotator.annotate(&subscript),
1492 Some(DataType::Int {
1493 length: None,
1494 integer_spelling: false
1495 })
1496 );
1497 }
1498
1499 #[test]
1500 fn test_struct_type() {
1501 let mut annotator = TypeAnnotator::new(None, None);
1502
1503 let struct_expr = Expression::Struct(Box::new(crate::expressions::Struct {
1505 fields: vec![
1506 (Some("name".to_string()), make_string_literal("Alice")),
1507 (Some("age".to_string()), make_int_literal(30)),
1508 ],
1509 }));
1510 let result = annotator.annotate(&struct_expr);
1511 assert!(matches!(result, Some(DataType::Struct { fields, .. }) if fields.len() == 2));
1512 }
1513
1514 #[test]
1515 fn test_map_type() {
1516 let mut annotator = TypeAnnotator::new(None, None);
1517
1518 let map_expr = Expression::Map(Box::new(crate::expressions::Map {
1520 keys: vec![make_string_literal("a"), make_string_literal("b")],
1521 values: vec![make_int_literal(1), make_int_literal(2)],
1522 }));
1523 let result = annotator.annotate(&map_expr);
1524 assert!(matches!(
1525 result,
1526 Some(DataType::Map { key_type, value_type })
1527 if matches!(*key_type, DataType::VarChar { .. })
1528 && matches!(*value_type, DataType::Int { .. })
1529 ));
1530 }
1531
1532 #[test]
1533 fn test_explode_array_type() {
1534 let mut annotator = TypeAnnotator::new(None, None);
1535
1536 let arr = Expression::Array(Box::new(crate::expressions::Array {
1538 expressions: vec![make_int_literal(1), make_int_literal(2)],
1539 }));
1540 let explode = Expression::Explode(Box::new(crate::expressions::UnaryFunc {
1541 this: arr,
1542 original_name: None,
1543 }));
1544 assert_eq!(
1545 annotator.annotate(&explode),
1546 Some(DataType::Int {
1547 length: None,
1548 integer_spelling: false
1549 })
1550 );
1551 }
1552
1553 #[test]
1554 fn test_unnest_array_type() {
1555 let mut annotator = TypeAnnotator::new(None, None);
1556
1557 let arr = Expression::Array(Box::new(crate::expressions::Array {
1559 expressions: vec![make_string_literal("a"), make_string_literal("b")],
1560 }));
1561 let unnest = Expression::Unnest(Box::new(crate::expressions::UnnestFunc {
1562 this: arr,
1563 expressions: Vec::new(),
1564 with_ordinality: false,
1565 alias: None,
1566 offset_alias: None,
1567 }));
1568 assert_eq!(
1569 annotator.annotate(&unnest),
1570 Some(DataType::VarChar {
1571 length: None,
1572 parenthesized_length: false
1573 })
1574 );
1575 }
1576
1577 #[test]
1578 fn test_set_operation_type() {
1579 let mut annotator = TypeAnnotator::new(None, None);
1580
1581 let select = Expression::Select(Box::new(crate::expressions::Select::default()));
1583 let union = Expression::Union(Box::new(crate::expressions::Union {
1584 left: select.clone(),
1585 right: select.clone(),
1586 all: false,
1587 distinct: false,
1588 with: None,
1589 order_by: None,
1590 limit: None,
1591 offset: None,
1592 by_name: false,
1593 side: None,
1594 kind: None,
1595 corresponding: false,
1596 strict: false,
1597 on_columns: Vec::new(),
1598 distribute_by: None,
1599 sort_by: None,
1600 cluster_by: None,
1601 }));
1602 assert_eq!(annotator.annotate(&union), None);
1603 }
1604}