Skip to main content

polyglot_sql/optimizer/
annotate_types.rs

//! Type Annotation for SQL Expressions
//!
//! This module provides type inference and annotation for SQL AST nodes.
//! It walks the expression tree and assigns data types to expressions based on:
//! - Literal values (strings, numbers, booleans)
//! - Column references (from schema)
//! - Function return types
//! - Operator result types (with coercion rules)
//!
//! Based on SQLGlot's optimizer/annotate_types.py
11
12use std::collections::HashMap;
13
14use crate::dialects::DialectType;
15use crate::expressions::{
16    BinaryOp, DataType, Expression, Function, Literal, Map, Struct, StructField, Subscript,
17};
18use crate::schema::Schema;
19
/// Broad type family used when coercing the operands of a binary operation.
///
/// Variants carry explicit discriminants and derive `Ord`, so they compare in
/// declaration order: `Text < Numeric < Timelike`. The intended rule is that
/// the higher-priority (greater) class wins during coercion.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum TypeCoercionClass {
    /// Character/binary string types (CHAR, VARCHAR, TEXT, BINARY, ...)
    Text = 0,
    /// Numeric types (INT, FLOAT, DECIMAL, ...)
    Numeric = 1,
    /// Temporal types (DATE, TIME, TIMESTAMP, INTERVAL)
    Timelike = 2,
}
31
32impl TypeCoercionClass {
33    /// Get the coercion class for a data type
34    pub fn from_data_type(dt: &DataType) -> Option<Self> {
35        match dt {
36            // Text types
37            DataType::Char { .. }
38            | DataType::VarChar { .. }
39            | DataType::Text
40            | DataType::Binary { .. }
41            | DataType::VarBinary { .. }
42            | DataType::Blob => Some(TypeCoercionClass::Text),
43
44            // Numeric types
45            DataType::Boolean
46            | DataType::TinyInt { .. }
47            | DataType::SmallInt { .. }
48            | DataType::Int { .. }
49            | DataType::BigInt { .. }
50            | DataType::Float { .. }
51            | DataType::Double { .. }
52            | DataType::Decimal { .. } => Some(TypeCoercionClass::Numeric),
53
54            // Timelike types
55            DataType::Date
56            | DataType::Time { .. }
57            | DataType::Timestamp { .. }
58            | DataType::Interval { .. } => Some(TypeCoercionClass::Timelike),
59
60            // Other types don't have a coercion class
61            _ => None,
62        }
63    }
64}
65
66/// Type annotation configuration and state
67pub struct TypeAnnotator<'a> {
68    /// Schema for looking up column types
69    _schema: Option<&'a dyn Schema>,
70    /// Dialect for dialect-specific type rules
71    _dialect: Option<DialectType>,
72    /// Whether to annotate types for all expressions
73    annotate_aggregates: bool,
74    /// Function return type mappings
75    function_return_types: HashMap<String, DataType>,
76}
77
78impl<'a> TypeAnnotator<'a> {
79    /// Create a new type annotator
80    pub fn new(schema: Option<&'a dyn Schema>, dialect: Option<DialectType>) -> Self {
81        let mut annotator = Self {
82            _schema: schema,
83            _dialect: dialect,
84            annotate_aggregates: true,
85            function_return_types: HashMap::new(),
86        };
87        annotator.init_function_return_types();
88        annotator
89    }
90
91    /// Initialize function return type mappings
92    fn init_function_return_types(&mut self) {
93        // Aggregate functions
94        self.function_return_types
95            .insert("COUNT".to_string(), DataType::BigInt { length: None });
96        self.function_return_types
97            .insert("SUM".to_string(), DataType::Decimal {
98                precision: None,
99                scale: None,
100            });
101        self.function_return_types
102            .insert("AVG".to_string(), DataType::Double { precision: None, scale: None });
103
104        // String functions
105        self.function_return_types
106            .insert("CONCAT".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
107        self.function_return_types
108            .insert("UPPER".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
109        self.function_return_types
110            .insert("LOWER".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
111        self.function_return_types
112            .insert("TRIM".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
113        self.function_return_types
114            .insert("LTRIM".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
115        self.function_return_types
116            .insert("RTRIM".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
117        self.function_return_types
118            .insert("SUBSTRING".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
119        self.function_return_types
120            .insert("SUBSTR".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
121        self.function_return_types
122            .insert("REPLACE".to_string(), DataType::VarChar { length: None, parenthesized_length: false });
123        self.function_return_types
124            .insert("LENGTH".to_string(), DataType::Int { length: None, integer_spelling: false });
125        self.function_return_types
126            .insert("CHAR_LENGTH".to_string(), DataType::Int { length: None, integer_spelling: false });
127
128        // Date/Time functions
129        self.function_return_types
130            .insert("NOW".to_string(), DataType::Timestamp {
131                precision: None,
132                timezone: false,
133            });
134        self.function_return_types
135            .insert("CURRENT_TIMESTAMP".to_string(), DataType::Timestamp {
136                precision: None,
137                timezone: false,
138            });
139        self.function_return_types
140            .insert("CURRENT_DATE".to_string(), DataType::Date);
141        self.function_return_types
142            .insert("CURRENT_TIME".to_string(), DataType::Time { precision: None, timezone: false });
143        self.function_return_types
144            .insert("DATE".to_string(), DataType::Date);
145        self.function_return_types
146            .insert("YEAR".to_string(), DataType::Int { length: None, integer_spelling: false });
147        self.function_return_types
148            .insert("MONTH".to_string(), DataType::Int { length: None, integer_spelling: false });
149        self.function_return_types
150            .insert("DAY".to_string(), DataType::Int { length: None, integer_spelling: false });
151        self.function_return_types
152            .insert("HOUR".to_string(), DataType::Int { length: None, integer_spelling: false });
153        self.function_return_types
154            .insert("MINUTE".to_string(), DataType::Int { length: None, integer_spelling: false });
155        self.function_return_types
156            .insert("SECOND".to_string(), DataType::Int { length: None, integer_spelling: false });
157        self.function_return_types
158            .insert("EXTRACT".to_string(), DataType::Int { length: None, integer_spelling: false });
159        self.function_return_types
160            .insert("DATE_DIFF".to_string(), DataType::Int { length: None, integer_spelling: false });
161        self.function_return_types
162            .insert("DATEDIFF".to_string(), DataType::Int { length: None, integer_spelling: false });
163
164        // Math functions
165        self.function_return_types
166            .insert("ABS".to_string(), DataType::Double { precision: None, scale: None });
167        self.function_return_types
168            .insert("ROUND".to_string(), DataType::Double { precision: None, scale: None });
169        self.function_return_types
170            .insert("FLOOR".to_string(), DataType::BigInt { length: None });
171        self.function_return_types
172            .insert("CEIL".to_string(), DataType::BigInt { length: None });
173        self.function_return_types
174            .insert("CEILING".to_string(), DataType::BigInt { length: None });
175        self.function_return_types
176            .insert("SQRT".to_string(), DataType::Double { precision: None, scale: None });
177        self.function_return_types
178            .insert("POWER".to_string(), DataType::Double { precision: None, scale: None });
179        self.function_return_types
180            .insert("MOD".to_string(), DataType::Int { length: None, integer_spelling: false });
181        self.function_return_types
182            .insert("LOG".to_string(), DataType::Double { precision: None, scale: None });
183        self.function_return_types
184            .insert("LN".to_string(), DataType::Double { precision: None, scale: None });
185        self.function_return_types
186            .insert("EXP".to_string(), DataType::Double { precision: None, scale: None });
187
188        // Null-handling functions return Unknown (infer from args)
189        self.function_return_types
190            .insert("COALESCE".to_string(), DataType::Unknown);
191        self.function_return_types
192            .insert("NULLIF".to_string(), DataType::Unknown);
193        self.function_return_types
194            .insert("GREATEST".to_string(), DataType::Unknown);
195        self.function_return_types
196            .insert("LEAST".to_string(), DataType::Unknown);
197    }
198
199    /// Annotate types for an expression tree
200    pub fn annotate(&mut self, expr: &Expression) -> Option<DataType> {
201        match expr {
202            // Literals
203            Expression::Literal(lit) => self.annotate_literal(lit),
204            Expression::Boolean(_) => Some(DataType::Boolean),
205            Expression::Null(_) => None, // NULL has no type
206
207            // Arithmetic binary operations
208            Expression::Add(op) | Expression::Sub(op) |
209            Expression::Mul(op) | Expression::Div(op) |
210            Expression::Mod(op) => self.annotate_arithmetic(op),
211
212            // Comparison operations - always boolean
213            Expression::Eq(_) | Expression::Neq(_) |
214            Expression::Lt(_) | Expression::Lte(_) |
215            Expression::Gt(_) | Expression::Gte(_) |
216            Expression::Like(_) | Expression::ILike(_) => Some(DataType::Boolean),
217
218            // Logical operations - always boolean
219            Expression::And(_) | Expression::Or(_) | Expression::Not(_) => Some(DataType::Boolean),
220
221            // Predicates - always boolean
222            Expression::Between(_) | Expression::In(_) |
223            Expression::IsNull(_) | Expression::IsTrue(_) | Expression::IsFalse(_) |
224            Expression::Is(_) | Expression::Exists(_) => Some(DataType::Boolean),
225
226            // String concatenation
227            Expression::Concat(_) => Some(DataType::VarChar { length: None, parenthesized_length: false }),
228
229            // Bitwise operations - integer
230            Expression::BitwiseAnd(_) | Expression::BitwiseOr(_) |
231            Expression::BitwiseXor(_) | Expression::BitwiseNot(_) => {
232                Some(DataType::BigInt { length: None })
233            }
234
235            // Negation preserves type
236            Expression::Neg(op) => self.annotate(&op.this),
237
238            // Functions
239            Expression::Function(func) => self.annotate_function(func),
240
241            // Typed aggregate functions
242            Expression::Count(_) => Some(DataType::BigInt { length: None }),
243            Expression::Sum(agg) => self.annotate_sum(&agg.this),
244            Expression::Avg(_) => Some(DataType::Double { precision: None, scale: None }),
245            Expression::Min(agg) => self.annotate(&agg.this),
246            Expression::Max(agg) => self.annotate(&agg.this),
247            Expression::GroupConcat(_) | Expression::StringAgg(_) | Expression::ListAgg(_) => {
248                Some(DataType::VarChar { length: None, parenthesized_length: false })
249            }
250
251            // Generic aggregate function
252            Expression::AggregateFunction(agg) => {
253                if !self.annotate_aggregates {
254                    return None;
255                }
256                let func_name = agg.name.to_uppercase();
257                self.get_aggregate_return_type(&func_name, &agg.args)
258            }
259
260            // Column references - look up type from schema if available
261            Expression::Column(col) => {
262                if let Some(schema) = &self._schema {
263                    let table_name = col.table.as_ref().map(|t| t.name.as_str()).unwrap_or("");
264                    schema
265                        .get_column_type(table_name, &col.name.name)
266                        .ok()
267                } else {
268                    None
269                }
270            }
271
272            // Cast expressions
273            Expression::Cast(cast) => Some(cast.to.clone()),
274            Expression::SafeCast(cast) => Some(cast.to.clone()),
275            Expression::TryCast(cast) => Some(cast.to.clone()),
276
277            // Subqueries - type is the type of the first SELECT expression
278            Expression::Subquery(subq) => {
279                if let Expression::Select(select) = &subq.this {
280                    if let Some(first) = select.expressions.first() {
281                        self.annotate(first)
282                    } else {
283                        None
284                    }
285                } else {
286                    None
287                }
288            }
289
290            // CASE expression - type of the first THEN/ELSE
291            Expression::Case(case) => {
292                if let Some(else_expr) = &case.else_ {
293                    self.annotate(else_expr)
294                } else if let Some((_, then_expr)) = case.whens.first() {
295                    self.annotate(then_expr)
296                } else {
297                    None
298                }
299            }
300
301            // Array expressions
302            Expression::Array(arr) => {
303                if let Some(first) = arr.expressions.first() {
304                    if let Some(elem_type) = self.annotate(first) {
305                        Some(DataType::Array {
306                            element_type: Box::new(elem_type),
307                            dimension: None,
308                        })
309                    } else {
310                        Some(DataType::Array {
311                            element_type: Box::new(DataType::Unknown),
312                            dimension: None,
313                        })
314                    }
315                } else {
316                    Some(DataType::Array {
317                        element_type: Box::new(DataType::Unknown),
318                        dimension: None,
319                    })
320                }
321            }
322
323            // Interval expressions
324            Expression::Interval(_) => Some(DataType::Interval { unit: None, to: None }),
325
326            // Window functions inherit type from their function
327            Expression::WindowFunction(window) => self.annotate(&window.this),
328
329            // Date/time expressions
330            Expression::CurrentDate(_) => Some(DataType::Date),
331            Expression::CurrentTime(_) => Some(DataType::Time { precision: None, timezone: false }),
332            Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
333                Some(DataType::Timestamp {
334                    precision: None,
335                    timezone: false,
336                })
337            }
338
339            // Date functions
340            Expression::DateAdd(_) | Expression::DateSub(_) |
341            Expression::ToDate(_) | Expression::Date(_) => Some(DataType::Date),
342            Expression::DateDiff(_) | Expression::Extract(_) => Some(DataType::Int { length: None, integer_spelling: false }),
343            Expression::ToTimestamp(_) => Some(DataType::Timestamp {
344                precision: None,
345                timezone: false,
346            }),
347
348            // String functions
349            Expression::Upper(_) | Expression::Lower(_) | Expression::Trim(_) |
350            Expression::LTrim(_) | Expression::RTrim(_) | Expression::Replace(_) |
351            Expression::Substring(_) | Expression::Reverse(_) | Expression::Left(_) |
352            Expression::Right(_) | Expression::Repeat(_) | Expression::Lpad(_) |
353            Expression::Rpad(_) | Expression::ConcatWs(_) | Expression::Overlay(_) => {
354                Some(DataType::VarChar { length: None, parenthesized_length: false })
355            }
356            Expression::Length(_) => Some(DataType::Int { length: None, integer_spelling: false }),
357
358            // Math functions
359            Expression::Abs(_) | Expression::Sqrt(_) | Expression::Cbrt(_) |
360            Expression::Ln(_) | Expression::Exp(_) | Expression::Power(_) |
361            Expression::Log(_) => Some(DataType::Double { precision: None, scale: None }),
362            Expression::Round(_) => Some(DataType::Double { precision: None, scale: None }),
363            Expression::Floor(_) | Expression::Ceil(_) | Expression::Sign(_) => {
364                Some(DataType::BigInt { length: None })
365            }
366
367            // Greatest/Least - coerce argument types
368            Expression::Greatest(v) | Expression::Least(v) => {
369                self.coerce_arg_types(&v.expressions)
370            }
371
372            // Alias - type of the inner expression
373            Expression::Alias(alias) => self.annotate(&alias.this),
374
375            // SELECT expressions - no scalar type
376            Expression::Select(_) => None,
377
378            // ============================================
379            // 3.1.8: Array/Map Indexing (Subscript/Bracket)
380            // ============================================
381            Expression::Subscript(sub) => self.annotate_subscript(sub),
382
383            // Dot access (struct.field) - returns Unknown without schema
384            Expression::Dot(_) => None,
385
386            // ============================================
387            // 3.1.9: STRUCT Construction
388            // ============================================
389            Expression::Struct(s) => self.annotate_struct(s),
390
391            // ============================================
392            // 3.1.10: MAP Construction
393            // ============================================
394            Expression::Map(map) => self.annotate_map(map),
395            Expression::MapFromEntries(mfe) => {
396                // MAP_FROM_ENTRIES(array_of_pairs) - infer from array element type
397                if let Some(DataType::Array { element_type, .. }) = self.annotate(&mfe.this) {
398                    if let DataType::Struct { fields, .. } = *element_type {
399                        if fields.len() >= 2 {
400                            return Some(DataType::Map {
401                                key_type: Box::new(fields[0].data_type.clone()),
402                                value_type: Box::new(fields[1].data_type.clone()),
403                            });
404                        }
405                    }
406                }
407                Some(DataType::Map {
408                    key_type: Box::new(DataType::Unknown),
409                    value_type: Box::new(DataType::Unknown),
410                })
411            }
412
413            // ============================================
414            // 3.1.11: SetOperation Type Coercion
415            // ============================================
416            Expression::Union(union) => self.annotate_set_operation(&union.left, &union.right),
417            Expression::Intersect(intersect) => {
418                self.annotate_set_operation(&intersect.left, &intersect.right)
419            }
420            Expression::Except(except) => {
421                self.annotate_set_operation(&except.left, &except.right)
422            }
423
424            // ============================================
425            // 3.1.12: UDTF Type Handling
426            // ============================================
427            Expression::Lateral(lateral) => {
428                // LATERAL subquery - type is the subquery's type
429                self.annotate(&lateral.this)
430            }
431            Expression::LateralView(lv) => {
432                // LATERAL VIEW - returns the exploded type
433                self.annotate_lateral_view(lv)
434            }
435            Expression::Unnest(unnest) => {
436                // UNNEST(array) - returns the element type of the array
437                if let Some(DataType::Array { element_type, .. }) = self.annotate(&unnest.this) {
438                    Some(*element_type)
439                } else {
440                    None
441                }
442            }
443            Expression::Explode(explode) => {
444                // EXPLODE(array) - returns the element type
445                if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
446                    Some(*element_type)
447                } else if let Some(DataType::Map { key_type, value_type }) =
448                    self.annotate(&explode.this)
449                {
450                    // EXPLODE(map) returns struct(key, value)
451                    Some(DataType::Struct {
452                        fields: vec![
453                            StructField::new("key".to_string(), *key_type),
454                            StructField::new("value".to_string(), *value_type),
455                        ],
456                        nested: false,
457                    })
458                } else {
459                    None
460                }
461            }
462            Expression::ExplodeOuter(explode) => {
463                // EXPLODE_OUTER - same as EXPLODE but preserves nulls
464                if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
465                    Some(*element_type)
466                } else {
467                    None
468                }
469            }
470            Expression::GenerateSeries(gs) => {
471                // GENERATE_SERIES returns the type of start/end
472                if let Some(ref start) = gs.start {
473                    self.annotate(start)
474                } else if let Some(ref end) = gs.end {
475                    self.annotate(end)
476                } else {
477                    Some(DataType::Int { length: None, integer_spelling: false })
478                }
479            }
480
481            // Other expressions - unknown
482            _ => None,
483        }
484    }
485
486    /// Annotate a subscript/bracket expression (array[index] or map[key])
487    fn annotate_subscript(&mut self, sub: &Subscript) -> Option<DataType> {
488        let base_type = self.annotate(&sub.this)?;
489
490        match base_type {
491            DataType::Array { element_type, .. } => Some(*element_type),
492            DataType::Map { value_type, .. } => Some(*value_type),
493            DataType::Json | DataType::JsonB => Some(DataType::Json), // JSON indexing returns JSON
494            DataType::VarChar { .. } | DataType::Text => {
495                // String indexing returns a character
496                Some(DataType::VarChar { length: Some(1), parenthesized_length: false })
497            }
498            _ => None,
499        }
500    }
501
502    /// Annotate a STRUCT literal
503    fn annotate_struct(&mut self, s: &Struct) -> Option<DataType> {
504        let fields: Vec<StructField> = s
505            .fields
506            .iter()
507            .map(|(name, expr)| {
508                let field_type = self.annotate(expr).unwrap_or(DataType::Unknown);
509                StructField::new(name.clone().unwrap_or_default(), field_type)
510            })
511            .collect();
512        Some(DataType::Struct { fields, nested: false })
513    }
514
515    /// Annotate a MAP literal
516    fn annotate_map(&mut self, map: &Map) -> Option<DataType> {
517        let key_type = if let Some(first_key) = map.keys.first() {
518            self.annotate(first_key).unwrap_or(DataType::Unknown)
519        } else {
520            DataType::Unknown
521        };
522
523        let value_type = if let Some(first_value) = map.values.first() {
524            self.annotate(first_value).unwrap_or(DataType::Unknown)
525        } else {
526            DataType::Unknown
527        };
528
529        Some(DataType::Map {
530            key_type: Box::new(key_type),
531            value_type: Box::new(value_type),
532        })
533    }
534
535    /// Annotate a SetOperation (UNION/INTERSECT/EXCEPT)
536    /// Returns None since set operations produce relation types, not scalar types
537    fn annotate_set_operation(
538        &mut self,
539        _left: &Expression,
540        _right: &Expression,
541    ) -> Option<DataType> {
542        // Set operations produce relations, not scalar types
543        // The column types would be coerced between left and right
544        // For now, return None as this is a relation-level type
545        None
546    }
547
548    /// Annotate a LATERAL VIEW expression
549    fn annotate_lateral_view(
550        &mut self,
551        lv: &crate::expressions::LateralView,
552    ) -> Option<DataType> {
553        // The type depends on the table-generating function
554        self.annotate(&lv.this)
555    }
556
557    /// Annotate a literal value
558    fn annotate_literal(&self, lit: &Literal) -> Option<DataType> {
559        match lit {
560            Literal::String(_) | Literal::NationalString(_) |
561            Literal::TripleQuotedString(_, _) | Literal::EscapeString(_) |
562            Literal::DollarString(_) | Literal::RawString(_) => Some(DataType::VarChar { length: None, parenthesized_length: false }),
563            Literal::Number(n) => {
564                // Try to determine if it's an integer or float
565                if n.contains('.') || n.contains('e') || n.contains('E') {
566                    Some(DataType::Double { precision: None, scale: None })
567                } else {
568                    // Check if it fits in an Int or needs BigInt
569                    if let Ok(_) = n.parse::<i32>() {
570                        Some(DataType::Int { length: None, integer_spelling: false })
571                    } else {
572                        Some(DataType::BigInt { length: None })
573                    }
574                }
575            }
576            Literal::HexString(_) | Literal::BitString(_) | Literal::ByteString(_) => {
577                Some(DataType::VarBinary { length: None })
578            }
579            Literal::HexNumber(_) => Some(DataType::BigInt { length: None }),
580            Literal::Date(_) => Some(DataType::Date),
581            Literal::Time(_) => Some(DataType::Time { precision: None, timezone: false }),
582            Literal::Timestamp(_) => Some(DataType::Timestamp {
583                precision: None,
584                timezone: false,
585            }),
586            Literal::Datetime(_) => Some(DataType::Custom {
587                name: "DATETIME".to_string(),
588            }),
589        }
590    }
591
592    /// Annotate an arithmetic binary operation
593    fn annotate_arithmetic(&mut self, op: &BinaryOp) -> Option<DataType> {
594        let left_type = self.annotate(&op.left);
595        let right_type = self.annotate(&op.right);
596
597        match (left_type, right_type) {
598            (Some(l), Some(r)) => self.coerce_types(&l, &r),
599            (Some(t), None) | (None, Some(t)) => Some(t),
600            (None, None) => None,
601        }
602    }
603
604    /// Annotate a function call
605    fn annotate_function(&mut self, func: &Function) -> Option<DataType> {
606        let func_name = func.name.to_uppercase();
607
608        // Check known function return types
609        if let Some(return_type) = self.function_return_types.get(&func_name) {
610            if *return_type != DataType::Unknown {
611                return Some(return_type.clone());
612            }
613        }
614
615        // For functions with Unknown return type, infer from arguments
616        match func_name.as_str() {
617            "COALESCE" | "IFNULL" | "NVL" | "ISNULL" => {
618                // Return type of first non-null argument
619                for arg in &func.args {
620                    if let Some(arg_type) = self.annotate(arg) {
621                        return Some(arg_type);
622                    }
623                }
624                None
625            }
626            "NULLIF" => {
627                // Return type of first argument
628                func.args.first().and_then(|arg| self.annotate(arg))
629            }
630            "GREATEST" | "LEAST" => {
631                // Coerce all argument types
632                self.coerce_arg_types(&func.args)
633            }
634            "IF" | "IIF" => {
635                // Return type of THEN/ELSE branches
636                if func.args.len() >= 2 {
637                    self.annotate(&func.args[1])
638                } else {
639                    None
640                }
641            }
642            _ => {
643                // Unknown function - try to infer from first argument
644                func.args.first().and_then(|arg| self.annotate(arg))
645            }
646        }
647    }
648
649    /// Get return type for aggregate functions
650    fn get_aggregate_return_type(&mut self, func_name: &str, args: &[Expression]) -> Option<DataType> {
651        match func_name {
652            "COUNT" | "COUNT_IF" => Some(DataType::BigInt { length: None }),
653            "SUM" => {
654                if let Some(arg) = args.first() {
655                    self.annotate_sum(arg)
656                } else {
657                    Some(DataType::Decimal {
658                        precision: None,
659                        scale: None,
660                    })
661                }
662            }
663            "AVG" => Some(DataType::Double { precision: None, scale: None }),
664            "MIN" | "MAX" => {
665                // Preserves input type
666                args.first().and_then(|arg| self.annotate(arg))
667            }
668            "STRING_AGG" | "GROUP_CONCAT" | "LISTAGG" | "ARRAY_AGG" => {
669                Some(DataType::VarChar { length: None, parenthesized_length: false })
670            }
671            "BOOL_AND" | "BOOL_OR" | "EVERY" | "ANY" | "SOME" => Some(DataType::Boolean),
672            "BIT_AND" | "BIT_OR" | "BIT_XOR" => Some(DataType::BigInt { length: None }),
673            "STDDEV" | "STDDEV_POP" | "STDDEV_SAMP" | "VARIANCE" | "VAR_POP" | "VAR_SAMP" => {
674                Some(DataType::Double { precision: None, scale: None })
675            }
676            "PERCENTILE_CONT" | "PERCENTILE_DISC" | "MEDIAN" => {
677                args.first().and_then(|arg| self.annotate(arg))
678            }
679            _ => None,
680        }
681    }
682
683    /// Annotate SUM function - promotes to at least BigInt
684    fn annotate_sum(&mut self, arg: &Expression) -> Option<DataType> {
685        match self.annotate(arg) {
686            Some(DataType::TinyInt { .. })
687            | Some(DataType::SmallInt { .. })
688            | Some(DataType::Int { .. }) => Some(DataType::BigInt { length: None }),
689            Some(DataType::BigInt { .. }) => Some(DataType::BigInt { length: None }),
690            Some(DataType::Float { .. }) | Some(DataType::Double { .. }) => Some(DataType::Double { precision: None, scale: None }),
691            Some(DataType::Decimal { precision, scale }) => {
692                Some(DataType::Decimal { precision, scale })
693            }
694            _ => Some(DataType::Decimal {
695                precision: None,
696                scale: None,
697            }),
698        }
699    }
700
701    /// Coerce multiple argument types to a common type
702    fn coerce_arg_types(&mut self, args: &[Expression]) -> Option<DataType> {
703        let mut result_type: Option<DataType> = None;
704        for arg in args {
705            if let Some(arg_type) = self.annotate(arg) {
706                result_type = match result_type {
707                    Some(t) => self.coerce_types(&t, &arg_type),
708                    None => Some(arg_type),
709                };
710            }
711        }
712        result_type
713    }
714
715    /// Coerce two types to a common type
716    fn coerce_types(&self, left: &DataType, right: &DataType) -> Option<DataType> {
717        // If types are the same, return that type
718        if left == right {
719            return Some(left.clone());
720        }
721
722        // Special case: Interval + Date/Timestamp
723        match (left, right) {
724            (DataType::Date, DataType::Interval { .. }) |
725            (DataType::Interval { .. }, DataType::Date) => return Some(DataType::Date),
726            (DataType::Timestamp { precision, timezone }, DataType::Interval { .. }) |
727            (DataType::Interval { .. }, DataType::Timestamp { precision, timezone }) => {
728                return Some(DataType::Timestamp {
729                    precision: *precision,
730                    timezone: *timezone
731                });
732            }
733            _ => {}
734        }
735
736        // Coerce based on class
737        let left_class = TypeCoercionClass::from_data_type(left);
738        let right_class = TypeCoercionClass::from_data_type(right);
739
740        match (left_class, right_class) {
741            // Same class: use higher-precision type within class
742            (Some(lc), Some(rc)) if lc == rc => {
743                // For numeric, choose wider type
744                if lc == TypeCoercionClass::Numeric {
745                    Some(self.wider_numeric_type(left, right))
746                } else {
747                    // For text and timelike, left wins by default
748                    Some(left.clone())
749                }
750            }
751            // Different classes: higher-priority class wins
752            (Some(lc), Some(rc)) => {
753                if lc > rc {
754                    Some(left.clone())
755                } else {
756                    Some(right.clone())
757                }
758            }
759            // One unknown: use the known type
760            (Some(_), None) => Some(left.clone()),
761            (None, Some(_)) => Some(right.clone()),
762            // Both unknown: return unknown
763            (None, None) => Some(DataType::Unknown),
764        }
765    }
766
767    /// Get the wider numeric type
768    fn wider_numeric_type(&self, left: &DataType, right: &DataType) -> DataType {
769        let order = |dt: &DataType| -> u8 {
770            match dt {
771                DataType::Boolean => 0,
772                DataType::TinyInt { .. } => 1,
773                DataType::SmallInt { .. } => 2,
774                DataType::Int { .. } => 3,
775                DataType::BigInt { .. } => 4,
776                DataType::Float { .. } => 5,
777                DataType::Double { .. } => 6,
778                DataType::Decimal { .. } => 7,
779                _ => 0,
780            }
781        };
782
783        if order(left) >= order(right) {
784            left.clone()
785        } else {
786            right.clone()
787        }
788    }
789}
790
791/// Convenience function to annotate types in an expression tree
792pub fn annotate_types(
793    expr: &Expression,
794    schema: Option<&dyn Schema>,
795    dialect: Option<DialectType>,
796) -> Option<DataType> {
797    let mut annotator = TypeAnnotator::new(schema, dialect);
798    annotator.annotate(expr)
799}
800
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{BooleanLiteral, Cast, Null};

    fn make_int_literal(val: i64) -> Expression {
        Expression::Literal(Literal::Number(val.to_string()))
    }

    /// Build a numeric literal from a float.
    ///
    /// Uses the `Debug` form rather than `to_string`. `to_string` renders whole
    /// values without a decimal point (`2.0` becomes `"2"`), which would give an
    /// integer literal; `Debug` keeps it as `"2.0"`.
    fn make_float_literal(val: f64) -> Expression {
        Expression::Literal(Literal::Number(format!("{:?}", val)))
    }

    fn make_string_literal(val: &str) -> Expression {
        Expression::Literal(Literal::String(val.to_string()))
    }

    fn make_bool_literal(val: bool) -> Expression {
        Expression::Boolean(BooleanLiteral { value: val })
    }

    #[test]
    fn test_literal_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        // Integer literal
        let int_expr = make_int_literal(42);
        assert_eq!(
            annotator.annotate(&int_expr),
            Some(DataType::Int { length: None, integer_spelling: false })
        );

        // Float literal (avoid near-PI values such as 3.14: clippy's
        // deny-by-default `approx_constant` lint rejects them)
        let float_expr = make_float_literal(2.75);
        assert_eq!(annotator.annotate(&float_expr), Some(DataType::Double { precision: None, scale: None }));

        // String literal
        let string_expr = make_string_literal("hello");
        assert_eq!(
            annotator.annotate(&string_expr),
            Some(DataType::VarChar { length: None, parenthesized_length: false })
        );

        // Boolean literal
        let bool_expr = make_bool_literal(true);
        assert_eq!(annotator.annotate(&bool_expr), Some(DataType::Boolean));

        // Null literal
        let null_expr = Expression::Null(Null);
        assert_eq!(annotator.annotate(&null_expr), None);
    }

    #[test]
    fn test_comparison_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        // Comparison returns boolean
        let cmp = Expression::Gt(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        assert_eq!(annotator.annotate(&cmp), Some(DataType::Boolean));

        // Equality returns boolean
        let eq = Expression::Eq(Box::new(BinaryOp::new(
            make_string_literal("a"),
            make_string_literal("b"),
        )));
        assert_eq!(annotator.annotate(&eq), Some(DataType::Boolean));
    }

    #[test]
    fn test_arithmetic_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        // Int + Int = Int
        let add_int = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_int_literal(2),
        )));
        assert_eq!(
            annotator.annotate(&add_int),
            Some(DataType::Int { length: None, integer_spelling: false })
        );

        // Int + Float = Double (wider type)
        let add_mixed = Expression::Add(Box::new(BinaryOp::new(
            make_int_literal(1),
            make_float_literal(2.5),
        )));
        assert_eq!(annotator.annotate(&add_mixed), Some(DataType::Double { precision: None, scale: None }));
    }

    #[test]
    fn test_string_concat_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // String || String = VarChar
        let concat = Expression::Concat(Box::new(BinaryOp::new(
            make_string_literal("hello"),
            make_string_literal(" world"),
        )));
        assert_eq!(
            annotator.annotate(&concat),
            Some(DataType::VarChar { length: None, parenthesized_length: false })
        );
    }

    #[test]
    fn test_cast_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // CAST(1 AS VARCHAR)
        let cast = Expression::Cast(Box::new(Cast {
            this: make_int_literal(1),
            to: DataType::VarChar { length: Some(10), parenthesized_length: false },
            trailing_comments: vec![],
            double_colon_syntax: false,
            format: None,
            default: None,
        }));
        assert_eq!(
            annotator.annotate(&cast),
            Some(DataType::VarChar { length: Some(10), parenthesized_length: false })
        );
    }

    #[test]
    fn test_function_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        // COUNT returns BigInt
        let count = Expression::Function(Box::new(Function::new("COUNT", vec![make_int_literal(1)])));
        assert_eq!(
            annotator.annotate(&count),
            Some(DataType::BigInt { length: None })
        );

        // UPPER returns VarChar
        let upper = Expression::Function(Box::new(Function::new("UPPER", vec![make_string_literal("hello")])));
        assert_eq!(
            annotator.annotate(&upper),
            Some(DataType::VarChar { length: None, parenthesized_length: false })
        );

        // NOW returns Timestamp
        let now = Expression::Function(Box::new(Function::new("NOW", vec![])));
        assert_eq!(
            annotator.annotate(&now),
            Some(DataType::Timestamp {
                precision: None,
                timezone: false
            })
        );
    }

    #[test]
    fn test_coalesce_type_inference() {
        let mut annotator = TypeAnnotator::new(None, None);

        // COALESCE(NULL, 1) returns Int (type of first non-null arg)
        let coalesce = Expression::Function(Box::new(Function::new(
            "COALESCE",
            vec![
                Expression::Null(Null),
                make_int_literal(1),
            ],
        )));
        assert_eq!(
            annotator.annotate(&coalesce),
            Some(DataType::Int { length: None, integer_spelling: false })
        );
    }

    #[test]
    fn test_type_coercion_class() {
        // Text types
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::VarChar { length: None, parenthesized_length: false }),
            Some(TypeCoercionClass::Text)
        );
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Text),
            Some(TypeCoercionClass::Text)
        );

        // Numeric types
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Int { length: None, integer_spelling: false }),
            Some(TypeCoercionClass::Numeric)
        );
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Double { precision: None, scale: None }),
            Some(TypeCoercionClass::Numeric)
        );

        // Timelike types
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Date),
            Some(TypeCoercionClass::Timelike)
        );
        assert_eq!(
            TypeCoercionClass::from_data_type(&DataType::Timestamp {
                precision: None,
                timezone: false
            }),
            Some(TypeCoercionClass::Timelike)
        );

        // Unknown types
        assert_eq!(TypeCoercionClass::from_data_type(&DataType::Json), None);
    }

    #[test]
    fn test_wider_numeric_type() {
        let annotator = TypeAnnotator::new(None, None);

        // Int vs BigInt -> BigInt
        let result = annotator.wider_numeric_type(
            &DataType::Int { length: None, integer_spelling: false },
            &DataType::BigInt { length: None },
        );
        assert_eq!(result, DataType::BigInt { length: None });

        // Float vs Double -> Double
        let result = annotator.wider_numeric_type(&DataType::Float { precision: None, scale: None, real_spelling: false }, &DataType::Double { precision: None, scale: None });
        assert_eq!(result, DataType::Double { precision: None, scale: None });

        // Int vs Double -> Double
        let result = annotator.wider_numeric_type(
            &DataType::Int { length: None, integer_spelling: false },
            &DataType::Double { precision: None, scale: None },
        );
        assert_eq!(result, DataType::Double { precision: None, scale: None });
    }

    #[test]
    fn test_aggregate_return_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        // SUM(int) returns BigInt
        let sum_type = annotator.get_aggregate_return_type("SUM", &[make_int_literal(1)]);
        assert_eq!(sum_type, Some(DataType::BigInt { length: None }));

        // AVG always returns Double
        let avg_type = annotator.get_aggregate_return_type("AVG", &[make_int_literal(1)]);
        assert_eq!(avg_type, Some(DataType::Double { precision: None, scale: None }));

        // MIN/MAX preserve input type
        let min_type = annotator.get_aggregate_return_type("MIN", &[make_string_literal("a")]);
        assert_eq!(min_type, Some(DataType::VarChar { length: None, parenthesized_length: false }));
    }

    #[test]
    fn test_date_literal_types() {
        let mut annotator = TypeAnnotator::new(None, None);

        // DATE literal
        let date_expr = Expression::Literal(Literal::Date("2024-01-15".to_string()));
        assert_eq!(annotator.annotate(&date_expr), Some(DataType::Date));

        // TIME literal
        let time_expr = Expression::Literal(Literal::Time("10:30:00".to_string()));
        assert_eq!(
            annotator.annotate(&time_expr),
            Some(DataType::Time { precision: None, timezone: false })
        );

        // TIMESTAMP literal
        let ts_expr = Expression::Literal(Literal::Timestamp("2024-01-15 10:30:00".to_string()));
        assert_eq!(
            annotator.annotate(&ts_expr),
            Some(DataType::Timestamp {
                precision: None,
                timezone: false
            })
        );
    }

    #[test]
    fn test_logical_operations() {
        let mut annotator = TypeAnnotator::new(None, None);

        // AND returns boolean
        let and_expr = Expression::And(Box::new(BinaryOp::new(
            make_bool_literal(true),
            make_bool_literal(false),
        )));
        assert_eq!(annotator.annotate(&and_expr), Some(DataType::Boolean));

        // OR returns boolean
        let or_expr = Expression::Or(Box::new(BinaryOp::new(
            make_bool_literal(true),
            make_bool_literal(false),
        )));
        assert_eq!(annotator.annotate(&or_expr), Some(DataType::Boolean));

        // NOT returns boolean
        let not_expr = Expression::Not(Box::new(crate::expressions::UnaryOp::new(
            make_bool_literal(true),
        )));
        assert_eq!(annotator.annotate(&not_expr), Some(DataType::Boolean));
    }

    // ========================================
    // Tests for newly implemented features
    // ========================================

    #[test]
    fn test_subscript_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // Array[index] returns element type
        let arr = Expression::Array(Box::new(crate::expressions::Array {
            expressions: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
            this: arr,
            index: make_int_literal(0),
        }));
        assert_eq!(
            annotator.annotate(&subscript),
            Some(DataType::Int { length: None, integer_spelling: false })
        );
    }

    #[test]
    fn test_subscript_map_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // Map[key] returns value type
        let map = Expression::Map(Box::new(crate::expressions::Map {
            keys: vec![make_string_literal("a")],
            values: vec![make_int_literal(1)],
        }));
        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
            this: map,
            index: make_string_literal("a"),
        }));
        assert_eq!(
            annotator.annotate(&subscript),
            Some(DataType::Int { length: None, integer_spelling: false })
        );
    }

    #[test]
    fn test_struct_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // STRUCT literal
        let struct_expr = Expression::Struct(Box::new(crate::expressions::Struct {
            fields: vec![
                (Some("name".to_string()), make_string_literal("Alice")),
                (Some("age".to_string()), make_int_literal(30)),
            ],
        }));
        let result = annotator.annotate(&struct_expr);
        assert!(matches!(result, Some(DataType::Struct { fields, .. }) if fields.len() == 2));
    }

    #[test]
    fn test_map_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // MAP literal
        let map_expr = Expression::Map(Box::new(crate::expressions::Map {
            keys: vec![make_string_literal("a"), make_string_literal("b")],
            values: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let result = annotator.annotate(&map_expr);
        assert!(matches!(
            result,
            Some(DataType::Map { key_type, value_type })
            if matches!(*key_type, DataType::VarChar { .. })
               && matches!(*value_type, DataType::Int { .. })
        ));
    }

    #[test]
    fn test_explode_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // EXPLODE(array) returns element type
        let arr = Expression::Array(Box::new(crate::expressions::Array {
            expressions: vec![make_int_literal(1), make_int_literal(2)],
        }));
        let explode = Expression::Explode(Box::new(crate::expressions::UnaryFunc {
            this: arr,
            original_name: None,
        }));
        assert_eq!(
            annotator.annotate(&explode),
            Some(DataType::Int { length: None, integer_spelling: false })
        );
    }

    #[test]
    fn test_unnest_array_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // UNNEST(array) returns element type
        let arr = Expression::Array(Box::new(crate::expressions::Array {
            expressions: vec![make_string_literal("a"), make_string_literal("b")],
        }));
        let unnest = Expression::Unnest(Box::new(crate::expressions::UnnestFunc {
            this: arr,
            expressions: Vec::new(),
            with_ordinality: false,
            alias: None,
            offset_alias: None,
        }));
        assert_eq!(
            annotator.annotate(&unnest),
            Some(DataType::VarChar { length: None, parenthesized_length: false })
        );
    }

    #[test]
    fn test_set_operation_type() {
        let mut annotator = TypeAnnotator::new(None, None);

        // UNION/INTERSECT/EXCEPT return None (they produce relations, not scalars)
        let select = Expression::Select(Box::new(crate::expressions::Select::default()));
        let union = Expression::Union(Box::new(crate::expressions::Union {
            left: select.clone(),
            right: select.clone(),
            all: false,
            distinct: false,
            with: None,
            order_by: None,
            limit: None,
            offset: None,
            by_name: false,
            side: None,
            kind: None,
            corresponding: false,
            strict: false,
            on_columns: Vec::new(),
            distribute_by: None,
            sort_by: None,
            cluster_by: None,
        }));
        assert_eq!(annotator.annotate(&union), None);
    }
}