Skip to main content

elo_rust/codegen/
type_inference.rs

//! Type inference system for ELO expressions
//!
//! Infers types for expressions to enable better error checking and code generation.
//! Uses a simple bottom-up approach: the type of an expression is synthesized from
//! the types of its sub-expressions.
5
6use crate::ast::{BinaryOperator, Expr, Literal, TemporalKeyword, UnaryOperator, Visitor};
7use std::fmt;
8
/// Inferred type of an ELO expression
///
/// Besides the concrete value types this includes three bookkeeping variants:
/// [`InferredType::Unknown`] (nothing could be determined),
/// [`InferredType::Numeric`] (some number, integer or float) and
/// [`InferredType::Error`] (inference detected incompatible types).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InferredType {
    /// Integer type
    Integer,

    /// Float type
    Float,

    /// String type
    String,

    /// Boolean type
    Boolean,

    /// Null/None type
    Null,

    /// Array with element type
    Array(Box<InferredType>),

    /// Object with field types (simplified - just track it's an object)
    Object,

    /// Date type
    Date,

    /// DateTime type
    DateTime,

    /// Duration type
    Duration,

    /// Unknown type (when inference fails or type can't be determined)
    Unknown,

    /// Any numeric type (integer or float)
    Numeric,

    /// Type error - incompatible types; the payload is a human-readable message
    Error(String),
}

impl fmt::Display for InferredType {
    /// Writes the lowercase name of the type; arrays render as `[elem]` and
    /// errors as `error(message)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Integer => write!(f, "integer"),
            Self::Float => write!(f, "float"),
            Self::String => write!(f, "string"),
            Self::Boolean => write!(f, "boolean"),
            Self::Null => write!(f, "null"),
            Self::Array(elem_type) => write!(f, "[{}]", elem_type),
            Self::Object => write!(f, "object"),
            Self::Date => write!(f, "date"),
            Self::DateTime => write!(f, "datetime"),
            Self::Duration => write!(f, "duration"),
            Self::Unknown => write!(f, "unknown"),
            Self::Numeric => write!(f, "number"),
            Self::Error(msg) => write!(f, "error({})", msg),
        }
    }
}

impl InferredType {
    /// Check if this is a numeric type (`Integer`, `Float` or `Numeric`)
    pub fn is_numeric(&self) -> bool {
        matches!(self, Self::Integer | Self::Float | Self::Numeric)
    }

    /// Check if this is a scalar type (`Integer`, `Float`, `String` or `Boolean`)
    pub fn is_scalar(&self) -> bool {
        matches!(
            self,
            Self::Integer | Self::Float | Self::String | Self::Boolean
        )
    }

    /// Check if this is an error type
    pub fn is_error(&self) -> bool {
        matches!(self, Self::Error(_))
    }

