Skip to main content

polyglot_sql/optimizer/
annotate_types.rs

1//! Type Annotation for SQL Expressions
2//!
3//! This module provides type inference and annotation for SQL AST nodes.
4//! It walks the expression tree and assigns data types to expressions based on:
5//! - Literal values (strings, numbers, booleans)
6//! - Column references (from schema)
7//! - Function return types
8//! - Operator result types (with coercion rules)
9//!
10//! Based on SQLGlot's optimizer/annotate_types.py
11
12use std::collections::HashMap;
13
14use crate::dialects::DialectType;
15use crate::expressions::{
16    BinaryOp, DataType, Expression, Function, Literal, Map, Struct, StructField, Subscript,
17};
18use crate::schema::Schema;
19
/// Broad family a data type belongs to when resolving the result type of a
/// binary operation.
///
/// Variants are declared from lowest to highest priority and the derived
/// ordering follows declaration order; higher-priority classes are meant to
/// win during coercion.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum TypeCoercionClass {
    /// Character and binary string types (CHAR, VARCHAR, TEXT, BINARY, VARBINARY, BLOB)
    Text = 0,
    /// Boolean and numeric types (TINYINT .. BIGINT, FLOAT, DOUBLE, DECIMAL)
    Numeric = 1,
    /// Date/time types (DATE, TIME, TIMESTAMP, INTERVAL)
    Timelike = 2,
}
31
32impl TypeCoercionClass {
33    /// Get the coercion class for a data type
34    pub fn from_data_type(dt: &DataType) -> Option<Self> {
35        match dt {
36            // Text types
37            DataType::Char { .. }
38            | DataType::VarChar { .. }
39            | DataType::Text
40            | DataType::Binary { .. }
41            | DataType::VarBinary { .. }
42            | DataType::Blob => Some(TypeCoercionClass::Text),
43
44            // Numeric types
45            DataType::Boolean
46            | DataType::TinyInt { .. }
47            | DataType::SmallInt { .. }
48            | DataType::Int { .. }
49            | DataType::BigInt { .. }
50            | DataType::Float { .. }
51            | DataType::Double { .. }
52            | DataType::Decimal { .. } => Some(TypeCoercionClass::Numeric),
53
54            // Timelike types
55            DataType::Date
56            | DataType::Time { .. }
57            | DataType::Timestamp { .. }
58            | DataType::Interval { .. } => Some(TypeCoercionClass::Timelike),
59
60            // Other types don't have a coercion class
61            _ => None,
62        }
63    }
64}
65
66/// Type annotation configuration and state
67pub struct TypeAnnotator<'a> {
68    /// Schema for looking up column types
69    _schema: Option<&'a dyn Schema>,
70    /// Dialect for dialect-specific type rules
71    _dialect: Option<DialectType>,
72    /// Whether to annotate types for all expressions
73    annotate_aggregates: bool,
74    /// Function return type mappings
75    function_return_types: HashMap<String, DataType>,
76}
77
78impl<'a> TypeAnnotator<'a> {
79    /// Create a new type annotator
80    pub fn new(schema: Option<&'a dyn Schema>, dialect: Option<DialectType>) -> Self {
81        let mut annotator = Self {
82            _schema: schema,
83            _dialect: dialect,
84            annotate_aggregates: true,
85            function_return_types: HashMap::new(),
86        };
87        annotator.init_function_return_types();
88        annotator
89    }
90
91    /// Initialize function return type mappings
92    fn init_function_return_types(&mut self) {
93        // Aggregate functions
94        self.function_return_types
95            .insert("COUNT".to_string(), DataType::BigInt { length: None });
96        self.function_return_types.insert(
97            "SUM".to_string(),
98            DataType::Decimal {
99                precision: None,
100                scale: None,
101            },
102        );
103        self.function_return_types.insert(
104            "AVG".to_string(),
105            DataType::Double {
106                precision: None,
107                scale: None,
108            },
109        );
110
111        // String functions
112        self.function_return_types.insert(
113            "CONCAT".to_string(),
114            DataType::VarChar {
115                length: None,
116                parenthesized_length: false,
117            },
118        );
119        self.function_return_types.insert(
120            "UPPER".to_string(),
121            DataType::VarChar {
122                length: None,
123                parenthesized_length: false,
124            },
125        );
126        self.function_return_types.insert(
127            "LOWER".to_string(),
128            DataType::VarChar {
129                length: None,
130                parenthesized_length: false,
131            },
132        );
133        self.function_return_types.insert(
134            "TRIM".to_string(),
135            DataType::VarChar {
136                length: None,
137                parenthesized_length: false,
138            },
139        );
140        self.function_return_types.insert(
141            "LTRIM".to_string(),
142            DataType::VarChar {
143                length: None,
144                parenthesized_length: false,
145            },
146        );
147        self.function_return_types.insert(
148            "RTRIM".to_string(),
149            DataType::VarChar {
150                length: None,
151                parenthesized_length: false,
152            },
153        );
154        self.function_return_types.insert(
155            "SUBSTRING".to_string(),
156            DataType::VarChar {
157                length: None,
158                parenthesized_length: false,
159            },
160        );
161        self.function_return_types.insert(
162            "SUBSTR".to_string(),
163            DataType::VarChar {
164                length: None,
165                parenthesized_length: false,
166            },
167        );
168        self.function_return_types.insert(
169            "REPLACE".to_string(),
170            DataType::VarChar {
171                length: None,
172                parenthesized_length: false,
173            },
174        );
175        self.function_return_types.insert(
176            "LENGTH".to_string(),
177            DataType::Int {
178                length: None,
179                integer_spelling: false,
180            },
181        );
182        self.function_return_types.insert(
183            "CHAR_LENGTH".to_string(),
184            DataType::Int {
185                length: None,
186                integer_spelling: false,
187            },
188        );
189
190        // Date/Time functions
191        self.function_return_types.insert(
192            "NOW".to_string(),
193            DataType::Timestamp {
194                precision: None,
195                timezone: false,
196            },
197        );
198        self.function_return_types.insert(
199            "CURRENT_TIMESTAMP".to_string(),
200            DataType::Timestamp {
201                precision: None,
202                timezone: false,
203            },
204        );
205        self.function_return_types
206            .insert("CURRENT_DATE".to_string(), DataType::Date);
207        self.function_return_types.insert(
208            "CURRENT_TIME".to_string(),
209            DataType::Time {
210                precision: None,
211                timezone: false,
212            },
213        );
214        self.function_return_types
215            .insert("DATE".to_string(), DataType::Date);
216        self.function_return_types.insert(
217            "YEAR".to_string(),
218            DataType::Int {
219                length: None,
220                integer_spelling: false,
221            },
222        );
223        self.function_return_types.insert(
224            "MONTH".to_string(),
225            DataType::Int {
226                length: None,
227                integer_spelling: false,
228            },
229        );
230        self.function_return_types.insert(
231            "DAY".to_string(),
232            DataType::Int {
233                length: None,
234                integer_spelling: false,
235            },
236        );
237        self.function_return_types.insert(
238            "HOUR".to_string(),
239            DataType::Int {
240                length: None,
241                integer_spelling: false,
242            },
243        );
244        self.function_return_types.insert(
245            "MINUTE".to_string(),
246            DataType::Int {
247                length: None,
248                integer_spelling: false,
249            },
250        );
251        self.function_return_types.insert(
252            "SECOND".to_string(),
253            DataType::Int {
254                length: None,
255                integer_spelling: false,
256            },
257        );
258        self.function_return_types.insert(
259            "EXTRACT".to_string(),
260            DataType::Int {
261                length: None,
262                integer_spelling: false,
263            },
264        );
265        self.function_return_types.insert(
266            "DATE_DIFF".to_string(),
267            DataType::Int {
268                length: None,
269                integer_spelling: false,
270            },
271        );
272        self.function_return_types.insert(
273            "DATEDIFF".to_string(),
274            DataType::Int {
275                length: None,
276                integer_spelling: false,
277            },
278        );
279
280        // Math functions
281        self.function_return_types.insert(
282            "ABS".to_string(),
283            DataType::Double {
284                precision: None,
285                scale: None,
286            },
287        );
288        self.function_return_types.insert(
289            "ROUND".to_string(),
290            DataType::Double {
291                precision: None,
292                scale: None,
293            },
294        );
295        self.function_return_types.insert(
296            "DATE_FORMAT".to_string(),
297            DataType::VarChar {
298                length: None,
299                parenthesized_length: false,
300            },
301        );
302        self.function_return_types.insert(
303            "FORMAT_DATE".to_string(),
304            DataType::VarChar {
305                length: None,
306                parenthesized_length: false,
307            },
308        );
309        self.function_return_types.insert(
310            "TIME_TO_STR".to_string(),
311            DataType::VarChar {
312                length: None,
313                parenthesized_length: false,
314            },
315        );
316        self.function_return_types.insert(
317            "SQRT".to_string(),
318            DataType::Double {
319                precision: None,
320                scale: None,
321            },
322        );
323        self.function_return_types.insert(
324            "POWER".to_string(),
325            DataType::Double {
326                precision: None,
327                scale: None,
328            },
329        );
330        self.function_return_types.insert(
331            "MOD".to_string(),
332            DataType::Int {
333                length: None,
334                integer_spelling: false,
335            },
336        );
337        self.function_return_types.insert(
338            "LOG".to_string(),
339            DataType::Double {
340                precision: None,
341                scale: None,
342            },
343        );
344        self.function_return_types.insert(
345            "LN".to_string(),
346            DataType::Double {
347                precision: None,
348                scale: None,
349            },
350        );
351        self.function_return_types.insert(
352            "EXP".to_string(),
353            DataType::Double {
354                precision: None,
355                scale: None,
356            },
357        );
358
359        // Null-handling functions return Unknown (infer from args)
360        self.function_return_types
361            .insert("COALESCE".to_string(), DataType::Unknown);
362        self.function_return_types
363            .insert("NULLIF".to_string(), DataType::Unknown);
364        self.function_return_types
365            .insert("GREATEST".to_string(), DataType::Unknown);
366        self.function_return_types
367            .insert("LEAST".to_string(), DataType::Unknown);
368    }
369
370    /// Annotate types for an expression tree
371    pub fn annotate(&mut self, expr: &Expression) -> Option<DataType> {
372        match expr {
373            // Literals
374            Expression::Literal(lit) => self.annotate_literal(lit),
375            Expression::Boolean(_) => Some(DataType::Boolean),
376            Expression::Null(_) => None, // NULL has no type
377
378            // Arithmetic binary operations
379            Expression::Add(op)
380            | Expression::Sub(op)
381            | Expression::Mul(op)
382            | Expression::Div(op)
383            | Expression::Mod(op) => self.annotate_arithmetic(op),
384
385            // Comparison operations - always boolean
386            Expression::Eq(_)
387            | Expression::Neq(_)
388            | Expression::Lt(_)
389            | Expression::Lte(_)
390            | Expression::Gt(_)
391            | Expression::Gte(_)
392            | Expression::Like(_)
393            | Expression::ILike(_) => Some(DataType::Boolean),
394
395            // Logical operations - always boolean
396            Expression::And(_) | Expression::Or(_) | Expression::Not(_) => Some(DataType::Boolean),
397
398            // Predicates - always boolean
399            Expression::Between(_)
400            | Expression::In(_)
401            | Expression::IsNull(_)
402            | Expression::IsTrue(_)
403            | Expression::IsFalse(_)
404            | Expression::Is(_)
405            | Expression::Exists(_) => Some(DataType::Boolean),
406
407            // String concatenation
408            Expression::Concat(_) => Some(DataType::VarChar {
409                length: None,
410                parenthesized_length: false,
411            }),
412
413            // Bitwise operations - integer
414            Expression::BitwiseAnd(_)
415            | Expression::BitwiseOr(_)
416            | Expression::BitwiseXor(_)
417            | Expression::BitwiseNot(_) => Some(DataType::BigInt { length: None }),
418
419            // Negation preserves type
420            Expression::Neg(op) => self.annotate(&op.this),
421
422            // Functions
423            Expression::Function(func) => self.annotate_function(func),
424
425            // Typed aggregate functions
426            Expression::Count(_) => Some(DataType::BigInt { length: None }),
427            Expression::Sum(agg) => self.annotate_sum(&agg.this),
428            Expression::Avg(_) => Some(DataType::Double {
429                precision: None,
430                scale: None,
431            }),
432            Expression::Min(agg) => self.annotate(&agg.this),
433            Expression::Max(agg) => self.annotate(&agg.this),
434            Expression::GroupConcat(_) | Expression::StringAgg(_) | Expression::ListAgg(_) => {
435                Some(DataType::VarChar {
436                    length: None,
437                    parenthesized_length: false,
438                })
439            }
440
441            // Generic aggregate function
442            Expression::AggregateFunction(agg) => {
443                if !self.annotate_aggregates {
444                    return None;
445                }
446                let func_name = agg.name.to_uppercase();
447                self.get_aggregate_return_type(&func_name, &agg.args)
448            }
449
450            // Column references - look up type from schema if available
451            Expression::Column(col) => {
452                if let Some(schema) = &self._schema {
453                    let table_name = col.table.as_ref().map(|t| t.name.as_str()).unwrap_or("");
454                    schema.get_column_type(table_name, &col.name.name).ok()
455                } else {
456                    None
457                }
458            }
459
460            // Cast expressions
461            Expression::Cast(cast) => Some(cast.to.clone()),
462            Expression::SafeCast(cast) => Some(cast.to.clone()),
463            Expression::TryCast(cast) => Some(cast.to.clone()),
464
465            // Subqueries - type is the type of the first SELECT expression
466            Expression::Subquery(subq) => {
467                if let Expression::Select(select) = &subq.this {
468                    if let Some(first) = select.expressions.first() {
469                        self.annotate(first)
470                    } else {
471                        None
472                    }
473                } else {
474                    None
475                }
476            }
477
478            // CASE expression - type of the first THEN/ELSE
479            Expression::Case(case) => {
480                if let Some(else_expr) = &case.else_ {
481                    self.annotate(else_expr)
482                } else if let Some((_, then_expr)) = case.whens.first() {
483                    self.annotate(then_expr)
484                } else {
485                    None
486                }
487            }
488
489            // Array expressions
490            Expression::Array(arr) => {
491                if let Some(first) = arr.expressions.first() {
492                    if let Some(elem_type) = self.annotate(first) {
493                        Some(DataType::Array {
494                            element_type: Box::new(elem_type),
495                            dimension: None,
496                        })
497                    } else {
498                        Some(DataType::Array {
499                            element_type: Box::new(DataType::Unknown),
500                            dimension: None,
501                        })
502                    }
503                } else {
504                    Some(DataType::Array {
505                        element_type: Box::new(DataType::Unknown),
506                        dimension: None,
507                    })
508                }
509            }
510
511            // Interval expressions
512            Expression::Interval(_) => Some(DataType::Interval {
513                unit: None,
514                to: None,
515            }),
516
517            // Window functions inherit type from their function
518            Expression::WindowFunction(window) => self.annotate(&window.this),
519
520            // Date/time expressions
521            Expression::CurrentDate(_) => Some(DataType::Date),
522            Expression::CurrentTime(_) => Some(DataType::Time {
523                precision: None,
524                timezone: false,
525            }),
526            Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
527                Some(DataType::Timestamp {
528                    precision: None,
529                    timezone: false,
530                })
531            }
532
533            // Date functions
534            Expression::DateAdd(_)
535            | Expression::DateSub(_)
536            | Expression::ToDate(_)
537            | Expression::Date(_) => Some(DataType::Date),
538            Expression::DateDiff(_) | Expression::Extract(_) => Some(DataType::Int {
539                length: None,
540                integer_spelling: false,
541            }),
542            Expression::ToTimestamp(_) => Some(DataType::Timestamp {
543                precision: None,
544                timezone: false,
545            }),
546
547            // String functions
548            Expression::Upper(_)
549            | Expression::Lower(_)
550            | Expression::Trim(_)
551            | Expression::LTrim(_)
552            | Expression::RTrim(_)
553            | Expression::Replace(_)
554            | Expression::Substring(_)
555            | Expression::Reverse(_)
556            | Expression::Left(_)
557            | Expression::Right(_)
558            | Expression::Repeat(_)
559            | Expression::Lpad(_)
560            | Expression::Rpad(_)
561            | Expression::ConcatWs(_)
562            | Expression::Overlay(_) => Some(DataType::VarChar {
563                length: None,
564                parenthesized_length: false,
565            }),
566            Expression::Length(_) => Some(DataType::Int {
567                length: None,
568                integer_spelling: false,
569            }),
570
571            // Math functions
572            Expression::Abs(_)
573            | Expression::Sqrt(_)
574            | Expression::Cbrt(_)
575            | Expression::Ln(_)
576            | Expression::Exp(_)
577            | Expression::Power(_)
578            | Expression::Log(_) => Some(DataType::Double {
579                precision: None,
580                scale: None,
581            }),
582            Expression::Round(_) => Some(DataType::Double {
583                precision: None,
584                scale: None,
585            }),
586            Expression::Floor(f) => self.annotate_math_function(&f.this),
587            Expression::Ceil(f) => self.annotate_math_function(&f.this),
588            Expression::Sign(s) => self.annotate(&s.this),
589            Expression::DateFormat(_) | Expression::FormatDate(_) | Expression::TimeToStr(_) => {
590                Some(DataType::VarChar {
591                    length: None,
592                    parenthesized_length: false,
593                })
594            }
595
596            // Greatest/Least - coerce argument types
597            Expression::Greatest(v) | Expression::Least(v) => self.coerce_arg_types(&v.expressions),
598
599            // Alias - type of the inner expression
600            Expression::Alias(alias) => self.annotate(&alias.this),
601
602            // SELECT expressions - no scalar type
603            Expression::Select(_) => None,
604
605            // ============================================
606            // 3.1.8: Array/Map Indexing (Subscript/Bracket)
607            // ============================================
608            Expression::Subscript(sub) => self.annotate_subscript(sub),
609
610            // Dot access (struct.field) - returns Unknown without schema
611            Expression::Dot(_) => None,
612
613            // ============================================
614            // 3.1.9: STRUCT Construction
615            // ============================================
616            Expression::Struct(s) => self.annotate_struct(s),
617
618            // ============================================
619            // 3.1.10: MAP Construction
620            // ============================================
621            Expression::Map(map) => self.annotate_map(map),
622            Expression::MapFromEntries(mfe) => {
623                // MAP_FROM_ENTRIES(array_of_pairs) - infer from array element type
624                if let Some(DataType::Array { element_type, .. }) = self.annotate(&mfe.this) {
625                    if let DataType::Struct { fields, .. } = *element_type {
626                        if fields.len() >= 2 {
627                            return Some(DataType::Map {
628                                key_type: Box::new(fields[0].data_type.clone()),
629                                value_type: Box::new(fields[1].data_type.clone()),
630                            });
631                        }
632                    }
633                }
634                Some(DataType::Map {
635                    key_type: Box::new(DataType::Unknown),
636                    value_type: Box::new(DataType::Unknown),
637                })
638            }
639
640            // ============================================
641            // 3.1.11: SetOperation Type Coercion
642            // ============================================
643            Expression::Union(union) => self.annotate_set_operation(&union.left, &union.right),
644            Expression::Intersect(intersect) => {
645                self.annotate_set_operation(&intersect.left, &intersect.right)
646            }
647            Expression::Except(except) => self.annotate_set_operation(&except.left, &except.right),
648
649            // ============================================
650            // 3.1.12: UDTF Type Handling
651            // ============================================
652            Expression::Lateral(lateral) => {
653                // LATERAL subquery - type is the subquery's type
654                self.annotate(&lateral.this)
655            }
656            Expression::LateralView(lv) => {
657                // LATERAL VIEW - returns the exploded type
658                self.annotate_lateral_view(lv)
659            }
660            Expression::Unnest(unnest) => {
661                // UNNEST(array) - returns the element type of the array
662                if let Some(DataType::Array { element_type, .. }) = self.annotate(&unnest.this) {
663                    Some(*element_type)
664                } else {
665                    None
666                }
667            }
668            Expression::Explode(explode) => {
669                // EXPLODE(array) - returns the element type
670                if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
671                    Some(*element_type)
672                } else if let Some(DataType::Map {
673                    key_type,
674                    value_type,
675                }) = self.annotate(&explode.this)
676                {
677                    // EXPLODE(map) returns struct(key, value)
678                    Some(DataType::Struct {
679                        fields: vec![
680                            StructField::new("key".to_string(), *key_type),
681                            StructField::new("value".to_string(), *value_type),
682                        ],
683                        nested: false,
684                    })
685                } else {
686                    None
687                }
688            }
689            Expression::ExplodeOuter(explode) => {
690                // EXPLODE_OUTER - same as EXPLODE but preserves nulls
691                if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
692                    Some(*element_type)
693                } else {
694                    None
695                }
696            }
697            Expression::GenerateSeries(gs) => {
698                // GENERATE_SERIES returns the type of start/end
699                if let Some(ref start) = gs.start {
700                    self.annotate(start)
701                } else if let Some(ref end) = gs.end {
702                    self.annotate(end)
703                } else {
704                    Some(DataType::Int {
705                        length: None,
706                        integer_spelling: false,
707                    })
708                }
709            }
710
711            // Other expressions - unknown
712            _ => None,
713        }
714    }
715
716    /// Annotate types in-place on the expression tree (bottom-up).
717    ///
718    /// First recurses into children, then computes this node's type using the
719    /// read-only `annotate` method, and finally stores the result via
720    /// `set_inferred_type`.
721    pub fn annotate_in_place(&mut self, expr: &mut Expression) {
722        // 1. Recurse into children (bottom-up)
723        self.annotate_children_in_place(expr);
724
725        // 2. Compute this node's type using the read-only method
726        //    (children already have their types set, but `annotate` re-derives
727        //    from structure, which is fine since the structure hasn't changed)
728        let dt = self.annotate(expr);
729
730        // 3. Store on the node
731        if let Some(data_type) = dt {
732            expr.set_inferred_type(data_type);
733        }
734    }
735
736    /// Recursively annotate children of an expression in-place.
737    fn annotate_children_in_place(&mut self, expr: &mut Expression) {
738        match expr {
739            // Binary operations
740            Expression::And(op)
741            | Expression::Or(op)
742            | Expression::Add(op)
743            | Expression::Sub(op)
744            | Expression::Mul(op)
745            | Expression::Div(op)
746            | Expression::Mod(op)
747            | Expression::Eq(op)
748            | Expression::Neq(op)
749            | Expression::Lt(op)
750            | Expression::Lte(op)
751            | Expression::Gt(op)
752            | Expression::Gte(op)
753            | Expression::Concat(op)
754            | Expression::BitwiseAnd(op)
755            | Expression::BitwiseOr(op)
756            | Expression::BitwiseXor(op)
757            | Expression::Adjacent(op)
758            | Expression::TsMatch(op)
759            | Expression::PropertyEQ(op)
760            | Expression::ArrayContainsAll(op)
761            | Expression::ArrayContainedBy(op)
762            | Expression::ArrayOverlaps(op)
763            | Expression::JSONBContainsAllTopKeys(op)
764            | Expression::JSONBContainsAnyTopKeys(op)
765            | Expression::JSONBDeleteAtPath(op)
766            | Expression::ExtendsLeft(op)
767            | Expression::ExtendsRight(op)
768            | Expression::Is(op)
769            | Expression::MemberOf(op)
770            | Expression::Match(op)
771            | Expression::NullSafeEq(op)
772            | Expression::NullSafeNeq(op)
773            | Expression::Glob(op)
774            | Expression::BitwiseLeftShift(op)
775            | Expression::BitwiseRightShift(op) => {
776                self.annotate_in_place(&mut op.left);
777                self.annotate_in_place(&mut op.right);
778            }
779
780            // Like operations
781            Expression::Like(op) | Expression::ILike(op) => {
782                self.annotate_in_place(&mut op.left);
783                self.annotate_in_place(&mut op.right);
784            }
785
786            // Unary operations
787            Expression::Not(op) | Expression::Neg(op) | Expression::BitwiseNot(op) => {
788                self.annotate_in_place(&mut op.this);
789            }
790
791            // Cast
792            Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
793                self.annotate_in_place(&mut c.this);
794            }
795
796            // Case
797            Expression::Case(c) => {
798                if let Some(ref mut operand) = c.operand {
799                    self.annotate_in_place(operand);
800                }
801                for (cond, then_expr) in &mut c.whens {
802                    self.annotate_in_place(cond);
803                    self.annotate_in_place(then_expr);
804                }
805                if let Some(ref mut else_expr) = c.else_ {
806                    self.annotate_in_place(else_expr);
807                }
808            }
809
810            // Alias
811            Expression::Alias(a) => {
812                self.annotate_in_place(&mut a.this);
813            }
814
815            // Column - leaf node, no children to recurse
816            Expression::Column(_) => {}
817
818            // Function
819            Expression::Function(f) => {
820                for arg in &mut f.args {
821                    self.annotate_in_place(arg);
822                }
823            }
824
825            // AggregateFunction
826            Expression::AggregateFunction(f) => {
827                for arg in &mut f.args {
828                    self.annotate_in_place(arg);
829                }
830            }
831
832            // WindowFunction
833            Expression::WindowFunction(w) => {
834                self.annotate_in_place(&mut w.this);
835            }
836
837            // Subquery
838            Expression::Subquery(s) => {
839                self.annotate_in_place(&mut s.this);
840            }
841
842            // UnaryFunc variants
843            Expression::Upper(f)
844            | Expression::Lower(f)
845            | Expression::Length(f)
846            | Expression::LTrim(f)
847            | Expression::RTrim(f)
848            | Expression::Reverse(f)
849            | Expression::Abs(f)
850            | Expression::Sqrt(f)
851            | Expression::Cbrt(f)
852            | Expression::Ln(f)
853            | Expression::Exp(f)
854            | Expression::Sign(f)
855            | Expression::Date(f)
856            | Expression::Time(f)
857            | Expression::Explode(f)
858            | Expression::ExplodeOuter(f)
859            | Expression::MapFromEntries(f)
860            | Expression::MapKeys(f)
861            | Expression::MapValues(f)
862            | Expression::ArrayLength(f)
863            | Expression::ArraySize(f)
864            | Expression::Cardinality(f)
865            | Expression::ArrayReverse(f)
866            | Expression::ArrayDistinct(f)
867            | Expression::ArrayFlatten(f)
868            | Expression::ArrayCompact(f)
869            | Expression::ToArray(f)
870            | Expression::JsonArrayLength(f)
871            | Expression::JsonKeys(f)
872            | Expression::JsonType(f)
873            | Expression::ParseJson(f)
874            | Expression::ToJson(f)
875            | Expression::Year(f)
876            | Expression::Month(f)
877            | Expression::Day(f)
878            | Expression::Hour(f)
879            | Expression::Minute(f)
880            | Expression::Second(f)
881            | Expression::Initcap(f)
882            | Expression::Ascii(f)
883            | Expression::Chr(f)
884            | Expression::Soundex(f)
885            | Expression::ByteLength(f)
886            | Expression::Hex(f)
887            | Expression::LowerHex(f)
888            | Expression::Unicode(f)
889            | Expression::Typeof(f)
890            | Expression::BitwiseCount(f)
891            | Expression::Epoch(f)
892            | Expression::EpochMs(f)
893            | Expression::Radians(f)
894            | Expression::Degrees(f)
895            | Expression::Sin(f)
896            | Expression::Cos(f)
897            | Expression::Tan(f)
898            | Expression::Asin(f)
899            | Expression::Acos(f)
900            | Expression::Atan(f)
901            | Expression::IsNan(f)
902            | Expression::IsInf(f) => {
903                self.annotate_in_place(&mut f.this);
904            }
905
906            // BinaryFunc variants
907            Expression::Power(f)
908            | Expression::NullIf(f)
909            | Expression::IfNull(f)
910            | Expression::Nvl(f)
911            | Expression::Contains(f)
912            | Expression::StartsWith(f)
913            | Expression::EndsWith(f)
914            | Expression::Levenshtein(f)
915            | Expression::ModFunc(f)
916            | Expression::IntDiv(f)
917            | Expression::Atan2(f)
918            | Expression::AddMonths(f)
919            | Expression::MonthsBetween(f)
920            | Expression::NextDay(f)
921            | Expression::UnixToTimeStr(f)
922            | Expression::ArrayContains(f)
923            | Expression::ArrayPosition(f)
924            | Expression::ArrayAppend(f)
925            | Expression::ArrayPrepend(f)
926            | Expression::ArrayUnion(f)
927            | Expression::ArrayExcept(f)
928            | Expression::ArrayRemove(f)
929            | Expression::StarMap(f)
930            | Expression::MapFromArrays(f)
931            | Expression::MapContainsKey(f)
932            | Expression::ElementAt(f)
933            | Expression::JsonMergePatch(f) => {
934                self.annotate_in_place(&mut f.this);
935                self.annotate_in_place(&mut f.expression);
936            }
937
938            // VarArgFunc variants
939            Expression::Coalesce(f)
940            | Expression::Greatest(f)
941            | Expression::Least(f)
942            | Expression::ArrayConcat(f)
943            | Expression::ArrayIntersect(f)
944            | Expression::ArrayZip(f)
945            | Expression::MapConcat(f)
946            | Expression::JsonArray(f) => {
947                for e in &mut f.expressions {
948                    self.annotate_in_place(e);
949                }
950            }
951
952            // AggFunc variants
953            Expression::Sum(f)
954            | Expression::Avg(f)
955            | Expression::Min(f)
956            | Expression::Max(f)
957            | Expression::ArrayAgg(f)
958            | Expression::CountIf(f)
959            | Expression::Stddev(f)
960            | Expression::StddevPop(f)
961            | Expression::StddevSamp(f)
962            | Expression::Variance(f)
963            | Expression::VarPop(f)
964            | Expression::VarSamp(f)
965            | Expression::Median(f)
966            | Expression::Mode(f)
967            | Expression::First(f)
968            | Expression::Last(f)
969            | Expression::AnyValue(f)
970            | Expression::ApproxDistinct(f)
971            | Expression::ApproxCountDistinct(f)
972            | Expression::LogicalAnd(f)
973            | Expression::LogicalOr(f)
974            | Expression::Skewness(f)
975            | Expression::ArrayConcatAgg(f)
976            | Expression::ArrayUniqueAgg(f)
977            | Expression::BoolXorAgg(f)
978            | Expression::BitwiseAndAgg(f)
979            | Expression::BitwiseOrAgg(f)
980            | Expression::BitwiseXorAgg(f) => {
981                self.annotate_in_place(&mut f.this);
982            }
983
984            // Select - recurse into expressions
985            Expression::Select(s) => {
986                for e in &mut s.expressions {
987                    self.annotate_in_place(e);
988                }
989            }
990
991            // Everything else - no children to recurse or not value-producing
992            _ => {}
993        }
994    }
995
996    /// Annotate math functions like FLOOR/CEIL that return Double for integer inputs
997    /// and preserve the input type otherwise (matching sqlglot's _annotate_math_functions).
998    fn annotate_math_function(&mut self, arg: &Expression) -> Option<DataType> {
999        let input_type = self.annotate(arg)?;
1000        match input_type {
1001            DataType::TinyInt { .. }
1002            | DataType::SmallInt { .. }
1003            | DataType::Int { .. }
1004            | DataType::BigInt { .. } => Some(DataType::Double {
1005                precision: None,
1006                scale: None,
1007            }),
1008            other => Some(other),
1009        }
1010    }
1011
1012    /// Annotate a subscript/bracket expression (array[index] or map[key])
1013    fn annotate_subscript(&mut self, sub: &Subscript) -> Option<DataType> {
1014        let base_type = self.annotate(&sub.this)?;
1015
1016        match base_type {
1017            DataType::Array { element_type, .. } => Some(*element_type),
1018            DataType::Map { value_type, .. } => Some(*value_type),
1019            DataType::Json | DataType::JsonB => Some(DataType::Json), // JSON indexing returns JSON
1020            DataType::VarChar { .. } | DataType::Text => {
1021                // String indexing returns a character
1022                Some(DataType::VarChar {
1023                    length: Some(1),
1024                    parenthesized_length: false,
1025                })
1026            }
1027            _ => None,
1028        }
1029    }
1030
1031    /// Annotate a STRUCT literal
1032    fn annotate_struct(&mut self, s: &Struct) -> Option<DataType> {
1033        let fields: Vec<StructField> = s
1034            .fields
1035            .iter()
1036            .map(|(name, expr)| {
1037                let field_type = self.annotate(expr).unwrap_or(DataType::Unknown);
1038                StructField::new(name.clone().unwrap_or_default(), field_type)
1039            })
1040            .collect();
1041        Some(DataType::Struct {
1042            fields,
1043            nested: false,
1044        })
1045    }
1046
1047    /// Annotate a MAP literal
1048    fn annotate_map(&mut self, map: &Map) -> Option<DataType> {
1049        let key_type = if let Some(first_key) = map.keys.first() {
1050            self.annotate(first_key).unwrap_or(DataType::Unknown)
1051        } else {
1052            DataType::Unknown
1053        };
1054
1055        let value_type = if let Some(first_value) = map.values.first() {
1056            self.annotate(first_value).unwrap_or(DataType::Unknown)
1057        } else {
1058            DataType::Unknown
1059        };
1060
1061        Some(DataType::Map {
1062            key_type: Box::new(key_type),
1063            value_type: Box::new(value_type),
1064        })
1065    }
1066
1067    /// Annotate a SetOperation (UNION/INTERSECT/EXCEPT)
1068    /// Returns None since set operations produce relation types, not scalar types
1069    fn annotate_set_operation(
1070        &mut self,
1071        _left: &Expression,
1072        _right: &Expression,
1073    ) -> Option<DataType> {
1074        // Set operations produce relations, not scalar types
1075        // The column types would be coerced between left and right
1076        // For now, return None as this is a relation-level type
1077        None
1078    }
1079
1080    /// Annotate a LATERAL VIEW expression
1081    fn annotate_lateral_view(&mut self, lv: &crate::expressions::LateralView) -> Option<DataType> {
1082        // The type depends on the table-generating function
1083        self.annotate(&lv.this)
1084    }
1085
1086    /// Annotate a literal value
1087    fn annotate_literal(&self, lit: &Literal) -> Option<DataType> {
1088        match lit {
1089            Literal::String(_)
1090            | Literal::NationalString(_)
1091            | Literal::TripleQuotedString(_, _)
1092            | Literal::EscapeString(_)
1093            | Literal::DollarString(_)
1094            | Literal::RawString(_) => Some(DataType::VarChar {
1095                length: None,
1096                parenthesized_length: false,
1097            }),
1098            Literal::Number(n) => {
1099                // Try to determine if it's an integer or float
1100                if n.contains('.') || n.contains('e') || n.contains('E') {
1101                    Some(DataType::Double {
1102                        precision: None,
1103                        scale: None,
1104                    })
1105                } else {
1106                    // Check if it fits in an Int or needs BigInt
1107                    if let Ok(_) = n.parse::<i32>() {
1108                        Some(DataType::Int {
1109                            length: None,
1110                            integer_spelling: false,
1111                        })
1112                    } else {
1113                        Some(DataType::BigInt { length: None })
1114                    }
1115                }
1116            }
1117            Literal::HexString(_) | Literal::BitString(_) | Literal::ByteString(_) => {
1118                Some(DataType::VarBinary { length: None })
1119            }
1120            Literal::HexNumber(_) => Some(DataType::BigInt { length: None }),
1121            Literal::Date(_) => Some(DataType::Date),
1122            Literal::Time(_) => Some(DataType::Time {
1123                precision: None,
1124                timezone: false,
1125            }),
1126            Literal::Timestamp(_) => Some(DataType::Timestamp {
1127                precision: None,
1128                timezone: false,
1129            }),
1130            Literal::Datetime(_) => Some(DataType::Custom {
1131                name: "DATETIME".to_string(),
1132            }),
1133        }
1134    }
1135
1136    /// Annotate an arithmetic binary operation
1137    fn annotate_arithmetic(&mut self, op: &BinaryOp) -> Option<DataType> {
1138        let left_type = self.annotate(&op.left);
1139        let right_type = self.annotate(&op.right);
1140
1141        match (left_type, right_type) {
1142            (Some(l), Some(r)) => self.coerce_types(&l, &r),
1143            (Some(t), None) | (None, Some(t)) => Some(t),
1144            (None, None) => None,
1145        }
1146    }
1147
1148    /// Annotate a function call
1149    fn annotate_function(&mut self, func: &Function) -> Option<DataType> {
1150        let func_name = func.name.to_uppercase();
1151
1152        // Check known function return types
1153        if let Some(return_type) = self.function_return_types.get(&func_name) {
1154            if *return_type != DataType::Unknown {
1155                return Some(return_type.clone());
1156            }
1157        }
1158
1159        // For functions with Unknown return type, infer from arguments
1160        match func_name.as_str() {
1161            "COALESCE" | "IFNULL" | "NVL" | "ISNULL" => {
1162                // Return type of first non-null argument
1163                for arg in &func.args {
1164                    if let Some(arg_type) = self.annotate(arg) {
1165                        return Some(arg_type);
1166                    }
1167                }
1168                None
1169            }
1170            "NULLIF" => {
1171                // Return type of first argument
1172                func.args.first().and_then(|arg| self.annotate(arg))
1173            }
1174            "GREATEST" | "LEAST" => {
1175                // Coerce all argument types
1176                self.coerce_arg_types(&func.args)
1177            }
1178            "IF" | "IIF" => {
1179                // Return type of THEN/ELSE branches
1180                if func.args.len() >= 2 {
1181                    self.annotate(&func.args[1])
1182                } else {
1183                    None
1184                }
1185            }
1186            _ => {
1187                // Unknown function - try to infer from first argument
1188                func.args.first().and_then(|arg| self.annotate(arg))
1189            }
1190        }
1191    }
1192
1193    /// Get return type for aggregate functions
1194    fn get_aggregate_return_type(
1195        &mut self,
1196        func_name: &str,
1197        args: &[Expression],
1198    ) -> Option<DataType> {
1199        match func_name {
1200            "COUNT" | "COUNT_IF" => Some(DataType::BigInt { length: None }),
1201            "SUM" => {
1202                if let Some(arg) = args.first() {
1203                    self.annotate_sum(arg)
1204                } else {
1205                    Some(DataType::Decimal {
1206                        precision: None,
1207                        scale: None,
1208                    })
1209                }
1210            }
1211            "AVG" => Some(DataType::Double {
1212                precision: None,
1213                scale: None,
1214            }),
1215            "MIN" | "MAX" => {
1216                // Preserves input type
1217                args.first().and_then(|arg| self.annotate(arg))
1218            }
1219            "STRING_AGG" | "GROUP_CONCAT" | "LISTAGG" | "ARRAY_AGG" => Some(DataType::VarChar {
1220                length: None,
1221                parenthesized_length: false,
1222            }),
1223            "BOOL_AND" | "BOOL_OR" | "EVERY" | "ANY" | "SOME" => Some(DataType::Boolean),
1224            "BIT_AND" | "BIT_OR" | "BIT_XOR" => Some(DataType::BigInt { length: None }),
1225            "STDDEV" | "STDDEV_POP" | "STDDEV_SAMP" | "VARIANCE" | "VAR_POP" | "VAR_SAMP" => {
1226                Some(DataType::Double {
1227                    precision: None,
1228                    scale: None,
1229                })
1230            }
1231            "PERCENTILE_CONT" | "PERCENTILE_DISC" | "MEDIAN" => {
1232                args.first().and_then(|arg| self.annotate(arg))
1233            }
1234            _ => None,
1235        }
1236    }
1237
1238    /// Annotate SUM function - promotes to at least BigInt
1239    fn annotate_sum(&mut self, arg: &Expression) -> Option<DataType> {
1240        match self.annotate(arg) {
1241            Some(DataType::TinyInt { .. })
1242            | Some(DataType::SmallInt { .. })
1243            | Some(DataType::Int { .. }) => Some(DataType::BigInt { length: None }),
1244            Some(DataType::BigInt { .. }) => Some(DataType::BigInt { length: None }),
1245            Some(DataType::Float { .. }) | Some(DataType::Double { .. }) => {
1246                Some(DataType::Double {
1247                    precision: None,
1248                    scale: None,
1249                })
1250            }
1251            Some(DataType::Decimal { precision, scale }) => {
1252                Some(DataType::Decimal { precision, scale })
1253            }
1254            _ => Some(DataType::Decimal {
1255                precision: None,
1256                scale: None,
1257            }),
1258        }
1259    }
1260
1261    /// Coerce multiple argument types to a common type
1262    fn coerce_arg_types(&mut self, args: &[Expression]) -> Option<DataType> {
1263        let mut result_type: Option<DataType> = None;
1264        for arg in args {
1265            if let Some(arg_type) = self.annotate(arg) {
1266                result_type = match result_type {
1267                    Some(t) => self.coerce_types(&t, &arg_type),
1268                    None => Some(arg_type),
1269                };
1270            }
1271        }
1272        result_type
1273    }
1274
1275    /// Coerce two types to a common type
1276    fn coerce_types(&self, left: &DataType, right: &DataType) -> Option<DataType> {
1277        // If types are the same, return that type
1278        if left == right {
1279            return Some(left.clone());
1280        }
1281
1282        // Special case: Interval + Date/Timestamp
1283        match (left, right) {
1284            (DataType::Date, DataType::Interval { .. })
1285            | (DataType::Interval { .. }, DataType::Date) => return Some(DataType::Date),
1286            (
1287                DataType::Timestamp {
1288                    precision,
1289                    timezone,
1290                },
1291                DataType::Interval { .. },
1292            )
1293            | (
1294                DataType::Interval { .. },
1295                DataType::Timestamp {
1296                    precision,
1297                    timezone,
1298                },
1299            ) => {
1300                return Some(DataType::Timestamp {
1301                    precision: *precision,
1302                    timezone: *timezone,
1303                });
1304            }
1305            _ => {}
1306        }
1307
1308        // Coerce based on class
1309        let left_class = TypeCoercionClass::from_data_type(left);
1310        let right_class = TypeCoercionClass::from_data_type(right);
1311
1312        match (left_class, right_class) {
1313            // Same class: use higher-precision type within class
1314            (Some(lc), Some(rc)) if lc == rc => {
1315                // For numeric, choose wider type
1316                if lc == TypeCoercionClass::Numeric {
1317                    Some(self.wider_numeric_type(left, right))
1318                } else {
1319                    // For text and timelike, left wins by default
1320                    Some(left.clone())
1321                }
1322            }
1323            // Different classes: higher-priority class wins
1324            (Some(lc), Some(rc)) => {
1325                if lc > rc {
1326                    Some(left.clone())
1327                } else {
1328                    Some(right.clone())
1329                }
1330            }
1331            // One unknown: use the known type
1332            (Some(_), None) => Some(left.clone()),
1333            (None, Some(_)) => Some(right.clone()),
1334            // Both unknown: return unknown
1335            (None, None) => Some(DataType::Unknown),
1336        }
1337    }
1338
1339    /// Get the wider numeric type
1340    fn wider_numeric_type(&self, left: &DataType, right: &DataType) -> DataType {
1341        let order = |dt: &DataType| -> u8 {
1342            match dt {
1343                DataType::Boolean => 0,
1344                DataType::TinyInt { .. } => 1,
1345                DataType::SmallInt { .. } => 2,
1346                DataType::Int { .. } => 3,
1347                DataType::BigInt { .. } => 4,
1348                DataType::Float { .. } => 5,
1349                DataType::Double { .. } => 6,
1350                DataType::Decimal { .. } => 7,
1351                _ => 0,
1352            }
1353        };
1354
1355        if order(left) >= order(right) {
1356            left.clone()
1357        } else {
1358            right.clone()
1359        }
1360    }
1361}
1362
1363/// Annotate types in-place on the expression tree.
1364///
1365/// Walks the AST bottom-up and sets `inferred_type` on each value-producing
1366/// node. After this call, `expr.inferred_type()` (and the same on any child
1367/// node) returns the inferred type.
1368pub fn annotate_types(
1369    expr: &mut Expression,
1370    schema: Option<&dyn Schema>,
1371    dialect: Option<DialectType>,
1372) {
1373    let mut annotator = TypeAnnotator::new(schema, dialect);
1374    annotator.annotate_in_place(expr);
1375}
1376
1377#[cfg(test)]
1378mod tests {
1379    use super::*;
1380    use crate::expressions::{BooleanLiteral, Cast, Null};
1381
1382    fn make_int_literal(val: i64) -> Expression {
1383        Expression::Literal(Literal::Number(val.to_string()))
1384    }
1385
1386    fn make_float_literal(val: f64) -> Expression {
1387        Expression::Literal(Literal::Number(val.to_string()))
1388    }
1389
1390    fn make_string_literal(val: &str) -> Expression {
1391        Expression::Literal(Literal::String(val.to_string()))
1392    }
1393
1394    fn make_bool_literal(val: bool) -> Expression {
1395        Expression::Boolean(BooleanLiteral { value: val })
1396    }
1397
1398    #[test]
1399    fn test_literal_types() {
1400        let mut annotator = TypeAnnotator::new(None, None);
1401
1402        // Integer literal
1403        let int_expr = make_int_literal(42);
1404        assert_eq!(
1405            annotator.annotate(&int_expr),
1406            Some(DataType::Int {
1407                length: None,
1408                integer_spelling: false
1409            })
1410        );
1411
1412        // Float literal
1413        let float_expr = make_float_literal(3.14);
1414        assert_eq!(
1415            annotator.annotate(&float_expr),
1416            Some(DataType::Double {
1417                precision: None,
1418                scale: None
1419            })
1420        );
1421
1422        // String literal
1423        let string_expr = make_string_literal("hello");
1424        assert_eq!(
1425            annotator.annotate(&string_expr),
1426            Some(DataType::VarChar {
1427                length: None,
1428                parenthesized_length: false
1429            })
1430        );
1431
1432        // Boolean literal
1433        let bool_expr = make_bool_literal(true);
1434        assert_eq!(annotator.annotate(&bool_expr), Some(DataType::Boolean));
1435
1436        // Null literal
1437        let null_expr = Expression::Null(Null);
1438        assert_eq!(annotator.annotate(&null_expr), None);
1439    }
1440
1441    #[test]
1442    fn test_comparison_types() {
1443        let mut annotator = TypeAnnotator::new(None, None);
1444
1445        // Comparison returns boolean
1446        let cmp = Expression::Gt(Box::new(BinaryOp::new(
1447            make_int_literal(1),
1448            make_int_literal(2),
1449        )));
1450        assert_eq!(annotator.annotate(&cmp), Some(DataType::Boolean));
1451
1452        // Equality returns boolean
1453        let eq = Expression::Eq(Box::new(BinaryOp::new(
1454            make_string_literal("a"),
1455            make_string_literal("b"),
1456        )));
1457        assert_eq!(annotator.annotate(&eq), Some(DataType::Boolean));
1458    }
1459
1460    #[test]
1461    fn test_arithmetic_types() {
1462        let mut annotator = TypeAnnotator::new(None, None);
1463
1464        // Int + Int = Int
1465        let add_int = Expression::Add(Box::new(BinaryOp::new(
1466            make_int_literal(1),
1467            make_int_literal(2),
1468        )));
1469        assert_eq!(
1470            annotator.annotate(&add_int),
1471            Some(DataType::Int {
1472                length: None,
1473                integer_spelling: false
1474            })
1475        );
1476
1477        // Int + Float = Double (wider type)
1478        let add_mixed = Expression::Add(Box::new(BinaryOp::new(
1479            make_int_literal(1),
1480            make_float_literal(2.5), // Use 2.5 so the string has a decimal point
1481        )));
1482        assert_eq!(
1483            annotator.annotate(&add_mixed),
1484            Some(DataType::Double {
1485                precision: None,
1486                scale: None
1487            })
1488        );
1489    }
1490
1491    #[test]
1492    fn test_string_concat_type() {
1493        let mut annotator = TypeAnnotator::new(None, None);
1494
1495        // String || String = VarChar
1496        let concat = Expression::Concat(Box::new(BinaryOp::new(
1497            make_string_literal("hello"),
1498            make_string_literal(" world"),
1499        )));
1500        assert_eq!(
1501            annotator.annotate(&concat),
1502            Some(DataType::VarChar {
1503                length: None,
1504                parenthesized_length: false
1505            })
1506        );
1507    }
1508
1509    #[test]
1510    fn test_cast_type() {
1511        let mut annotator = TypeAnnotator::new(None, None);
1512
1513        // CAST(1 AS VARCHAR)
1514        let cast = Expression::Cast(Box::new(Cast {
1515            this: make_int_literal(1),
1516            to: DataType::VarChar {
1517                length: Some(10),
1518                parenthesized_length: false,
1519            },
1520            trailing_comments: vec![],
1521            double_colon_syntax: false,
1522            format: None,
1523            default: None,
1524            inferred_type: None,
1525        }));
1526        assert_eq!(
1527            annotator.annotate(&cast),
1528            Some(DataType::VarChar {
1529                length: Some(10),
1530                parenthesized_length: false
1531            })
1532        );
1533    }
1534
1535    #[test]
1536    fn test_function_types() {
1537        let mut annotator = TypeAnnotator::new(None, None);
1538
1539        // COUNT returns BigInt
1540        let count =
1541            Expression::Function(Box::new(Function::new("COUNT", vec![make_int_literal(1)])));
1542        assert_eq!(
1543            annotator.annotate(&count),
1544            Some(DataType::BigInt { length: None })
1545        );
1546
1547        // UPPER returns VarChar
1548        let upper = Expression::Function(Box::new(Function::new(
1549            "UPPER",
1550            vec![make_string_literal("hello")],
1551        )));
1552        assert_eq!(
1553            annotator.annotate(&upper),
1554            Some(DataType::VarChar {
1555                length: None,
1556                parenthesized_length: false
1557            })
1558        );
1559
1560        // NOW returns Timestamp
1561        let now = Expression::Function(Box::new(Function::new("NOW", vec![])));
1562        assert_eq!(
1563            annotator.annotate(&now),
1564            Some(DataType::Timestamp {
1565                precision: None,
1566                timezone: false
1567            })
1568        );
1569    }
1570
1571    #[test]
1572    fn test_coalesce_type_inference() {
1573        let mut annotator = TypeAnnotator::new(None, None);
1574
1575        // COALESCE(NULL, 1) returns Int (type of first non-null arg)
1576        let coalesce = Expression::Function(Box::new(Function::new(
1577            "COALESCE",
1578            vec![Expression::Null(Null), make_int_literal(1)],
1579        )));
1580        assert_eq!(
1581            annotator.annotate(&coalesce),
1582            Some(DataType::Int {
1583                length: None,
1584                integer_spelling: false
1585            })
1586        );
1587    }
1588
1589    #[test]
1590    fn test_type_coercion_class() {
1591        // Text types
1592        assert_eq!(
1593            TypeCoercionClass::from_data_type(&DataType::VarChar {
1594                length: None,
1595                parenthesized_length: false
1596            }),
1597            Some(TypeCoercionClass::Text)
1598        );
1599        assert_eq!(
1600            TypeCoercionClass::from_data_type(&DataType::Text),
1601            Some(TypeCoercionClass::Text)
1602        );
1603
1604        // Numeric types
1605        assert_eq!(
1606            TypeCoercionClass::from_data_type(&DataType::Int {
1607                length: None,
1608                integer_spelling: false
1609            }),
1610            Some(TypeCoercionClass::Numeric)
1611        );
1612        assert_eq!(
1613            TypeCoercionClass::from_data_type(&DataType::Double {
1614                precision: None,
1615                scale: None
1616            }),
1617            Some(TypeCoercionClass::Numeric)
1618        );
1619
1620        // Timelike types
1621        assert_eq!(
1622            TypeCoercionClass::from_data_type(&DataType::Date),
1623            Some(TypeCoercionClass::Timelike)
1624        );
1625        assert_eq!(
1626            TypeCoercionClass::from_data_type(&DataType::Timestamp {
1627                precision: None,
1628                timezone: false
1629            }),
1630            Some(TypeCoercionClass::Timelike)
1631        );
1632
1633        // Unknown types
1634        assert_eq!(TypeCoercionClass::from_data_type(&DataType::Json), None);
1635    }
1636
1637    #[test]
1638    fn test_wider_numeric_type() {
1639        let annotator = TypeAnnotator::new(None, None);
1640
1641        // Int vs BigInt -> BigInt
1642        let result = annotator.wider_numeric_type(
1643            &DataType::Int {
1644                length: None,
1645                integer_spelling: false,
1646            },
1647            &DataType::BigInt { length: None },
1648        );
1649        assert_eq!(result, DataType::BigInt { length: None });
1650
1651        // Float vs Double -> Double
1652        let result = annotator.wider_numeric_type(
1653            &DataType::Float {
1654                precision: None,
1655                scale: None,
1656                real_spelling: false,
1657            },
1658            &DataType::Double {
1659                precision: None,
1660                scale: None,
1661            },
1662        );
1663        assert_eq!(
1664            result,
1665            DataType::Double {
1666                precision: None,
1667                scale: None
1668            }
1669        );
1670
1671        // Int vs Double -> Double
1672        let result = annotator.wider_numeric_type(
1673            &DataType::Int {
1674                length: None,
1675                integer_spelling: false,
1676            },
1677            &DataType::Double {
1678                precision: None,
1679                scale: None,
1680            },
1681        );
1682        assert_eq!(
1683            result,
1684            DataType::Double {
1685                precision: None,
1686                scale: None
1687            }
1688        );
1689    }
1690
1691    #[test]
1692    fn test_aggregate_return_types() {
1693        let mut annotator = TypeAnnotator::new(None, None);
1694
1695        // SUM(int) returns BigInt
1696        let sum_type = annotator.get_aggregate_return_type("SUM", &[make_int_literal(1)]);
1697        assert_eq!(sum_type, Some(DataType::BigInt { length: None }));
1698
1699        // AVG always returns Double
1700        let avg_type = annotator.get_aggregate_return_type("AVG", &[make_int_literal(1)]);
1701        assert_eq!(
1702            avg_type,
1703            Some(DataType::Double {
1704                precision: None,
1705                scale: None
1706            })
1707        );
1708
1709        // MIN/MAX preserve input type
1710        let min_type = annotator.get_aggregate_return_type("MIN", &[make_string_literal("a")]);
1711        assert_eq!(
1712            min_type,
1713            Some(DataType::VarChar {
1714                length: None,
1715                parenthesized_length: false
1716            })
1717        );
1718    }
1719
1720    #[test]
1721    fn test_date_literal_types() {
1722        let mut annotator = TypeAnnotator::new(None, None);
1723
1724        // DATE literal
1725        let date_expr = Expression::Literal(Literal::Date("2024-01-15".to_string()));
1726        assert_eq!(annotator.annotate(&date_expr), Some(DataType::Date));
1727
1728        // TIME literal
1729        let time_expr = Expression::Literal(Literal::Time("10:30:00".to_string()));
1730        assert_eq!(
1731            annotator.annotate(&time_expr),
1732            Some(DataType::Time {
1733                precision: None,
1734                timezone: false
1735            })
1736        );
1737
1738        // TIMESTAMP literal
1739        let ts_expr = Expression::Literal(Literal::Timestamp("2024-01-15 10:30:00".to_string()));
1740        assert_eq!(
1741            annotator.annotate(&ts_expr),
1742            Some(DataType::Timestamp {
1743                precision: None,
1744                timezone: false
1745            })
1746        );
1747    }
1748
1749    #[test]
1750    fn test_logical_operations() {
1751        let mut annotator = TypeAnnotator::new(None, None);
1752
1753        // AND returns boolean
1754        let and_expr = Expression::And(Box::new(BinaryOp::new(
1755            make_bool_literal(true),
1756            make_bool_literal(false),
1757        )));
1758        assert_eq!(annotator.annotate(&and_expr), Some(DataType::Boolean));
1759
1760        // OR returns boolean
1761        let or_expr = Expression::Or(Box::new(BinaryOp::new(
1762            make_bool_literal(true),
1763            make_bool_literal(false),
1764        )));
1765        assert_eq!(annotator.annotate(&or_expr), Some(DataType::Boolean));
1766
1767        // NOT returns boolean
1768        let not_expr = Expression::Not(Box::new(crate::expressions::UnaryOp::new(
1769            make_bool_literal(true),
1770        )));
1771        assert_eq!(annotator.annotate(&not_expr), Some(DataType::Boolean));
1772    }
1773
1774    // ========================================
1775    // Tests for newly implemented features
1776    // ========================================
1777
1778    #[test]
1779    fn test_subscript_array_type() {
1780        let mut annotator = TypeAnnotator::new(None, None);
1781
1782        // Array[index] returns element type
1783        let arr = Expression::Array(Box::new(crate::expressions::Array {
1784            expressions: vec![make_int_literal(1), make_int_literal(2)],
1785        }));
1786        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
1787            this: arr,
1788            index: make_int_literal(0),
1789        }));
1790        assert_eq!(
1791            annotator.annotate(&subscript),
1792            Some(DataType::Int {
1793                length: None,
1794                integer_spelling: false
1795            })
1796        );
1797    }
1798
1799    #[test]
1800    fn test_subscript_map_type() {
1801        let mut annotator = TypeAnnotator::new(None, None);
1802
1803        // Map[key] returns value type
1804        let map = Expression::Map(Box::new(crate::expressions::Map {
1805            keys: vec![make_string_literal("a")],
1806            values: vec![make_int_literal(1)],
1807        }));
1808        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
1809            this: map,
1810            index: make_string_literal("a"),
1811        }));
1812        assert_eq!(
1813            annotator.annotate(&subscript),
1814            Some(DataType::Int {
1815                length: None,
1816                integer_spelling: false
1817            })
1818        );
1819    }
1820
1821    #[test]
1822    fn test_struct_type() {
1823        let mut annotator = TypeAnnotator::new(None, None);
1824
1825        // STRUCT literal
1826        let struct_expr = Expression::Struct(Box::new(crate::expressions::Struct {
1827            fields: vec![
1828                (Some("name".to_string()), make_string_literal("Alice")),
1829                (Some("age".to_string()), make_int_literal(30)),
1830            ],
1831        }));
1832        let result = annotator.annotate(&struct_expr);
1833        assert!(matches!(result, Some(DataType::Struct { fields, .. }) if fields.len() == 2));
1834    }
1835
1836    #[test]
1837    fn test_map_type() {
1838        let mut annotator = TypeAnnotator::new(None, None);
1839
1840        // MAP literal
1841        let map_expr = Expression::Map(Box::new(crate::expressions::Map {
1842            keys: vec![make_string_literal("a"), make_string_literal("b")],
1843            values: vec![make_int_literal(1), make_int_literal(2)],
1844        }));
1845        let result = annotator.annotate(&map_expr);
1846        assert!(matches!(
1847            result,
1848            Some(DataType::Map { key_type, value_type })
1849            if matches!(*key_type, DataType::VarChar { .. })
1850               && matches!(*value_type, DataType::Int { .. })
1851        ));
1852    }
1853
1854    #[test]
1855    fn test_explode_array_type() {
1856        let mut annotator = TypeAnnotator::new(None, None);
1857
1858        // EXPLODE(array) returns element type
1859        let arr = Expression::Array(Box::new(crate::expressions::Array {
1860            expressions: vec![make_int_literal(1), make_int_literal(2)],
1861        }));
1862        let explode = Expression::Explode(Box::new(crate::expressions::UnaryFunc {
1863            this: arr,
1864            original_name: None,
1865            inferred_type: None,
1866        }));
1867        assert_eq!(
1868            annotator.annotate(&explode),
1869            Some(DataType::Int {
1870                length: None,
1871                integer_spelling: false
1872            })
1873        );
1874    }
1875
1876    #[test]
1877    fn test_unnest_array_type() {
1878        let mut annotator = TypeAnnotator::new(None, None);
1879
1880        // UNNEST(array) returns element type
1881        let arr = Expression::Array(Box::new(crate::expressions::Array {
1882            expressions: vec![make_string_literal("a"), make_string_literal("b")],
1883        }));
1884        let unnest = Expression::Unnest(Box::new(crate::expressions::UnnestFunc {
1885            this: arr,
1886            expressions: Vec::new(),
1887            with_ordinality: false,
1888            alias: None,
1889            offset_alias: None,
1890        }));
1891        assert_eq!(
1892            annotator.annotate(&unnest),
1893            Some(DataType::VarChar {
1894                length: None,
1895                parenthesized_length: false
1896            })
1897        );
1898    }
1899
1900    #[test]
1901    fn test_set_operation_type() {
1902        let mut annotator = TypeAnnotator::new(None, None);
1903
1904        // UNION/INTERSECT/EXCEPT return None (they produce relations, not scalars)
1905        let select = Expression::Select(Box::new(crate::expressions::Select::default()));
1906        let union = Expression::Union(Box::new(crate::expressions::Union {
1907            left: select.clone(),
1908            right: select.clone(),
1909            all: false,
1910            distinct: false,
1911            with: None,
1912            order_by: None,
1913            limit: None,
1914            offset: None,
1915            by_name: false,
1916            side: None,
1917            kind: None,
1918            corresponding: false,
1919            strict: false,
1920            on_columns: Vec::new(),
1921            distribute_by: None,
1922            sort_by: None,
1923            cluster_by: None,
1924        }));
1925        assert_eq!(annotator.annotate(&union), None);
1926    }
1927
1928    #[test]
1929    fn test_floor_ceil_input_dependent_types() {
1930        use crate::expressions::{CeilFunc, FloorFunc};
1931
1932        let mut annotator = TypeAnnotator::new(None, None);
1933
1934        // FLOOR/CEIL with integer literal → Double (integers get promoted)
1935        let floor_int = Expression::Floor(Box::new(FloorFunc {
1936            this: make_int_literal(42),
1937            scale: None,
1938            to: None,
1939        }));
1940        assert_eq!(
1941            annotator.annotate(&floor_int),
1942            Some(DataType::Double {
1943                precision: None,
1944                scale: None,
1945            })
1946        );
1947
1948        let ceil_int = Expression::Ceil(Box::new(CeilFunc {
1949            this: make_int_literal(42),
1950            decimals: None,
1951            to: None,
1952        }));
1953        assert_eq!(
1954            annotator.annotate(&ceil_int),
1955            Some(DataType::Double {
1956                precision: None,
1957                scale: None,
1958            })
1959        );
1960
1961        // FLOOR with float literal → Double (literals are always Double)
1962        let floor_float = Expression::Floor(Box::new(FloorFunc {
1963            this: make_float_literal(3.14),
1964            scale: None,
1965            to: None,
1966        }));
1967        assert_eq!(
1968            annotator.annotate(&floor_float),
1969            Some(DataType::Double {
1970                precision: None,
1971                scale: None,
1972            })
1973        );
1974
1975        // FLOOR via Function("FLOOR") path → falls through to arg-based inference
1976        let floor_fn =
1977            Expression::Function(Box::new(Function::new("FLOOR", vec![make_int_literal(1)])));
1978        assert_eq!(
1979            annotator.annotate(&floor_fn),
1980            Some(DataType::Int {
1981                length: None,
1982                integer_spelling: false,
1983            })
1984        );
1985    }
1986
1987    #[test]
1988    fn test_sign_preserves_input_type() {
1989        use crate::expressions::UnaryFunc;
1990
1991        let mut annotator = TypeAnnotator::new(None, None);
1992
1993        // SIGN with integer literal → Int (preserves input type)
1994        let sign_int = Expression::Sign(Box::new(UnaryFunc {
1995            this: make_int_literal(42),
1996            original_name: None,
1997            inferred_type: None,
1998        }));
1999        assert_eq!(
2000            annotator.annotate(&sign_int),
2001            Some(DataType::Int {
2002                length: None,
2003                integer_spelling: false,
2004            })
2005        );
2006
2007        // SIGN with float literal → Double (preserves input type)
2008        let sign_float = Expression::Sign(Box::new(UnaryFunc {
2009            this: make_float_literal(3.14),
2010            original_name: None,
2011            inferred_type: None,
2012        }));
2013        assert_eq!(
2014            annotator.annotate(&sign_float),
2015            Some(DataType::Double {
2016                precision: None,
2017                scale: None,
2018            })
2019        );
2020
2021        // SIGN with a CAST to INT → Int (preserves input type)
2022        let sign_cast = Expression::Sign(Box::new(UnaryFunc {
2023            this: Expression::Cast(Box::new(Cast {
2024                this: make_int_literal(42),
2025                to: DataType::Int {
2026                    length: None,
2027                    integer_spelling: false,
2028                },
2029                format: None,
2030                trailing_comments: Vec::new(),
2031                double_colon_syntax: false,
2032                default: None,
2033                inferred_type: None,
2034            })),
2035            original_name: None,
2036            inferred_type: None,
2037        }));
2038        assert_eq!(
2039            annotator.annotate(&sign_cast),
2040            Some(DataType::Int {
2041                length: None,
2042                integer_spelling: false,
2043            })
2044        );
2045    }
2046
2047    #[test]
2048    fn test_date_format_types() {
2049        use crate::expressions::{DateFormatFunc, TimeToStr};
2050
2051        let mut annotator = TypeAnnotator::new(None, None);
2052
2053        // DateFormat → VarChar
2054        let date_fmt = Expression::DateFormat(Box::new(DateFormatFunc {
2055            this: make_string_literal("2024-01-01"),
2056            format: make_string_literal("%Y-%m-%d"),
2057        }));
2058        assert_eq!(
2059            annotator.annotate(&date_fmt),
2060            Some(DataType::VarChar {
2061                length: None,
2062                parenthesized_length: false,
2063            })
2064        );
2065
2066        // FormatDate → VarChar
2067        let format_date = Expression::FormatDate(Box::new(DateFormatFunc {
2068            this: make_string_literal("2024-01-01"),
2069            format: make_string_literal("%Y-%m-%d"),
2070        }));
2071        assert_eq!(
2072            annotator.annotate(&format_date),
2073            Some(DataType::VarChar {
2074                length: None,
2075                parenthesized_length: false,
2076            })
2077        );
2078
2079        // TimeToStr → VarChar
2080        let time_to_str = Expression::TimeToStr(Box::new(TimeToStr {
2081            this: Box::new(make_string_literal("2024-01-01")),
2082            format: "%Y-%m-%d".to_string(),
2083            culture: None,
2084            zone: None,
2085        }));
2086        assert_eq!(
2087            annotator.annotate(&time_to_str),
2088            Some(DataType::VarChar {
2089                length: None,
2090                parenthesized_length: false,
2091            })
2092        );
2093
2094        // DATE_FORMAT via Function path → VarChar (uses function_return_types)
2095        let date_fmt_fn = Expression::Function(Box::new(Function::new(
2096            "DATE_FORMAT",
2097            vec![
2098                make_string_literal("2024-01-01"),
2099                make_string_literal("%Y-%m-%d"),
2100            ],
2101        )));
2102        assert_eq!(
2103            annotator.annotate(&date_fmt_fn),
2104            Some(DataType::VarChar {
2105                length: None,
2106                parenthesized_length: false,
2107            })
2108        );
2109    }
2110
2111    // ===== In-place annotation tests (Step 9) =====
2112
2113    #[test]
2114    fn test_annotate_in_place_sets_type_on_root() {
2115        // Literals don't have inferred_type field, so test with a BinaryOp
2116        let mut expr = Expression::Add(Box::new(BinaryOp::new(
2117            make_int_literal(1),
2118            make_int_literal(2),
2119        )));
2120        annotate_types(&mut expr, None, None);
2121        assert_eq!(
2122            expr.inferred_type(),
2123            Some(&DataType::Int {
2124                length: None,
2125                integer_spelling: false,
2126            })
2127        );
2128    }
2129
2130    #[test]
2131    fn test_annotate_in_place_sets_types_on_children() {
2132        // (a + b) + (c - d) where all are ints
2133        // This tests that inner BinaryOp children also get annotated
2134        let inner_add = Expression::Add(Box::new(BinaryOp::new(
2135            make_int_literal(1),
2136            make_float_literal(2.5),
2137        )));
2138        let inner_sub = Expression::Sub(Box::new(BinaryOp::new(
2139            make_int_literal(3),
2140            make_int_literal(4),
2141        )));
2142        let mut expr = Expression::Add(Box::new(BinaryOp::new(inner_add, inner_sub)));
2143        annotate_types(&mut expr, None, None);
2144
2145        // Root (Add) should be Double (wider of Double and Int)
2146        assert_eq!(
2147            expr.inferred_type(),
2148            Some(&DataType::Double {
2149                precision: None,
2150                scale: None,
2151            })
2152        );
2153
2154        // Children should also have types
2155        if let Expression::Add(op) = &expr {
2156            // Left child (1 + 2.5) should be Double
2157            assert_eq!(
2158                op.left.inferred_type(),
2159                Some(&DataType::Double {
2160                    precision: None,
2161                    scale: None,
2162                })
2163            );
2164            // Right child (3 - 4) should be Int
2165            assert_eq!(
2166                op.right.inferred_type(),
2167                Some(&DataType::Int {
2168                    length: None,
2169                    integer_spelling: false,
2170                })
2171            );
2172        } else {
2173            panic!("Expected Add expression");
2174        }
2175    }
2176
2177    #[test]
2178    fn test_annotate_in_place_comparison() {
2179        let mut expr = Expression::Eq(Box::new(BinaryOp::new(
2180            make_int_literal(1),
2181            make_int_literal(2),
2182        )));
2183        annotate_types(&mut expr, None, None);
2184        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));
2185    }
2186
2187    #[test]
2188    fn test_annotate_in_place_cast() {
2189        let mut expr = Expression::Cast(Box::new(Cast {
2190            this: make_int_literal(42),
2191            to: DataType::VarChar {
2192                length: None,
2193                parenthesized_length: false,
2194            },
2195            trailing_comments: vec![],
2196            double_colon_syntax: false,
2197            format: None,
2198            default: None,
2199            inferred_type: None,
2200        }));
2201        annotate_types(&mut expr, None, None);
2202        assert_eq!(
2203            expr.inferred_type(),
2204            Some(&DataType::VarChar {
2205                length: None,
2206                parenthesized_length: false,
2207            })
2208        );
2209    }
2210
2211    #[test]
2212    fn test_annotate_in_place_nested_expression() {
2213        // (1 + 2) > 0  -> should be Boolean at root, Int for the Add
2214        let add = Expression::Add(Box::new(BinaryOp::new(
2215            make_int_literal(1),
2216            make_int_literal(2),
2217        )));
2218        let mut expr = Expression::Gt(Box::new(BinaryOp::new(add, make_int_literal(0))));
2219        annotate_types(&mut expr, None, None);
2220
2221        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));
2222
2223        // The left child (Add) should be Int
2224        if let Expression::Gt(op) = &expr {
2225            assert_eq!(
2226                op.left.inferred_type(),
2227                Some(&DataType::Int {
2228                    length: None,
2229                    integer_spelling: false,
2230                })
2231            );
2232        }
2233    }
2234
2235    #[test]
2236    fn test_annotate_in_place_parsed_sql() {
2237        use crate::parser::Parser;
2238        let mut expr =
2239            Parser::parse_sql("SELECT 1 + 2.0, 'hello', TRUE").expect("parse failed")[0].clone();
2240        annotate_types(&mut expr, None, None);
2241
2242        // The expression tree should have types annotated throughout
2243        // We can't easily inspect deep inside a parsed Select, but at minimum
2244        // the root Select itself won't have a type (it's not value-producing)
2245        assert!(expr.inferred_type().is_none());
2246    }
2247
2248    #[test]
2249    fn test_inferred_type_json_roundtrip() {
2250        let mut expr = Expression::Add(Box::new(BinaryOp::new(
2251            make_int_literal(1),
2252            make_int_literal(2),
2253        )));
2254        annotate_types(&mut expr, None, None);
2255
2256        // Serialize to JSON
2257        let json = serde_json::to_string(&expr).expect("serialize failed");
2258        // The JSON should contain the inferred_type
2259        assert!(json.contains("inferred_type"));
2260
2261        // Deserialize back
2262        let deserialized: Expression = serde_json::from_str(&json).expect("deserialize failed");
2263        assert_eq!(
2264            deserialized.inferred_type(),
2265            Some(&DataType::Int {
2266                length: None,
2267                integer_spelling: false,
2268            })
2269        );
2270    }
2271
2272    #[test]
2273    fn test_inferred_type_none_not_serialized() {
2274        // When inferred_type is None, it should not appear in JSON
2275        let expr = Expression::Add(Box::new(BinaryOp::new(
2276            make_int_literal(1),
2277            make_int_literal(2),
2278        )));
2279        let json = serde_json::to_string(&expr).expect("serialize failed");
2280        assert!(!json.contains("inferred_type"));
2281    }
2282}