1use std::collections::HashMap;
13
14use crate::dialects::DialectType;
15use crate::expressions::{
16 BinaryOp, DataType, Expression, Function, Literal, Map, Struct, StructField, Subscript,
17};
18use crate::schema::Schema;
19
/// Broad coercion family of a [`DataType`].
///
/// The derived `Ord` follows the variant order (which matches the explicit
/// discriminants), so `Text < Numeric < Timelike`. `TypeAnnotator::coerce_types`
/// compares two classes with `>` and keeps the operand whose class is greater.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TypeCoercionClass {
    /// Character and binary string types (lowest precedence).
    Text = 0,
    /// Boolean, integer, floating-point and decimal types.
    Numeric = 1,
    /// Date, time, timestamp and interval types (highest precedence).
    Timelike = 2,
}
31
32impl TypeCoercionClass {
33 pub fn from_data_type(dt: &DataType) -> Option<Self> {
35 match dt {
36 DataType::Char { .. }
38 | DataType::VarChar { .. }
39 | DataType::Text
40 | DataType::Binary { .. }
41 | DataType::VarBinary { .. }
42 | DataType::Blob => Some(TypeCoercionClass::Text),
43
44 DataType::Boolean
46 | DataType::TinyInt { .. }
47 | DataType::SmallInt { .. }
48 | DataType::Int { .. }
49 | DataType::BigInt { .. }
50 | DataType::Float { .. }
51 | DataType::Double { .. }
52 | DataType::Decimal { .. } => Some(TypeCoercionClass::Numeric),
53
54 DataType::Date
56 | DataType::Time { .. }
57 | DataType::Timestamp { .. }
58 | DataType::Interval { .. } => Some(TypeCoercionClass::Timelike),
59
60 _ => None,
62 }
63 }
64}
65
/// Infers the [`DataType`] of SQL expressions.
pub struct TypeAnnotator<'a> {
    /// Optional schema consulted by `annotate` to resolve column types;
    /// without it, columns are left untyped.
    _schema: Option<&'a dyn Schema>,
    /// Dialect passed to `new`. It is stored here but not read by any of the
    /// annotator methods visible in this file.
    _dialect: Option<DialectType>,
    /// When `false`, `Expression::AggregateFunction` nodes are annotated as
    /// `None`. Set to `true` by `new`.
    annotate_aggregates: bool,
    /// Fixed return types keyed by upper-cased function name. Entries mapped
    /// to `DataType::Unknown` are skipped by `annotate_function`, which then
    /// derives the type from the call's arguments.
    function_return_types: HashMap<String, DataType>,
}
77
78impl<'a> TypeAnnotator<'a> {
/// Creates an annotator with aggregate annotation enabled and the built-in
/// function return-type table populated.
///
/// `schema` is used to resolve column types when present; `dialect` is
/// stored as given.
pub fn new(schema: Option<&'a dyn Schema>, dialect: Option<DialectType>) -> Self {
    let mut annotator = Self {
        _schema: schema,
        _dialect: dialect,
        annotate_aggregates: true,
        function_return_types: HashMap::new(),
    };
    // The table starts empty and is filled in before the annotator is handed out.
    annotator.init_function_return_types();
    annotator
}
90
91 fn init_function_return_types(&mut self) {
93 self.function_return_types
95 .insert("COUNT".to_string(), DataType::BigInt { length: None });
96 self.function_return_types
97 .insert("SUM".to_string(), DataType::Decimal {
98 precision: None,
99 scale: None,
100 });
101 self.function_return_types
102 .insert("AVG".to_string(), DataType::Double { precision: None, scale: None });
103
104 self.function_return_types
106 .insert("CONCAT".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
107 self.function_return_types
108 .insert("UPPER".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
109 self.function_return_types
110 .insert("LOWER".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
111 self.function_return_types
112 .insert("TRIM".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
113 self.function_return_types
114 .insert("LTRIM".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
115 self.function_return_types
116 .insert("RTRIM".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
117 self.function_return_types
118 .insert("SUBSTRING".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
119 self.function_return_types
120 .insert("SUBSTR".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
121 self.function_return_types
122 .insert("REPLACE".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
123 self.function_return_types
124 .insert("LENGTH".to_string(), DataType::Int { length: None, integer_spelling: false });
125 self.function_return_types
126 .insert("CHAR_LENGTH".to_string(), DataType::Int { length: None, integer_spelling: false });
127
128 self.function_return_types
130 .insert("NOW".to_string(), DataType::Timestamp {
131 precision: None,
132 timezone: false,
133 });
134 self.function_return_types
135 .insert("CURRENT_TIMESTAMP".to_string(), DataType::Timestamp {
136 precision: None,
137 timezone: false,
138 });
139 self.function_return_types
140 .insert("CURRENT_DATE".to_string(), DataType::Date);
141 self.function_return_types
142 .insert("CURRENT_TIME".to_string(), DataType::Time { precision: None, timezone: false });
143 self.function_return_types
144 .insert("DATE".to_string(), DataType::Date);
145 self.function_return_types
146 .insert("YEAR".to_string(), DataType::Int { length: None, integer_spelling: false });
147 self.function_return_types
148 .insert("MONTH".to_string(), DataType::Int { length: None, integer_spelling: false });
149 self.function_return_types
150 .insert("DAY".to_string(), DataType::Int { length: None, integer_spelling: false });
151 self.function_return_types
152 .insert("HOUR".to_string(), DataType::Int { length: None, integer_spelling: false });
153 self.function_return_types
154 .insert("MINUTE".to_string(), DataType::Int { length: None, integer_spelling: false });
155 self.function_return_types
156 .insert("SECOND".to_string(), DataType::Int { length: None, integer_spelling: false });
157 self.function_return_types
158 .insert("EXTRACT".to_string(), DataType::Int { length: None, integer_spelling: false });
159 self.function_return_types
160 .insert("DATE_DIFF".to_string(), DataType::Int { length: None, integer_spelling: false });
161 self.function_return_types
162 .insert("DATEDIFF".to_string(), DataType::Int { length: None, integer_spelling: false });
163
164 self.function_return_types
166 .insert("ABS".to_string(), DataType::Double { precision: None, scale: None });
167 self.function_return_types
168 .insert("ROUND".to_string(), DataType::Double { precision: None, scale: None });
169 self.function_return_types
170 .insert("FLOOR".to_string(), DataType::BigInt { length: None });
171 self.function_return_types
172 .insert("CEIL".to_string(), DataType::BigInt { length: None });
173 self.function_return_types
174 .insert("CEILING".to_string(), DataType::BigInt { length: None });
175 self.function_return_types
176 .insert("SQRT".to_string(), DataType::Double { precision: None, scale: None });
177 self.function_return_types
178 .insert("POWER".to_string(), DataType::Double { precision: None, scale: None });
179 self.function_return_types
180 .insert("MOD".to_string(), DataType::Int { length: None, integer_spelling: false });
181 self.function_return_types
182 .insert("LOG".to_string(), DataType::Double { precision: None, scale: None });
183 self.function_return_types
184 .insert("LN".to_string(), DataType::Double { precision: None, scale: None });
185 self.function_return_types
186 .insert("EXP".to_string(), DataType::Double { precision: None, scale: None });
187
188 self.function_return_types
190 .insert("COALESCE".to_string(), DataType::Unknown);
191 self.function_return_types
192 .insert("NULLIF".to_string(), DataType::Unknown);
193 self.function_return_types
194 .insert("GREATEST".to_string(), DataType::Unknown);
195 self.function_return_types
196 .insert("LEAST".to_string(), DataType::Unknown);
197 }
198
199 pub fn annotate(&mut self, expr: &Expression) -> Option<DataType> {
201 match expr {
202 Expression::Literal(lit) => self.annotate_literal(lit),
204 Expression::Boolean(_) => Some(DataType::Boolean),
205 Expression::Null(_) => None, Expression::Add(op) | Expression::Sub(op) |
209 Expression::Mul(op) | Expression::Div(op) |
210 Expression::Mod(op) => self.annotate_arithmetic(op),
211
212 Expression::Eq(_) | Expression::Neq(_) |
214 Expression::Lt(_) | Expression::Lte(_) |
215 Expression::Gt(_) | Expression::Gte(_) |
216 Expression::Like(_) | Expression::ILike(_) => Some(DataType::Boolean),
217
218 Expression::And(_) | Expression::Or(_) | Expression::Not(_) => Some(DataType::Boolean),
220
221 Expression::Between(_) | Expression::In(_) |
223 Expression::IsNull(_) | Expression::IsTrue(_) | Expression::IsFalse(_) |
224 Expression::Is(_) | Expression::Exists(_) => Some(DataType::Boolean),
225
226 Expression::Concat(_) => Some(DataType::VarChar { length: None, parenthesized_length: false }),
228
229 Expression::BitwiseAnd(_) | Expression::BitwiseOr(_) |
231 Expression::BitwiseXor(_) | Expression::BitwiseNot(_) => {
232 Some(DataType::BigInt { length: None })
233 }
234
235 Expression::Neg(op) => self.annotate(&op.this),
237
238 Expression::Function(func) => self.annotate_function(func),
240
241 Expression::Count(_) => Some(DataType::BigInt { length: None }),
243 Expression::Sum(agg) => self.annotate_sum(&agg.this),
244 Expression::Avg(_) => Some(DataType::Double { precision: None, scale: None }),
245 Expression::Min(agg) => self.annotate(&agg.this),
246 Expression::Max(agg) => self.annotate(&agg.this),
247 Expression::GroupConcat(_) | Expression::StringAgg(_) | Expression::ListAgg(_) => {
248 Some(DataType::VarChar { length: None, parenthesized_length: false })
249 }
250
251 Expression::AggregateFunction(agg) => {
253 if !self.annotate_aggregates {
254 return None;
255 }
256 let func_name = agg.name.to_uppercase();
257 self.get_aggregate_return_type(&func_name, &agg.args)
258 }
259
260 Expression::Column(col) => {
262 if let Some(schema) = &self._schema {
263 let table_name = col.table.as_ref().map(|t| t.name.as_str()).unwrap_or("");
264 schema
265 .get_column_type(table_name, &col.name.name)
266 .ok()
267 } else {
268 None
269 }
270 }
271
272 Expression::Cast(cast) => Some(cast.to.clone()),
274 Expression::SafeCast(cast) => Some(cast.to.clone()),
275 Expression::TryCast(cast) => Some(cast.to.clone()),
276
277 Expression::Subquery(subq) => {
279 if let Expression::Select(select) = &subq.this {
280 if let Some(first) = select.expressions.first() {
281 self.annotate(first)
282 } else {
283 None
284 }
285 } else {
286 None
287 }
288 }
289
290 Expression::Case(case) => {
292 if let Some(else_expr) = &case.else_ {
293 self.annotate(else_expr)
294 } else if let Some((_, then_expr)) = case.whens.first() {
295 self.annotate(then_expr)
296 } else {
297 None
298 }
299 }
300
301 Expression::Array(arr) => {
303 if let Some(first) = arr.expressions.first() {
304 if let Some(elem_type) = self.annotate(first) {
305 Some(DataType::Array {
306 element_type: Box::new(elem_type),
307 dimension: None,
308 })
309 } else {
310 Some(DataType::Array {
311 element_type: Box::new(DataType::Unknown),
312 dimension: None,
313 })
314 }
315 } else {
316 Some(DataType::Array {
317 element_type: Box::new(DataType::Unknown),
318 dimension: None,
319 })
320 }
321 }
322
323 Expression::Interval(_) => Some(DataType::Interval { unit: None, to: None }),
325
326 Expression::WindowFunction(window) => self.annotate(&window.this),
328
329 Expression::CurrentDate(_) => Some(DataType::Date),
331 Expression::CurrentTime(_) => Some(DataType::Time { precision: None, timezone: false }),
332 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
333 Some(DataType::Timestamp {
334 precision: None,
335 timezone: false,
336 })
337 }
338
339 Expression::DateAdd(_) | Expression::DateSub(_) |
341 Expression::ToDate(_) | Expression::Date(_) => Some(DataType::Date),
342 Expression::DateDiff(_) | Expression::Extract(_) => Some(DataType::Int { length: None, integer_spelling: false }),
343 Expression::ToTimestamp(_) => Some(DataType::Timestamp {
344 precision: None,
345 timezone: false,
346 }),
347
348 Expression::Upper(_) | Expression::Lower(_) | Expression::Trim(_) |
350 Expression::LTrim(_) | Expression::RTrim(_) | Expression::Replace(_) |
351 Expression::Substring(_) | Expression::Reverse(_) | Expression::Left(_) |
352 Expression::Right(_) | Expression::Repeat(_) | Expression::Lpad(_) |
353 Expression::Rpad(_) | Expression::ConcatWs(_) | Expression::Overlay(_) => {
354 Some(DataType::VarChar { length: None, parenthesized_length: false })
355 }
356 Expression::Length(_) => Some(DataType::Int { length: None, integer_spelling: false }),
357
358 Expression::Abs(_) | Expression::Sqrt(_) | Expression::Cbrt(_) |
360 Expression::Ln(_) | Expression::Exp(_) | Expression::Power(_) |
361 Expression::Log(_) => Some(DataType::Double { precision: None, scale: None }),
362 Expression::Round(_) => Some(DataType::Double { precision: None, scale: None }),
363 Expression::Floor(_) | Expression::Ceil(_) | Expression::Sign(_) => {
364 Some(DataType::BigInt { length: None })
365 }
366
367 Expression::Greatest(v) | Expression::Least(v) => {
369 self.coerce_arg_types(&v.expressions)
370 }
371
372 Expression::Alias(alias) => self.annotate(&alias.this),
374
375 Expression::Select(_) => None,
377
378 Expression::Subscript(sub) => self.annotate_subscript(sub),
382
383 Expression::Dot(_) => None,
385
386 Expression::Struct(s) => self.annotate_struct(s),
390
391 Expression::Map(map) => self.annotate_map(map),
395 Expression::MapFromEntries(mfe) => {
396 if let Some(DataType::Array { element_type, .. }) = self.annotate(&mfe.this) {
398 if let DataType::Struct { fields, .. } = *element_type {
399 if fields.len() >= 2 {
400 return Some(DataType::Map {
401 key_type: Box::new(fields[0].data_type.clone()),
402 value_type: Box::new(fields[1].data_type.clone()),
403 });
404 }
405 }
406 }
407 Some(DataType::Map {
408 key_type: Box::new(DataType::Unknown),
409 value_type: Box::new(DataType::Unknown),
410 })
411 }
412
413 Expression::Union(union) => self.annotate_set_operation(&union.left, &union.right),
417 Expression::Intersect(intersect) => {
418 self.annotate_set_operation(&intersect.left, &intersect.right)
419 }
420 Expression::Except(except) => {
421 self.annotate_set_operation(&except.left, &except.right)
422 }
423
424 Expression::Lateral(lateral) => {
428 self.annotate(&lateral.this)
430 }
431 Expression::LateralView(lv) => {
432 self.annotate_lateral_view(lv)
434 }
435 Expression::Unnest(unnest) => {
436 if let Some(DataType::Array { element_type, .. }) = self.annotate(&unnest.this) {
438 Some(*element_type)
439 } else {
440 None
441 }
442 }
443 Expression::Explode(explode) => {
444 if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
446 Some(*element_type)
447 } else if let Some(DataType::Map { key_type, value_type }) =
448 self.annotate(&explode.this)
449 {
450 Some(DataType::Struct {
452 fields: vec![
453 StructField::new("key".to_string(), *key_type),
454 StructField::new("value".to_string(), *value_type),
455 ],
456 nested: false,
457 })
458 } else {
459 None
460 }
461 }
462 Expression::ExplodeOuter(explode) => {
463 if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
465 Some(*element_type)
466 } else {
467 None
468 }
469 }
470 Expression::GenerateSeries(gs) => {
471 if let Some(ref start) = gs.start {
473 self.annotate(start)
474 } else if let Some(ref end) = gs.end {
475 self.annotate(end)
476 } else {
477 Some(DataType::Int { length: None, integer_spelling: false })
478 }
479 }
480
481 _ => None,
483 }
484 }
485
486 fn annotate_subscript(&mut self, sub: &Subscript) -> Option<DataType> {
488 let base_type = self.annotate(&sub.this)?;
489
490 match base_type {
491 DataType::Array { element_type, .. } => Some(*element_type),
492 DataType::Map { value_type, .. } => Some(*value_type),
493 DataType::Json | DataType::JsonB => Some(DataType::Json), DataType::VarChar { .. } | DataType::Text => {
495 Some(DataType::VarChar { length: Some(1), parenthesized_length: false })
497 }
498 _ => None,
499 }
500 }
501
502 fn annotate_struct(&mut self, s: &Struct) -> Option<DataType> {
504 let fields: Vec<StructField> = s
505 .fields
506 .iter()
507 .map(|(name, expr)| {
508 let field_type = self.annotate(expr).unwrap_or(DataType::Unknown);
509 StructField::new(name.clone().unwrap_or_default(), field_type)
510 })
511 .collect();
512 Some(DataType::Struct { fields, nested: false })
513 }
514
515 fn annotate_map(&mut self, map: &Map) -> Option<DataType> {
517 let key_type = if let Some(first_key) = map.keys.first() {
518 self.annotate(first_key).unwrap_or(DataType::Unknown)
519 } else {
520 DataType::Unknown
521 };
522
523 let value_type = if let Some(first_value) = map.values.first() {
524 self.annotate(first_value).unwrap_or(DataType::Unknown)
525 } else {
526 DataType::Unknown
527 };
528
529 Some(DataType::Map {
530 key_type: Box::new(key_type),
531 value_type: Box::new(value_type),
532 })
533 }
534
/// Annotates a set operation (`UNION`, `INTERSECT`, `EXCEPT`).
///
/// Currently a stub: both operands are ignored and the result is always
/// `None`.
fn annotate_set_operation(
    &mut self,
    _left: &Expression,
    _right: &Expression,
) -> Option<DataType> {
    None
}
547
/// Annotates a `LATERAL VIEW` by returning the type of the expression it
/// wraps (`lv.this`).
fn annotate_lateral_view(
    &mut self,
    lv: &crate::expressions::LateralView,
) -> Option<DataType> {
    self.annotate(&lv.this)
}
556
557 fn annotate_literal(&self, lit: &Literal) -> Option<DataType> {
559 match lit {
560 Literal::String(_) | Literal::NationalString(_) |
561 Literal::TripleQuotedString(_, _) | Literal::EscapeString(_) |
562 Literal::DollarString(_) | Literal::RawString(_) => Some(DataType::VarChar { length: None, parenthesized_length: false }),
563 Literal::Number(n) => {
564 if n.contains('.') || n.contains('e') || n.contains('E') {
566 Some(DataType::Double { precision: None, scale: None })
567 } else {
568 if let Ok(_) = n.parse::<i32>() {
570 Some(DataType::Int { length: None, integer_spelling: false })
571 } else {
572 Some(DataType::BigInt { length: None })
573 }
574 }
575 }
576 Literal::HexString(_) | Literal::BitString(_) | Literal::ByteString(_) => {
577 Some(DataType::VarBinary { length: None })
578 }
579 Literal::HexNumber(_) => Some(DataType::BigInt { length: None }),
580 Literal::Date(_) => Some(DataType::Date),
581 Literal::Time(_) => Some(DataType::Time { precision: None, timezone: false }),
582 Literal::Timestamp(_) => Some(DataType::Timestamp {
583 precision: None,
584 timezone: false,
585 }),
586 Literal::Datetime(_) => Some(DataType::Custom {
587 name: "DATETIME".to_string(),
588 }),
589 }
590 }
591
592 fn annotate_arithmetic(&mut self, op: &BinaryOp) -> Option<DataType> {
594 let left_type = self.annotate(&op.left);
595 let right_type = self.annotate(&op.right);
596
597 match (left_type, right_type) {
598 (Some(l), Some(r)) => self.coerce_types(&l, &r),
599 (Some(t), None) | (None, Some(t)) => Some(t),
600 (None, None) => None,
601 }
602 }
603
604 fn annotate_function(&mut self, func: &Function) -> Option<DataType> {
606 let func_name = func.name.to_uppercase();
607
608 if let Some(return_type) = self.function_return_types.get(&func_name) {
610 if *return_type != DataType::Unknown {
611 return Some(return_type.clone());
612 }
613 }
614
615 match func_name.as_str() {
617 "COALESCE" | "IFNULL" | "NVL" | "ISNULL" => {
618 for arg in &func.args {
620 if let Some(arg_type) = self.annotate(arg) {
621 return Some(arg_type);
622 }
623 }
624 None
625 }
626 "NULLIF" => {
627 func.args.first().and_then(|arg| self.annotate(arg))
629 }
630 "GREATEST" | "LEAST" => {
631 self.coerce_arg_types(&func.args)
633 }
634 "IF" | "IIF" => {
635 if func.args.len() >= 2 {
637 self.annotate(&func.args[1])
638 } else {
639 None
640 }
641 }
642 _ => {
643 func.args.first().and_then(|arg| self.annotate(arg))
645 }
646 }
647 }
648
649 fn get_aggregate_return_type(&mut self, func_name: &str, args: &[Expression]) -> Option<DataType> {
651 match func_name {
652 "COUNT" | "COUNT_IF" => Some(DataType::BigInt { length: None }),
653 "SUM" => {
654 if let Some(arg) = args.first() {
655 self.annotate_sum(arg)
656 } else {
657 Some(DataType::Decimal {
658 precision: None,
659 scale: None,
660 })
661 }
662 }
663 "AVG" => Some(DataType::Double { precision: None, scale: None }),
664 "MIN" | "MAX" => {
665 args.first().and_then(|arg| self.annotate(arg))
667 }
668 "STRING_AGG" | "GROUP_CONCAT" | "LISTAGG" | "ARRAY_AGG" => {
669 Some(DataType::VarChar { length: None, parenthesized_length: false })
670 }
671 "BOOL_AND" | "BOOL_OR" | "EVERY" | "ANY" | "SOME" => Some(DataType::Boolean),
672 "BIT_AND" | "BIT_OR" | "BIT_XOR" => Some(DataType::BigInt { length: None }),
673 "STDDEV" | "STDDEV_POP" | "STDDEV_SAMP" | "VARIANCE" | "VAR_POP" | "VAR_SAMP" => {
674 Some(DataType::Double { precision: None, scale: None })
675 }
676 "PERCENTILE_CONT" | "PERCENTILE_DISC" | "MEDIAN" => {
677 args.first().and_then(|arg| self.annotate(arg))
678 }
679 _ => None,
680 }
681 }
682
683 fn annotate_sum(&mut self, arg: &Expression) -> Option<DataType> {
685 match self.annotate(arg) {
686 Some(DataType::TinyInt { .. })
687 | Some(DataType::SmallInt { .. })
688 | Some(DataType::Int { .. }) => Some(DataType::BigInt { length: None }),
689 Some(DataType::BigInt { .. }) => Some(DataType::BigInt { length: None }),
690 Some(DataType::Float { .. }) | Some(DataType::Double { .. }) => Some(DataType::Double { precision: None, scale: None }),
691 Some(DataType::Decimal { precision, scale }) => {
692 Some(DataType::Decimal { precision, scale })
693 }
694 _ => Some(DataType::Decimal {
695 precision: None,
696 scale: None,
697 }),
698 }
699 }
700
701 fn coerce_arg_types(&mut self, args: &[Expression]) -> Option<DataType> {
703 let mut result_type: Option<DataType> = None;
704 for arg in args {
705 if let Some(arg_type) = self.annotate(arg) {
706 result_type = match result_type {
707 Some(t) => self.coerce_types(&t, &arg_type),
708 None => Some(arg_type),
709 };
710 }
711 }
712 result_type
713 }
714
715 fn coerce_types(&self, left: &DataType, right: &DataType) -> Option<DataType> {
717 if left == right {
719 return Some(left.clone());
720 }
721
722 match (left, right) {
724 (DataType::Date, DataType::Interval { .. }) |
725 (DataType::Interval { .. }, DataType::Date) => return Some(DataType::Date),
726 (DataType::Timestamp { precision, timezone }, DataType::Interval { .. }) |
727 (DataType::Interval { .. }, DataType::Timestamp { precision, timezone }) => {
728 return Some(DataType::Timestamp {
729 precision: *precision,
730 timezone: *timezone
731 });
732 }
733 _ => {}
734 }
735
736 let left_class = TypeCoercionClass::from_data_type(left);
738 let right_class = TypeCoercionClass::from_data_type(right);
739
740 match (left_class, right_class) {
741 (Some(lc), Some(rc)) if lc == rc => {
743 if lc == TypeCoercionClass::Numeric {
745 Some(self.wider_numeric_type(left, right))
746 } else {
747 Some(left.clone())
749 }
750 }
751 (Some(lc), Some(rc)) => {
753 if lc > rc {
754 Some(left.clone())
755 } else {
756 Some(right.clone())
757 }
758 }
759 (Some(_), None) => Some(left.clone()),
761 (None, Some(_)) => Some(right.clone()),
762 (None, None) => Some(DataType::Unknown),
764 }
765 }
766
767 fn wider_numeric_type(&self, left: &DataType, right: &DataType) -> DataType {
769 let order = |dt: &DataType| -> u8 {
770 match dt {
771 DataType::Boolean => 0,
772 DataType::TinyInt { .. } => 1,
773 DataType::SmallInt { .. } => 2,
774 DataType::Int { .. } => 3,
775 DataType::BigInt { .. } => 4,
776 DataType::Float { .. } => 5,
777 DataType::Double { .. } => 6,
778 DataType::Decimal { .. } => 7,
779 _ => 0,
780 }
781 };
782
783 if order(left) >= order(right) {
784 left.clone()
785 } else {
786 right.clone()
787 }
788 }
789}
790
791pub fn annotate_types(
793 expr: &Expression,
794 schema: Option<&dyn Schema>,
795 dialect: Option<DialectType>,
796) -> Option<DataType> {
797 let mut annotator = TypeAnnotator::new(schema, dialect);
798 annotator.annotate(expr)
799}
800
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{BooleanLiteral, Cast, Null};

    /// Builds a numeric literal expression holding the integer `val`.
    fn make_int_literal(val: i64) -> Expression {
        Expression::Literal(Literal::Number(val.to_string()))
    }

    /// Builds a numeric literal expression holding `val`.
    ///
    /// Whole-valued floats print without a fractional part (`2.0` becomes
    /// `"2"`), so pass a value with a fraction to get a floating-point
    /// literal.
    fn make_float_literal(val: f64) -> Expression {
        Expression::Literal(Literal::Number(val.to_string()))
    }

    /// Builds a plain string literal expression.
    fn make_string_literal(val: &str) -> Expression {
        Expression::Literal(Literal::String(val.to_string()))
    }

    /// Builds a boolean literal expression.
    fn make_bool_literal(val: bool) -> Expression {
        Expression::Boolean(BooleanLiteral { value: val })
    }

    /// Literals map to their natural types and `NULL` is untyped.
    #[test]
    fn test_literal_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let int_expr = make_int_literal(42);
        assert_eq!(
            annotator.annotate(&int_expr),
            Some(DataType::Int { length: None, integer_spelling: false })
        );

        let float_expr = make_float_literal(3.14);
        assert_eq!(
            annotator.annotate(&float_expr),
            Some(DataType::Double { precision: None, scale: None })
        );

        let string_expr = make_string_literal("hello");
        assert_eq!(
            annotator.annotate(&string_expr),
            Some(DataType::VarChar { length: None, parenthesized_length: false })
        );

        let bool_expr = make_bool_literal(true);
        assert_eq!(annotator.annotate(&bool_expr), Some(DataType::Boolean));

        let null_expr = Expression::Null(Null);
        assert_eq!(annotator.annotate(&null_expr), None);
    }

    /// Comparisons are boolean regardless of operand types.
    #[test]
    fn test_comparison_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let cmp = Expression::Gt(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        assert_eq!(annotator.annotate(&cmp), Some(DataType::Boolean));

        let eq = Expression::Eq(Box::new(BinaryOp::new(
            make_string_literal("a"),
            make_string_literal("b"),
        )));
        assert_eq!(annotator.annotate(&eq), Some(DataType::Boolean));
    }

    /// Arithmetic keeps matching types and widens mixed int/double operands.
    #[test]
    fn test_arithmetic_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let add_int = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        assert_eq!(
            annotator.annotate(&add_int),
            Some(DataType::Int { length: None, integer_spelling: false })
        );

        let add_mixed = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_float_literal(2.5),
        )));
        assert_eq!(
            annotator.annotate(&add_mixed),
            Some(DataType::Double { precision: None, scale: None })
        );
    }

    /// String concatenation yields an unbounded `VarChar`.
    #[test]
    fn test_string_concat_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let concat = Expression::Concat(Box::new(BinaryOp::new(
            make_string_literal("hello"),
            make_string_literal(" world"),
        )));
        assert_eq!(
            annotator.annotate(&concat),
            Some(DataType::VarChar { length: None, parenthesized_length: false })
        );
    }

    /// A cast is typed as its target type, including the declared length.
    #[test]
    fn test_cast_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let cast = Expression::Cast(Box::new(Cast {
            this: make_int_literal(1),
            to: DataType::VarChar { length: Some(10), parenthesized_length: false },
            trailing_comments: vec![],
            double_colon_syntax: false,
            format: None,
            default: None,
        }));
        assert_eq!(
            annotator.annotate(&cast),
            Some(DataType::VarChar { length: Some(10), parenthesized_length: false })
        );
    }

    /// Functions with a fixed return type are resolved from the built-in table.
    #[test]
    fn test_function_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let count =
            Expression::Function(Box::new(Function::new("COUNT", vec![make_int_literal(1)])));
        assert_eq!(
            annotator.annotate(&count),
            Some(DataType::BigInt { length: None })
        );

        let upper = Expression::Function(Box::new(Function::new(
            "UPPER",
            vec![make_string_literal("hello")],
        )));
        assert_eq!(
            annotator.annotate(&upper),
            Some(DataType::VarChar { length: None, parenthesized_length: false })
        );

        let now = Expression::Function(Box::new(Function::new("NOW", vec![])));
        assert_eq!(
            annotator.annotate(&now),
            Some(DataType::Timestamp {
                precision: None,
                timezone: false
            })
        );
    }

    /// `COALESCE` skips untyped arguments and uses the first typed one.
    #[test]
    fn test_coalesce_type_inference() {
        let mut annotator = TypeAnnotator::new(None, None);

        let coalesce = Expression::Function(Box::new(Function::new(
            "COALESCE",
            vec![Expression::Null(Null), make_int_literal(1)],
        )));
        assert_eq!(
            annotator.annotate(&coalesce),
            Some(DataType::Int { length: None, integer_spelling: false })
        );
    }

    /// Each type family maps to its coercion class; JSON has none.
    #[test]
    fn test_type_coercion_class() {
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::VarChar {
                length: None,
                parenthesized_length: false
            }),
            Some(TypeCoercionClass::Text)
        );
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Text),
            Some(TypeCoercionClass::Text)
        );

        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Int {
                length: None,
                integer_spelling: false
            }),
            Some(TypeCoercionClass::Numeric)
        );
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Double { precision: None, scale: None }),
            Some(TypeCoercionClass::Numeric)
        );

        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Date),
            Some(TypeCoercionClass::Timelike)
        );
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Timestamp {
                precision: None,
                timezone: false
            }),
            Some(TypeCoercionClass::Timelike)
        );

        assert_eq!(TypeCoercionClass::from_data_type(&DataType::Json), None);
    }

    /// The wider of two numeric types wins, whichever side it is on.
    #[test]
    fn test_wider_numeric_type() {
        let annotator = TypeAnnotator::new(None, None);

        let result = annotator.wider_numeric_type(
            &DataType::Int { length: None, integer_spelling: false },
            &DataType::BigInt { length: None },
        );
        assert_eq!(result, DataType::BigInt { length: None });

        let result = annotator.wider_numeric_type(
            &DataType::Float { precision: None, scale: None, real_spelling: false },
            &DataType::Double { precision: None, scale: None },
        );
        assert_eq!(result, DataType::Double { precision: None, scale: None });

        let result = annotator.wider_numeric_type(
            &DataType::Int { length: None, integer_spelling: false },
            &DataType::Double { precision: None, scale: None },
        );
        assert_eq!(result, DataType::Double { precision: None, scale: None });
    }

    /// `SUM` widens integers, `AVG` is double, `MIN` keeps its argument type.
    #[test]
    fn test_aggregate_return_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let sum_type = annotator.get_aggregate_return_type("SUM", &[make_int_literal(1)]);
        assert_eq!(sum_type, Some(DataType::BigInt { length: None }));

        let avg_type = annotator.get_aggregate_return_type("AVG", &[make_int_literal(1)]);
        assert_eq!(avg_type, Some(DataType::Double { precision: None, scale: None }));

        let min_type = annotator.get_aggregate_return_type("MIN", &[make_string_literal("a")]);
        assert_eq!(
            min_type,
            Some(DataType::VarChar { length: None, parenthesized_length: false })
        );
    }

    /// Typed date, time and timestamp literals map to the matching types.
    #[test]
    fn test_date_literal_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        let date_expr = Expression::Literal(Literal::Date("2024-01-15".to_string()));
        assert_eq!(annotator.annotate(&date_expr), Some(DataType::Date));

        let time_expr = Expression::Literal(Literal::Time("10:30:00".to_string()));
        assert_eq!(
            annotator.annotate(&time_expr),
            Some(DataType::Time { precision: None, timezone: false })
        );

        let ts_expr = Expression::Literal(Literal::Timestamp("2024-01-15 10:30:00".to_string()));
        assert_eq!(
            annotator.annotate(&ts_expr),
            Some(DataType::Timestamp {
                precision: None,
                timezone: false
            })
        );
    }

    /// `AND`, `OR` and `NOT` are boolean.
    #[test]
    fn test_logical_operations() {
        let mut annotator = TypeAnnotator::new(None, None);

        let and_expr = Expression::And(Box::new(BinaryOp::new(
            make_bool_literal(true),
            make_bool_literal(false),
        )));
        assert_eq!(annotator.annotate(&and_expr), Some(DataType::Boolean));

        let or_expr = Expression::Or(Box::new(BinaryOp::new(
            make_bool_literal(true),
            make_bool_literal(false),
        )));
        assert_eq!(annotator.annotate(&or_expr), Some(DataType::Boolean));

        let not_expr = Expression::Not(Box::new(crate::expressions::UnaryOp::new(
            make_bool_literal(true),
        )));
        assert_eq!(annotator.annotate(&not_expr), Some(DataType::Boolean));
    }

    /// Indexing an array yields its element type.
    #[test]
    fn test_subscript_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let arr = Expression::Array(Box::new(crate::expressions::Array {
            expressions: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
            this: arr,
            index: make_int_literal(0),
        }));
        assert_eq!(
            annotator.annotate(&subscript),
            Some(DataType::Int { length: None, integer_spelling: false })
        );
    }

    /// Indexing a map yields its value type.
    #[test]
    fn test_subscript_map_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let map = Expression::Map(Box::new(crate::expressions::Map {
            keys: vec![make_string_literal("a")],
            values: vec![make_int_literal(1)],
        }));
        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
            this: map,
            index: make_string_literal("a"),
        }));
        assert_eq!(
            annotator.annotate(&subscript),
            Some(DataType::Int { length: None, integer_spelling: false })
        );
    }

    /// A struct literal produces a struct type with one field per entry.
    #[test]
    fn test_struct_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let struct_expr = Expression::Struct(Box::new(crate::expressions::Struct {
            fields: vec![
                (Some("name".to_string()), make_string_literal("Alice")),
                (Some("age".to_string()), make_int_literal(30)),
            ],
        }));
        let result = annotator.annotate(&struct_expr);
        assert!(matches!(result, Some(DataType::Struct { fields, .. }) if fields.len() == 2));
    }

    /// A map literal takes its key and value types from its entries.
    #[test]
    fn test_map_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let map_expr = Expression::Map(Box::new(crate::expressions::Map {
            keys: vec![make_string_literal("a"), make_string_literal("b")],
            values: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let result = annotator.annotate(&map_expr);
        assert!(matches!(
            result,
            Some(DataType::Map { key_type, value_type })
                if matches!(*key_type, DataType::VarChar { .. })
                    && matches!(*value_type, DataType::Int { .. })
        ));
    }

    /// Exploding an array yields its element type.
    #[test]
    fn test_explode_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let arr = Expression::Array(Box::new(crate::expressions::Array {
            expressions: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let explode = Expression::Explode(Box::new(crate::expressions::UnaryFunc {
            this: arr,
            original_name: None,
        }));
        assert_eq!(
            annotator.annotate(&explode),
            Some(DataType::Int { length: None, integer_spelling: false })
        );
    }

    /// Unnesting an array yields its element type.
    #[test]
    fn test_unnest_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let arr = Expression::Array(Box::new(crate::expressions::Array {
            expressions: vec![make_string_literal("a"), make_string_literal("b")],
        }));
        let unnest = Expression::Unnest(Box::new(crate::expressions::UnnestFunc {
            this: arr,
            expressions: Vec::new(),
            with_ordinality: false,
            alias: None,
            offset_alias: None,
        }));
        assert_eq!(
            annotator.annotate(&unnest),
            Some(DataType::VarChar { length: None, parenthesized_length: false })
        );
    }

    /// Set operations are currently left untyped.
    #[test]
    fn test_set_operation_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        let select = Expression::Select(Box::new(crate::expressions::Select::default()));
        let union = Expression::Union(Box::new(crate::expressions::Union {
            left: select.clone(),
            right: select.clone(),
            all: false,
            distinct: false,
            with: None,
            order_by: None,
            limit: None,
            offset: None,
            by_name: false,
            side: None,
            kind: None,
            corresponding: false,
            strict: false,
            on_columns: Vec::new(),
            distribute_by: None,
            sort_by: None,
            cluster_by: None,
        }));
        assert_eq!(annotator.annotate(&union), None);
    }
}