Skip to main content

polyglot_sql/optimizer/
annotate_types.rs

1//! Type Annotation for SQL Expressions
2//!
3//! This module provides type inference and annotation for SQL AST nodes.
4//! It walks the expression tree and assigns data types to expressions based on:
5//! - Literal values (strings, numbers, booleans)
6//! - Column references (from schema)
7//! - Function return types
8//! - Operator result types (with coercion rules)
9//!
10//! Based on SQLGlot's optimizer/annotate_types.py
11
12use std::collections::HashMap;
13
14use crate::dialects::DialectType;
15use crate::expressions::{
16    BinaryOp, DataType, Expression, Function, IfFunc, ListAggOverflow, Literal, Map, Nvl2Func,
17    Struct, StructField, Subscript,
18};
19use crate::schema::Schema;
20
21/// Type coercion class for determining result types in binary operations.
22/// Higher-priority classes win during coercion.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
24pub enum TypeCoercionClass {
25    /// Text types (CHAR, VARCHAR, TEXT)
26    Text = 0,
27    /// Numeric types (INT, FLOAT, DECIMAL, etc.)
28    Numeric = 1,
29    /// Time-like types (DATE, TIME, TIMESTAMP, INTERVAL)
30    Timelike = 2,
31}
32
33impl TypeCoercionClass {
34    /// Get the coercion class for a data type
35    pub fn from_data_type(dt: &DataType) -> Option<Self> {
36        match dt {
37            // Text types
38            DataType::Char { .. }
39            | DataType::VarChar { .. }
40            | DataType::Text
41            | DataType::Binary { .. }
42            | DataType::VarBinary { .. }
43            | DataType::Blob => Some(TypeCoercionClass::Text),
44
45            // Numeric types
46            DataType::Boolean
47            | DataType::TinyInt { .. }
48            | DataType::SmallInt { .. }
49            | DataType::Int { .. }
50            | DataType::BigInt { .. }
51            | DataType::Float { .. }
52            | DataType::Double { .. }
53            | DataType::Decimal { .. } => Some(TypeCoercionClass::Numeric),
54
55            // Timelike types
56            DataType::Date
57            | DataType::Time { .. }
58            | DataType::Timestamp { .. }
59            | DataType::Interval { .. } => Some(TypeCoercionClass::Timelike),
60
61            // Other types don't have a coercion class
62            _ => None,
63        }
64    }
65}
66
67/// Type annotation configuration and state
68pub struct TypeAnnotator<'a> {
69    /// Schema for looking up column types
70    _schema: Option<&'a dyn Schema>,
71    /// Dialect for dialect-specific type rules
72    _dialect: Option<DialectType>,
73    /// Whether to annotate types for all expressions
74    annotate_aggregates: bool,
75    /// Function return type mappings
76    function_return_types: HashMap<String, DataType>,
77}
78
79impl<'a> TypeAnnotator<'a> {
80    /// Create a new type annotator
81    pub fn new(schema: Option<&'a dyn Schema>, dialect: Option<DialectType>) -> Self {
82        let mut annotator = Self {
83            _schema: schema,
84            _dialect: dialect,
85            annotate_aggregates: true,
86            function_return_types: HashMap::new(),
87        };
88        annotator.init_function_return_types();
89        annotator
90    }
91
92    /// Initialize function return type mappings
93    fn init_function_return_types(&mut self) {
94        // Aggregate functions
95        self.function_return_types
96            .insert("COUNT".to_string(), DataType::BigInt { length: None });
97        self.function_return_types.insert(
98            "SUM".to_string(),
99            DataType::Decimal {
100                precision: None,
101                scale: None,
102            },
103        );
104        self.function_return_types.insert(
105            "AVG".to_string(),
106            DataType::Double {
107                precision: None,
108                scale: None,
109            },
110        );
111
112        // String functions
113        self.function_return_types.insert(
114            "CONCAT".to_string(),
115            DataType::VarChar {
116                length: None,
117                parenthesized_length: false,
118            },
119        );
120        self.function_return_types.insert(
121            "UPPER".to_string(),
122            DataType::VarChar {
123                length: None,
124                parenthesized_length: false,
125            },
126        );
127        self.function_return_types.insert(
128            "LOWER".to_string(),
129            DataType::VarChar {
130                length: None,
131                parenthesized_length: false,
132            },
133        );
134        self.function_return_types.insert(
135            "TRIM".to_string(),
136            DataType::VarChar {
137                length: None,
138                parenthesized_length: false,
139            },
140        );
141        self.function_return_types.insert(
142            "LTRIM".to_string(),
143            DataType::VarChar {
144                length: None,
145                parenthesized_length: false,
146            },
147        );
148        self.function_return_types.insert(
149            "RTRIM".to_string(),
150            DataType::VarChar {
151                length: None,
152                parenthesized_length: false,
153            },
154        );
155        self.function_return_types.insert(
156            "SUBSTRING".to_string(),
157            DataType::VarChar {
158                length: None,
159                parenthesized_length: false,
160            },
161        );
162        self.function_return_types.insert(
163            "SUBSTR".to_string(),
164            DataType::VarChar {
165                length: None,
166                parenthesized_length: false,
167            },
168        );
169        self.function_return_types.insert(
170            "REPLACE".to_string(),
171            DataType::VarChar {
172                length: None,
173                parenthesized_length: false,
174            },
175        );
176        self.function_return_types.insert(
177            "LENGTH".to_string(),
178            DataType::Int {
179                length: None,
180                integer_spelling: false,
181            },
182        );
183        self.function_return_types.insert(
184            "CHAR_LENGTH".to_string(),
185            DataType::Int {
186                length: None,
187                integer_spelling: false,
188            },
189        );
190
191        // Date/Time functions
192        self.function_return_types.insert(
193            "NOW".to_string(),
194            DataType::Timestamp {
195                precision: None,
196                timezone: false,
197            },
198        );
199        self.function_return_types.insert(
200            "CURRENT_TIMESTAMP".to_string(),
201            DataType::Timestamp {
202                precision: None,
203                timezone: false,
204            },
205        );
206        self.function_return_types
207            .insert("CURRENT_DATE".to_string(), DataType::Date);
208        self.function_return_types.insert(
209            "CURRENT_TIME".to_string(),
210            DataType::Time {
211                precision: None,
212                timezone: false,
213            },
214        );
215        self.function_return_types
216            .insert("DATE".to_string(), DataType::Date);
217        self.function_return_types.insert(
218            "YEAR".to_string(),
219            DataType::Int {
220                length: None,
221                integer_spelling: false,
222            },
223        );
224        self.function_return_types.insert(
225            "MONTH".to_string(),
226            DataType::Int {
227                length: None,
228                integer_spelling: false,
229            },
230        );
231        self.function_return_types.insert(
232            "DAY".to_string(),
233            DataType::Int {
234                length: None,
235                integer_spelling: false,
236            },
237        );
238        self.function_return_types.insert(
239            "HOUR".to_string(),
240            DataType::Int {
241                length: None,
242                integer_spelling: false,
243            },
244        );
245        self.function_return_types.insert(
246            "MINUTE".to_string(),
247            DataType::Int {
248                length: None,
249                integer_spelling: false,
250            },
251        );
252        self.function_return_types.insert(
253            "SECOND".to_string(),
254            DataType::Int {
255                length: None,
256                integer_spelling: false,
257            },
258        );
259        self.function_return_types.insert(
260            "EXTRACT".to_string(),
261            DataType::Int {
262                length: None,
263                integer_spelling: false,
264            },
265        );
266        self.function_return_types.insert(
267            "DATE_DIFF".to_string(),
268            DataType::Int {
269                length: None,
270                integer_spelling: false,
271            },
272        );
273        self.function_return_types.insert(
274            "DATEDIFF".to_string(),
275            DataType::Int {
276                length: None,
277                integer_spelling: false,
278            },
279        );
280
281        // Math functions
282        self.function_return_types.insert(
283            "ABS".to_string(),
284            DataType::Double {
285                precision: None,
286                scale: None,
287            },
288        );
289        self.function_return_types.insert(
290            "ROUND".to_string(),
291            DataType::Double {
292                precision: None,
293                scale: None,
294            },
295        );
296        self.function_return_types.insert(
297            "DATE_FORMAT".to_string(),
298            DataType::VarChar {
299                length: None,
300                parenthesized_length: false,
301            },
302        );
303        self.function_return_types.insert(
304            "FORMAT_DATE".to_string(),
305            DataType::VarChar {
306                length: None,
307                parenthesized_length: false,
308            },
309        );
310        self.function_return_types.insert(
311            "TIME_TO_STR".to_string(),
312            DataType::VarChar {
313                length: None,
314                parenthesized_length: false,
315            },
316        );
317        self.function_return_types.insert(
318            "SQRT".to_string(),
319            DataType::Double {
320                precision: None,
321                scale: None,
322            },
323        );
324        self.function_return_types.insert(
325            "POWER".to_string(),
326            DataType::Double {
327                precision: None,
328                scale: None,
329            },
330        );
331        self.function_return_types.insert(
332            "MOD".to_string(),
333            DataType::Int {
334                length: None,
335                integer_spelling: false,
336            },
337        );
338        self.function_return_types.insert(
339            "LOG".to_string(),
340            DataType::Double {
341                precision: None,
342                scale: None,
343            },
344        );
345        self.function_return_types.insert(
346            "LN".to_string(),
347            DataType::Double {
348                precision: None,
349                scale: None,
350            },
351        );
352        self.function_return_types.insert(
353            "EXP".to_string(),
354            DataType::Double {
355                precision: None,
356                scale: None,
357            },
358        );
359
360        // Null-handling functions return Unknown (infer from args)
361        self.function_return_types
362            .insert("COALESCE".to_string(), DataType::Unknown);
363        self.function_return_types
364            .insert("NULLIF".to_string(), DataType::Unknown);
365        self.function_return_types
366            .insert("GREATEST".to_string(), DataType::Unknown);
367        self.function_return_types
368            .insert("LEAST".to_string(), DataType::Unknown);
369    }
370
371    /// Annotate types for an expression tree
372    pub fn annotate(&mut self, expr: &Expression) -> Option<DataType> {
373        match expr {
374            // Literals
375            Expression::Literal(lit) => self.annotate_literal(lit),
376            Expression::Boolean(_) => Some(DataType::Boolean),
377            Expression::Null(_) => None, // NULL has no type
378
379            // Arithmetic binary operations
380            Expression::Add(op)
381            | Expression::Sub(op)
382            | Expression::Mul(op)
383            | Expression::Div(op)
384            | Expression::Mod(op) => self.annotate_arithmetic(op),
385
386            // Comparison operations - always boolean
387            Expression::Eq(_)
388            | Expression::Neq(_)
389            | Expression::Lt(_)
390            | Expression::Lte(_)
391            | Expression::Gt(_)
392            | Expression::Gte(_)
393            | Expression::Like(_)
394            | Expression::ILike(_) => Some(DataType::Boolean),
395
396            // Logical operations - always boolean
397            Expression::And(_) | Expression::Or(_) | Expression::Not(_) => Some(DataType::Boolean),
398
399            // Predicates - always boolean
400            Expression::Between(_)
401            | Expression::In(_)
402            | Expression::IsNull(_)
403            | Expression::IsTrue(_)
404            | Expression::IsFalse(_)
405            | Expression::Is(_)
406            | Expression::Exists(_) => Some(DataType::Boolean),
407
408            // String concatenation
409            Expression::Concat(_) => Some(DataType::VarChar {
410                length: None,
411                parenthesized_length: false,
412            }),
413
414            // Bitwise operations - integer
415            Expression::BitwiseAnd(_)
416            | Expression::BitwiseOr(_)
417            | Expression::BitwiseXor(_)
418            | Expression::BitwiseNot(_) => Some(DataType::BigInt { length: None }),
419
420            // Negation preserves type
421            Expression::Neg(op) => self.annotate(&op.this),
422
423            // Functions
424            Expression::Function(func) => self.annotate_function(func),
425            Expression::IfFunc(if_func) => self.annotate_if_func(if_func),
426            Expression::Nvl2(nvl2) => self.annotate_nvl2(nvl2),
427
428            // Typed aggregate functions
429            Expression::Count(_) => Some(DataType::BigInt { length: None }),
430            Expression::Sum(agg) => self.annotate_sum(&agg.this),
431            Expression::SumIf(f) => self.annotate_sum(&f.this),
432            Expression::Avg(_) => Some(DataType::Double {
433                precision: None,
434                scale: None,
435            }),
436            Expression::Min(agg) => self.annotate(&agg.this),
437            Expression::Max(agg) => self.annotate(&agg.this),
438            Expression::GroupConcat(_) | Expression::StringAgg(_) | Expression::ListAgg(_) => {
439                Some(DataType::VarChar {
440                    length: None,
441                    parenthesized_length: false,
442                })
443            }
444
445            // Generic aggregate function
446            Expression::AggregateFunction(agg) => {
447                if !self.annotate_aggregates {
448                    return None;
449                }
450                let func_name = agg.name.to_uppercase();
451                self.get_aggregate_return_type(&func_name, &agg.args)
452            }
453
454            // Column references - look up type from schema if available
455            Expression::Column(col) => {
456                if let Some(schema) = &self._schema {
457                    let table_name = col.table.as_ref().map(|t| t.name.as_str()).unwrap_or("");
458                    schema.get_column_type(table_name, &col.name.name).ok()
459                } else {
460                    None
461                }
462            }
463
464            // Cast expressions
465            Expression::Cast(cast) => Some(cast.to.clone()),
466            Expression::SafeCast(cast) => Some(cast.to.clone()),
467            Expression::TryCast(cast) => Some(cast.to.clone()),
468
469            // Subqueries - type is the type of the first SELECT expression
470            Expression::Subquery(subq) => {
471                if let Expression::Select(select) = &subq.this {
472                    if let Some(first) = select.expressions.first() {
473                        self.annotate(first)
474                    } else {
475                        None
476                    }
477                } else {
478                    None
479                }
480            }
481
482            // CASE expression - type of the first THEN/ELSE
483            Expression::Case(case) => {
484                if let Some(else_expr) = &case.else_ {
485                    self.annotate(else_expr)
486                } else if let Some((_, then_expr)) = case.whens.first() {
487                    self.annotate(then_expr)
488                } else {
489                    None
490                }
491            }
492
493            // Array expressions
494            Expression::Array(arr) => {
495                if let Some(first) = arr.expressions.first() {
496                    if let Some(elem_type) = self.annotate(first) {
497                        Some(DataType::Array {
498                            element_type: Box::new(elem_type),
499                            dimension: None,
500                        })
501                    } else {
502                        Some(DataType::Array {
503                            element_type: Box::new(DataType::Unknown),
504                            dimension: None,
505                        })
506                    }
507                } else {
508                    Some(DataType::Array {
509                        element_type: Box::new(DataType::Unknown),
510                        dimension: None,
511                    })
512                }
513            }
514
515            // Interval expressions
516            Expression::Interval(_) => Some(DataType::Interval {
517                unit: None,
518                to: None,
519            }),
520
521            // Window functions inherit type from their function
522            Expression::WindowFunction(window) => self.annotate(&window.this),
523
524            // Date/time expressions
525            Expression::CurrentDate(_) => Some(DataType::Date),
526            Expression::CurrentTime(_) => Some(DataType::Time {
527                precision: None,
528                timezone: false,
529            }),
530            Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
531                Some(DataType::Timestamp {
532                    precision: None,
533                    timezone: false,
534                })
535            }
536
537            // Date functions
538            Expression::DateAdd(_)
539            | Expression::DateSub(_)
540            | Expression::ToDate(_)
541            | Expression::Date(_) => Some(DataType::Date),
542            Expression::DateDiff(_) | Expression::Extract(_) => Some(DataType::Int {
543                length: None,
544                integer_spelling: false,
545            }),
546            Expression::ToTimestamp(_) => Some(DataType::Timestamp {
547                precision: None,
548                timezone: false,
549            }),
550
551            // String functions
552            Expression::Upper(_)
553            | Expression::Lower(_)
554            | Expression::Trim(_)
555            | Expression::LTrim(_)
556            | Expression::RTrim(_)
557            | Expression::Replace(_)
558            | Expression::Substring(_)
559            | Expression::Reverse(_)
560            | Expression::Left(_)
561            | Expression::Right(_)
562            | Expression::Repeat(_)
563            | Expression::Lpad(_)
564            | Expression::Rpad(_)
565            | Expression::ConcatWs(_)
566            | Expression::Overlay(_) => Some(DataType::VarChar {
567                length: None,
568                parenthesized_length: false,
569            }),
570            Expression::Length(_) => Some(DataType::Int {
571                length: None,
572                integer_spelling: false,
573            }),
574
575            // Math functions
576            Expression::Abs(_)
577            | Expression::Sqrt(_)
578            | Expression::Cbrt(_)
579            | Expression::Ln(_)
580            | Expression::Exp(_)
581            | Expression::Power(_)
582            | Expression::Log(_) => Some(DataType::Double {
583                precision: None,
584                scale: None,
585            }),
586            Expression::Round(_) => Some(DataType::Double {
587                precision: None,
588                scale: None,
589            }),
590            Expression::Floor(f) => self.annotate_math_function(&f.this),
591            Expression::Ceil(f) => self.annotate_math_function(&f.this),
592            Expression::Sign(s) => self.annotate(&s.this),
593            Expression::DateFormat(_) | Expression::FormatDate(_) | Expression::TimeToStr(_) => {
594                Some(DataType::VarChar {
595                    length: None,
596                    parenthesized_length: false,
597                })
598            }
599
600            // Greatest/Least - coerce argument types
601            Expression::Greatest(v) | Expression::Least(v) => self.coerce_arg_types(&v.expressions),
602
603            // Alias - type of the inner expression
604            Expression::Alias(alias) => self.annotate(&alias.this),
605
606            // SELECT expressions - no scalar type
607            Expression::Select(_) => None,
608
609            // ============================================
610            // 3.1.8: Array/Map Indexing (Subscript/Bracket)
611            // ============================================
612            Expression::Subscript(sub) => self.annotate_subscript(sub),
613
614            // Dot access (struct.field) - returns Unknown without schema
615            Expression::Dot(_) => None,
616
617            // ============================================
618            // 3.1.9: STRUCT Construction
619            // ============================================
620            Expression::Struct(s) => self.annotate_struct(s),
621
622            // ============================================
623            // 3.1.10: MAP Construction
624            // ============================================
625            Expression::Map(map) => self.annotate_map(map),
626            Expression::MapFromEntries(mfe) => {
627                // MAP_FROM_ENTRIES(array_of_pairs) - infer from array element type
628                if let Some(DataType::Array { element_type, .. }) = self.annotate(&mfe.this) {
629                    if let DataType::Struct { fields, .. } = *element_type {
630                        if fields.len() >= 2 {
631                            return Some(DataType::Map {
632                                key_type: Box::new(fields[0].data_type.clone()),
633                                value_type: Box::new(fields[1].data_type.clone()),
634                            });
635                        }
636                    }
637                }
638                Some(DataType::Map {
639                    key_type: Box::new(DataType::Unknown),
640                    value_type: Box::new(DataType::Unknown),
641                })
642            }
643
644            // ============================================
645            // 3.1.11: SetOperation Type Coercion
646            // ============================================
647            Expression::Union(union) => self.annotate_set_operation(&union.left, &union.right),
648            Expression::Intersect(intersect) => {
649                self.annotate_set_operation(&intersect.left, &intersect.right)
650            }
651            Expression::Except(except) => self.annotate_set_operation(&except.left, &except.right),
652
653            // ============================================
654            // 3.1.12: UDTF Type Handling
655            // ============================================
656            Expression::Lateral(lateral) => {
657                // LATERAL subquery - type is the subquery's type
658                self.annotate(&lateral.this)
659            }
660            Expression::LateralView(lv) => {
661                // LATERAL VIEW - returns the exploded type
662                self.annotate_lateral_view(lv)
663            }
664            Expression::Unnest(unnest) => {
665                // UNNEST(array) - returns the element type of the array
666                if let Some(DataType::Array { element_type, .. }) = self.annotate(&unnest.this) {
667                    Some(*element_type)
668                } else {
669                    None
670                }
671            }
672            Expression::Explode(explode) => {
673                // EXPLODE(array) - returns the element type
674                if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
675                    Some(*element_type)
676                } else if let Some(DataType::Map {
677                    key_type,
678                    value_type,
679                }) = self.annotate(&explode.this)
680                {
681                    // EXPLODE(map) returns struct(key, value)
682                    Some(DataType::Struct {
683                        fields: vec![
684                            StructField::new("key".to_string(), *key_type),
685                            StructField::new("value".to_string(), *value_type),
686                        ],
687                        nested: false,
688                    })
689                } else {
690                    None
691                }
692            }
693            Expression::ExplodeOuter(explode) => {
694                // EXPLODE_OUTER - same as EXPLODE but preserves nulls
695                if let Some(DataType::Array { element_type, .. }) = self.annotate(&explode.this) {
696                    Some(*element_type)
697                } else {
698                    None
699                }
700            }
701            Expression::GenerateSeries(gs) => {
702                // GENERATE_SERIES returns the type of start/end
703                if let Some(ref start) = gs.start {
704                    self.annotate(start)
705                } else if let Some(ref end) = gs.end {
706                    self.annotate(end)
707                } else {
708                    Some(DataType::Int {
709                        length: None,
710                        integer_spelling: false,
711                    })
712                }
713            }
714
715            // Other expressions - unknown
716            _ => None,
717        }
718    }
719
720    /// Annotate types in-place on the expression tree (bottom-up).
721    ///
722    /// First recurses into children, then computes this node's type using the
723    /// read-only `annotate` method, and finally stores the result via
724    /// `set_inferred_type`.
725    pub fn annotate_in_place(&mut self, expr: &mut Expression) {
726        // 1. Recurse into children (bottom-up)
727        self.annotate_children_in_place(expr);
728
729        // 2. Compute this node's type using the read-only method
730        //    (children already have their types set, but `annotate` re-derives
731        //    from structure, which is fine since the structure hasn't changed)
732        let dt = self.annotate(expr);
733
734        // 3. Store on the node
735        if let Some(data_type) = dt {
736            expr.set_inferred_type(data_type);
737        }
738    }
739
740    /// Recursively annotate children of an expression in-place.
741    fn annotate_children_in_place(&mut self, expr: &mut Expression) {
742        match expr {
743            // Binary operations
744            Expression::And(op)
745            | Expression::Or(op)
746            | Expression::Add(op)
747            | Expression::Sub(op)
748            | Expression::Mul(op)
749            | Expression::Div(op)
750            | Expression::Mod(op)
751            | Expression::Eq(op)
752            | Expression::Neq(op)
753            | Expression::Lt(op)
754            | Expression::Lte(op)
755            | Expression::Gt(op)
756            | Expression::Gte(op)
757            | Expression::Concat(op)
758            | Expression::BitwiseAnd(op)
759            | Expression::BitwiseOr(op)
760            | Expression::BitwiseXor(op)
761            | Expression::Adjacent(op)
762            | Expression::TsMatch(op)
763            | Expression::PropertyEQ(op)
764            | Expression::ArrayContainsAll(op)
765            | Expression::ArrayContainedBy(op)
766            | Expression::ArrayOverlaps(op)
767            | Expression::JSONBContainsAllTopKeys(op)
768            | Expression::JSONBContainsAnyTopKeys(op)
769            | Expression::JSONBDeleteAtPath(op)
770            | Expression::ExtendsLeft(op)
771            | Expression::ExtendsRight(op)
772            | Expression::Is(op)
773            | Expression::MemberOf(op)
774            | Expression::Match(op)
775            | Expression::NullSafeEq(op)
776            | Expression::NullSafeNeq(op)
777            | Expression::Glob(op)
778            | Expression::BitwiseLeftShift(op)
779            | Expression::BitwiseRightShift(op) => {
780                self.annotate_in_place(&mut op.left);
781                self.annotate_in_place(&mut op.right);
782            }
783
784            // Like operations
785            Expression::Like(op) | Expression::ILike(op) => {
786                self.annotate_in_place(&mut op.left);
787                self.annotate_in_place(&mut op.right);
788            }
789
790            // Unary operations
791            Expression::Not(op) | Expression::Neg(op) | Expression::BitwiseNot(op) => {
792                self.annotate_in_place(&mut op.this);
793            }
794
795            // Cast
796            Expression::Cast(c) | Expression::TryCast(c) | Expression::SafeCast(c) => {
797                self.annotate_in_place(&mut c.this);
798            }
799
800            // Case
801            Expression::Case(c) => {
802                if let Some(ref mut operand) = c.operand {
803                    self.annotate_in_place(operand);
804                }
805                for (cond, then_expr) in &mut c.whens {
806                    self.annotate_in_place(cond);
807                    self.annotate_in_place(then_expr);
808                }
809                if let Some(ref mut else_expr) = c.else_ {
810                    self.annotate_in_place(else_expr);
811                }
812            }
813
814            // Alias
815            Expression::Alias(a) => {
816                self.annotate_in_place(&mut a.this);
817            }
818
819            // Column - leaf node, no children to recurse
820            Expression::Column(_) => {}
821
822            // Function
823            Expression::Function(f) => {
824                for arg in &mut f.args {
825                    self.annotate_in_place(arg);
826                }
827            }
828
829            // Dedicated conditional functions
830            Expression::IfFunc(f) => {
831                self.annotate_in_place(&mut f.condition);
832                self.annotate_in_place(&mut f.true_value);
833                if let Some(false_value) = &mut f.false_value {
834                    self.annotate_in_place(false_value);
835                }
836            }
837            Expression::Nvl2(f) => {
838                self.annotate_in_place(&mut f.this);
839                self.annotate_in_place(&mut f.true_value);
840                self.annotate_in_place(&mut f.false_value);
841            }
842
843            // AggregateFunction
844            Expression::AggregateFunction(f) => {
845                for arg in &mut f.args {
846                    self.annotate_in_place(arg);
847                }
848            }
849
850            // Dedicated aggregate / string functions
851            Expression::Count(f) => {
852                if let Some(this) = &mut f.this {
853                    self.annotate_in_place(this);
854                }
855                if let Some(filter) = &mut f.filter {
856                    self.annotate_in_place(filter);
857                }
858            }
859            Expression::GroupConcat(f) => {
860                self.annotate_in_place(&mut f.this);
861                if let Some(separator) = &mut f.separator {
862                    self.annotate_in_place(separator);
863                }
864                if let Some(order_by) = &mut f.order_by {
865                    for ordered in order_by {
866                        self.annotate_in_place(&mut ordered.this);
867                    }
868                }
869                if let Some(filter) = &mut f.filter {
870                    self.annotate_in_place(filter);
871                }
872            }
873            Expression::StringAgg(f) => {
874                self.annotate_in_place(&mut f.this);
875                if let Some(separator) = &mut f.separator {
876                    self.annotate_in_place(separator);
877                }
878                if let Some(order_by) = &mut f.order_by {
879                    for ordered in order_by {
880                        self.annotate_in_place(&mut ordered.this);
881                    }
882                }
883                if let Some(filter) = &mut f.filter {
884                    self.annotate_in_place(filter);
885                }
886                if let Some(limit) = &mut f.limit {
887                    self.annotate_in_place(limit);
888                }
889            }
890            Expression::ListAgg(f) => {
891                self.annotate_in_place(&mut f.this);
892                if let Some(separator) = &mut f.separator {
893                    self.annotate_in_place(separator);
894                }
895                if let Some(order_by) = &mut f.order_by {
896                    for ordered in order_by {
897                        self.annotate_in_place(&mut ordered.this);
898                    }
899                }
900                if let Some(filter) = &mut f.filter {
901                    self.annotate_in_place(filter);
902                }
903                if let Some(ListAggOverflow::Truncate {
904                    filler: Some(filler),
905                    ..
906                }) = &mut f.on_overflow
907                {
908                    self.annotate_in_place(filler);
909                }
910            }
911            Expression::SumIf(f) => {
912                self.annotate_in_place(&mut f.this);
913                self.annotate_in_place(&mut f.condition);
914                if let Some(filter) = &mut f.filter {
915                    self.annotate_in_place(filter);
916                }
917            }
918
919            // WindowFunction
920            Expression::WindowFunction(w) => {
921                self.annotate_in_place(&mut w.this);
922            }
923
924            // Subquery
925            Expression::Subquery(s) => {
926                self.annotate_in_place(&mut s.this);
927            }
928
929            // UnaryFunc variants
930            Expression::Upper(f)
931            | Expression::Lower(f)
932            | Expression::Length(f)
933            | Expression::LTrim(f)
934            | Expression::RTrim(f)
935            | Expression::Reverse(f)
936            | Expression::Abs(f)
937            | Expression::Sqrt(f)
938            | Expression::Cbrt(f)
939            | Expression::Ln(f)
940            | Expression::Exp(f)
941            | Expression::Sign(f)
942            | Expression::Date(f)
943            | Expression::Time(f)
944            | Expression::Explode(f)
945            | Expression::ExplodeOuter(f)
946            | Expression::MapFromEntries(f)
947            | Expression::MapKeys(f)
948            | Expression::MapValues(f)
949            | Expression::ArrayLength(f)
950            | Expression::ArraySize(f)
951            | Expression::Cardinality(f)
952            | Expression::ArrayReverse(f)
953            | Expression::ArrayDistinct(f)
954            | Expression::ArrayFlatten(f)
955            | Expression::ArrayCompact(f)
956            | Expression::ToArray(f)
957            | Expression::JsonArrayLength(f)
958            | Expression::JsonKeys(f)
959            | Expression::JsonType(f)
960            | Expression::ParseJson(f)
961            | Expression::ToJson(f)
962            | Expression::Year(f)
963            | Expression::Month(f)
964            | Expression::Day(f)
965            | Expression::Hour(f)
966            | Expression::Minute(f)
967            | Expression::Second(f)
968            | Expression::Initcap(f)
969            | Expression::Ascii(f)
970            | Expression::Chr(f)
971            | Expression::Soundex(f)
972            | Expression::ByteLength(f)
973            | Expression::Hex(f)
974            | Expression::LowerHex(f)
975            | Expression::Unicode(f)
976            | Expression::Typeof(f)
977            | Expression::BitwiseCount(f)
978            | Expression::Epoch(f)
979            | Expression::EpochMs(f)
980            | Expression::Radians(f)
981            | Expression::Degrees(f)
982            | Expression::Sin(f)
983            | Expression::Cos(f)
984            | Expression::Tan(f)
985            | Expression::Asin(f)
986            | Expression::Acos(f)
987            | Expression::Atan(f)
988            | Expression::IsNan(f)
989            | Expression::IsInf(f) => {
990                self.annotate_in_place(&mut f.this);
991            }
992
993            // BinaryFunc variants
994            Expression::Power(f)
995            | Expression::NullIf(f)
996            | Expression::IfNull(f)
997            | Expression::Nvl(f)
998            | Expression::Contains(f)
999            | Expression::StartsWith(f)
1000            | Expression::EndsWith(f)
1001            | Expression::Levenshtein(f)
1002            | Expression::ModFunc(f)
1003            | Expression::IntDiv(f)
1004            | Expression::Atan2(f)
1005            | Expression::AddMonths(f)
1006            | Expression::MonthsBetween(f)
1007            | Expression::NextDay(f)
1008            | Expression::UnixToTimeStr(f)
1009            | Expression::ArrayContains(f)
1010            | Expression::ArrayPosition(f)
1011            | Expression::ArrayAppend(f)
1012            | Expression::ArrayPrepend(f)
1013            | Expression::ArrayUnion(f)
1014            | Expression::ArrayExcept(f)
1015            | Expression::ArrayRemove(f)
1016            | Expression::StarMap(f)
1017            | Expression::MapFromArrays(f)
1018            | Expression::MapContainsKey(f)
1019            | Expression::ElementAt(f)
1020            | Expression::JsonMergePatch(f) => {
1021                self.annotate_in_place(&mut f.this);
1022                self.annotate_in_place(&mut f.expression);
1023            }
1024
1025            // VarArgFunc variants
1026            Expression::Coalesce(f)
1027            | Expression::Greatest(f)
1028            | Expression::Least(f)
1029            | Expression::ArrayConcat(f)
1030            | Expression::ArrayIntersect(f)
1031            | Expression::ArrayZip(f)
1032            | Expression::MapConcat(f)
1033            | Expression::JsonArray(f) => {
1034                for e in &mut f.expressions {
1035                    self.annotate_in_place(e);
1036                }
1037            }
1038
1039            // AggFunc variants
1040            Expression::Sum(f)
1041            | Expression::Avg(f)
1042            | Expression::Min(f)
1043            | Expression::Max(f)
1044            | Expression::ArrayAgg(f)
1045            | Expression::CountIf(f)
1046            | Expression::Stddev(f)
1047            | Expression::StddevPop(f)
1048            | Expression::StddevSamp(f)
1049            | Expression::Variance(f)
1050            | Expression::VarPop(f)
1051            | Expression::VarSamp(f)
1052            | Expression::Median(f)
1053            | Expression::Mode(f)
1054            | Expression::First(f)
1055            | Expression::Last(f)
1056            | Expression::AnyValue(f)
1057            | Expression::ApproxDistinct(f)
1058            | Expression::ApproxCountDistinct(f)
1059            | Expression::LogicalAnd(f)
1060            | Expression::LogicalOr(f)
1061            | Expression::Skewness(f)
1062            | Expression::ArrayConcatAgg(f)
1063            | Expression::ArrayUniqueAgg(f)
1064            | Expression::BoolXorAgg(f)
1065            | Expression::BitwiseAndAgg(f)
1066            | Expression::BitwiseOrAgg(f)
1067            | Expression::BitwiseXorAgg(f) => {
1068                self.annotate_in_place(&mut f.this);
1069            }
1070
1071            // Select - recurse into expressions
1072            Expression::Select(s) => {
1073                for e in &mut s.expressions {
1074                    self.annotate_in_place(e);
1075                }
1076            }
1077
1078            // Everything else - no children to recurse or not value-producing
1079            _ => {}
1080        }
1081    }
1082
1083    /// Annotate math functions like FLOOR/CEIL that return Double for integer inputs
1084    /// and preserve the input type otherwise (matching sqlglot's _annotate_math_functions).
1085    fn annotate_math_function(&mut self, arg: &Expression) -> Option<DataType> {
1086        let input_type = self.annotate(arg)?;
1087        match input_type {
1088            DataType::TinyInt { .. }
1089            | DataType::SmallInt { .. }
1090            | DataType::Int { .. }
1091            | DataType::BigInt { .. } => Some(DataType::Double {
1092                precision: None,
1093                scale: None,
1094            }),
1095            other => Some(other),
1096        }
1097    }
1098
1099    /// Annotate a subscript/bracket expression (array[index] or map[key])
1100    fn annotate_subscript(&mut self, sub: &Subscript) -> Option<DataType> {
1101        let base_type = self.annotate(&sub.this)?;
1102
1103        match base_type {
1104            DataType::Array { element_type, .. } => Some(*element_type),
1105            DataType::Map { value_type, .. } => Some(*value_type),
1106            DataType::Json | DataType::JsonB => Some(DataType::Json), // JSON indexing returns JSON
1107            DataType::VarChar { .. } | DataType::Text => {
1108                // String indexing returns a character
1109                Some(DataType::VarChar {
1110                    length: Some(1),
1111                    parenthesized_length: false,
1112                })
1113            }
1114            _ => None,
1115        }
1116    }
1117
1118    /// Annotate a STRUCT literal
1119    fn annotate_struct(&mut self, s: &Struct) -> Option<DataType> {
1120        let fields: Vec<StructField> = s
1121            .fields
1122            .iter()
1123            .map(|(name, expr)| {
1124                let field_type = self.annotate(expr).unwrap_or(DataType::Unknown);
1125                StructField::new(name.clone().unwrap_or_default(), field_type)
1126            })
1127            .collect();
1128        Some(DataType::Struct {
1129            fields,
1130            nested: false,
1131        })
1132    }
1133
1134    /// Annotate a MAP literal
1135    fn annotate_map(&mut self, map: &Map) -> Option<DataType> {
1136        let key_type = if let Some(first_key) = map.keys.first() {
1137            self.annotate(first_key).unwrap_or(DataType::Unknown)
1138        } else {
1139            DataType::Unknown
1140        };
1141
1142        let value_type = if let Some(first_value) = map.values.first() {
1143            self.annotate(first_value).unwrap_or(DataType::Unknown)
1144        } else {
1145            DataType::Unknown
1146        };
1147
1148        Some(DataType::Map {
1149            key_type: Box::new(key_type),
1150            value_type: Box::new(value_type),
1151        })
1152    }
1153
1154    /// Annotate a SetOperation (UNION/INTERSECT/EXCEPT)
1155    /// Returns None since set operations produce relation types, not scalar types
1156    fn annotate_set_operation(
1157        &mut self,
1158        _left: &Expression,
1159        _right: &Expression,
1160    ) -> Option<DataType> {
1161        // Set operations produce relations, not scalar types
1162        // The column types would be coerced between left and right
1163        // For now, return None as this is a relation-level type
1164        None
1165    }
1166
1167    /// Annotate a LATERAL VIEW expression
1168    fn annotate_lateral_view(&mut self, lv: &crate::expressions::LateralView) -> Option<DataType> {
1169        // The type depends on the table-generating function
1170        self.annotate(&lv.this)
1171    }
1172
1173    /// Annotate a literal value
1174    fn annotate_literal(&self, lit: &Literal) -> Option<DataType> {
1175        match lit {
1176            Literal::String(_)
1177            | Literal::NationalString(_)
1178            | Literal::TripleQuotedString(_, _)
1179            | Literal::EscapeString(_)
1180            | Literal::DollarString(_)
1181            | Literal::RawString(_) => Some(DataType::VarChar {
1182                length: None,
1183                parenthesized_length: false,
1184            }),
1185            Literal::Number(n) => {
1186                // Try to determine if it's an integer or float
1187                if n.contains('.') || n.contains('e') || n.contains('E') {
1188                    Some(DataType::Double {
1189                        precision: None,
1190                        scale: None,
1191                    })
1192                } else {
1193                    // Check if it fits in an Int or needs BigInt
1194                    if let Ok(_) = n.parse::<i32>() {
1195                        Some(DataType::Int {
1196                            length: None,
1197                            integer_spelling: false,
1198                        })
1199                    } else {
1200                        Some(DataType::BigInt { length: None })
1201                    }
1202                }
1203            }
1204            Literal::HexString(_) | Literal::BitString(_) | Literal::ByteString(_) => {
1205                Some(DataType::VarBinary { length: None })
1206            }
1207            Literal::HexNumber(_) => Some(DataType::BigInt { length: None }),
1208            Literal::Date(_) => Some(DataType::Date),
1209            Literal::Time(_) => Some(DataType::Time {
1210                precision: None,
1211                timezone: false,
1212            }),
1213            Literal::Timestamp(_) => Some(DataType::Timestamp {
1214                precision: None,
1215                timezone: false,
1216            }),
1217            Literal::Datetime(_) => Some(DataType::Custom {
1218                name: "DATETIME".to_string(),
1219            }),
1220        }
1221    }
1222
1223    /// Annotate an arithmetic binary operation
1224    fn annotate_arithmetic(&mut self, op: &BinaryOp) -> Option<DataType> {
1225        let left_type = self.annotate(&op.left);
1226        let right_type = self.annotate(&op.right);
1227
1228        match (left_type, right_type) {
1229            (Some(l), Some(r)) => self.coerce_types(&l, &r),
1230            (Some(t), None) | (None, Some(t)) => Some(t),
1231            (None, None) => None,
1232        }
1233    }
1234
1235    /// Annotate a function call
1236    fn annotate_function(&mut self, func: &Function) -> Option<DataType> {
1237        let func_name = func.name.to_uppercase();
1238
1239        // Check known function return types
1240        if let Some(return_type) = self.function_return_types.get(&func_name) {
1241            if *return_type != DataType::Unknown {
1242                return Some(return_type.clone());
1243            }
1244        }
1245
1246        // For functions with Unknown return type, infer from arguments
1247        match func_name.as_str() {
1248            "COALESCE" | "IFNULL" | "NVL" | "ISNULL" => {
1249                // Return type of first non-null argument
1250                for arg in &func.args {
1251                    if let Some(arg_type) = self.annotate(arg) {
1252                        return Some(arg_type);
1253                    }
1254                }
1255                None
1256            }
1257            "NULLIF" => {
1258                // Return type of first argument
1259                func.args.first().and_then(|arg| self.annotate(arg))
1260            }
1261            "GREATEST" | "LEAST" => {
1262                // Coerce all argument types
1263                self.coerce_arg_types(&func.args)
1264            }
1265            "IF" | "IIF" => {
1266                // Return type of THEN/ELSE branches
1267                if func.args.len() >= 2 {
1268                    self.annotate(&func.args[1])
1269                } else {
1270                    None
1271                }
1272            }
1273            _ => {
1274                // Unknown function - try to infer from first argument
1275                func.args.first().and_then(|arg| self.annotate(arg))
1276            }
1277        }
1278    }
1279
1280    /// Annotate IF/IIF/IFF conditional function
1281    fn annotate_if_func(&mut self, func: &IfFunc) -> Option<DataType> {
1282        let true_type = self.annotate(&func.true_value);
1283        let false_type = func
1284            .false_value
1285            .as_ref()
1286            .and_then(|expr| self.annotate(expr));
1287
1288        match (true_type, false_type) {
1289            (Some(left), Some(right)) => self.coerce_types(&left, &right),
1290            (Some(dt), None) | (None, Some(dt)) => Some(dt),
1291            (None, None) => None,
1292        }
1293    }
1294
1295    /// Annotate NVL2 conditional function from its true/false branches
1296    fn annotate_nvl2(&mut self, func: &Nvl2Func) -> Option<DataType> {
1297        let true_type = self.annotate(&func.true_value);
1298        let false_type = self.annotate(&func.false_value);
1299
1300        match (true_type, false_type) {
1301            (Some(left), Some(right)) => self.coerce_types(&left, &right),
1302            (Some(dt), None) | (None, Some(dt)) => Some(dt),
1303            (None, None) => None,
1304        }
1305    }
1306
1307    /// Get return type for aggregate functions
1308    fn get_aggregate_return_type(
1309        &mut self,
1310        func_name: &str,
1311        args: &[Expression],
1312    ) -> Option<DataType> {
1313        match func_name {
1314            "COUNT" | "COUNT_IF" => Some(DataType::BigInt { length: None }),
1315            "SUM_IF" => {
1316                if let Some(arg) = args.first() {
1317                    self.annotate_sum(arg)
1318                } else {
1319                    Some(DataType::Decimal {
1320                        precision: None,
1321                        scale: None,
1322                    })
1323                }
1324            }
1325            "SUM" => {
1326                if let Some(arg) = args.first() {
1327                    self.annotate_sum(arg)
1328                } else {
1329                    Some(DataType::Decimal {
1330                        precision: None,
1331                        scale: None,
1332                    })
1333                }
1334            }
1335            "AVG" => Some(DataType::Double {
1336                precision: None,
1337                scale: None,
1338            }),
1339            "MIN" | "MAX" => {
1340                // Preserves input type
1341                args.first().and_then(|arg| self.annotate(arg))
1342            }
1343            "STRING_AGG" | "GROUP_CONCAT" | "LISTAGG" | "ARRAY_AGG" => Some(DataType::VarChar {
1344                length: None,
1345                parenthesized_length: false,
1346            }),
1347            "BOOL_AND" | "BOOL_OR" | "EVERY" | "ANY" | "SOME" => Some(DataType::Boolean),
1348            "BIT_AND" | "BIT_OR" | "BIT_XOR" => Some(DataType::BigInt { length: None }),
1349            "STDDEV" | "STDDEV_POP" | "STDDEV_SAMP" | "VARIANCE" | "VAR_POP" | "VAR_SAMP" => {
1350                Some(DataType::Double {
1351                    precision: None,
1352                    scale: None,
1353                })
1354            }
1355            "PERCENTILE_CONT" | "PERCENTILE_DISC" | "MEDIAN" => {
1356                args.first().and_then(|arg| self.annotate(arg))
1357            }
1358            _ => None,
1359        }
1360    }
1361
1362    /// Annotate SUM function - promotes to at least BigInt
1363    fn annotate_sum(&mut self, arg: &Expression) -> Option<DataType> {
1364        match self.annotate(arg) {
1365            Some(DataType::TinyInt { .. })
1366            | Some(DataType::SmallInt { .. })
1367            | Some(DataType::Int { .. }) => Some(DataType::BigInt { length: None }),
1368            Some(DataType::BigInt { .. }) => Some(DataType::BigInt { length: None }),
1369            Some(DataType::Float { .. }) | Some(DataType::Double { .. }) => {
1370                Some(DataType::Double {
1371                    precision: None,
1372                    scale: None,
1373                })
1374            }
1375            Some(DataType::Decimal { precision, scale }) => {
1376                Some(DataType::Decimal { precision, scale })
1377            }
1378            _ => Some(DataType::Decimal {
1379                precision: None,
1380                scale: None,
1381            }),
1382        }
1383    }
1384
1385    /// Coerce multiple argument types to a common type
1386    fn coerce_arg_types(&mut self, args: &[Expression]) -> Option<DataType> {
1387        let mut result_type: Option<DataType> = None;
1388        for arg in args {
1389            if let Some(arg_type) = self.annotate(arg) {
1390                result_type = match result_type {
1391                    Some(t) => self.coerce_types(&t, &arg_type),
1392                    None => Some(arg_type),
1393                };
1394            }
1395        }
1396        result_type
1397    }
1398
1399    /// Coerce two types to a common type
1400    fn coerce_types(&self, left: &DataType, right: &DataType) -> Option<DataType> {
1401        // If types are the same, return that type
1402        if left == right {
1403            return Some(left.clone());
1404        }
1405
1406        // Special case: Interval + Date/Timestamp
1407        match (left, right) {
1408            (DataType::Date, DataType::Interval { .. })
1409            | (DataType::Interval { .. }, DataType::Date) => return Some(DataType::Date),
1410            (
1411                DataType::Timestamp {
1412                    precision,
1413                    timezone,
1414                },
1415                DataType::Interval { .. },
1416            )
1417            | (
1418                DataType::Interval { .. },
1419                DataType::Timestamp {
1420                    precision,
1421                    timezone,
1422                },
1423            ) => {
1424                return Some(DataType::Timestamp {
1425                    precision: *precision,
1426                    timezone: *timezone,
1427                });
1428            }
1429            _ => {}
1430        }
1431
1432        // Coerce based on class
1433        let left_class = TypeCoercionClass::from_data_type(left);
1434        let right_class = TypeCoercionClass::from_data_type(right);
1435
1436        match (left_class, right_class) {
1437            // Same class: use higher-precision type within class
1438            (Some(lc), Some(rc)) if lc == rc => {
1439                // For numeric, choose wider type
1440                if lc == TypeCoercionClass::Numeric {
1441                    Some(self.wider_numeric_type(left, right))
1442                } else {
1443                    // For text and timelike, left wins by default
1444                    Some(left.clone())
1445                }
1446            }
1447            // Different classes: higher-priority class wins
1448            (Some(lc), Some(rc)) => {
1449                if lc > rc {
1450                    Some(left.clone())
1451                } else {
1452                    Some(right.clone())
1453                }
1454            }
1455            // One unknown: use the known type
1456            (Some(_), None) => Some(left.clone()),
1457            (None, Some(_)) => Some(right.clone()),
1458            // Both unknown: return unknown
1459            (None, None) => Some(DataType::Unknown),
1460        }
1461    }
1462
1463    /// Get the wider numeric type
1464    fn wider_numeric_type(&self, left: &DataType, right: &DataType) -> DataType {
1465        let order = |dt: &DataType| -> u8 {
1466            match dt {
1467                DataType::Boolean => 0,
1468                DataType::TinyInt { .. } => 1,
1469                DataType::SmallInt { .. } => 2,
1470                DataType::Int { .. } => 3,
1471                DataType::BigInt { .. } => 4,
1472                DataType::Float { .. } => 5,
1473                DataType::Double { .. } => 6,
1474                DataType::Decimal { .. } => 7,
1475                _ => 0,
1476            }
1477        };
1478
1479        if order(left) >= order(right) {
1480            left.clone()
1481        } else {
1482            right.clone()
1483        }
1484    }
1485}
1486
1487/// Annotate types in-place on the expression tree.
1488///
1489/// Walks the AST bottom-up and sets `inferred_type` on each value-producing
1490/// node. After this call, `expr.inferred_type()` (and the same on any child
1491/// node) returns the inferred type.
1492pub fn annotate_types(
1493    expr: &mut Expression,
1494    schema: Option<&dyn Schema>,
1495    dialect: Option<DialectType>,
1496) {
1497    let mut annotator = TypeAnnotator::new(schema, dialect);
1498    annotator.annotate_in_place(expr);
1499}
1500
1501#[cfg(test)]
1502mod tests {
1503    use super::*;
1504    use crate::expressions::{BooleanLiteral, Cast, Null};
1505    use crate::{parse_one, DialectType, MappingSchema, Schema};
1506
1507    fn make_int_literal(val: i64) -> Expression {
1508        Expression::Literal(Box::new(Literal::Number(val.to_string())))
1509    }
1510
1511    fn make_float_literal(val: f64) -> Expression {
1512        Expression::Literal(Box::new(Literal::Number(val.to_string())))
1513    }
1514
1515    fn make_string_literal(val: &str) -> Expression {
1516        Expression::Literal(Box::new(Literal::String(val.to_string())))
1517    }
1518
1519    fn make_bool_literal(val: bool) -> Expression {
1520        Expression::Boolean(BooleanLiteral { value: val })
1521    }
1522
1523    #[test]
1524    fn test_literal_types() {
1525        let mut annotator = TypeAnnotator::new(None, None);
1526
1527        // Integer literal
1528        let int_expr = make_int_literal(42);
1529        assert_eq!(
1530            annotator.annotate(&int_expr),
1531            Some(DataType::Int {
1532                length: None,
1533                integer_spelling: false
1534            })
1535        );
1536
1537        // Float literal
1538        let float_expr = make_float_literal(3.14);
1539        assert_eq!(
1540            annotator.annotate(&float_expr),
1541            Some(DataType::Double {
1542                precision: None,
1543                scale: None
1544            })
1545        );
1546
1547        // String literal
1548        let string_expr = make_string_literal("hello");
1549        assert_eq!(
1550            annotator.annotate(&string_expr),
1551            Some(DataType::VarChar {
1552                length: None,
1553                parenthesized_length: false
1554            })
1555        );
1556
1557        // Boolean literal
1558        let bool_expr = make_bool_literal(true);
1559        assert_eq!(annotator.annotate(&bool_expr), Some(DataType::Boolean));
1560
1561        // Null literal
1562        let null_expr = Expression::Null(Null);
1563        assert_eq!(annotator.annotate(&null_expr), None);
1564    }
1565
1566    #[test]
1567    fn test_comparison_types() {
1568        let mut annotator = TypeAnnotator::new(None, None);
1569
1570        // Comparison returns boolean
1571        let cmp = Expression::Gt(Box::new(BinaryOp::new(
1572            make_int_literal(1),
1573            make_int_literal(2),
1574        )));
1575        assert_eq!(annotator.annotate(&cmp), Some(DataType::Boolean));
1576
1577        // Equality returns boolean
1578        let eq = Expression::Eq(Box::new(BinaryOp::new(
1579            make_string_literal("a"),
1580            make_string_literal("b"),
1581        )));
1582        assert_eq!(annotator.annotate(&eq), Some(DataType::Boolean));
1583    }
1584
1585    #[test]
1586    fn test_arithmetic_types() {
1587        let mut annotator = TypeAnnotator::new(None, None);
1588
1589        // Int + Int = Int
1590        let add_int = Expression::Add(Box::new(BinaryOp::new(
1591            make_int_literal(1),
1592            make_int_literal(2),
1593        )));
1594        assert_eq!(
1595            annotator.annotate(&add_int),
1596            Some(DataType::Int {
1597                length: None,
1598                integer_spelling: false
1599            })
1600        );
1601
1602        // Int + Float = Double (wider type)
1603        let add_mixed = Expression::Add(Box::new(BinaryOp::new(
1604            make_int_literal(1),
1605            make_float_literal(2.5), // Use 2.5 so the string has a decimal point
1606        )));
1607        assert_eq!(
1608            annotator.annotate(&add_mixed),
1609            Some(DataType::Double {
1610                precision: None,
1611                scale: None
1612            })
1613        );
1614    }
1615
1616    #[test]
1617    fn test_string_concat_type() {
1618        let mut annotator = TypeAnnotator::new(None, None);
1619
1620        // String || String = VarChar
1621        let concat = Expression::Concat(Box::new(BinaryOp::new(
1622            make_string_literal("hello"),
1623            make_string_literal(" world"),
1624        )));
1625        assert_eq!(
1626            annotator.annotate(&concat),
1627            Some(DataType::VarChar {
1628                length: None,
1629                parenthesized_length: false
1630            })
1631        );
1632    }
1633
1634    #[test]
1635    fn test_cast_type() {
1636        let mut annotator = TypeAnnotator::new(None, None);
1637
1638        // CAST(1 AS VARCHAR)
1639        let cast = Expression::Cast(Box::new(Cast {
1640            this: make_int_literal(1),
1641            to: DataType::VarChar {
1642                length: Some(10),
1643                parenthesized_length: false,
1644            },
1645            trailing_comments: vec![],
1646            double_colon_syntax: false,
1647            format: None,
1648            default: None,
1649            inferred_type: None,
1650        }));
1651        assert_eq!(
1652            annotator.annotate(&cast),
1653            Some(DataType::VarChar {
1654                length: Some(10),
1655                parenthesized_length: false
1656            })
1657        );
1658    }
1659
1660    #[test]
1661    fn test_function_types() {
1662        let mut annotator = TypeAnnotator::new(None, None);
1663
1664        // COUNT returns BigInt
1665        let count =
1666            Expression::Function(Box::new(Function::new("COUNT", vec![make_int_literal(1)])));
1667        assert_eq!(
1668            annotator.annotate(&count),
1669            Some(DataType::BigInt { length: None })
1670        );
1671
1672        // UPPER returns VarChar
1673        let upper = Expression::Function(Box::new(Function::new(
1674            "UPPER",
1675            vec![make_string_literal("hello")],
1676        )));
1677        assert_eq!(
1678            annotator.annotate(&upper),
1679            Some(DataType::VarChar {
1680                length: None,
1681                parenthesized_length: false
1682            })
1683        );
1684
1685        // NOW returns Timestamp
1686        let now = Expression::Function(Box::new(Function::new("NOW", vec![])));
1687        assert_eq!(
1688            annotator.annotate(&now),
1689            Some(DataType::Timestamp {
1690                precision: None,
1691                timezone: false
1692            })
1693        );
1694    }
1695
1696    #[test]
1697    fn test_coalesce_type_inference() {
1698        let mut annotator = TypeAnnotator::new(None, None);
1699
1700        // COALESCE(NULL, 1) returns Int (type of first non-null arg)
1701        let coalesce = Expression::Function(Box::new(Function::new(
1702            "COALESCE",
1703            vec![Expression::Null(Null), make_int_literal(1)],
1704        )));
1705        assert_eq!(
1706            annotator.annotate(&coalesce),
1707            Some(DataType::Int {
1708                length: None,
1709                integer_spelling: false
1710            })
1711        );
1712    }
1713
1714    #[test]
1715    fn test_type_coercion_class() {
1716        // Text types
1717        assert_eq!(
1718            TypeCoercionClass::from_data_type(&DataType::VarChar {
1719                length: None,
1720                parenthesized_length: false
1721            }),
1722            Some(TypeCoercionClass::Text)
1723        );
1724        assert_eq!(
1725            TypeCoercionClass::from_data_type(&DataType::Text),
1726            Some(TypeCoercionClass::Text)
1727        );
1728
1729        // Numeric types
1730        assert_eq!(
1731            TypeCoercionClass::from_data_type(&DataType::Int {
1732                length: None,
1733                integer_spelling: false
1734            }),
1735            Some(TypeCoercionClass::Numeric)
1736        );
1737        assert_eq!(
1738            TypeCoercionClass::from_data_type(&DataType::Double {
1739                precision: None,
1740                scale: None
1741            }),
1742            Some(TypeCoercionClass::Numeric)
1743        );
1744
1745        // Timelike types
1746        assert_eq!(
1747            TypeCoercionClass::from_data_type(&DataType::Date),
1748            Some(TypeCoercionClass::Timelike)
1749        );
1750        assert_eq!(
1751            TypeCoercionClass::from_data_type(&DataType::Timestamp {
1752                precision: None,
1753                timezone: false
1754            }),
1755            Some(TypeCoercionClass::Timelike)
1756        );
1757
1758        // Unknown types
1759        assert_eq!(TypeCoercionClass::from_data_type(&DataType::Json), None);
1760    }
1761
1762    #[test]
1763    fn test_wider_numeric_type() {
1764        let annotator = TypeAnnotator::new(None, None);
1765
1766        // Int vs BigInt -> BigInt
1767        let result = annotator.wider_numeric_type(
1768            &DataType::Int {
1769                length: None,
1770                integer_spelling: false,
1771            },
1772            &DataType::BigInt { length: None },
1773        );
1774        assert_eq!(result, DataType::BigInt { length: None });
1775
1776        // Float vs Double -> Double
1777        let result = annotator.wider_numeric_type(
1778            &DataType::Float {
1779                precision: None,
1780                scale: None,
1781                real_spelling: false,
1782            },
1783            &DataType::Double {
1784                precision: None,
1785                scale: None,
1786            },
1787        );
1788        assert_eq!(
1789            result,
1790            DataType::Double {
1791                precision: None,
1792                scale: None
1793            }
1794        );
1795
1796        // Int vs Double -> Double
1797        let result = annotator.wider_numeric_type(
1798            &DataType::Int {
1799                length: None,
1800                integer_spelling: false,
1801            },
1802            &DataType::Double {
1803                precision: None,
1804                scale: None,
1805            },
1806        );
1807        assert_eq!(
1808            result,
1809            DataType::Double {
1810                precision: None,
1811                scale: None
1812            }
1813        );
1814    }
1815
1816    #[test]
1817    fn test_aggregate_return_types() {
1818        let mut annotator = TypeAnnotator::new(None, None);
1819
1820        // SUM(int) returns BigInt
1821        let sum_type = annotator.get_aggregate_return_type("SUM", &[make_int_literal(1)]);
1822        assert_eq!(sum_type, Some(DataType::BigInt { length: None }));
1823
1824        // AVG always returns Double
1825        let avg_type = annotator.get_aggregate_return_type("AVG", &[make_int_literal(1)]);
1826        assert_eq!(
1827            avg_type,
1828            Some(DataType::Double {
1829                precision: None,
1830                scale: None
1831            })
1832        );
1833
1834        // MIN/MAX preserve input type
1835        let min_type = annotator.get_aggregate_return_type("MIN", &[make_string_literal("a")]);
1836        assert_eq!(
1837            min_type,
1838            Some(DataType::VarChar {
1839                length: None,
1840                parenthesized_length: false
1841            })
1842        );
1843    }
1844
1845    #[test]
1846    fn test_date_literal_types() {
1847        let mut annotator = TypeAnnotator::new(None, None);
1848
1849        // DATE literal
1850        let date_expr = Expression::Literal(Box::new(Literal::Date("2024-01-15".to_string())));
1851        assert_eq!(annotator.annotate(&date_expr), Some(DataType::Date));
1852
1853        // TIME literal
1854        let time_expr = Expression::Literal(Box::new(Literal::Time("10:30:00".to_string())));
1855        assert_eq!(
1856            annotator.annotate(&time_expr),
1857            Some(DataType::Time {
1858                precision: None,
1859                timezone: false
1860            })
1861        );
1862
1863        // TIMESTAMP literal
1864        let ts_expr = Expression::Literal(Box::new(Literal::Timestamp(
1865            "2024-01-15 10:30:00".to_string(),
1866        )));
1867        assert_eq!(
1868            annotator.annotate(&ts_expr),
1869            Some(DataType::Timestamp {
1870                precision: None,
1871                timezone: false
1872            })
1873        );
1874    }
1875
1876    #[test]
1877    fn test_logical_operations() {
1878        let mut annotator = TypeAnnotator::new(None, None);
1879
1880        // AND returns boolean
1881        let and_expr = Expression::And(Box::new(BinaryOp::new(
1882            make_bool_literal(true),
1883            make_bool_literal(false),
1884        )));
1885        assert_eq!(annotator.annotate(&and_expr), Some(DataType::Boolean));
1886
1887        // OR returns boolean
1888        let or_expr = Expression::Or(Box::new(BinaryOp::new(
1889            make_bool_literal(true),
1890            make_bool_literal(false),
1891        )));
1892        assert_eq!(annotator.annotate(&or_expr), Some(DataType::Boolean));
1893
1894        // NOT returns boolean
1895        let not_expr = Expression::Not(Box::new(crate::expressions::UnaryOp::new(
1896            make_bool_literal(true),
1897        )));
1898        assert_eq!(annotator.annotate(&not_expr), Some(DataType::Boolean));
1899    }
1900
1901    // ========================================
1902    // Tests for newly implemented features
1903    // ========================================
1904
1905    #[test]
1906    fn test_subscript_array_type() {
1907        let mut annotator = TypeAnnotator::new(None, None);
1908
1909        // Array[index] returns element type
1910        let arr = Expression::Array(Box::new(crate::expressions::Array {
1911            expressions: vec![make_int_literal(1), make_int_literal(2)],
1912        }));
1913        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
1914            this: arr,
1915            index: make_int_literal(0),
1916        }));
1917        assert_eq!(
1918            annotator.annotate(&subscript),
1919            Some(DataType::Int {
1920                length: None,
1921                integer_spelling: false
1922            })
1923        );
1924    }
1925
1926    #[test]
1927    fn test_subscript_map_type() {
1928        let mut annotator = TypeAnnotator::new(None, None);
1929
1930        // Map[key] returns value type
1931        let map = Expression::Map(Box::new(crate::expressions::Map {
1932            keys: vec![make_string_literal("a")],
1933            values: vec![make_int_literal(1)],
1934        }));
1935        let subscript = Expression::Subscript(Box::new(crate::expressions::Subscript {
1936            this: map,
1937            index: make_string_literal("a"),
1938        }));
1939        assert_eq!(
1940            annotator.annotate(&subscript),
1941            Some(DataType::Int {
1942                length: None,
1943                integer_spelling: false
1944            })
1945        );
1946    }
1947
1948    #[test]
1949    fn test_struct_type() {
1950        let mut annotator = TypeAnnotator::new(None, None);
1951
1952        // STRUCT literal
1953        let struct_expr = Expression::Struct(Box::new(crate::expressions::Struct {
1954            fields: vec![
1955                (Some("name".to_string()), make_string_literal("Alice")),
1956                (Some("age".to_string()), make_int_literal(30)),
1957            ],
1958        }));
1959        let result = annotator.annotate(&struct_expr);
1960        assert!(matches!(result, Some(DataType::Struct { fields, .. }) if fields.len() == 2));
1961    }
1962
1963    #[test]
1964    fn test_map_type() {
1965        let mut annotator = TypeAnnotator::new(None, None);
1966
1967        // MAP literal
1968        let map_expr = Expression::Map(Box::new(crate::expressions::Map {
1969            keys: vec![make_string_literal("a"), make_string_literal("b")],
1970            values: vec![make_int_literal(1), make_int_literal(2)],
1971        }));
1972        let result = annotator.annotate(&map_expr);
1973        assert!(matches!(
1974            result,
1975            Some(DataType::Map { key_type, value_type })
1976            if matches!(*key_type, DataType::VarChar { .. })
1977               && matches!(*value_type, DataType::Int { .. })
1978        ));
1979    }
1980
1981    #[test]
1982    fn test_explode_array_type() {
1983        let mut annotator = TypeAnnotator::new(None, None);
1984
1985        // EXPLODE(array) returns element type
1986        let arr = Expression::Array(Box::new(crate::expressions::Array {
1987            expressions: vec![make_int_literal(1), make_int_literal(2)],
1988        }));
1989        let explode = Expression::Explode(Box::new(crate::expressions::UnaryFunc {
1990            this: arr,
1991            original_name: None,
1992            inferred_type: None,
1993        }));
1994        assert_eq!(
1995            annotator.annotate(&explode),
1996            Some(DataType::Int {
1997                length: None,
1998                integer_spelling: false
1999            })
2000        );
2001    }
2002
2003    #[test]
2004    fn test_unnest_array_type() {
2005        let mut annotator = TypeAnnotator::new(None, None);
2006
2007        // UNNEST(array) returns element type
2008        let arr = Expression::Array(Box::new(crate::expressions::Array {
2009            expressions: vec![make_string_literal("a"), make_string_literal("b")],
2010        }));
2011        let unnest = Expression::Unnest(Box::new(crate::expressions::UnnestFunc {
2012            this: arr,
2013            expressions: Vec::new(),
2014            with_ordinality: false,
2015            alias: None,
2016            offset_alias: None,
2017        }));
2018        assert_eq!(
2019            annotator.annotate(&unnest),
2020            Some(DataType::VarChar {
2021                length: None,
2022                parenthesized_length: false
2023            })
2024        );
2025    }
2026
2027    #[test]
2028    fn test_set_operation_type() {
2029        let mut annotator = TypeAnnotator::new(None, None);
2030
2031        // UNION/INTERSECT/EXCEPT return None (they produce relations, not scalars)
2032        let select = Expression::Select(Box::new(crate::expressions::Select::default()));
2033        let union = Expression::Union(Box::new(crate::expressions::Union {
2034            left: select.clone(),
2035            right: select.clone(),
2036            all: false,
2037            distinct: false,
2038            with: None,
2039            order_by: None,
2040            limit: None,
2041            offset: None,
2042            by_name: false,
2043            side: None,
2044            kind: None,
2045            corresponding: false,
2046            strict: false,
2047            on_columns: Vec::new(),
2048            distribute_by: None,
2049            sort_by: None,
2050            cluster_by: None,
2051        }));
2052        assert_eq!(annotator.annotate(&union), None);
2053    }
2054
2055    #[test]
2056    fn test_floor_ceil_input_dependent_types() {
2057        use crate::expressions::{CeilFunc, FloorFunc};
2058
2059        let mut annotator = TypeAnnotator::new(None, None);
2060
2061        // FLOOR/CEIL with integer literal → Double (integers get promoted)
2062        let floor_int = Expression::Floor(Box::new(FloorFunc {
2063            this: make_int_literal(42),
2064            scale: None,
2065            to: None,
2066        }));
2067        assert_eq!(
2068            annotator.annotate(&floor_int),
2069            Some(DataType::Double {
2070                precision: None,
2071                scale: None,
2072            })
2073        );
2074
2075        let ceil_int = Expression::Ceil(Box::new(CeilFunc {
2076            this: make_int_literal(42),
2077            decimals: None,
2078            to: None,
2079        }));
2080        assert_eq!(
2081            annotator.annotate(&ceil_int),
2082            Some(DataType::Double {
2083                precision: None,
2084                scale: None,
2085            })
2086        );
2087
2088        // FLOOR with float literal → Double (literals are always Double)
2089        let floor_float = Expression::Floor(Box::new(FloorFunc {
2090            this: make_float_literal(3.14),
2091            scale: None,
2092            to: None,
2093        }));
2094        assert_eq!(
2095            annotator.annotate(&floor_float),
2096            Some(DataType::Double {
2097                precision: None,
2098                scale: None,
2099            })
2100        );
2101
2102        // FLOOR via Function("FLOOR") path → falls through to arg-based inference
2103        let floor_fn =
2104            Expression::Function(Box::new(Function::new("FLOOR", vec![make_int_literal(1)])));
2105        assert_eq!(
2106            annotator.annotate(&floor_fn),
2107            Some(DataType::Int {
2108                length: None,
2109                integer_spelling: false,
2110            })
2111        );
2112    }
2113
2114    #[test]
2115    fn test_sign_preserves_input_type() {
2116        use crate::expressions::UnaryFunc;
2117
2118        let mut annotator = TypeAnnotator::new(None, None);
2119
2120        // SIGN with integer literal → Int (preserves input type)
2121        let sign_int = Expression::Sign(Box::new(UnaryFunc {
2122            this: make_int_literal(42),
2123            original_name: None,
2124            inferred_type: None,
2125        }));
2126        assert_eq!(
2127            annotator.annotate(&sign_int),
2128            Some(DataType::Int {
2129                length: None,
2130                integer_spelling: false,
2131            })
2132        );
2133
2134        // SIGN with float literal → Double (preserves input type)
2135        let sign_float = Expression::Sign(Box::new(UnaryFunc {
2136            this: make_float_literal(3.14),
2137            original_name: None,
2138            inferred_type: None,
2139        }));
2140        assert_eq!(
2141            annotator.annotate(&sign_float),
2142            Some(DataType::Double {
2143                precision: None,
2144                scale: None,
2145            })
2146        );
2147
2148        // SIGN with a CAST to INT → Int (preserves input type)
2149        let sign_cast = Expression::Sign(Box::new(UnaryFunc {
2150            this: Expression::Cast(Box::new(Cast {
2151                this: make_int_literal(42),
2152                to: DataType::Int {
2153                    length: None,
2154                    integer_spelling: false,
2155                },
2156                format: None,
2157                trailing_comments: Vec::new(),
2158                double_colon_syntax: false,
2159                default: None,
2160                inferred_type: None,
2161            })),
2162            original_name: None,
2163            inferred_type: None,
2164        }));
2165        assert_eq!(
2166            annotator.annotate(&sign_cast),
2167            Some(DataType::Int {
2168                length: None,
2169                integer_spelling: false,
2170            })
2171        );
2172    }
2173
2174    #[test]
2175    fn test_date_format_types() {
2176        use crate::expressions::{DateFormatFunc, TimeToStr};
2177
2178        let mut annotator = TypeAnnotator::new(None, None);
2179
2180        // DateFormat → VarChar
2181        let date_fmt = Expression::DateFormat(Box::new(DateFormatFunc {
2182            this: make_string_literal("2024-01-01"),
2183            format: make_string_literal("%Y-%m-%d"),
2184        }));
2185        assert_eq!(
2186            annotator.annotate(&date_fmt),
2187            Some(DataType::VarChar {
2188                length: None,
2189                parenthesized_length: false,
2190            })
2191        );
2192
2193        // FormatDate → VarChar
2194        let format_date = Expression::FormatDate(Box::new(DateFormatFunc {
2195            this: make_string_literal("2024-01-01"),
2196            format: make_string_literal("%Y-%m-%d"),
2197        }));
2198        assert_eq!(
2199            annotator.annotate(&format_date),
2200            Some(DataType::VarChar {
2201                length: None,
2202                parenthesized_length: false,
2203            })
2204        );
2205
2206        // TimeToStr → VarChar
2207        let time_to_str = Expression::TimeToStr(Box::new(TimeToStr {
2208            this: Box::new(make_string_literal("2024-01-01")),
2209            format: "%Y-%m-%d".to_string(),
2210            culture: None,
2211            zone: None,
2212        }));
2213        assert_eq!(
2214            annotator.annotate(&time_to_str),
2215            Some(DataType::VarChar {
2216                length: None,
2217                parenthesized_length: false,
2218            })
2219        );
2220
2221        // DATE_FORMAT via Function path → VarChar (uses function_return_types)
2222        let date_fmt_fn = Expression::Function(Box::new(Function::new(
2223            "DATE_FORMAT",
2224            vec![
2225                make_string_literal("2024-01-01"),
2226                make_string_literal("%Y-%m-%d"),
2227            ],
2228        )));
2229        assert_eq!(
2230            annotator.annotate(&date_fmt_fn),
2231            Some(DataType::VarChar {
2232                length: None,
2233                parenthesized_length: false,
2234            })
2235        );
2236    }
2237
    // ===== In-place annotation tests (annotate_types writes inferred_type into the tree) =====
