Skip to main content

polyglot_sql/optimizer/
annotate_types.rs

1//! Type Annotation for SQL Expressions
2//!
3//! This module provides type inference and annotation for SQL AST nodes.
4//! It walks the expression tree and assigns data types to expressions based on:
5//! - Literal values (strings, numbers, booleans)
6//! - Column references (from schema)
7//! - Function return types
8//! - Operator result types (with coercion rules)
9//!
10//! Based on SQLGlot's optimizer/annotate_types.py
11
12use std::collections::HashMap;
13
14use crate::dialects::DialectType;
15use crate::expressions::{
16    BinaryOp, DataType, Expression, Function, IfFunc, ListAggOverflow, Literal, Map, Nvl2Func,
17    Struct, StructField, Subscript,
18};
19use crate::schema::Schema;
20
/// Coercion class of a data type, used to pick the result type of a binary operation.
///
/// The explicit discriminants define the derived ordering
/// (`Text < Numeric < Timelike`); the class that compares greater takes
/// priority during coercion.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TypeCoercionClass {
    /// Text types (CHAR, VARCHAR, TEXT)
    Text = 0,
    /// Numeric types (INT, FLOAT, DECIMAL, etc.)
    Numeric = 1,
    /// Time-like types (DATE, TIME, TIMESTAMP, INTERVAL)
    Timelike = 2,
}
32
33impl TypeCoercionClass {
34    /// Get the coercion class for a data type
35    pub fn from_data_type(dt: &DataType) -> Option<Self> {
36        match dt {
37            // Text types
38            DataType::Char { .. }
39            | DataType::VarChar { .. }
40            | DataType::Text
41            | DataType::Binary { .. }
42            | DataType::VarBinary { .. }
43            | DataType::Blob => Some(TypeCoercionClass::Text),
44
45            // Numeric types
46            DataType::Boolean
47            | DataType::TinyInt { .. }
48            | DataType::SmallInt { .. }
49            | DataType::Int { .. }
50            | DataType::BigInt { .. }
51            | DataType::Float { .. }
52            | DataType::Double { .. }
53            | DataType::Decimal { .. } => Some(TypeCoercionClass::Numeric),
54
55            // Timelike types
56            DataType::Date
57            | DataType::Time { .. }
58            | DataType::Timestamp { .. }
59            | DataType::Interval { .. } => Some(TypeCoercionClass::Timelike),
60
61            // Other types don't have a coercion class
62            _ => None,
63        }
64    }
65}
66
/// Type annotation configuration and state.
///
/// Holds the optional schema and dialect handed to `new`, plus the table of
/// fixed function return types that `new` fills in via
/// `init_function_return_types`.
pub struct TypeAnnotator<'a> {
    /// Schema for looking up column types; `None` means column references
    /// cannot be resolved.
    _schema: Option<&'a dyn Schema>,
    /// Dialect for dialect-specific type rules
    _dialect: Option<DialectType>,
    /// When `false`, generic `Expression::AggregateFunction` nodes are left
    /// without a type. Set to `true` by `new`.
    annotate_aggregates: bool,
    /// Function return type mappings, keyed by upper-case function name
    function_return_types: HashMap<String, DataType>,
}
78
79impl<'a> TypeAnnotator<'a> {
80    /// Create a new type annotator
81    pub fn new(schema: Option<&'a dyn Schema>, dialect: Option<DialectType>) -> Self {
82        let mut annotator = Self {
83            _schema: schema,
84            _dialect: dialect,
85            annotate_aggregates: true,
86            function_return_types: HashMap::new(),
87        };
88        annotator.init_function_return_types();
89        annotator
90    }
91
92    /// Initialize function return type mappings
93    fn init_function_return_types(&mut self) {
94        // Aggregate functions
95        self.function_return_types
96            .insert("COUNT".to_string(), DataType::BigInt { length: None });
97        self.function_return_types.insert(
98            "SUM".to_string(),
99            DataType::Decimal {
100                precision: None,
101                scale: None,
102            },
103        );
104        self.function_return_types.insert(
105            "AVG".to_string(),
106            DataType::Double {
107                precision: None,
108                scale: None,
109            },
110        );
111
112        // String functions
113        self.function_return_types.insert(
114            "CONCAT".to_string(),
115            DataType::VarChar {
116                length: None,
117                parenthesized_length: false,
118            },
119        );
120        self.function_return_types.insert(
121            "UPPER".to_string(),
122            DataType::VarChar {
123                length: None,
124                parenthesized_length: false,
125            },
126        );
127        self.function_return_types.insert(
128            "LOWER".to_string(),
129            DataType::VarChar {
130                length: None,
131                parenthesized_length: false,
132            },
133        );
134        self.function_return_types.insert(
135            "TRIM".to_string(),
136            DataType::VarChar {
137                length: None,
138                parenthesized_length: false,
139            },
140        );
141        self.function_return_types.insert(
142            "LTRIM".to_string(),
143            DataType::VarChar {
144                length: None,
145                parenthesized_length: false,
146            },
147        );
148        self.function_return_types.insert(
149            "RTRIM".to_string(),
150            DataType::VarChar {
151                length: None,
152                parenthesized_length: false,
153            },
154        );
155        self.function_return_types.insert(
156            "SUBSTRING".to_string(),
157            DataType::VarChar {
158                length: None,
159                parenthesized_length: false,
160            },
161        );
162        self.function_return_types.insert(
163            "SUBSTR".to_string(),
164            DataType::VarChar {
165                length: None,
166                parenthesized_length: false,
167            },
168        );
169        self.function_return_types.insert(
170            "REPLACE".to_string(),
171            DataType::VarChar {
172                length: None,
173                parenthesized_length: false,
174            },
175        );
176        self.function_return_types.insert(
177            "LENGTH".to_string(),
178            DataType::Int {
179                length: None,
180                integer_spelling: false,
181            },
182        );
183        self.function_return_types.insert(
184            "CHAR_LENGTH".to_string(),
185            DataType::Int {
186                length: None,
187                integer_spelling: false,
188            },
189        );
190
191        // Date/Time functions
192        self.function_return_types.insert(
193            "NOW".to_string(),
194            DataType::Timestamp {
195                precision: None,
196                timezone: false,
197            },
198        );
199        self.function_return_types.insert(
200            "CURRENT_TIMESTAMP".to_string(),
201            DataType::Timestamp {
202                precision: None,
203                timezone: false,
204            },
205        );
206        self.function_return_types
207            .insert("CURRENT_DATE".to_string(), DataType::Date);
208        self.function_return_types.insert(
209            "CURRENT_TIME".to_string(),
210            DataType::Time {
211                precision: None,
212                timezone: false,
213            },
214        );
215        self.function_return_types
216            .insert("DATE".to_string(), DataType::Date);
217        self.function_return_types.insert(
218            "YEAR".to_string(),
219            DataType::Int {
220                length: None,
221                integer_spelling: false,
222            },
223        );
224        self.function_return_types.insert(
225            "MONTH".to_string(),
226            DataType::Int {
227                length: None,
228                integer_spelling: false,
229            },
230        );
231        self.function_return_types.insert(
232            "DAY".to_string(),
233            DataType::Int {
234                length: None,
235                integer_spelling: false,
236            },
237        );
238        self.function_return_types.insert(
239            "HOUR".to_string(),
240            DataType::Int {
241                length: None,
242                integer_spelling: false,
243            },
244        );
245        self.function_return_types.insert(
246            "MINUTE".to_string(),
247            DataType::Int {
248                length: None,
249                integer_spelling: false,
250            },
251        );
252        self.function_return_types.insert(
253            "SECOND".to_string(),
254            DataType::Int {
255                length: None,
256                integer_spelling: false,
257            },
258        );
259        self.function_return_types.insert(
260            "EXTRACT".to_string(),
261            DataType::Int {
262                length: None,
263                integer_spelling: false,
264            },
265        );
266        self.function_return_types.insert(
267            "DATE_DIFF".to_string(),
268            DataType::Int {
269                length: None,
270                integer_spelling: false,
271            },
272        );
273        self.function_return_types.insert(
274            "DATEDIFF".to_string(),
275            DataType::Int {
276                length: None,
277                integer_spelling: false,
278            },
279        );
280
281        // Math functions
282        self.function_return_types.insert(
283            "ABS".to_string(),
284            DataType::Double {
285                precision: None,
286                scale: None,
287            },
288        );
289        self.function_return_types.insert(
290            "ROUND".to_string(),
291            DataType::Double {
292                precision: None,
293                scale: None,
294            },
295        );
296        self.function_return_types.insert(
297            "DATE_FORMAT".to_string(),
298            DataType::VarChar {
299                length: None,
300                parenthesized_length: false,
301            },
302        );
303        self.function_return_types.insert(
304            "FORMAT_DATE".to_string(),
305            DataType::VarChar {
306                length: None,
307                parenthesized_length: false,
308            },
309        );
310        self.function_return_types.insert(
311            "TIME_TO_STR".to_string(),
312            DataType::VarChar {
313                length: None,
314                parenthesized_length: false,
315            },
316        );
317        self.function_return_types.insert(
318            "SQRT".to_string(),
319            DataType::Double {
320                precision: None,
321                scale: None,
322            },
323        );
324        self.function_return_types.insert(
325            "POWER".to_string(),
326            DataType::Double {
327                precision: None,
328                scale: None,
329            },
330        );
331        self.function_return_types.insert(
332            "MOD".to_string(),
333            DataType::Int {
334                length: None,
335                integer_spelling: false,
336            },
337        );
338        self.function_return_types.insert(
339            "LOG".to_string(),
340            DataType::Double {
341                precision: None,
342                scale: None,
343            },
344        );
345        self.function_return_types.insert(
346            "LN".to_string(),
347            DataType::Double {
348                precision: None,
349                scale: None,
350            },
351        );
352        self.function_return_types.insert(
353            "EXP".to_string(),
354            DataType::Double {
355                precision: None,
356                scale: None,
357            },
358        );
359
360        // Null-handling functions return Unknown (infer from args)
361        self.function_return_types
362            .insert("COALESCE".to_string(), DataType::Unknown);
363        self.function_return_types
364            .insert("NULLIF".to_string(), DataType::Unknown);
365        self.function_return_types
366            .insert("GREATEST".to_string(), DataType::Unknown);
367        self.function_return_types
368            .insert("LEAST".to_string(), DataType::Unknown);
369    }
370
371    /// Annotate types for an expression tree
372    pub fn annotate(&mut self, expr: &Expression) -> Option<DataType> {
373        match expr {
374            // Literals
375            Expression::Literal(lit) => self.annotate_literal(lit),
376            Expression::Boolean(_) => Some(DataType::Boolean),
377            Expression::Null(_) => None, // NULL has no type
378
379            // Arithmetic binary operations
380            Expression::Add(op)
381            | Expression::Sub(op)
382            | Expression::Mul(op)
383            | Expression::Div(op)
384            | Expression::Mod(op) => self.annotate_arithmetic(op),
385
386            // Comparison operations - always boolean
387            Expression::Eq(_)
388            | Expression::Neq(_)
389            | Expression::Lt(_)
390            | Expression::Lte(_)
391            | Expression::Gt(_)
392            | Expression::Gte(_)
393            | Expression::Like(_)
394            | Expression::ILike(_) => Some(DataType::Boolean),
395
396            // Logical operations - always boolean
397            Expression::And(_) | Expression::Or(_) | Expression::Not(_) => Some(DataType::Boolean),
398
399            // Predicates - always boolean
400            Expression::Between(_)
401            | Expression::In(_)
402            | Expression::IsNull(_)
403            | Expression::IsTrue(_)
404            | Expression::IsFalse(_)
405            | Expression::Is(_)
406            | Expression::Exists(_) => Some(DataType::Boolean),
407
408            // String concatenation
409            Expression::Concat(_) => Some(DataType::VarChar {
410                length: None,
411                parenthesized_length: false,
412            }),
413
414            // Bitwise operations - integer
415            Expression::BitwiseAnd(_)
416            | Expression::BitwiseOr(_)
417            | Expression::BitwiseXor(_)
418            | Expression::BitwiseNot(_) => Some(DataType::BigInt { length: None }),
419
420            // Negation preserves type
421            Expression::Neg(op) => self.annotate(&op.this),
422
423            // Functions
424            Expression::Function(func) => self.annotate_function(func),
425            Expression::IfFunc(if_func) => self.annotate_if_func(if_func),
426            Expression::Nvl2(nvl2) => self.annotate_nvl2(nvl2),
427
428            // Typed aggregate functions
429            Expression::Count(_) => Some(DataType::BigInt { length: None }),
430            Expression::Sum(agg) => self.annotate_sum(&agg.this),
431            Expression::SumIf(f) => self.annotate_sum(&f.this),
432            Expression::Avg(_) => Some(DataType::Double {
433                precision: None,
434                scale: None,
435            }),
436            Expression::Min(agg) => self.annotate(&agg.this),
437            Expression::Max(agg) => self.annotate(&agg.this),
438            Expression::GroupConcat(_) | Expression::StringAgg(_) | Expression::ListAgg(_) => {
439                Some(DataType::VarChar {
440                    length: None,
441                    parenthesized_length: false,
442                })
443            }
444
445            // Generic aggregate function
446            Expression::AggregateFunction(agg) => {
447                if !self.annotate_aggregates {
448                    return None;
449                }
450                let func_name = agg.name.to_uppercase();
451                self.get_aggregate_return_type(&func_name, &agg.args)
452            }
453
454            // Column references - look up type from schema if available
455            Expression::Column(col) => {
456                if let Some(schema) = &self._schema {
457                    let table_name = col.table.as_ref().map(|t| t.name.as_str()).unwrap_or("");
458                    schema.get_column_type(table_name, &col.name.name).ok()
459                } else {
460                    None
461                }
462            }
463
464            // Cast expressions
465            Expression::Cast(cast) => Some(cast.to.clone()),
466            Expression::SafeCast(cast) => Some(cast.to.clone()),
467            Expression::TryCast(cast) => Some(cast.to.clone()),
468
469            // Subqueries - type is the type of the first SELECT expression
470            Expression::Subquery(subq) => {
471                if let Expression::Select(select) = &subq.this {
472                    if let Some(first) = select.expressions.first() {
473                        self.annotate(first)
474                    } else {
475                        None
476                    }
477                } else {
478                    None
479                }
480            }
481
482            // CASE expression - type of the first THEN/ELSE
483            Expression::Case(case) => {
484                if let Some(else_expr) = &case.else_ {
485                    self.annotate(else_expr)
486                } else if let Some((_, then_expr)) = case.whens.first() {
487                    self.annotate(then_expr)
488                } else {
489                    None
490                }
491            }
492
493            // Array expressions
494            Expression::Array(arr) => {
495                if let Some(first) = arr.expressions.first() {
496                    if let Some(elem_type) = self.annotate(first) {
497                        Some(DataType::Array {
498                            element_type: Box::new(elem_type),
499                            dimension: None,
500                        })
501                    } else {
502                        Some(DataType::Array {
503                            element_type: Box::new(DataType::Unknown),
504                            dimension: None,
505                        })
506                    }
507                } else {
508                    Some(DataType::Array {
509                        element_type: Box::new(DataType::Unknown),
510                        dimension: None,
511                    })
512                }
513            }
514
515            // Interval expressions
516            Expression::Interval(_) => Some(DataType::Interval {
517                unit: None,
518                to: None,
519            }),
520
521            // Window functions inherit type from their function
522            Expression::WindowFunction(window) => self.annotate(&window.this),
523
524            // Date/time expressions
525            Expression::CurrentDate(_) => Some(DataType::Date),
526            Expression::CurrentTime(_) => Some(DataType::Time {
527                precision: None,
528                timezone: false,
529            }),
530            Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
531                Some(DataType::Timestamp {
532                    precision: None,
533                    timezone: false,
534                })
535            }
536
537            // Date functions
538            Expression::DateAdd(_)
539            | Expression::DateSub(_)
540            | Expression::ToDate(_)
541            | Expression::Date(_) => Some(DataType::Date),
542            Expression::DateDiff(_) | Expression::Extract(_) => Some(DataType::Int {
543                length: None,
544                integer_spelling: false,
545            }),
546            Expression::ToTimestamp(_) => Some(DataType::Timestamp {
547                precision: None,
548                timezone: false,
549            }),
550
551            // String functions
552            Expression::Upper(_)
553            | Expression::Lower(_)
554            | Expression::Trim(_)
555            | Expression::LTrim(_)
556            | Expression::RTrim(_)
557            | Expression::Replace(_)
558            | Expression::Substring(_)
559            | Expression::Reverse(_)
560            | Expression::Left(_)
561            | Expression::Right(_)
562            | Expression::Repeat(_)
563            | Expression::Lpad(_)
564            | Expression::Rpad(_)
565            | Expression::ConcatWs(_)
566            | Expression::Overlay(_) => Some(DataType::VarChar {
567                length: None,
568                parenthesized_length: false,
569            }),
570            Expression::Length(_) => Some(DataType::Int {
571                length: None,
572                integer_spelling: false,
573            }),
574
575            // Math functions
576            Expression::Abs(_)
577            | Expression::Sqrt(_)
578            | Expression::Cbrt(_)
579            | Expression::Ln(_)
580            | Expression::Exp(_)
581            | Expression::Power(_)
582            | Expression::Log(_) => Some(DataType::Double {
583                precision: None,
584                scale: None,
585            }),
586            Expression::Round(_) => Some(DataType::Double {
587                precision: None,
588                scale: None,
589            }),
590            Expression::Floor(f) => self.annotate_math_function(&f.this),
591            Expression::Ceil(f) => self.annotate_math_function(&f.this),
592            Expression::Sign(s) => self.annotate(&s.this),
593            Expression::DateFormat(_) | Expression::FormatDate(_) | Expression::TimeToStr(_) => {
594                Some(DataType::VarChar {
595                    length: None,
596                    parenthesized_length: false,
597                })
598            }
599
600            // Greatest/Least - coerce argument types
601            Expression::Greatest(v) | Expression::Least(v) => self.coerce_arg_types(&v.expressions),
602
603            // Alias - type of the inner expression
604            Expression::Alias(alias) => self.annotate(&alias.this),
605
606            // SELECT expressions - no scalar type
607            Expression::Select(_) => None,
608
609            // ============================================
610            // 3.1.8: Array/Map Indexing (Subscript/Bracket)
611            // ============================================
612            Expression::Subscript(sub) => self.annotate_subscript(sub),
613
614            // Dot access (struct.field) - returns Unknown without schema
615            Expression::Dot(_) => None,
616
617            // ============================================
618            // 3.1.9: STRUCT Construction
619            // ============================================
620            Expression::Struct(s) => self.annotate_struct(s),
621
622            // ============================================
623            // 3.1.10: MAP Construction
624            // ============================================
625            Expression::Map(map) => self.annotate_map(map),
626            Expression::MapFromEntries(mfe) => {
627                // MAP_FROM_ENTRIES(array_of_pairs) - infer from array element type
628                if let Some(DataType::Array { element_type, .. }) = self.annotate(&mfe.this) {
629                    if let DataType::Struct { fields, .. } = *element_type {
630                        if fields.len() >= 2 {
631                            return Some(DataType::Map {
632                                key_type: Box::new(fields[0].data_type.clone()),
633                                value_type: Box::new(fields[1].data_type.clone()),
634                            });
635                        }
636                    }
637                }
638                Some(DataType::Map {
639                    key_type: Box::new(DataType::Unknown),
640                    value_type: Box::new(DataType::Unknown),
641                })
642            }
643
644            // ============================================
645            // 3.1.11: SetOperation Type Coercion
646            // ============================================
647            Expression::Union(union) => self.annotate_set_operation(&union.left, &union.right),
648            Expression::Intersect(intersect) => {
649                self.annotate_set_operation(&intersect.left, &intersect.right)
650            }
651            Expression::Except(except) => self.annotate_set_operation(&except.left, &except.right),
652
653            // ============================================
654            // 3.1.12: UDTF Type Handling
655            // ============================================
656            Expression::Lateral(lateral) => {
657                // LATERAL subquery - type is the subquery's type
658                self.annotate(&lateral.this)
659            }
660            Expression::LateralView(lv) => {
661                // LATERAL VIEW - returns the exploded type
662                self.annotate_lateral_view(lv)
663            }
664            Expression::Unnest(unnest) => {
665                // UNNEST(array) - returns the element type of the array
666                if let Some(DataType::Array { element_type, .. }) = self.annotate(&unnest.this) {
667                    Some(*element_type)
668                } else {
669                    None
670                }
671            }
672            Expression::Explode(explode) => {
673                // EXPLODE(array) - returns the element type
674                if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
675                    Some(*element_type)
676                } else if let Some(DataType::Map {
677                    key_type,
678                    value_type,
679                }) = self.annotate(&explode.this)
680                {
681                    // EXPLODE(map) returns struct(key, value)
682                    Some(DataType::Struct {
683                        fields: vec![
684                            StructField::new("key".to_string(), *key_type),
685                            StructField::new("value".to_string(), *value_type),
686                        ],
687                        nested: false,
688                    })
689                } else {
690                    None
691                }
692            }
693            Expression::ExplodeOuter(explode) => {
694                // EXPLODE_OUTER - same as EXPLODE but preserves nulls
695                if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
696                    Some(*element_type)
697                } else {
698                    None
699                }
700            }
701            Expression::GenerateSeries(gs) => {
702                // GENERATE_SERIES returns the type of start/end
703                if let Some(ref start) = gs.start {
704                    self.annotate(start)
705                } else if let Some(ref end) = gs.end {
706                    self.annotate(end)
707                } else {
708                    Some(DataType::Int {
709                        length: None,
710                        integer_spelling: false,
711                    })
712                }
713            }
714
715            // Other expressions - unknown
716            _ => None,
717        }
718    }
719
720    /// Annotate types in-place on the expression tree (bottom-up).
721    ///
722    /// First recurses into children, then computes this node's type using the
723    /// read-only `annotate` method, and finally stores the result via
724    /// `set_inferred_type`.
725    pub fn annotate_in_place(&mut self, expr: &mut Expression) {
726        // 1. Recurse into children (bottom-up)
727        self.annotate_children_in_place(expr);
728
729        // 2. Compute this node's type using the read-only method
730        //    (children already have their types set, but `annotate` re-derives
731        //    from structure, which is fine since the structure hasn't changed)
732        let dt = self.annotate(expr);
733
734        // 3. Store on the node
735        if let Some(data_type) = dt {
736            expr.set_inferred_type(data_type);
737        }
738    }
739
740    /// Recursively annotate children of an expression in-place.
741    fn annotate_children_in_place(&mut self, expr: &mut Expression) {
742        match expr {
743            // Binary operations
744            Expression::And(op)
745            | Expression::Or(op)
746            | Expression::Add(op)
747            | Expression::Sub(op)
748            | Expression::Mul(op)
749            | Expression::Div(op)
750            | Expression::Mod(op)
751            | Expression::Eq(op)
752            | Expression::Neq(op)
753            | Expression::Lt(op)
754            | Expression::Lte(op)
755            | Expression::Gt(op)
756            | Expression::Gte(op)
757            | Expression::Concat(op)
758            | Expression::BitwiseAnd(op)
759            | Expression::BitwiseOr(op)
760            | Expression::BitwiseXor(op)
761            | Expression::Adjacent(op)
762            | Expression::TsMatch(op)
763            | Expression::PropertyEQ(op)
764            | Expression::ArrayContainsAll(op)
765            | Expression::ArrayContainedBy(op)
766            | Expression::ArrayOverlaps(op)
767            | Expression::JSONBContainsAllTopKeys(op)
768            | Expression::JSONBContainsAnyTopKeys(op)
769            | Expression::JSONBDeleteAtPath(op)
770            | Expression::ExtendsLeft(op)
771            | Expression::ExtendsRight(op)
772            | Expression::Is(op)
773            | Expression::MemberOf(op)
774            | Expression::Match(op)
775            | Expression::NullSafeEq(op)
776            | Expression::NullSafeNeq(op)
777            | Expression::Glob(op)
778            | Expression::BitwiseLeftShift(op)
779            | Expression::BitwiseRightShift(op) => {
780                self.annotate_in_place(&mut op.left);
781                self.annotate_in_place(&mut op.right);
782            }
783
784            // Like operations
785            Expression::Like(op) | Expression::ILike(op) => {
786                self.annotate_in_place(&mut op.left);
787                self.annotate_in_place(&mut op.right);
788            }
789
790            // Unary operations
791            Expression::Not(op) | Expression::Neg(op) | Expression::BitwiseNot(op) => {
792                self.annotate_in_place(&mut op.this);
793            }
794
795            // Cast
796            Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
797                self.annotate_in_place(&mut c.this);
798            }
799
800            // Case
801            Expression::Case(c) => {
802                if let Some(ref mut operand) = c.operand {
803                    self.annotate_in_place(operand);
804                }
805                for (cond, then_expr) in &mut c.whens {
806                    self.annotate_in_place(cond);
807                    self.annotate_in_place(then_expr);
808                }
809                if let Some(ref mut else_expr) = c.else_ {
810                    self.annotate_in_place(else_expr);
811                }
812            }
813
814            // Alias
815            Expression::Alias(a) => {
816                self.annotate_in_place(&mut a.this);
817            }
818
819            // Column - leaf node, no children to recurse
820            Expression::Column(_) => {}
821
822            // Function
823            Expression::Function(f) => {
824                for arg in &mut f.args {
825                    self.annotate_in_place(arg);
826                }
827            }
828
829            // Dedicated conditional functions
830            Expression::IfFunc(f) => {
831                self.annotate_in_place(&mut f.condition);
832                self.annotate_in_place(&mut f.true_value);
833                if let Some(false_value) = &mut f.false_value {
834                    self.annotate_in_place(false_value);
835                }
836            }
837            Expression::Nvl2(f) => {
838                self.annotate_in_place(&mut f.this);
839                self.annotate_in_place(&mut f.true_value);
840                self.annotate_in_place(&mut f.false_value);
841            }
842
843            // AggregateFunction
844            Expression::AggregateFunction(f) => {
845                for arg in &mut f.args {
846                    self.annotate_in_place(arg);
847                }
848            }
849
850            // Dedicated aggregate / string functions
851            Expression::Count(f) => {
852                if let Some(this) = &mut f.this {
853                    self.annotate_in_place(this);
854                }
855                if let Some(filter) = &mut f.filter {
856                    self.annotate_in_place(filter);
857                }
858            }
859            Expression::GroupConcat(f) => {
860                self.annotate_in_place(&mut f.this);
861                if let Some(separator) = &mut f.separator {
862                    self.annotate_in_place(separator);
863                }
864                if let Some(order_by) = &mut f.order_by {
865                    for ordered in order_by {
866                        self.annotate_in_place(&mut ordered.this);
867                    }
868                }
869                if let Some(filter) = &mut f.filter {
870                    self.annotate_in_place(filter);
871                }
872            }
873            Expression::StringAgg(f) => {
874                self.annotate_in_place(&mut f.this);
875                if let Some(separator) = &mut f.separator {
876                    self.annotate_in_place(separator);
877                }
878                if let Some(order_by) = &mut f.order_by {
879                    for ordered in order_by {
880                        self.annotate_in_place(&mut ordered.this);
881                    }
882                }
883                if let Some(filter) = &mut f.filter {
884                    self.annotate_in_place(filter);
885                }
886                if let Some(limit) = &mut f.limit {
887                    self.annotate_in_place(limit);
888                }
889            }
890            Expression::ListAgg(f) => {
891                self.annotate_in_place(&mut f.this);
892                if let Some(separator) = &mut f.separator {
893                    self.annotate_in_place(separator);
894                }
895                if let Some(order_by) = &mut f.order_by {
896                    for ordered in order_by {
897                        self.annotate_in_place(&mut ordered.this);
898                    }
899                }
900                if let Some(filter) = &mut f.filter {
901                    self.annotate_in_place(filter);
902                }
903                if let Some(ListAggOverflow::Truncate {
904                    filler: Some(filler),
905                    ..
906                }) = &mut f.on_overflow
907                {
908                    self.annotate_in_place(filler);
909                }
910            }
911            Expression::SumIf(f) => {
912                self.annotate_in_place(&mut f.this);
913                self.annotate_in_place(&mut f.condition);
914                if let Some(filter) = &mut f.filter {
915                    self.annotate_in_place(filter);
916                }
917            }
918
919            // WindowFunction
920            Expression::WindowFunction(w) => {
921                self.annotate_in_place(&mut w.this);
922            }
923
924            // Subquery
925            Expression::Subquery(s) => {
926                self.annotate_in_place(&mut s.this);
927            }
928
929            // UnaryFunc variants
930            Expression::Upper(f)
931            | Expression::Lower(f)
932            | Expression::Length(f)
933            | Expression::LTrim(f)
934            | Expression::RTrim(f)
935            | Expression::Reverse(f)
936            | Expression::Abs(f)
937            | Expression::Sqrt(f)
938            | Expression::Cbrt(f)
939            | Expression::Ln(f)
940            | Expression::Exp(f)
941            | Expression::Sign(f)
942            | Expression::Date(f)
943            | Expression::Time(f)
944            | Expression::Explode(f)
945            | Expression::ExplodeOuter(f)
946            | Expression::MapFromEntries(f)
947            | Expression::MapKeys(f)
948            | Expression::MapValues(f)
949            | Expression::ArrayLength(f)
950            | Expression::ArraySize(f)
951            | Expression::Cardinality(f)
952            | Expression::ArrayReverse(f)
953            | Expression::ArrayDistinct(f)
954            | Expression::ArrayFlatten(f)
955            | Expression::ArrayCompact(f)
956            | Expression::ToArray(f)
957            | Expression::JsonArrayLength(f)
958            | Expression::JsonKeys(f)
959            | Expression::JsonType(f)
960            | Expression::ParseJson(f)
961            | Expression::ToJson(f)
962            | Expression::Year(f)
963            | Expression::Month(f)
964            | Expression::Day(f)
965            | Expression::Hour(f)
966            | Expression::Minute(f)
967            | Expression::Second(f)
968            | Expression::Initcap(f)
969            | Expression::Ascii(f)
970            | Expression::Chr(f)
971            | Expression::Soundex(f)
972            | Expression::ByteLength(f)
973            | Expression::Hex(f)
974            | Expression::LowerHex(f)
975            | Expression::Unicode(f)
976            | Expression::Typeof(f)
977            | Expression::BitwiseCount(f)
978            | Expression::Epoch(f)
979            | Expression::EpochMs(f)
980            | Expression::Radians(f)
981            | Expression::Degrees(f)
982            | Expression::Sin(f)
983            | Expression::Cos(f)
984            | Expression::Tan(f)
985            | Expression::Asin(f)
986            | Expression::Acos(f)
987            | Expression::Atan(f)
988            | Expression::IsNan(f)
989            | Expression::IsInf(f) => {
990                self.annotate_in_place(&mut f.this);
991            }
992
993            // BinaryFunc variants
994            Expression::Power(f)
995            | Expression::NullIf(f)
996            | Expression::IfNull(f)
997            | Expression::Nvl(f)
998            | Expression::Contains(f)
999            | Expression::StartsWith(f)
1000            | Expression::EndsWith(f)
1001            | Expression::Levenshtein(f)
1002            | Expression::ModFunc(f)
1003            | Expression::IntDiv(f)
1004            | Expression::Atan2(f)
1005            | Expression::AddMonths(f)
1006            | Expression::MonthsBetween(f)
1007            | Expression::NextDay(f)
1008            | Expression::UnixToTimeStr(f)
1009            | Expression::ArrayContains(f)
1010            | Expression::ArrayPosition(f)
1011            | Expression::ArrayAppend(f)
1012            | Expression::ArrayPrepend(f)
1013            | Expression::ArrayUnion(f)
1014            | Expression::ArrayExcept(f)
1015            | Expression::ArrayRemove(f)
1016            | Expression::StarMap(f)
1017            | Expression::MapFromArrays(f)
1018            | Expression::MapContainsKey(f)
1019            | Expression::ElementAt(f)
1020            | Expression::JsonMergePatch(f) => {
1021                self.annotate_in_place(&mut f.this);
1022                self.annotate_in_place(&mut f.expression);
1023            }
1024
1025            // VarArgFunc variants
1026            Expression::Coalesce(f)
1027            | Expression::Greatest(f)
1028            | Expression::Least(f)
1029            | Expression::ArrayConcat(f)
1030            | Expression::ArrayIntersect(f)
1031            | Expression::ArrayZip(f)
1032            | Expression::MapConcat(f)
1033            | Expression::JsonArray(f) => {
1034                for e in &mut f.expressions {
1035                    self.annotate_in_place(e);
1036                }
1037            }
1038
1039            // AggFunc variants
1040            Expression::Sum(f)
1041            | Expression::Avg(f)
1042            | Expression::Min(f)
1043            | Expression::Max(f)
1044            | Expression::ArrayAgg(f)
1045            | Expression::CountIf(f)
1046            | Expression::Stddev(f)
1047            | Expression::StddevPop(f)
1048            | Expression::StddevSamp(f)
1049            | Expression::Variance(f)
1050            | Expression::VarPop(f)
1051            | Expression::VarSamp(f)
1052            | Expression::Median(f)
1053            | Expression::Mode(f)
1054            | Expression::First(f)
1055            | Expression::Last(f)
1056            | Expression::AnyValue(f)
1057            | Expression::ApproxDistinct(f)
1058            | Expression::ApproxCountDistinct(f)
1059            | Expression::LogicalAnd(f)
1060            | Expression::LogicalOr(f)
1061            | Expression::Skewness(f)
1062            | Expression::ArrayConcatAgg(f)
1063            | Expression::ArrayUniqueAgg(f)
1064            | Expression::BoolXorAgg(f)
1065            | Expression::BitwiseAndAgg(f)
1066            | Expression::BitwiseOrAgg(f)
1067            | Expression::BitwiseXorAgg(f) => {
1068                self.annotate_in_place(&mut f.this);
1069            }
1070
1071            // Select - recurse into expressions
1072            Expression::Select(s) => {
1073                for e in &mut s.expressions {
1074                    self.annotate_in_place(e);
1075                }
1076            }
1077
1078            // Everything else - no children to recurse or not value-producing
1079            _ => {}
1080        }
1081    }
1082
1083    /// Annotate math functions like FLOOR/CEIL that return Double for integer inputs
1084    /// and preserve the input type otherwise (matching sqlglot's _annotate_math_functions).
1085    fn annotate_math_function(&mut self, arg: &Expression) -> Option<DataType> {
1086        let input_type = self.annotate(arg)?;
1087        match input_type {
1088            DataType::TinyInt { .. }
1089            | DataType::SmallInt { .. }
1090            | DataType::Int { .. }
1091            | DataType::BigInt { .. } => Some(DataType::Double {
1092                precision: None,
1093                scale: None,
1094            }),
1095            other => Some(other),
1096        }
1097    }
1098
1099    /// Annotate a subscript/bracket expression (array[index] or map[key])
1100    fn annotate_subscript(&mut self, sub: &Subscript) -> Option<DataType> {
1101        let base_type = self.annotate(&sub.this)?;
1102
1103        match base_type {
1104            DataType::Array { element_type, .. } => Some(*element_type),
1105            DataType::Map { value_type, .. } => Some(*value_type),
1106            DataType::Json | DataType::JsonB => Some(DataType::Json), // JSON indexing returns JSON
1107            DataType::VarChar { .. } | DataType::Text => {
1108                // String indexing returns a character
1109                Some(DataType::VarChar {
1110                    length: Some(1),
1111                    parenthesized_length: false,
1112                })
1113            }
1114            _ => None,
1115        }
1116    }
1117
1118    /// Annotate a STRUCT literal
1119    fn annotate_struct(&mut self, s: &Struct) -> Option<DataType> {
1120        let fields: Vec<StructField> = s
1121            .fields
1122            .iter()
1123            .map(|(name, expr)| {
1124                let field_type = self.annotate(expr).unwrap_or(DataType::Unknown);
1125                StructField::new(name.clone().unwrap_or_default(), field_type)
1126            })
1127            .collect();
1128        Some(DataType::Struct {
1129            fields,
1130            nested: false,
1131        })
1132    }
1133
1134    /// Annotate a MAP literal
1135    fn annotate_map(&mut self, map: &Map) -> Option<DataType> {
1136        let key_type = if let Some(first_key) = map.keys.first() {
1137            self.annotate(first_key).unwrap_or(DataType::Unknown)
1138        } else {
1139            DataType::Unknown
1140        };
1141
1142        let value_type = if let Some(first_value) = map.values.first() {
1143            self.annotate(first_value).unwrap_or(DataType::Unknown)
1144        } else {
1145            DataType::Unknown
1146        };
1147
1148        Some(DataType::Map {
1149            key_type: Box::new(key_type),
1150            value_type: Box::new(value_type),
1151        })
1152    }
1153
1154    /// Annotate a SetOperation (UNION/INTERSECT/EXCEPT)
1155    /// Returns None since set operations produce relation types, not scalar types
1156    fn annotate_set_operation(
1157        &mut self,
1158        _left: &Expression,
1159        _right: &Expression,
1160    ) -> Option<DataType> {
1161        // Set operations produce relations, not scalar types
1162        // The column types would be coerced between left and right
1163        // For now, return None as this is a relation-level type
1164        None
1165    }
1166
1167    /// Annotate a LATERAL VIEW expression
1168    fn annotate_lateral_view(&mut self, lv: &crate::expressions::LateralView) -> Option<DataType> {
1169        // The type depends on the table-generating function
1170        self.annotate(&lv.this)
1171    }
1172
1173    /// Annotate a literal value
1174    fn annotate_literal(&self, lit: &Literal) -> Option<DataType> {
1175        match lit {
1176            Literal::String(_)
1177            | Literal::NationalString(_)
1178            | Literal::TripleQuotedString(_, _)
1179            | Literal::EscapeString(_)
1180            | Literal::DollarString(_)
1181            | Literal::RawString(_) => Some(DataType::VarChar {
1182                length: None,
1183                parenthesized_length: false,
1184            }),
1185            Literal::Number(n) => {
1186                // Try to determine if it's an integer or float
1187                if n.contains('.') || n.contains('e') || n.contains('E') {
1188                    Some(DataType::Double {
1189                        precision: None,
1190                        scale: None,
1191                    })
1192                } else {
1193                    // Check if it fits in an Int or needs BigInt
1194                    if let Ok(_) = n.parse::<i32>() {
1195                        Some(DataType::Int {
1196                            length: None,
1197                            integer_spelling: false,
1198                        })
1199                    } else {
1200                        Some(DataType::BigInt { length: None })
1201                    }
1202                }
1203            }
1204            Literal::HexString(_) | Literal::BitString(_) | Literal::ByteString(_) => {
1205                Some(DataType::VarBinary { length: None })
1206            }
1207            Literal::HexNumber(_) => Some(DataType::BigInt { length: None }),
1208            Literal::Date(_) => Some(DataType::Date),
1209            Literal::Time(_) => Some(DataType::Time {
1210                precision: None,
1211                timezone: false,
1212            }),
1213            Literal::Timestamp(_) => Some(DataType::Timestamp {
1214                precision: None,
1215                timezone: false,
1216            }),
1217            Literal::Datetime(_) => Some(DataType::Custom {
1218                name: "DATETIME".to_string(),
1219            }),
1220        }
1221    }
1222
1223    /// Annotate an arithmetic binary operation
1224    fn annotate_arithmetic(&mut self, op: &BinaryOp) -> Option<DataType> {
1225        let left_type = self.annotate(&op.left);
1226        let right_type = self.annotate(&op.right);
1227
1228        match (left_type, right_type) {
1229            (Some(l), Some(r)) => self.coerce_types(&l, &r),
1230            (Some(t), None) | (None, Some(t)) => Some(t),
1231            (None, None) => None,
1232        }
1233    }
1234
1235    /// Annotate a function call
1236    fn annotate_function(&mut self, func: &Function) -> Option<DataType> {
1237        let func_name = func.name.to_uppercase();
1238
1239        // Check known function return types
1240        if let Some(return_type) = self.function_return_types.get(&func_name) {
1241            if *return_type != DataType::Unknown {
1242                return Some(return_type.clone());
1243            }
1244        }
1245
1246        // For functions with Unknown return type, infer from arguments
1247        match func_name.as_str() {
1248            "COALESCE" | "IFNULL" | "NVL" | "ISNULL" => {
1249                // Return type of first non-null argument
1250                for arg in &func.args {
1251                    if let Some(arg_type) = self.annotate(arg) {
1252                        return Some(arg_type);
1253                    }
1254                }
1255                None
1256            }
1257            "NULLIF" => {
1258                // Return type of first argument
1259                func.args.first().and_then(|arg| self.annotate(arg))
1260            }
1261            "GREATEST" | "LEAST" => {
1262                // Coerce all argument types
1263                self.coerce_arg_types(&func.args)
1264            }
1265            "IF" | "IIF" => {
1266                // Return type of THEN/ELSE branches
1267                if func.args.len() >= 2 {
1268                    self.annotate(&func.args[1])
1269                } else {
1270                    None
1271                }
1272            }
1273            _ => {
1274                // Unknown function - try to infer from first argument
1275                func.args.first().and_then(|arg| self.annotate(arg))
1276            }
1277        }
1278    }
1279
1280    /// Annotate IF/IIF/IFF conditional function
1281    fn annotate_if_func(&mut self, func: &IfFunc) -> Option<DataType> {
1282        let true_type = self.annotate(&func.true_value);
1283        let false_type = func
1284            .false_value
1285            .as_ref()
1286            .and_then(|expr| self.annotate(expr));
1287
1288        match (true_type, false_type) {
1289            (Some(left), Some(right)) => self.coerce_types(&left, &right),
1290            (Some(dt), None) | (None, Some(dt)) => Some(dt),
1291            (None, None) => None,
1292        }
1293    }
1294
1295    /// Annotate NVL2 conditional function from its true/false branches
1296    fn annotate_nvl2(&mut self, func: &Nvl2Func) -> Option<DataType> {
1297        let true_type = self.annotate(&func.true_value);
1298        let false_type = self.annotate(&func.false_value);
1299
1300        match (true_type, false_type) {
1301            (Some(left), Some(right)) => self.coerce_types(&left, &right),
1302            (Some(dt), None) | (None, Some(dt)) => Some(dt),
1303            (None, None) => None,
1304        }
1305    }
1306
1307    /// Get return type for aggregate functions
1308    fn get_aggregate_return_type(
1309        &mut self,
1310        func_name: &str,
1311        args: &[Expression],
1312    ) -> Option<DataType> {
1313        match func_name {
1314            "COUNT" | "COUNT_IF" => Some(DataType::BigInt { length: None }),
1315            "SUM_IF" => {
1316                if let Some(arg) = args.first() {
1317                    self.annotate_sum(arg)
1318                } else {
1319                    Some(DataType::Decimal {
1320                        precision: None,
1321                        scale: None,
1322                    })
1323                }
1324            }
1325            "SUM" => {
1326                if let Some(arg) = args.first() {
1327                    self.annotate_sum(arg)
1328                } else {
1329                    Some(DataType::Decimal {
1330                        precision: None,
1331                        scale: None,
1332                    })
1333                }
1334            }
1335            "AVG" => Some(DataType::Double {
1336                precision: None,
1337                scale: None,
1338            }),
1339            "MIN" | "MAX" => {
1340                // Preserves input type
1341                args.first().and_then(|arg| self.annotate(arg))
1342            }
1343            "STRING_AGG" | "GROUP_CONCAT" | "LISTAGG" | "ARRAY_AGG" => Some(DataType::VarChar {
1344                length: None,
1345                parenthesized_length: false,
1346            }),
1347            "BOOL_AND" | "BOOL_OR" | "EVERY" | "ANY" | "SOME" => Some(DataType::Boolean),
1348            "BIT_AND" | "BIT_OR" | "BIT_XOR" => Some(DataType::BigInt { length: None }),
1349            "STDDEV" | "STDDEV_POP" | "STDDEV_SAMP" | "VARIANCE" | "VAR_POP" | "VAR_SAMP" => {
1350                Some(DataType::Double {
1351                    precision: None,
1352                    scale: None,
1353                })
1354            }
1355            "PERCENTILE_CONT" | "PERCENTILE_DISC" | "MEDIAN" => {
1356                args.first().and_then(|arg| self.annotate(arg))
1357            }
1358            _ => None,
1359        }
1360    }
1361
1362    /// Annotate SUM function - promotes to at least BigInt
1363    fn annotate_sum(&mut self, arg: &Expression) -> Option<DataType> {
1364        match self.annotate(arg) {
1365            Some(DataType::TinyInt { .. })
1366            | Some(DataType::SmallInt { .. })
1367            | Some(DataType::Int { .. }) => Some(DataType::BigInt { length: None }),
1368            Some(DataType::BigInt { .. }) => Some(DataType::BigInt { length: None }),
1369            Some(DataType::Float { .. }) | Some(DataType::Double { .. }) => {
1370                Some(DataType::Double {
1371                    precision: None,
1372                    scale: None,
1373                })
1374            }
1375            Some(DataType::Decimal { precision, scale }) => {
1376                Some(DataType::Decimal { precision, scale })
1377            }
1378            _ => Some(DataType::Decimal {
1379                precision: None,
1380                scale: None,
1381            }),
1382        }
1383    }
1384
1385    /// Coerce multiple argument types to a common type
1386    fn coerce_arg_types(&mut self, args: &[Expression]) -> Option<DataType> {
1387        let mut result_type: Option<DataType> = None;
1388        for arg in args {
1389            if let Some(arg_type) = self.annotate(arg) {
1390                result_type = match result_type {
1391                    Some(t) => self.coerce_types(&t, &arg_type),
1392                    None => Some(arg_type),
1393                };
1394            }
1395        }
1396        result_type
1397    }
1398
1399    /// Coerce two types to a common type
1400    fn coerce_types(&self, left: &DataType, right: &DataType) -> Option<DataType> {
1401        // If types are the same, return that type
1402        if left == right {
1403            return Some(left.clone());
1404        }
1405
1406        // Special case: Interval + Date/Timestamp
1407        match (left, right) {
1408            (DataType::Date, DataType::Interval { .. })
1409            | (DataType::Interval { .. }, DataType::Date) => return Some(DataType::Date),
1410            (
1411                DataType::Timestamp {
1412                    precision,
1413                    timezone,
1414                },
1415                DataType::Interval { .. },
1416            )
1417            | (
1418                DataType::Interval { .. },
1419                DataType::Timestamp {
1420                    precision,
1421                    timezone,
1422                },
1423            ) => {
1424                return Some(DataType::Timestamp {
1425                    precision: *precision,
1426                    timezone: *timezone,
1427                });
1428            }
1429            _ => {}
1430        }
1431
1432        // Coerce based on class
1433        let left_class = TypeCoercionClass::from_data_type(left);
1434        let right_class = TypeCoercionClass::from_data_type(right);
1435
1436        match (left_class, right_class) {
1437            // Same class: use higher-precision type within class
1438            (Some(lc), Some(rc)) if lc == rc => {
1439                // For numeric, choose wider type
1440                if lc == TypeCoercionClass::Numeric {
1441                    Some(self.wider_numeric_type(left, right))
1442                } else {
1443                    // For text and timelike, left wins by default
1444                    Some(left.clone())
1445                }
1446            }
1447            // Different classes: higher-priority class wins
1448            (Some(lc), Some(rc)) => {
1449                if lc > rc {
1450                    Some(left.clone())
1451                } else {
1452                    Some(right.clone())
1453                }
1454            }
1455            // One unknown: use the known type
1456            (Some(_), None) => Some(left.clone()),
1457            (None, Some(_)) => Some(right.clone()),
1458            // Both unknown: return unknown
1459            (None, None) => Some(DataType::Unknown),
1460        }
1461    }
1462
1463    /// Get the wider numeric type
1464    fn wider_numeric_type(&self, left: &DataType, right: &DataType) -> DataType {
1465        let order = |dt: &DataType| -> u8 {
1466            match dt {
1467                DataType::Boolean => 0,
1468                DataType::TinyInt { .. } => 1,
1469                DataType::SmallInt { .. } => 2,
1470                DataType::Int { .. } => 3,
1471                DataType::BigInt { .. } => 4,
1472                DataType::Float { .. } => 5,
1473                DataType::Double { .. } => 6,
1474                DataType::Decimal { .. } => 7,
1475                _ => 0,
1476            }
1477        };
1478
1479        if order(left) >= order(right) {
1480            left.clone()
1481        } else {
1482            right.clone()
1483        }
1484    }
1485}
1486
1487/// Annotate types in-place on the expression tree.
1488///
1489/// Walks the AST bottom-up and sets `inferred_type` on each value-producing
1490/// node. After this call, `expr.inferred_type()` (and the same on any child
1491/// node) returns the inferred type.
1492pub fn annotate_types(
1493    expr: &mut Expression,
1494    schema: Option<&dyn Schema>,
1495    dialect: Option<DialectType>,
1496) {
1497    let mut annotator = TypeAnnotator::new(schema, dialect);
1498    annotator.annotate_in_place(expr);
1499}
1500
1501#[cfg(test)]
1502mod tests {
1503    use super::*;
1504    use crate::expressions::{BooleanLiteral, Cast, Null};
1505    use crate::{parse_one, DialectType, MappingSchema, Schema};
1506
1507    fn make_int_literal(val: i64) -> Expression {
1508        Expression::Literal(Literal::Number(val.to_string()))
1509    }
1510
1511    fn make_float_literal(val: f64) -> Expression {
1512        Expression::Literal(Literal::Number(val.to_string()))
1513    }
1514
1515    fn make_string_literal(val: &str) -> Expression {
1516        Expression::Literal(Literal::String(val.to_string()))
1517    }
1518
1519    fn make_bool_literal(val: bool) -> Expression {
1520        Expression::Boolean(BooleanLiteral { value: val })
1521    }
1522
1523    #[test]
1524    fn test_literal_types() {
1525        let mut annotator = TypeAnnotator::new(None, None);
1526
1527        // Integer literal
1528        let int_expr = make_int_literal(42);
1529        assert_eq!(
1530            annotator.annotate(&int_expr),
1531            Some(DataType::Int {
1532                length: None,
1533                integer_spelling: false
1534            })
1535        );
1536
1537        // Float literal
1538        let float_expr = make_float_literal(3.14);
1539        assert_eq!(
1540            annotator.annotate(&float_expr),
1541            Some(DataType::Double {
1542                precision: None,
1543                scale: None
1544            })
1545        );
1546
1547        // String literal
1548        let string_expr = make_string_literal("hello");
1549        assert_eq!(
1550            annotator.annotate(&string_expr),
1551            Some(DataType::VarChar {
1552                length: None,
1553                parenthesized_length: false
1554            })
1555        );
1556
1557        // Boolean literal
1558        let bool_expr = make_bool_literal(true);
1559        assert_eq!(annotator.annotate(&bool_expr), Some(DataType::Boolean));
1560
1561        // Null literal
1562        let null_expr = Expression::Null(Null);
1563        assert_eq!(annotator.annotate(&null_expr), None);
1564    }
1565
1566    #[test]
1567    fn test_comparison_types() {
1568        let mut annotator = TypeAnnotator::new(None, None);
1569
1570        // Comparison returns boolean
1571        let cmp = Expression::Gt(Box::new(BinaryOp::new(
1572            make_int_literal(1),
1573            make_int_literal(2),
1574        )));
1575        assert_eq!(annotator.annotate(&cmp), Some(DataType::Boolean));
1576
1577        // Equality returns boolean
1578        let eq = Expression::Eq(Box::new(BinaryOp::new(
1579            make_string_literal("a"),
1580            make_string_literal("b"),
1581        )));
1582        assert_eq!(annotator.annotate(&eq), Some(DataType::Boolean));
1583    }
1584
1585    #[test]
1586    fn test_arithmetic_types() {
1587        let mut annotator = TypeAnnotator::new(None, None);
1588
1589        // Int + Int = Int
1590        let add_int = Expression::Add(Box::new(BinaryOp::new(
1591            make_int_literal(1),
1592            make_int_literal(2),
1593        )));
1594        assert_eq!(
1595            annotator.annotate(&add_int),
1596            Some(DataType::Int {
1597                length: None,
1598                integer_spelling: false
1599            })
1600        );
1601
1602        // Int + Float = Double (wider type)
1603        let add_mixed = Expression::Add(Box::new(BinaryOp::new(
1604            make_int_literal(1),
1605            make_float_literal(2.5), // Use 2.5 so the string has a decimal point
1606        )));
1607        assert_eq!(
1608            annotator.annotate(&add_mixed),
1609            Some(DataType::Double {
1610                precision: None,
1611                scale: None
1612            })
1613        );
1614    }
1615
1616    #[test]
1617    fn test_string_concat_type() {
1618        let mut annotator = TypeAnnotator::new(None, None);
1619
1620        // String || String = VarChar
1621        let concat = Expression::Concat(Box::new(BinaryOp::new(
1622            make_string_literal("hello"),
1623            make_string_literal(" world"),
1624        )));
1625        assert_eq!(
1626            annotator.annotate(&concat),
1627            Some(DataType::VarChar {
1628                length: None,
1629                parenthesized_length: false
1630            })
1631        );
1632    }
1633
1634    #[test]
1635    fn test_cast_type() {
1636        let mut annotator = TypeAnnotator::new(None, None);
1637
1638        // CAST(1 AS VARCHAR)
1639        let cast = Expression::Cast(Box::new(Cast {
1640            this: make_int_literal(1),
1641            to: DataType::VarChar {
1642                length: Some(10),
1643                parenthesized_length: false,
1644            },
1645            trailing_comments: vec![],
1646            double_colon_syntax: false,
1647            format: None,
1648            default: None,
1649            inferred_type: None,
1650        }));
1651        assert_eq!(
1652            annotator.annotate(&cast),
1653            Some(DataType::VarChar {
1654                length: Some(10),
1655                parenthesized_length: false
1656            })
1657        );
1658    }
1659
1660    #[test]
1661    fn test_function_types() {
1662        let mut annotator = TypeAnnotator::new(None, None);
1663
1664        // COUNT returns BigInt
1665        let count =
1666            Expression::Function(Box::new(Function::new("COUNT", vec![make_int_literal(1)])));
1667        assert_eq!(
1668            annotator.annotate(&count),
1669            Some(DataType::BigInt { length: None })
1670        );
1671
1672        // UPPER returns VarChar
1673        let upper = Expression::Function(Box::new(Function::new(
1674            "UPPER",
1675            vec![make_string_literal("hello")],
1676        )));
1677        assert_eq!(
1678            annotator.annotate(&upper),
1679            Some(DataType::VarChar {
1680                length: None,
1681                parenthesized_length: false
1682            })
1683        );
1684
1685        // NOW returns Timestamp
1686        let now = Expression::Function(Box::new(Function::new("NOW", vec![])));
1687        assert_eq!(
1688            annotator.annotate(&now),
1689            Some(DataType::Timestamp {
1690                precision: None,
1691                timezone: false
1692            })
1693        );
1694    }
1695
1696    #[test]
1697    fn test_coalesce_type_inference() {
1698        let mut annotator = TypeAnnotator::new(None, None);
1699
1700        // COALESCE(NULL, 1) returns Int (type of first non-null arg)
1701        let coalesce = Expression::Function(Box::new(Function::new(
1702            "COALESCE",
1703            vec![Expression::Null(Null), make_int_literal(1)],
1704        )));
1705        assert_eq!(
1706            annotator.annotate(&coalesce),
1707            Some(DataType::Int {
1708                length: None,
1709                integer_spelling: false
1710            })
1711        );
1712    }
1713
1714    #[test]
1715    fn test_type_coercion_class() {
1716        // Text types
1717        assert_eq!(
1718            TypeCoercionClass::from_data_type(&DataType::VarChar {
1719                length: None,
1720                parenthesized_length: false
1721            }),
1722            Some(TypeCoercionClass::Text)
1723        );
1724        assert_eq!(
1725            TypeCoercionClass::from_data_type(&DataType::Text),
1726            Some(TypeCoercionClass::Text)
1727        );
1728
1729        // Numeric types
1730        assert_eq!(
1731            TypeCoercionClass::from_data_type(&DataType::Int {
1732                length: None,
1733                integer_spelling: false
1734            }),
1735            Some(TypeCoercionClass::Numeric)
1736        );
1737        assert_eq!(
1738            TypeCoercionClass::from_data_type(&DataType::Double {
1739                precision: None,
1740                scale: None
1741            }),
1742            Some(TypeCoercionClass::Numeric)
1743        );
1744
1745        // Timelike types
1746        assert_eq!(
1747            TypeCoercionClass::from_data_type(&DataType::Date),
1748            Some(TypeCoercionClass::Timelike)
1749        );
1750        assert_eq!(
1751            TypeCoercionClass::from_data_type(&DataType::Timestamp {
1752                precision: None,
1753                timezone: false
1754            }),
1755            Some(TypeCoercionClass::Timelike)
1756        );
1757
1758        // Unknown types
1759        assert_eq!(TypeCoercionClass::from_data_type(&DataType::Json), None);
1760    }
1761
1762    #[test]
1763    fn test_wider_numeric_type() {
1764        let annotator = TypeAnnotator::new(None, None);
1765
1766        // Int vs BigInt -> BigInt
1767        let result = annotator.wider_numeric_type(
1768            &DataType::Int {
1769                length: None,
1770                integer_spelling: false,
1771            },
1772            &DataType::BigInt { length: None },
1773        );
1774        assert_eq!(result, DataType::BigInt { length: None });
1775
1776        // Float vs Double -> Double
1777        let result = annotator.wider_numeric_type(
1778            &DataType::Float {
1779                precision: None,
1780                scale: None,
1781                real_spelling: false,
1782            },
1783            &DataType::Double {
1784                precision: None,
1785                scale: None,
1786            },
1787        );
1788        assert_eq!(
1789            result,
1790            DataType::Double {
1791                precision: None,
1792                scale: None
1793            }
1794        );
1795
1796        // Int vs Double -> Double
1797        let result = annotator.wider_numeric_type(
1798            &DataType::Int {
1799                length: None,
1800                integer_spelling: false,
1801            },
1802            &DataType::Double {
1803                precision: None,
1804                scale: None,
1805            },
1806        );
1807        assert_eq!(
1808            result,
1809            DataType::Double {
1810                precision: None,
1811                scale: None
1812            }
1813        );
1814    }
1815
1816    #[test]
1817    fn test_aggregate_return_types() {
1818        let mut annotator = TypeAnnotator::new(None, None);
1819
1820        // SUM(int) returns BigInt
1821        let sum_type = annotator.get_aggregate_return_type("SUM", &[make_int_literal(1)]);
1822        assert_eq!(sum_type, Some(DataType::BigInt { length: None }));
1823
1824        // AVG always returns Double
1825        let avg_type = annotator.get_aggregate_return_type("AVG", &[make_int_literal(1)]);
1826        assert_eq!(
1827            avg_type,
1828            Some(DataType::Double {
1829                precision: None,
1830                scale: None
1831            })
1832        );
1833
1834        // MIN/MAX preserve input type
1835        let min_type = annotator.get_aggregate_return_type("MIN", &[make_string_literal("a")]);
1836        assert_eq!(
1837            min_type,
1838            Some(DataType::VarChar {
1839                length: None,
1840                parenthesized_length: false
1841            })
1842        );
1843    }
1844
1845    #[test]
1846    fn test_date_literal_types() {
1847        let mut annotator = TypeAnnotator::new(None, None);
1848
1849        // DATE literal
1850        let date_expr = Expression::Literal(Literal::Date("2024-01-15".to_string()));
1851        assert_eq!(annotator.annotate(&date_expr), Some(DataType::Date));
1852
1853        // TIME literal
1854        let time_expr = Expression::Literal(Literal::Time("10:30:00".to_string()));
1855        assert_eq!(
1856            annotator.annotate(&time_expr),
1857            Some(DataType::Time {
1858                precision: None,
1859                timezone: false
1860            })
1861        );
1862
1863        // TIMESTAMP literal
1864        let ts_expr = Expression::Literal(Literal::Timestamp("2024-01-15 10:30:00".to_string()));
1865        assert_eq!(
1866            annotator.annotate(&ts_expr),
1867            Some(DataType::Timestamp {
1868                precision: None,
1869                timezone: false
1870            })
1871        );
1872    }
1873
1874    #[test]
1875    fn test_logical_operations() {
1876        let mut annotator = TypeAnnotator::new(None, None);
1877
1878        // AND returns boolean
1879        let and_expr = Expression::And(Box::new(BinaryOp::new(
1880            make_bool_literal(true),
1881            make_bool_literal(false),
1882        )));
1883        assert_eq!(annotator.annotate(&and_expr), Some(DataType::Boolean));
1884
1885        // OR returns boolean
1886        let or_expr = Expression::Or(Box::new(BinaryOp::new(
1887            make_bool_literal(true),
1888            make_bool_literal(false),
1889        )));
1890        assert_eq!(annotator.annotate(&or_expr), Some(DataType::Boolean));
1891
1892        // NOT returns boolean
1893        let not_expr = Expression::Not(Box::new(crate::expressions::UnaryOp::new(
1894            make_bool_literal(true),
1895        )));
1896        assert_eq!(annotator.annotate(&not_expr), Some(DataType::Boolean));
1897    }
1898
1899    // ========================================
1900    // Tests for newly implemented features
1901    // ========================================
1902
1903    #[test]
1904    fn test_subscript_array_type() {
1905        let mut annotator = TypeAnnotator::new(None, None);
1906
1907        // Array[index] returns element type
1908        let arr = Expression::Array(Box::new(crate::expressions::Array {
1909            expressions: vec![make_int_literal(1), make_int_literal(2)],
1910        }));
1911        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
1912            this: arr,
1913            index: make_int_literal(0),
1914        }));
1915        assert_eq!(
1916            annotator.annotate(&subscript),
1917            Some(DataType::Int {
1918                length: None,
1919                integer_spelling: false
1920            })
1921        );
1922    }
1923
1924    #[test]
1925    fn test_subscript_map_type() {
1926        let mut annotator = TypeAnnotator::new(None, None);
1927
1928        // Map[key] returns value type
1929        let map = Expression::Map(Box::new(crate::expressions::Map {
1930            keys: vec![make_string_literal("a")],
1931            values: vec![make_int_literal(1)],
1932        }));
1933        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
1934            this: map,
1935            index: make_string_literal("a"),
1936        }));
1937        assert_eq!(
1938            annotator.annotate(&subscript),
1939            Some(DataType::Int {
1940                length: None,
1941                integer_spelling: false
1942            })
1943        );
1944    }
1945
1946    #[test]
1947    fn test_struct_type() {
1948        let mut annotator = TypeAnnotator::new(None, None);
1949
1950        // STRUCT literal
1951        let struct_expr = Expression::Struct(Box::new(crate::expressions::Struct {
1952            fields: vec![
1953                (Some("name".to_string()), make_string_literal("Alice")),
1954                (Some("age".to_string()), make_int_literal(30)),
1955            ],
1956        }));
1957        let result = annotator.annotate(&struct_expr);
1958        assert!(matches!(result, Some(DataType::Struct { fields, .. }) if fields.len() == 2));
1959    }
1960
1961    #[test]
1962    fn test_map_type() {
1963        let mut annotator = TypeAnnotator::new(None, None);
1964
1965        // MAP literal
1966        let map_expr = Expression::Map(Box::new(crate::expressions::Map {
1967            keys: vec![make_string_literal("a"), make_string_literal("b")],
1968            values: vec![make_int_literal(1), make_int_literal(2)],
1969        }));
1970        let result = annotator.annotate(&map_expr);
1971        assert!(matches!(
1972            result,
1973            Some(DataType::Map { key_type, value_type })
1974            if matches!(*key_type, DataType::VarChar { .. })
1975               && matches!(*value_type, DataType::Int { .. })
1976        ));
1977    }
1978
1979    #[test]
1980    fn test_explode_array_type() {
1981        let mut annotator = TypeAnnotator::new(None, None);
1982
1983        // EXPLODE(array) returns element type
1984        let arr = Expression::Array(Box::new(crate::expressions::Array {
1985            expressions: vec![make_int_literal(1), make_int_literal(2)],
1986        }));
1987        let explode = Expression::Explode(Box::new(crate::expressions::UnaryFunc {
1988            this: arr,
1989            original_name: None,
1990            inferred_type: None,
1991        }));
1992        assert_eq!(
1993            annotator.annotate(&explode),
1994            Some(DataType::Int {
1995                length: None,
1996                integer_spelling: false
1997            })
1998        );
1999    }
2000
2001    #[test]
2002    fn test_unnest_array_type() {
2003        let mut annotator = TypeAnnotator::new(None, None);
2004
2005        // UNNEST(array) returns element type
2006        let arr = Expression::Array(Box::new(crate::expressions::Array {
2007            expressions: vec![make_string_literal("a"), make_string_literal("b")],
2008        }));
2009        let unnest = Expression::Unnest(Box::new(crate::expressions::UnnestFunc {
2010            this: arr,
2011            expressions: Vec::new(),
2012            with_ordinality: false,
2013            alias: None,
2014            offset_alias: None,
2015        }));
2016        assert_eq!(
2017            annotator.annotate(&unnest),
2018            Some(DataType::VarChar {
2019                length: None,
2020                parenthesized_length: false
2021            })
2022        );
2023    }
2024
2025    #[test]
2026    fn test_set_operation_type() {
2027        let mut annotator = TypeAnnotator::new(None, None);
2028
2029        // UNION/INTERSECT/EXCEPT return None (they produce relations, not scalars)
2030        let select = Expression::Select(Box::new(crate::expressions::Select::default()));
2031        let union = Expression::Union(Box::new(crate::expressions::Union {
2032            left: select.clone(),
2033            right: select.clone(),
2034            all: false,
2035            distinct: false,
2036            with: None,
2037            order_by: None,
2038            limit: None,
2039            offset: None,
2040            by_name: false,
2041            side: None,
2042            kind: None,
2043            corresponding: false,
2044            strict: false,
2045            on_columns: Vec::new(),
2046            distribute_by: None,
2047            sort_by: None,
2048            cluster_by: None,
2049        }));
2050        assert_eq!(annotator.annotate(&union), None);
2051    }
2052
2053    #[test]
2054    fn test_floor_ceil_input_dependent_types() {
2055        use crate::expressions::{CeilFunc, FloorFunc};
2056
2057        let mut annotator = TypeAnnotator::new(None, None);
2058
2059        // FLOOR/CEIL with integer literal → Double (integers get promoted)
2060        let floor_int = Expression::Floor(Box::new(FloorFunc {
2061            this: make_int_literal(42),
2062            scale: None,
2063            to: None,
2064        }));
2065        assert_eq!(
2066            annotator.annotate(&floor_int),
2067            Some(DataType::Double {
2068                precision: None,
2069                scale: None,
2070            })
2071        );
2072
2073        let ceil_int = Expression::Ceil(Box::new(CeilFunc {
2074            this: make_int_literal(42),
2075            decimals: None,
2076            to: None,
2077        }));
2078        assert_eq!(
2079            annotator.annotate(&ceil_int),
2080            Some(DataType::Double {
2081                precision: None,
2082                scale: None,
2083            })
2084        );
2085
2086        // FLOOR with float literal → Double (literals are always Double)
2087        let floor_float = Expression::Floor(Box::new(FloorFunc {
2088            this: make_float_literal(3.14),
2089            scale: None,
2090            to: None,
2091        }));
2092        assert_eq!(
2093            annotator.annotate(&floor_float),
2094            Some(DataType::Double {
2095                precision: None,
2096                scale: None,
2097            })
2098        );
2099
2100        // FLOOR via Function("FLOOR") path → falls through to arg-based inference
2101        let floor_fn =
2102            Expression::Function(Box::new(Function::new("FLOOR", vec![make_int_literal(1)])));
2103        assert_eq!(
2104            annotator.annotate(&floor_fn),
2105            Some(DataType::Int {
2106                length: None,
2107                integer_spelling: false,
2108            })
2109        );
2110    }
2111
2112    #[test]
2113    fn test_sign_preserves_input_type() {
2114        use crate::expressions::UnaryFunc;
2115
2116        let mut annotator = TypeAnnotator::new(None, None);
2117
2118        // SIGN with integer literal → Int (preserves input type)
2119        let sign_int = Expression::Sign(Box::new(UnaryFunc {
2120            this: make_int_literal(42),
2121            original_name: None,
2122            inferred_type: None,
2123        }));
2124        assert_eq!(
2125            annotator.annotate(&sign_int),
2126            Some(DataType::Int {
2127                length: None,
2128                integer_spelling: false,
2129            })
2130        );
2131
2132        // SIGN with float literal → Double (preserves input type)
2133        let sign_float = Expression::Sign(Box::new(UnaryFunc {
2134            this: make_float_literal(3.14),
2135            original_name: None,
2136            inferred_type: None,
2137        }));
2138        assert_eq!(
2139            annotator.annotate(&sign_float),
2140            Some(DataType::Double {
2141                precision: None,
2142                scale: None,
2143            })
2144        );
2145
2146        // SIGN with a CAST to INT → Int (preserves input type)
2147        let sign_cast = Expression::Sign(Box::new(UnaryFunc {
2148            this: Expression::Cast(Box::new(Cast {
2149                this: make_int_literal(42),
2150                to: DataType::Int {
2151                    length: None,
2152                    integer_spelling: false,
2153                },
2154                format: None,
2155                trailing_comments: Vec::new(),
2156                double_colon_syntax: false,
2157                default: None,
2158                inferred_type: None,
2159            })),
2160            original_name: None,
2161            inferred_type: None,
2162        }));
2163        assert_eq!(
2164            annotator.annotate(&sign_cast),
2165            Some(DataType::Int {
2166                length: None,
2167                integer_spelling: false,
2168            })
2169        );
2170    }
2171
2172    #[test]
2173    fn test_date_format_types() {
2174        use crate::expressions::{DateFormatFunc, TimeToStr};
2175
2176        let mut annotator = TypeAnnotator::new(None, None);
2177
2178        // DateFormat → VarChar
2179        let date_fmt = Expression::DateFormat(Box::new(DateFormatFunc {
2180            this: make_string_literal("2024-01-01"),
2181            format: make_string_literal("%Y-%m-%d"),
2182        }));
2183        assert_eq!(
2184            annotator.annotate(&date_fmt),
2185            Some(DataType::VarChar {
2186                length: None,
2187                parenthesized_length: false,
2188            })
2189        );
2190
2191        // FormatDate → VarChar
2192        let format_date = Expression::FormatDate(Box::new(DateFormatFunc {
2193            this: make_string_literal("2024-01-01"),
2194            format: make_string_literal("%Y-%m-%d"),
2195        }));
2196        assert_eq!(
2197            annotator.annotate(&format_date),
2198            Some(DataType::VarChar {
2199                length: None,
2200                parenthesized_length: false,
2201            })
2202        );
2203
2204        // TimeToStr → VarChar
2205        let time_to_str = Expression::TimeToStr(Box::new(TimeToStr {
2206            this: Box::new(make_string_literal("2024-01-01")),
2207            format: "%Y-%m-%d".to_string(),
2208            culture: None,
2209            zone: None,
2210        }));
2211        assert_eq!(
2212            annotator.annotate(&time_to_str),
2213            Some(DataType::VarChar {
2214                length: None,
2215                parenthesized_length: false,
2216            })
2217        );
2218
2219        // DATE_FORMAT via Function path → VarChar (uses function_return_types)
2220        let date_fmt_fn = Expression::Function(Box::new(Function::new(
2221            "DATE_FORMAT",
2222            vec![
2223                make_string_literal("2024-01-01"),
2224                make_string_literal("%Y-%m-%d"),
2225            ],
2226        )));
2227        assert_eq!(
2228            annotator.annotate(&date_fmt_fn),
2229            Some(DataType::VarChar {
2230                length: None,
2231                parenthesized_length: false,
2232            })
2233        );
2234    }
2235
2236    // ===== In-place annotation tests (Step 9) =====
2237
2238    #[test]
2239    fn test_annotate_in_place_sets_type_on_root() {
2240        // Literals don't have inferred_type field, so test with a BinaryOp
2241        let mut expr = Expression::Add(Box::new(BinaryOp::new(
2242            make_int_literal(1),
2243            make_int_literal(2),
2244        )));
2245        annotate_types(&mut expr, None, None);
2246        assert_eq!(
2247            expr.inferred_type(),
2248            Some(&DataType::Int {
2249                length: None,
2250                integer_spelling: false,
2251            })
2252        );
2253    }
2254
2255    #[test]
2256    fn test_annotate_in_place_sets_types_on_children() {
2257        // (a + b) + (c - d) where all are ints
2258        // This tests that inner BinaryOp children also get annotated
2259        let inner_add = Expression::Add(Box::new(BinaryOp::new(
2260            make_int_literal(1),
2261            make_float_literal(2.5),
2262        )));
2263        let inner_sub = Expression::Sub(Box::new(BinaryOp::new(
2264            make_int_literal(3),
2265            make_int_literal(4),
2266        )));
2267        let mut expr = Expression::Add(Box::new(BinaryOp::new(inner_add, inner_sub)));
2268        annotate_types(&mut expr, None, None);
2269
2270        // Root (Add) should be Double (wider of Double and Int)
2271        assert_eq!(
2272            expr.inferred_type(),
2273            Some(&DataType::Double {
2274                precision: None,
2275                scale: None,
2276            })
2277        );
2278
2279        // Children should also have types
2280        if let Expression::Add(op) = &expr {
2281            // Left child (1 + 2.5) should be Double
2282            assert_eq!(
2283                op.left.inferred_type(),
2284                Some(&DataType::Double {
2285                    precision: None,
2286                    scale: None,
2287                })
2288            );
2289            // Right child (3 - 4) should be Int
2290            assert_eq!(
2291                op.right.inferred_type(),
2292                Some(&DataType::Int {
2293                    length: None,
2294                    integer_spelling: false,
2295                })
2296            );
2297        } else {
2298            panic!("Expected Add expression");
2299        }
2300    }
2301
2302    #[test]
2303    fn test_annotate_in_place_comparison() {
2304        let mut expr = Expression::Eq(Box::new(BinaryOp::new(
2305            make_int_literal(1),
2306            make_int_literal(2),
2307        )));
2308        annotate_types(&mut expr, None, None);
2309        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));
2310    }
2311
2312    #[test]
2313    fn test_annotate_in_place_cast() {
2314        let mut expr = Expression::Cast(Box::new(Cast {
2315            this: make_int_literal(42),
2316            to: DataType::VarChar {
2317                length: None,
2318                parenthesized_length: false,
2319            },
2320            trailing_comments: vec![],
2321            double_colon_syntax: false,
2322            format: None,
2323            default: None,
2324            inferred_type: None,
2325        }));
2326        annotate_types(&mut expr, None, None);
2327        assert_eq!(
2328            expr.inferred_type(),
2329            Some(&DataType::VarChar {
2330                length: None,
2331                parenthesized_length: false,
2332            })
2333        );
2334    }
2335
2336    #[test]
2337    fn test_annotate_in_place_nested_expression() {
2338        // (1 + 2) > 0  -> should be Boolean at root, Int for the Add
2339        let add = Expression::Add(Box::new(BinaryOp::new(
2340            make_int_literal(1),
2341            make_int_literal(2),
2342        )));
2343        let mut expr = Expression::Gt(Box::new(BinaryOp::new(add, make_int_literal(0))));
2344        annotate_types(&mut expr, None, None);
2345
2346        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));
2347
2348        // The left child (Add) should be Int
2349        if let Expression::Gt(op) = &expr {
2350            assert_eq!(
2351                op.left.inferred_type(),
2352                Some(&DataType::Int {
2353                    length: None,
2354                    integer_spelling: false,
2355                })
2356            );
2357        }
2358    }
2359
2360    #[test]
2361    fn test_annotate_in_place_parsed_sql() {
2362        use crate::parser::Parser;
2363        let mut expr =
2364            Parser::parse_sql("SELECT 1 + 2.0, 'hello', TRUE").expect("parse failed")[0].clone();
2365        annotate_types(&mut expr, None, None);
2366
2367        // The expression tree should have types annotated throughout
2368        // We can't easily inspect deep inside a parsed Select, but at minimum
2369        // the root Select itself won't have a type (it's not value-producing)
2370        assert!(expr.inferred_type().is_none());
2371    }
2372
2373    #[test]
2374    fn test_inferred_type_json_roundtrip() {
2375        let mut expr = Expression::Add(Box::new(BinaryOp::new(
2376            make_int_literal(1),
2377            make_int_literal(2),
2378        )));
2379        annotate_types(&mut expr, None, None);
2380
2381        // Serialize to JSON
2382        let json = serde_json::to_string(&expr).expect("serialize failed");
2383        // The JSON should contain the inferred_type
2384        assert!(json.contains("inferred_type"));
2385
2386        // Deserialize back
2387        let deserialized: Expression = serde_json::from_str(&json).expect("deserialize failed");
2388        assert_eq!(
2389            deserialized.inferred_type(),
2390            Some(&DataType::Int {
2391                length: None,
2392                integer_spelling: false,
2393            })
2394        );
2395    }
2396
2397    #[test]
2398    fn test_inferred_type_none_not_serialized() {
2399        // When inferred_type is None, it should not appear in JSON
2400        let expr = Expression::Add(Box::new(BinaryOp::new(
2401            make_int_literal(1),
2402            make_int_literal(2),
2403        )));
2404        let json = serde_json::to_string(&expr).expect("serialize failed");
2405        assert!(!json.contains("inferred_type"));
2406    }
2407
2408    #[test]
2409    fn test_annotate_if_func_bigquery_node_and_alias_type() {
2410        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
2411        schema
2412            .add_table(
2413                "t",
2414                &[("col1".to_string(), DataType::String { length: None })],
2415                None,
2416            )
2417            .unwrap();
2418
2419        let mut expr = parse_one(
2420            "SELECT IF(col1 IS NOT NULL, 1, 0) AS x FROM t",
2421            DialectType::BigQuery,
2422        )
2423        .unwrap();
2424        annotate_types(&mut expr, Some(&schema), Some(DialectType::BigQuery));
2425
2426        let Expression::Select(select) = &expr else {
2427            panic!("expected select");
2428        };
2429        let Expression::Alias(alias) = &select.expressions[0] else {
2430            panic!("expected alias");
2431        };
2432
2433        assert_eq!(
2434            alias.this.inferred_type(),
2435            Some(&DataType::Int {
2436                length: None,
2437                integer_spelling: false,
2438            })
2439        );
2440        assert_eq!(
2441            select.expressions[0].inferred_type(),
2442            Some(&DataType::Int {
2443                length: None,
2444                integer_spelling: false,
2445            })
2446        );
2447    }
2448
2449    #[test]
2450    fn test_annotate_nvl2_node_type() {
2451        let mut expr = parse_one("SELECT NVL2(a, 1, 0) AS x", DialectType::Generic).unwrap();
2452        annotate_types(&mut expr, None, None);
2453
2454        let Expression::Select(select) = &expr else {
2455            panic!("expected select");
2456        };
2457        let Expression::Alias(alias) = &select.expressions[0] else {
2458            panic!("expected alias");
2459        };
2460
2461        assert_eq!(
2462            alias.this.inferred_type(),
2463            Some(&DataType::Int {
2464                length: None,
2465                integer_spelling: false,
2466            })
2467        );
2468    }
2469
2470    #[test]
2471    fn test_annotate_count_node_type() {
2472        let mut expr = parse_one("SELECT COUNT(1) AS x", DialectType::Generic).unwrap();
2473        annotate_types(&mut expr, None, None);
2474
2475        let Expression::Select(select) = &expr else {
2476            panic!("expected select");
2477        };
2478        let Expression::Alias(alias) = &select.expressions[0] else {
2479            panic!("expected alias");
2480        };
2481
2482        assert_eq!(
2483            alias.this.inferred_type(),
2484            Some(&DataType::BigInt { length: None })
2485        );
2486    }
2487
2488    #[test]
2489    fn test_annotate_group_concat_node_type() {
2490        let mut expr = parse_one("SELECT GROUP_CONCAT(a) AS x", DialectType::Generic).unwrap();
2491        annotate_types(&mut expr, None, None);
2492
2493        let Expression::Select(select) = &expr else {
2494            panic!("expected select");
2495        };
2496        let Expression::Alias(alias) = &select.expressions[0] else {
2497            panic!("expected alias");
2498        };
2499
2500        assert_eq!(
2501            alias.this.inferred_type(),
2502            Some(&DataType::VarChar {
2503                length: None,
2504                parenthesized_length: false,
2505            })
2506        );
2507    }
2508
2509    #[test]
2510    fn test_annotate_sum_if_generic_aggregate_type() {
2511        let mut expr =
2512            parse_one("SELECT SUM_IF(1, a > 0) AS x FROM t", DialectType::Generic).unwrap();
2513        annotate_types(&mut expr, None, None);
2514
2515        let Expression::Select(select) = &expr else {
2516            panic!("expected select");
2517        };
2518        let Expression::Alias(alias) = &select.expressions[0] else {
2519            panic!("expected alias");
2520        };
2521
2522        assert_eq!(
2523            select.expressions[0].inferred_type(),
2524            Some(&DataType::BigInt { length: None })
2525        );
2526        assert_eq!(
2527            alias.this.inferred_type(),
2528            Some(&DataType::BigInt { length: None })
2529        );
2530    }
2531}