    /// Get a common type that's compatible with both types
    ///
    /// Rules, in priority order:
    /// 1. identical types unify to themselves;
    /// 2. an existing `Error` is returned unchanged (the left one wins), so
    ///    messages are not nested into a new "cannot unify error(...)" error;
    /// 3. `Unknown` unifies with anything and yields the other type;
    /// 4. `Integer` and `Float` unify to `Float`;
    /// 5. `Numeric` unifies with any numeric type and yields that type;
    /// 6. arrays unify element-wise (an element mismatch is reported *inside*
    ///    the returned `Array`);
    /// 7. everything else is an `Error` describing the mismatch.
    pub fn common_type(a: &InferredType, b: &InferredType) -> InferredType {
        match (a, b) {
            // Same types unify to themselves
            (t1, t2) if t1 == t2 => t1.clone(),

            // Propagate an already-detected error instead of wrapping it again
            (err @ InferredType::Error(_), _) | (_, err @ InferredType::Error(_)) => err.clone(),

            // Unknown can unify with anything. This arm must precede the
            // `Numeric` arm, otherwise Numeric/Unknown is reported as a mismatch.
            (InferredType::Unknown, t) | (t, InferredType::Unknown) => t.clone(),

            // Numeric type unification
            (InferredType::Integer, InferredType::Float)
            | (InferredType::Float, InferredType::Integer) => InferredType::Float,
            (InferredType::Numeric, other) | (other, InferredType::Numeric) => {
                if other.is_numeric() {
                    other.clone()
                } else {
                    InferredType::Error(format!(
                        "Type mismatch: cannot unify numeric and {}",
                        other
                    ))
                }
            }

            // Array types unify on element type
            (InferredType::Array(lhs_elem), InferredType::Array(rhs_elem)) => {
                InferredType::Array(Box::new(Self::common_type(lhs_elem, rhs_elem)))
            }

            // Otherwise, type mismatch
            (lhs, rhs) => {
                InferredType::Error(format!("Type mismatch: cannot unify {} and {}", lhs, rhs))
            }
        }
    }
}
125
126/// Type inference visitor
127///
128/// Analyzes expressions and infers their types.
129/// Returns the inferred type for each expression.
130#[derive(Debug)]
131pub struct TypeInferenceVisitor;
132
133impl TypeInferenceVisitor {
134    /// Create a new type inference visitor
135    pub fn new() -> Self {
136        TypeInferenceVisitor
137    }
138
139    /// Infer the type of an expression
140    pub fn infer(&self, expr: &Expr) -> InferredType {
141        Self::infer_expr(expr)
142    }
143
144    /// Helper function to infer expression type without mut self
145    fn infer_expr(expr: &Expr) -> InferredType {
146        match expr {
147            Expr::Literal(lit) => match lit {
148                Literal::Integer(_) => InferredType::Integer,
149                Literal::Float(_) => InferredType::Float,
150                Literal::Boolean(_) => InferredType::Boolean,
151            },
152            Expr::Null => InferredType::Null,
153            Expr::Identifier(_) => InferredType::Unknown,
154            Expr::String(_) => InferredType::String,
155            Expr::FieldAccess { .. } => InferredType::Unknown,
156            Expr::BinaryOp { op, left, right } => Self::infer_binary_op(*op, left, right),
157            Expr::UnaryOp { op, operand } => Self::infer_unary_op(*op, operand),
158            Expr::FunctionCall { name, args } => Self::infer_function_call(name, args),
159            Expr::Lambda { .. } => InferredType::Unknown,
160            Expr::Let { body, .. } => Self::infer_expr(body),
161            Expr::If {
162                then_branch,
163                else_branch,
164                ..
165            } => {
166                let then_type = Self::infer_expr(then_branch);
167                let else_type = Self::infer_expr(else_branch);
168                InferredType::common_type(&then_type, &else_type)
169            }
170            Expr::Array(elements) => {
171                if elements.is_empty() {
172                    InferredType::Array(Box::new(InferredType::Unknown))
173                } else {
174                    let first_type = Self::infer_expr(&elements[0]);
175                    let mut common = first_type;
176                    for elem in &elements[1..] {
177                        let elem_type = Self::infer_expr(elem);
178                        common = InferredType::common_type(&common, &elem_type);
179                        if common.is_error() {
180                            break;
181                        }
182                    }
183                    InferredType::Array(Box::new(common))
184                }
185            }
186            Expr::Object(_) => InferredType::Object,
187            Expr::Pipe { functions, .. } => {
188                if functions.is_empty() {
189                    InferredType::Unknown
190                } else {
191                    Self::infer_expr(functions.last().unwrap())
192                }
193            }
194            Expr::Alternative {
195                primary,
196                alternative,
197            } => {
198                let primary_type = Self::infer_expr(primary);
199                let alt_type = Self::infer_expr(alternative);
200                InferredType::common_type(&primary_type, &alt_type)
201            }
202            Expr::Guard { body, .. } => Self::infer_expr(body),
203            Expr::Date(_) => InferredType::Date,
204            Expr::DateTime(_) => InferredType::DateTime,
205            Expr::Duration(_) => InferredType::Duration,
206            Expr::TemporalKeyword(keyword) => match keyword {
207                TemporalKeyword::Now => InferredType::DateTime,
208                TemporalKeyword::Today | TemporalKeyword::Tomorrow | TemporalKeyword::Yesterday => {
209                    InferredType::Date
210                }
211                _ => InferredType::Date, // Boundary operations return dates
212            },
213        }
214    }
215
216    fn infer_binary_op(op: BinaryOperator, left: &Expr, right: &Expr) -> InferredType {
217        let left_type = Self::infer_expr(left);
218        let right_type = Self::infer_expr(right);
219
220        match op {
221            BinaryOperator::Add => match (&left_type, &right_type) {
222                (InferredType::Integer, InferredType::Integer) => InferredType::Integer,
223                (InferredType::Float, InferredType::Float) => InferredType::Float,
224                (InferredType::Integer, InferredType::Float)
225                | (InferredType::Float, InferredType::Integer) => InferredType::Float,
226                (InferredType::String, InferredType::String) => InferredType::String,
227                // Temporal arithmetic: date/datetime + duration
228                (InferredType::Date, InferredType::Duration)
229                | (InferredType::Duration, InferredType::Date) => InferredType::Date,
230                (InferredType::DateTime, InferredType::Duration)
231                | (InferredType::Duration, InferredType::DateTime) => InferredType::DateTime,
232                (InferredType::Duration, InferredType::Duration) => InferredType::Duration,
233                // Handle Unknown by returning the other type
234                (InferredType::Unknown, t) | (t, InferredType::Unknown) => t.clone(),
235                _ => InferredType::Error(format!("Cannot add {} and {}", left_type, right_type)),
236            },
237            BinaryOperator::Sub => match (&left_type, &right_type) {
238                (InferredType::Integer, InferredType::Integer) => InferredType::Integer,
239                (InferredType::Float, InferredType::Float) => InferredType::Float,
240                (InferredType::Integer, InferredType::Float)
241                | (InferredType::Float, InferredType::Integer) => InferredType::Float,
242                // Temporal arithmetic: date/datetime - duration = date/datetime, date - date = duration
243                (InferredType::Date, InferredType::Duration) => InferredType::Date,
244                (InferredType::DateTime, InferredType::Duration) => InferredType::DateTime,
245                (InferredType::Date, InferredType::Date) => InferredType::Duration,
246                (InferredType::DateTime, InferredType::DateTime) => InferredType::Duration,
247                (InferredType::Duration, InferredType::Duration) => InferredType::Duration,
248                // Handle Unknown
249                (InferredType::Unknown, t) | (t, InferredType::Unknown) => {
250                    if t.is_numeric() {
251                        t.clone()
252                    } else {
253                        InferredType::Error(format!(
254                            "Cannot apply arithmetic to {} and {}",
255                            left_type, right_type
256                        ))
257                    }
258                }
259                _ => InferredType::Error(format!(
260                    "Cannot apply arithmetic to {} and {}",
261                    left_type, right_type
262                )),
263            },
264            BinaryOperator::Mul | BinaryOperator::Div => {
265                match (&left_type, &right_type) {
266                    (InferredType::Integer, InferredType::Integer) => InferredType::Integer,
267                    (InferredType::Float, InferredType::Float) => InferredType::Float,
268                    (InferredType::Integer, InferredType::Float)
269                    | (InferredType::Float, InferredType::Integer) => InferredType::Float,
270                    // Handle Unknown
271                    (InferredType::Unknown, t) | (t, InferredType::Unknown) => {
272                        if t.is_numeric() {
273                            t.clone()
274                        } else {
275                            InferredType::Error(format!(
276                                "Cannot apply arithmetic to {} and {}",
277                                left_type, right_type
278                            ))
279                        }
280                    }
281                    _ => InferredType::Error(format!(
282                        "Cannot apply arithmetic to {} and {}",
283                        left_type, right_type
284                    )),
285                }
286            }
287            BinaryOperator::Mod | BinaryOperator::Pow => {
288                if left_type.is_numeric() && right_type.is_numeric() {
289                    InferredType::Integer
290                } else if left_type == InferredType::Unknown || right_type == InferredType::Unknown
291                {
292                    // If either is Unknown but the other is numeric, assume Integer result
293                    InferredType::Integer
294                } else {
295                    InferredType::Error(format!(
296                        "Cannot apply operator to {} and {}",
297                        left_type, right_type
298                    ))
299                }
300            }
301            BinaryOperator::Eq
302            | BinaryOperator::Neq
303            | BinaryOperator::Lt
304            | BinaryOperator::Lte
305            | BinaryOperator::Gt
306            | BinaryOperator::Gte => InferredType::Boolean,
307            BinaryOperator::And | BinaryOperator::Or => InferredType::Boolean,
308        }
309    }
310
311    fn infer_unary_op(op: UnaryOperator, operand: &Expr) -> InferredType {
312        let operand_type = Self::infer_expr(operand);
313        match op {
314            UnaryOperator::Not => InferredType::Boolean,
315            UnaryOperator::Neg | UnaryOperator::Plus => operand_type,
316        }
317    }
318
319    fn infer_function_call(name: &str, args: &[Expr]) -> InferredType {
320        match name {
321            "length" | "uppercase" | "lowercase" | "trim" | "contains" | "starts_with"
322            | "ends_with" => InferredType::String,
323            "map" | "filter" | "sort" => InferredType::Array(Box::new(InferredType::Unknown)),
324            "abs" | "min" | "max" | "round" | "floor" | "ceil" => {
325                if args.is_empty() {
326                    InferredType::Unknown
327                } else {
328                    let arg_type = Self::infer_expr(&args[0]);
329                    if arg_type.is_numeric() {
330                        arg_type
331                    } else {
332                        InferredType::Error(format!("Expected numeric argument, got {}", arg_type))
333                    }
334                }
335            }
336            "all" | "any" => InferredType::Boolean,
337            _ => InferredType::Unknown,
338        }
339    }
340}
341
342impl Default for TypeInferenceVisitor {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
348impl Visitor<InferredType> for TypeInferenceVisitor {
349    fn visit_expr(&mut self, expr: &Expr) -> InferredType {
350        Self::infer_expr(expr)
351    }
352
353    fn visit_literal(&mut self, lit: &Literal) -> InferredType {
354        match lit {
355            Literal::Integer(_) => InferredType::Integer,
356            Literal::Float(_) => InferredType::Float,
357            Literal::Boolean(_) => InferredType::Boolean,
358        }
359    }
360
361    fn visit_null(&mut self) -> InferredType {
362        InferredType::Null
363    }
364
365    fn visit_identifier(&mut self, _name: &str) -> InferredType {
366        InferredType::Unknown
367    }
368
369    fn visit_field_access(&mut self, _receiver: &Expr, _field: &str) -> InferredType {
370        InferredType::Unknown
371    }
372
373    fn visit_binary_op(&mut self, op: BinaryOperator, left: &Expr, right: &Expr) -> InferredType {
374        Self::infer_binary_op(op, left, right)
375    }
376
377    fn visit_unary_op(&mut self, op: UnaryOperator, operand: &Expr) -> InferredType {
378        Self::infer_unary_op(op, operand)
379    }
380
381    fn visit_function_call(&mut self, name: &str, args: &[Expr]) -> InferredType {
382        Self::infer_function_call(name, args)
383    }
384
385    fn visit_lambda(&mut self, _param: &str, _body: &Expr) -> InferredType {
386        InferredType::Unknown
387    }
388
389    fn visit_let(&mut self, _name: &str, _value: &Expr, body: &Expr) -> InferredType {
390        Self::infer_expr(body)
391    }
392
393    fn visit_if(
394        &mut self,
395        _condition: &Expr,
396        then_branch: &Expr,
397        else_branch: &Expr,
398    ) -> InferredType {
399        let then_type = Self::infer_expr(then_branch);
400        let else_type = Self::infer_expr(else_branch);
401        InferredType::common_type(&then_type, &else_type)
402    }
403
404    fn visit_array(&mut self, elements: &[Expr]) -> InferredType {
405        if elements.is_empty() {
406            InferredType::Array(Box::new(InferredType::Unknown))
407        } else {
408            let first_type = Self::infer_expr(&elements[0]);
409            let mut common = first_type;
410            for elem in &elements[1..] {
411                let elem_type = Self::infer_expr(elem);
412                common = InferredType::common_type(&common, &elem_type);
413                if common.is_error() {
414                    break;
415                }
416            }
417            InferredType::Array(Box::new(common))
418        }
419    }
420
421    fn visit_object(&mut self, _fields: &[(String, Expr)]) -> InferredType {
422        InferredType::Object
423    }
424
425    fn visit_pipe(&mut self, value: &Expr, functions: &[Expr]) -> InferredType {
426        if functions.is_empty() {
427            Self::infer_expr(value)
428        } else {
429            Self::infer_expr(functions.last().unwrap())
430        }
431    }
432
433    fn visit_alternative(&mut self, primary: &Expr, alternative: &Expr) -> InferredType {
434        let primary_type = Self::infer_expr(primary);
435        let alt_type = Self::infer_expr(alternative);
436        InferredType::common_type(&primary_type, &alt_type)
437    }
438
439    fn visit_guard(&mut self, _condition: &Expr, body: &Expr) -> InferredType {
440        Self::infer_expr(body)
441    }
442
443    fn visit_date(&mut self, _date: &str) -> InferredType {
444        InferredType::String
445    }
446
447    fn visit_datetime(&mut self, _datetime: &str) -> InferredType {
448        InferredType::String
449    }
450
451    fn visit_duration(&mut self, _duration: &str) -> InferredType {
452        InferredType::String
453    }
454
455    fn visit_temporal_keyword(&mut self, _keyword: TemporalKeyword) -> InferredType {
456        InferredType::String
457    }
458
459    fn visit_string(&mut self, _value: &str) -> InferredType {
460        InferredType::String
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `source` and returns the type inferred for the resulting expression.
    ///
    /// Panics when `source` does not parse, which fails the calling test.
    fn inferred(source: &str) -> InferredType {
        let expr = Parser::parse(source).unwrap();
        TypeInferenceVisitor::infer_expr(&expr)
    }

    // --- literals -----------------------------------------------------------

    #[test]
    fn test_infer_integer_literal() {
        assert_eq!(inferred("42"), InferredType::Integer);
    }

    #[test]
    fn test_infer_float_literal() {
        assert_eq!(inferred("3.14"), InferredType::Float);
    }

    #[test]
    fn test_infer_string_literal() {
        assert_eq!(inferred("'hello'"), InferredType::String);
    }

    #[test]
    fn test_infer_boolean_literal() {
        assert_eq!(inferred("true"), InferredType::Boolean);
    }

    #[test]
    fn test_infer_null_literal() {
        assert_eq!(inferred("null"), InferredType::Null);
    }

    // --- operators ----------------------------------------------------------

    #[test]
    fn test_infer_integer_addition() {
        assert_eq!(inferred("1 + 2"), InferredType::Integer);
    }

    #[test]
    fn test_infer_float_arithmetic() {
        assert_eq!(inferred("3.0 + 2.0"), InferredType::Float);
    }

    #[test]
    fn test_infer_mixed_numeric() {
        assert_eq!(inferred("1 + 2.0"), InferredType::Float);
    }

    #[test]
    fn test_infer_comparison() {
        assert_eq!(inferred("5 > 3"), InferredType::Boolean);
    }

    #[test]
    fn test_infer_logical_and() {
        assert_eq!(inferred("true && false"), InferredType::Boolean);
    }

    // --- arrays -------------------------------------------------------------

    #[test]
    fn test_infer_array_integers() {
        let expected = InferredType::Array(Box::new(InferredType::Integer));
        assert_eq!(inferred("[1, 2, 3]"), expected);
    }

    #[test]
    fn test_infer_array_mixed_numeric() {
        let expected = InferredType::Array(Box::new(InferredType::Float));
        assert_eq!(inferred("[1, 2.0, 3]"), expected);
    }

    #[test]
    fn test_infer_empty_array() {
        let expected = InferredType::Array(Box::new(InferredType::Unknown));
        assert_eq!(inferred("[]"), expected);
    }

    // --- control flow and bindings --------------------------------------------

    #[test]
    fn test_infer_if_same_types() {
        assert_eq!(inferred("if true then 1 else 2"), InferredType::Integer);
    }

    #[test]
    fn test_infer_if_different_numeric_types() {
        assert_eq!(inferred("if true then 1 else 2.0"), InferredType::Float);
    }

    #[test]
    fn test_infer_let_expression() {
        assert_eq!(inferred("let x = 5 in x + 3"), InferredType::Integer);
    }

    #[test]
    fn test_infer_unary_not() {
        assert_eq!(inferred("!true"), InferredType::Boolean);
    }

    #[test]
    fn test_infer_string_concat() {
        assert_eq!(inferred("'hello' + ' world'"), InferredType::String);
    }

    // --- InferredType helpers -------------------------------------------------

    #[test]
    fn test_type_common_type_same() {
        let unified = InferredType::common_type(&InferredType::Integer, &InferredType::Integer);
        assert_eq!(unified, InferredType::Integer);
    }

    #[test]
    fn test_type_common_type_numeric() {
        let unified = InferredType::common_type(&InferredType::Integer, &InferredType::Float);
        assert_eq!(unified, InferredType::Float);
    }

    #[test]
    fn test_type_common_type_unknown() {
        let unified = InferredType::common_type(&InferredType::Unknown, &InferredType::Integer);
        assert_eq!(unified, InferredType::Integer);
    }

    #[test]
    fn test_type_is_numeric() {
        assert!(InferredType::Integer.is_numeric());
        assert!(InferredType::Float.is_numeric());
        assert!(InferredType::Numeric.is_numeric());
        assert!(!InferredType::String.is_numeric());
    }

    #[test]
    fn test_type_is_scalar() {
        assert!(InferredType::Integer.is_scalar());
        assert!(InferredType::String.is_scalar());
        assert!(!InferredType::Array(Box::new(InferredType::Integer)).is_scalar());
    }
}
633}