2239
2240    #[test]
2241    fn test_annotate_in_place_sets_type_on_root() {
2242        // Literals don't have inferred_type field, so test with a BinaryOp
2243        let mut expr = Expression::Add(Box::new(BinaryOp::new(
2244            make_int_literal(1),
2245            make_int_literal(2),
2246        )));
2247        annotate_types(&mut expr, None, None);
2248        assert_eq!(
2249            expr.inferred_type(),
2250            Some(&DataType::Int {
2251                length: None,
2252                integer_spelling: false,
2253            })
2254        );
2255    }
2256
2257    #[test]
2258    fn test_annotate_in_place_sets_types_on_children() {
2259        // (a + b) + (c - d) where all are ints
2260        // This tests that inner BinaryOp children also get annotated
2261        let inner_add = Expression::Add(Box::new(BinaryOp::new(
2262            make_int_literal(1),
2263            make_float_literal(2.5),
2264        )));
2265        let inner_sub = Expression::Sub(Box::new(BinaryOp::new(
2266            make_int_literal(3),
2267            make_int_literal(4),
2268        )));
2269        let mut expr = Expression::Add(Box::new(BinaryOp::new(inner_add, inner_sub)));
2270        annotate_types(&mut expr, None, None);
2271
2272        // Root (Add) should be Double (wider of Double and Int)
2273        assert_eq!(
2274            expr.inferred_type(),
2275            Some(&DataType::Double {
2276                precision: None,
2277                scale: None,
2278            })
2279        );
2280
2281        // Children should also have types
2282        if let Expression::Add(op) = &expr {
2283            // Left child (1 + 2.5) should be Double
2284            assert_eq!(
2285                op.left.inferred_type(),
2286                Some(&DataType::Double {
2287                    precision: None,
2288                    scale: None,
2289                })
2290            );
2291            // Right child (3 - 4) should be Int
2292            assert_eq!(
2293                op.right.inferred_type(),
2294                Some(&DataType::Int {
2295                    length: None,
2296                    integer_spelling: false,
2297                })
2298            );
2299        } else {
2300            panic!("Expected Add expression");
2301        }
2302    }
2303
2304    #[test]
2305    fn test_annotate_in_place_comparison() {
2306        let mut expr = Expression::Eq(Box::new(BinaryOp::new(
2307            make_int_literal(1),
2308            make_int_literal(2),
2309        )));
2310        annotate_types(&mut expr, None, None);
2311        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));
2312    }
2313
2314    #[test]
2315    fn test_annotate_in_place_cast() {
2316        let mut expr = Expression::Cast(Box::new(Cast {
2317            this: make_int_literal(42),
2318            to: DataType::VarChar {
2319                length: None,
2320                parenthesized_length: false,
2321            },
2322            trailing_comments: vec![],
2323            double_colon_syntax: false,
2324            format: None,
2325            default: None,
2326            inferred_type: None,
2327        }));
2328        annotate_types(&mut expr, None, None);
2329        assert_eq!(
2330            expr.inferred_type(),
2331            Some(&DataType::VarChar {
2332                length: None,
2333                parenthesized_length: false,
2334            })
2335        );
2336    }
2337
2338    #[test]
2339    fn test_annotate_in_place_nested_expression() {
2340        // (1 + 2) > 0  -> should be Boolean at root, Int for the Add
2341        let add = Expression::Add(Box::new(BinaryOp::new(
2342            make_int_literal(1),
2343            make_int_literal(2),
2344        )));
2345        let mut expr = Expression::Gt(Box::new(BinaryOp::new(add, make_int_literal(0))));
2346        annotate_types(&mut expr, None, None);
2347
2348        assert_eq!(expr.inferred_type(), Some(&DataType::Boolean));
2349
2350        // The left child (Add) should be Int
2351        if let Expression::Gt(op) = &expr {
2352            assert_eq!(
2353                op.left.inferred_type(),
2354                Some(&DataType::Int {
2355                    length: None,
2356                    integer_spelling: false,
2357                })
2358            );
2359        }
2360    }
2361
2362    #[test]
2363    fn test_annotate_in_place_parsed_sql() {
2364        use crate::parser::Parser;
2365        let mut expr =
2366            Parser::parse_sql("SELECT 1 + 2.0, 'hello', TRUE").expect("parse failed")[0].clone();
2367        annotate_types(&mut expr, None, None);
2368
2369        // The expression tree should have types annotated throughout
2370        // We can't easily inspect deep inside a parsed Select, but at minimum
2371        // the root Select itself won't have a type (it's not value-producing)
2372        assert!(expr.inferred_type().is_none());
2373    }
2374
2375    #[test]
2376    fn test_inferred_type_json_roundtrip() {
2377        let mut expr = Expression::Add(Box::new(BinaryOp::new(
2378            make_int_literal(1),
2379            make_int_literal(2),
2380        )));
2381        annotate_types(&mut expr, None, None);
2382
2383        // Serialize to JSON
2384        let json = serde_json::to_string(&expr).expect("serialize failed");
2385        // The JSON should contain the inferred_type
2386        assert!(json.contains("inferred_type"));
2387
2388        // Deserialize back
2389        let deserialized: Expression = serde_json::from_str(&json).expect("deserialize failed");
2390        assert_eq!(
2391            deserialized.inferred_type(),
2392            Some(&DataType::Int {
2393                length: None,
2394                integer_spelling: false,
2395            })
2396        );
2397    }
2398
2399    #[test]
2400    fn test_inferred_type_none_not_serialized() {
2401        // When inferred_type is None, it should not appear in JSON
2402        let expr = Expression::Add(Box::new(BinaryOp::new(
2403            make_int_literal(1),
2404            make_int_literal(2),
2405        )));
2406        let json = serde_json::to_string(&expr).expect("serialize failed");
2407        assert!(!json.contains("inferred_type"));
2408    }
2409
2410    #[test]
2411    fn test_annotate_if_func_bigquery_node_and_alias_type() {
2412        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
2413        schema
2414            .add_table(
2415                "t",
2416                &[("col1".to_string(), DataType::String { length: None })],
2417                None,
2418            )
2419            .unwrap();
2420
2421        let mut expr = parse_one(
2422            "SELECT IF(col1 IS NOT NULL, 1, 0) AS x FROM t",
2423            DialectType::BigQuery,
2424        )
2425        .unwrap();
2426        annotate_types(&mut expr, Some(&schema), Some(DialectType::BigQuery));
2427
2428        let Expression::Select(select) = &expr else {
2429            panic!("expected select");
2430        };
2431        let Expression::Alias(alias) = &select.expressions[0] else {
2432            panic!("expected alias");
2433        };
2434
2435        assert_eq!(
2436            alias.this.inferred_type(),
2437            Some(&DataType::Int {
2438                length: None,
2439                integer_spelling: false,
2440            })
2441        );
2442        assert_eq!(
2443            select.expressions[0].inferred_type(),
2444            Some(&DataType::Int {
2445                length: None,
2446                integer_spelling: false,
2447            })
2448        );
2449    }
2450
2451    #[test]
2452    fn test_annotate_nvl2_node_type() {
2453        let mut expr = parse_one("SELECT NVL2(a, 1, 0) AS x", DialectType::Generic).unwrap();
2454        annotate_types(&mut expr, None, None);
2455
2456        let Expression::Select(select) = &expr else {
2457            panic!("expected select");
2458        };
2459        let Expression::Alias(alias) = &select.expressions[0] else {
2460            panic!("expected alias");
2461        };
2462
2463        assert_eq!(
2464            alias.this.inferred_type(),
2465            Some(&DataType::Int {
2466                length: None,
2467                integer_spelling: false,
2468            })
2469        );
2470    }
2471
2472    #[test]
2473    fn test_annotate_count_node_type() {
2474        let mut expr = parse_one("SELECT COUNT(1) AS x", DialectType::Generic).unwrap();
2475        annotate_types(&mut expr, None, None);
2476
2477        let Expression::Select(select) = &expr else {
2478            panic!("expected select");
2479        };
2480        let Expression::Alias(alias) = &select.expressions[0] else {
2481            panic!("expected alias");
2482        };
2483
2484        assert_eq!(
2485            alias.this.inferred_type(),
2486            Some(&DataType::BigInt { length: None })
2487        );
2488    }
2489
2490    #[test]
2491    fn test_annotate_group_concat_node_type() {
2492        let mut expr = parse_one("SELECT GROUP_CONCAT(a) AS x", DialectType::Generic).unwrap();
2493        annotate_types(&mut expr, None, None);
2494
2495        let Expression::Select(select) = &expr else {
2496            panic!("expected select");
2497        };
2498        let Expression::Alias(alias) = &select.expressions[0] else {
2499            panic!("expected alias");
2500        };
2501
2502        assert_eq!(
2503            alias.this.inferred_type(),
2504            Some(&DataType::VarChar {
2505                length: None,
2506                parenthesized_length: false,
2507            })
2508        );
2509    }
2510
2511    #[test]
2512    fn test_annotate_sum_if_generic_aggregate_type() {
2513        let mut expr =
2514            parse_one("SELECT SUM_IF(1, a > 0) AS x FROM t", DialectType::Generic).unwrap();
2515        annotate_types(&mut expr, None, None);
2516
2517        let Expression::Select(select) = &expr else {
2518            panic!("expected select");
2519        };
2520        let Expression::Alias(alias) = &select.expressions[0] else {
2521            panic!("expected alias");
2522        };
2523
2524        assert_eq!(
2525            select.expressions[0].inferred_type(),
2526            Some(&DataType::BigInt { length: None })
2527        );
2528        assert_eq!(
2529            alias.this.inferred_type(),
2530            Some(&DataType::BigInt { length: None })
2531        );
2532    }
2533}