fusabi_frontend/
inference.rs

1//! Hindley-Milner Type Inference (Algorithm W)
2//!
3//! This module implements Layer 2 of the Fusabi type system: constraint-based type inference
4//! using the Hindley-Milner algorithm. It builds on Layer 1 (types.rs) to provide complete
5//! type checking for F# expressions.
6//!
7//! # Architecture
8//!
9//! The type inference algorithm follows the classic Hindley-Milner approach:
10//! 1. **Constraint Generation**: Traverse the AST and generate type constraints
11//! 2. **Unification**: Solve constraints using Robinson's unification algorithm
12//! 3. **Generalization**: Generalize types in let-bindings for polymorphism
13//! 4. **Instantiation**: Instantiate polymorphic type schemes with fresh variables
14//!
15//! # Example
16//!
17//! ```rust
18//! use fusabi_frontend::inference::TypeInference;
19//! use fusabi_frontend::types::TypeEnv;
20//! use fusabi_frontend::ast::{Expr, Literal};
21//!
22//! let mut inference = TypeInference::new();
23//! let env = TypeEnv::new();
24//! let expr = Expr::Lit(Literal::Int(42));
25//!
26//! let ty = inference.infer_and_solve(&expr, &env).unwrap();
27//! // ty is Type::Int
28//! ```
29//!
30//! # Key Features
31//!
32//! - **Let-polymorphism**: Automatic generalization in let-bindings
33//! - **Occurs check**: Prevents infinite types
34//! - **Pattern matching**: Full support for match expressions
35//! - **Records and variants**: Type checking for structural types
36//! - **Helpful errors**: Detailed error messages with suggestions
37//! - **Auto-recursive detection**: Automatically detects recursive lambdas (issue #126)
38
39use crate::ast::{BinOp, Expr, Literal, MatchArm, Pattern};
40use crate::error::{TypeError, TypeErrorKind};
41use crate::types::{Substitution, Type, TypeEnv, TypeScheme, TypeVar};
42use std::collections::HashMap;
43
44/// Constraint representing equality between two types.
45///
46/// During inference, we generate constraints like `t1 = t2` which are later
47/// solved through unification.
48#[derive(Debug, Clone, PartialEq)]
49pub enum Constraint {
50    /// Two types must be equal
51    Equal(Type, Type),
52}
53
54/// Type inference engine implementing Algorithm W.
55///
56/// Maintains state for fresh type variable generation and constraint accumulation.
57pub struct TypeInference {
58    /// Counter for generating fresh type variables
59    next_var_id: usize,
60    /// Accumulated type constraints
61    constraints: Vec<Constraint>,
62}
63
64#[allow(clippy::result_large_err)]
65impl TypeInference {
66    /// Create a new type inference instance.
67    pub fn new() -> Self {
68        TypeInference {
69            next_var_id: 0,
70            constraints: Vec::new(),
71        }
72    }
73
74    /// Generate a fresh type variable.
75    ///
76    /// Each call produces a unique type variable that hasn't been used before.
77    pub fn fresh_var(&mut self) -> TypeVar {
78        let id = self.next_var_id;
79        self.next_var_id += 1;
80        TypeVar::fresh(id)
81    }
82
83    /// Add a constraint to the constraint set.
84    fn add_constraint(&mut self, constraint: Constraint) {
85        self.constraints.push(constraint);
86    }
87
88    /// Check if an expression references a variable (for auto-recursion detection).
89    ///
90    /// This performs a simple free variable analysis to detect if `name` appears
91    /// anywhere in the expression. Used to automatically treat `let x = fun ... x ...`
92    /// as recursive without requiring explicit `let rec`.
93    fn expr_references_var(expr: &Expr, name: &str) -> bool {
94        match expr {
95            Expr::Var(var_name) => var_name == name,
96            Expr::Lambda { param, body } => {
97                // If the lambda parameter shadows the name, don't look inside
98                if param == name {
99                    false
100                } else {
101                    Self::expr_references_var(body, name)
102                }
103            }
104            Expr::App { func, arg } => {
105                Self::expr_references_var(func, name) || Self::expr_references_var(arg, name)
106            }
107            Expr::Let {
108                name: let_name,
109                value,
110                body,
111            } => {
112                // Check value, but if let shadows the name, don't check body
113                Self::expr_references_var(value, name)
114                    || (let_name != name && Self::expr_references_var(body, name))
115            }
116            Expr::LetRec {
117                name: rec_name,
118                value,
119                body,
120            } => {
121                // Similar to Let
122                Self::expr_references_var(value, name)
123                    || (rec_name != name && Self::expr_references_var(body, name))
124            }
125            Expr::LetRecMutual { bindings, body } => {
126                // Check all binding values
127                bindings
128                    .iter()
129                    .any(|(_, expr)| Self::expr_references_var(expr, name))
130                    // Check body unless one of the bindings shadows the name
131                    || (!bindings.iter().any(|(n, _)| n == name)
132                        && Self::expr_references_var(body, name))
133            }
134            Expr::If {
135                cond,
136                then_branch,
137                else_branch,
138            } => {
139                Self::expr_references_var(cond, name)
140                    || Self::expr_references_var(then_branch, name)
141                    || Self::expr_references_var(else_branch, name)
142            }
143            Expr::BinOp { left, right, .. } => {
144                Self::expr_references_var(left, name) || Self::expr_references_var(right, name)
145            }
146            Expr::Tuple(elements) | Expr::List(elements) | Expr::Array(elements) => {
147                elements.iter().any(|e| Self::expr_references_var(e, name))
148            }
149            Expr::Cons { head, tail } => {
150                Self::expr_references_var(head, name) || Self::expr_references_var(tail, name)
151            }
152            Expr::ArrayIndex { array, index } => {
153                Self::expr_references_var(array, name) || Self::expr_references_var(index, name)
154            }
155            Expr::ArrayUpdate {
156                array,
157                index,
158                value,
159            } => {
160                Self::expr_references_var(array, name)
161                    || Self::expr_references_var(index, name)
162                    || Self::expr_references_var(value, name)
163            }
164            Expr::ArrayLength(array) => Self::expr_references_var(array, name),
165            Expr::RecordLiteral { fields, .. } => fields
166                .iter()
167                .any(|(_, expr)| Self::expr_references_var(expr, name)),
168            Expr::RecordAccess { record, .. } => Self::expr_references_var(record, name),
169            Expr::RecordUpdate { record, fields } => {
170                Self::expr_references_var(record, name)
171                    || fields
172                        .iter()
173                        .any(|(_, expr)| Self::expr_references_var(expr, name))
174            }
175            Expr::VariantConstruct { fields, .. } => fields
176                .iter()
177                .any(|expr| Self::expr_references_var(expr, name)),
178            Expr::Match { scrutinee, arms } => {
179                Self::expr_references_var(scrutinee, name)
180                    || arms.iter().any(|arm| {
181                        // Check if pattern binds the name (shadows it)
182                        let pattern_binds = Self::pattern_binds(&arm.pattern, name);
183                        // Only check body if pattern doesn't shadow the name
184                        !pattern_binds && Self::expr_references_var(&arm.body, name)
185                    })
186            }
187            Expr::MethodCall { receiver, args, .. } => {
188                Self::expr_references_var(receiver, name)
189                    || args.iter().any(|e| Self::expr_references_var(e, name))
190            }
191            Expr::While { cond, body } => {
192                Self::expr_references_var(cond, name) || Self::expr_references_var(body, name)
193            }
194            Expr::ComputationExpr { body, .. } => {
195                // Check if any statement in the CE body references the variable
196                body.iter().any(|stmt| {
197                    use crate::ast::CEStatement;
198                    match stmt {
199                        CEStatement::Let { value, .. }
200                        | CEStatement::LetBang { value, .. }
201                        | CEStatement::DoBang { value }
202                        | CEStatement::Return { value }
203                        | CEStatement::ReturnBang { value }
204                        | CEStatement::Yield { value }
205                        | CEStatement::YieldBang { value }
206                        | CEStatement::Expr { value } => Self::expr_references_var(value, name),
207                    }
208                })
209            }
210            // Literals and control flow don't reference variables
211            Expr::Lit(_) | Expr::Break | Expr::Continue => false,
212        }
213    }
214
215    /// Check if a pattern binds a variable name.
216    fn pattern_binds(pattern: &Pattern, name: &str) -> bool {
217        match pattern {
218            Pattern::Var(var_name) => var_name == name,
219            Pattern::Tuple(patterns) | Pattern::Variant { patterns, .. } => {
220                patterns.iter().any(|p| Self::pattern_binds(p, name))
221            }
222            Pattern::Wildcard | Pattern::Literal(_) => false,
223        }
224    }
225
226    /// Infer the type of an expression in the given environment.
227    ///
228    /// This is the main entry point for type inference. It generates constraints
229    /// and returns a type (possibly containing type variables).
230    ///
231    /// # Arguments
232    ///
233    /// * `expr` - The expression to type check
234    /// * `env` - The type environment containing variable bindings
235    ///
236    /// # Returns
237    ///
238    /// The inferred type, or a type error if inference fails.
239    pub fn infer(&mut self, expr: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
240        match expr {
241            // Literals have concrete types
242            Expr::Lit(lit) => Ok(self.infer_literal(lit)),
243
244            // Variables: lookup in environment and instantiate
245            Expr::Var(name) => self.infer_var(name, env),
246
247            // Lambda: fun x -> body
248            Expr::Lambda { param, body } => self.infer_lambda(param, body, env),
249
250            // Function application: f arg
251            Expr::App { func, arg } => self.infer_app(func, arg, env),
252
253            // Let-binding: let x = value in body
254            Expr::Let { name, value, body } => self.infer_let(name, value, body, env, false),
255
256            // Recursive let-binding: let rec f = value in body
257            Expr::LetRec { name, value, body } => self.infer_let(name, value, body, env, true),
258
259            // Mutually recursive bindings: let rec f = ... and g = ... in body
260            Expr::LetRecMutual { bindings, body } => self.infer_let_rec_mutual(bindings, body, env),
261
262            // Conditional: if cond then t else e
263            Expr::If {
264                cond,
265                then_branch,
266                else_branch,
267            } => self.infer_if(cond, then_branch, else_branch, env),
268
269            // Binary operations: e1 op e2
270            Expr::BinOp { op, left, right } => self.infer_binop(*op, left, right, env),
271
272            // Tuple: (e1, e2, ...)
273            Expr::Tuple(elements) => self.infer_tuple(elements, env),
274
275            // List: [e1; e2; ...]
276            Expr::List(elements) => self.infer_list(elements, env),
277
278            // Cons: e1 :: e2
279            Expr::Cons { head, tail } => self.infer_cons(head, tail, env),
280
281            // Array: [|e1; e2; ...|]
282            Expr::Array(elements) => self.infer_array(elements, env),
283
284            // Array indexing: arr.[idx]
285            Expr::ArrayIndex { array, index } => self.infer_array_index(array, index, env),
286
287            // Array update: arr.[idx] <- value
288            Expr::ArrayUpdate {
289                array,
290                index,
291                value,
292            } => self.infer_array_update(array, index, value, env),
293
294            // Array length: Array.length arr
295            Expr::ArrayLength(array) => self.infer_array_length(array, env),
296
297            // Record literal: { field1 = e1; field2 = e2 }
298            Expr::RecordLiteral { type_name, fields } => {
299                self.infer_record_literal(type_name, fields, env)
300            }
301
302            // Record access: record.field
303            Expr::RecordAccess { record, field } => self.infer_record_access(record, field, env),
304
305            // Record update: { record with field = value }
306            Expr::RecordUpdate { record, fields } => self.infer_record_update(record, fields, env),
307
308            // Variant constructor: Some(42), None, etc.
309            Expr::VariantConstruct {
310                type_name,
311                variant,
312                fields,
313            } => self.infer_variant_construct(type_name, variant, fields, env),
314
315            // Pattern matching: match scrutinee with | pat1 -> e1 | pat2 -> e2
316            Expr::Match { scrutinee, arms } => self.infer_match(scrutinee, arms, env),
317
318            // Method call: obj.method(args...)
319            Expr::MethodCall {
320                receiver,
321                method_name: _,
322                args: _,
323            } => {
324                // For now, we infer method calls conservatively
325                // Type check the receiver
326                self.infer(receiver, env)?;
327                // Return a fresh type variable since we don't know the method's return type
328                Ok(Type::Var(self.fresh_var()))
329            }
330
331            // While loop: while cond do body
332            Expr::While { cond, body } => {
333                // Condition must be bool
334                let cond_ty = self.infer(cond, env)?;
335                self.unify(&cond_ty, &Type::Bool)?;
336                // Type check the body
337                self.infer(body, env)?;
338                // While loops return unit
339                Ok(Type::Unit)
340            }
341
342            // Break statement
343            Expr::Break => {
344                // Break has unit type but can only appear in loops
345                // We'll let the compiler handle loop context validation
346                Ok(Type::Unit)
347            }
348
349            // Continue statement
350            Expr::Continue => {
351                // Continue has unit type but can only appear in loops
352                // We'll let the compiler handle loop context validation
353                Ok(Type::Unit)
354            }
355
356            // Computation expression (stub implementation)
357            Expr::ComputationExpr {
358                builder: _,
359                body: _,
360            } => {
361                // TODO: Implement proper type inference for computation expressions
362                // For now, return a fresh type variable
363                Ok(Type::Var(self.fresh_var()))
364            }
365        }
366    }
367
368    /// Infer the type of a literal value.
369    fn infer_literal(&self, lit: &Literal) -> Type {
370        match lit {
371            Literal::Int(_) => Type::Int,
372            Literal::Float(_) => Type::Float,
373            Literal::Bool(_) => Type::Bool,
374            Literal::Str(_) => Type::String,
375            Literal::Unit => Type::Unit,
376        }
377    }
378
379    /// Infer the type of a variable by looking it up in the environment.
380    fn infer_var(&mut self, name: &str, env: &TypeEnv) -> Result<Type, TypeError> {
381        match env.lookup(name) {
382            Some(scheme) => {
383                // Instantiate the type scheme with fresh type variables
384                Ok(env.instantiate(scheme, &mut || self.fresh_var()))
385            }
386            None => Err(TypeError::new(TypeErrorKind::UnboundVariable {
387                name: name.to_string(),
388            })),
389        }
390    }
391
392    /// Infer the type of a lambda function.
393    ///
394    /// For `fun x -> body`, we:
395    /// 1. Create a fresh type variable α for the parameter
396    /// 2. Extend the environment with x: α
397    /// 3. Infer the type β of the body
398    /// 4. Return α -> β
399    fn infer_lambda(&mut self, param: &str, body: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
400        let param_type = Type::Var(self.fresh_var());
401        let param_scheme = TypeScheme::mono(param_type.clone());
402        let extended_env = env.extend(param.to_string(), param_scheme);
403
404        let body_type = self.infer(body, &extended_env)?;
405
406        Ok(Type::Function(Box::new(param_type), Box::new(body_type)))
407    }
408
409    /// Infer the type of a function application.
410    ///
411    /// For `f arg`, we:
412    /// 1. Infer the type tf of f
413    /// 2. Infer the type targ of arg
414    /// 3. Create a fresh type variable α for the result
415    /// 4. Add constraint: tf = targ -> α
416    /// 5. Return α
417    fn infer_app(&mut self, func: &Expr, arg: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
418        let func_type = self.infer(func, env)?;
419        let arg_type = self.infer(arg, env)?;
420        let result_type = Type::Var(self.fresh_var());
421
422        // Constraint: func_type = arg_type -> result_type
423        let expected_func_type = Type::Function(Box::new(arg_type), Box::new(result_type.clone()));
424        self.add_constraint(Constraint::Equal(func_type, expected_func_type));
425
426        Ok(result_type)
427    }
428
429    /// Infer the type of a let-binding with automatic recursion detection.
430    ///
431    /// For `let x = value in body`:
432    /// 1. Check if value references x (auto-detect recursion)
433    /// 2. If recursive or is_recursive is true, infer with x in scope
434    /// 3. Generalize the type (let-polymorphism)
435    /// 4. Infer and return the type of body
436    ///
437    /// This implements issue #126: automatic recursive function detection
438    /// for lambda expressions like `let factorial = fun n -> ... factorial ...`
439    fn infer_let(
440        &mut self,
441        name: &str,
442        value: &Expr,
443        body: &Expr,
444        env: &TypeEnv,
445        is_recursive: bool,
446    ) -> Result<Type, TypeError> {
447        // Auto-detect recursion: check if value references name
448        let auto_recursive = !is_recursive && Self::expr_references_var(value, name);
449        let treat_as_recursive = is_recursive || auto_recursive;
450
451        let value_type = if treat_as_recursive {
452            // For recursive bindings, assume a fresh type variable for the name
453            let rec_var = Type::Var(self.fresh_var());
454            let rec_scheme = TypeScheme::mono(rec_var.clone());
455            let rec_env = env.extend(name.to_string(), rec_scheme);
456
457            // Infer the value type in the extended environment
458            let inferred = self.infer(value, &rec_env)?;
459
460            // Add constraint: rec_var = inferred
461            self.add_constraint(Constraint::Equal(rec_var, inferred.clone()));
462            inferred
463        } else {
464            // Non-recursive: infer value in current environment
465            self.infer(value, env)?
466        };
467
468        // Generalize the type (let-polymorphism)
469        let value_scheme = env.generalize(&value_type);
470
471        // Extend environment and infer body
472        let extended_env = env.extend(name.to_string(), value_scheme);
473        self.infer(body, &extended_env)
474    }
475
476    /// Infer the type of mutually recursive let-bindings.
477    fn infer_let_rec_mutual(
478        &mut self,
479        bindings: &[(String, Expr)],
480        body: &Expr,
481        env: &TypeEnv,
482    ) -> Result<Type, TypeError> {
483        // Create fresh type variables for all bindings
484        let mut rec_env = env.clone();
485        let mut binding_vars = Vec::new();
486
487        for (name, _) in bindings {
488            let var = Type::Var(self.fresh_var());
489            rec_env.insert(name.clone(), TypeScheme::mono(var.clone()));
490            binding_vars.push((name.clone(), var));
491        }
492
493        // Infer types for all bindings in the extended environment
494        for ((_, expr), (_name, var)) in bindings.iter().zip(binding_vars.iter()) {
495            let inferred = self.infer(expr, &rec_env)?;
496            self.add_constraint(Constraint::Equal(var.clone(), inferred));
497        }
498
499        // Infer body type
500        self.infer(body, &rec_env)
501    }
502
503    /// Infer the type of a conditional expression.
504    ///
505    /// For `if cond then t else e`:
506    /// 1. Infer type of cond and constrain it to bool
507    /// 2. Infer types of both branches
508    /// 3. Constrain both branches to have the same type
509    /// 4. Return the branch type
510    fn infer_if(
511        &mut self,
512        cond: &Expr,
513        then_branch: &Expr,
514        else_branch: &Expr,
515        env: &TypeEnv,
516    ) -> Result<Type, TypeError> {
517        let cond_type = self.infer(cond, env)?;
518        self.add_constraint(Constraint::Equal(cond_type, Type::Bool));
519
520        let then_type = self.infer(then_branch, env)?;
521        let else_type = self.infer(else_branch, env)?;
522
523        // Both branches must have the same type
524        self.add_constraint(Constraint::Equal(then_type.clone(), else_type));
525
526        Ok(then_type)
527    }
528
529    /// Infer the type of a binary operation.
530    fn infer_binop(
531        &mut self,
532        op: BinOp,
533        left: &Expr,
534        right: &Expr,
535        env: &TypeEnv,
536    ) -> Result<Type, TypeError> {
537        let left_type = self.infer(left, env)?;
538        let right_type = self.infer(right, env)?;
539
540        if op.is_arithmetic() {
541            // Arithmetic: both operands must be int or float, result is same type
542            // For simplicity, we constrain to int (full implementation would support float)
543            self.add_constraint(Constraint::Equal(left_type.clone(), Type::Int));
544            self.add_constraint(Constraint::Equal(right_type, Type::Int));
545            Ok(Type::Int)
546        } else if op.is_comparison() {
547            // Comparison: operands must have the same type, result is bool
548            self.add_constraint(Constraint::Equal(left_type, right_type));
549            Ok(Type::Bool)
550        } else if op.is_logical() {
551            // Logical: both operands must be bool, result is bool
552            self.add_constraint(Constraint::Equal(left_type, Type::Bool));
553            self.add_constraint(Constraint::Equal(right_type, Type::Bool));
554            Ok(Type::Bool)
555        } else {
556            unreachable!("Unknown binary operator")
557        }
558    }
559
560    /// Infer the type of a tuple.
561    fn infer_tuple(&mut self, elements: &[Expr], env: &TypeEnv) -> Result<Type, TypeError> {
562        let mut element_types = Vec::new();
563        for element in elements {
564            element_types.push(self.infer(element, env)?);
565        }
566        Ok(Type::Tuple(element_types))
567    }
568
569    /// Infer the type of a list.
570    ///
571    /// All elements must have the same type.
572    fn infer_list(&mut self, elements: &[Expr], env: &TypeEnv) -> Result<Type, TypeError> {
573        if elements.is_empty() {
574            // Empty list has polymorphic type 'a list
575            Ok(Type::List(Box::new(Type::Var(self.fresh_var()))))
576        } else {
577            let first_type = self.infer(&elements[0], env)?;
578            for element in &elements[1..] {
579                let element_type = self.infer(element, env)?;
580                self.add_constraint(Constraint::Equal(first_type.clone(), element_type));
581            }
582            Ok(Type::List(Box::new(first_type)))
583        }
584    }
585
586    /// Infer the type of cons operator (::).
587    fn infer_cons(&mut self, head: &Expr, tail: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
588        let head_type = self.infer(head, env)?;
589        let tail_type = self.infer(tail, env)?;
590
591        let expected_tail = Type::List(Box::new(head_type.clone()));
592        self.add_constraint(Constraint::Equal(tail_type, expected_tail));
593
594        Ok(Type::List(Box::new(head_type)))
595    }
596
597    /// Infer the type of an array.
598    fn infer_array(&mut self, elements: &[Expr], env: &TypeEnv) -> Result<Type, TypeError> {
599        if elements.is_empty() {
600            Ok(Type::Array(Box::new(Type::Var(self.fresh_var()))))
601        } else {
602            let first_type = self.infer(&elements[0], env)?;
603            for element in &elements[1..] {
604                let element_type = self.infer(element, env)?;
605                self.add_constraint(Constraint::Equal(first_type.clone(), element_type));
606            }
607            Ok(Type::Array(Box::new(first_type)))
608        }
609    }
610
611    /// Infer the type of array indexing.
612    fn infer_array_index(
613        &mut self,
614        array: &Expr,
615        index: &Expr,
616        env: &TypeEnv,
617    ) -> Result<Type, TypeError> {
618        let array_type = self.infer(array, env)?;
619        let index_type = self.infer(index, env)?;
620
621        // Index must be int
622        self.add_constraint(Constraint::Equal(index_type, Type::Int));
623
624        // Array must be array type, extract element type
625        let element_type = Type::Var(self.fresh_var());
626        let expected_array_type = Type::Array(Box::new(element_type.clone()));
627        self.add_constraint(Constraint::Equal(array_type, expected_array_type));
628
629        Ok(element_type)
630    }
631
632    /// Infer the type of array update.
633    fn infer_array_update(
634        &mut self,
635        array: &Expr,
636        index: &Expr,
637        value: &Expr,
638        env: &TypeEnv,
639    ) -> Result<Type, TypeError> {
640        let array_type = self.infer(array, env)?;
641        let index_type = self.infer(index, env)?;
642        let value_type = self.infer(value, env)?;
643
644        // Index must be int
645        self.add_constraint(Constraint::Equal(index_type, Type::Int));
646
647        // Value type must match array element type
648        let expected_array_type = Type::Array(Box::new(value_type));
649        self.add_constraint(Constraint::Equal(array_type.clone(), expected_array_type));
650
651        Ok(array_type)
652    }
653
654    /// Infer the type of array length.
655    fn infer_array_length(&mut self, array: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
656        let array_type = self.infer(array, env)?;
657
658        // Must be an array
659        let element_type = Type::Var(self.fresh_var());
660        let expected_array_type = Type::Array(Box::new(element_type));
661        self.add_constraint(Constraint::Equal(array_type, expected_array_type));
662
663        Ok(Type::Int)
664    }
665
666    /// Infer the type of a record literal.
667    fn infer_record_literal(
668        &mut self,
669        _type_name: &str,
670        fields: &[(String, Box<Expr>)],
671        env: &TypeEnv,
672    ) -> Result<Type, TypeError> {
673        let mut field_types = HashMap::new();
674
675        for (field_name, field_expr) in fields {
676            let field_type = self.infer(field_expr, env)?;
677            field_types.insert(field_name.clone(), field_type);
678        }
679
680        Ok(Type::Record(field_types))
681    }
682
683    /// Infer the type of record field access.
684    fn infer_record_access(
685        &mut self,
686        record: &Expr,
687        field: &str,
688        env: &TypeEnv,
689    ) -> Result<Type, TypeError> {
690        let record_type = self.infer(record, env)?;
691
692        // Create a fresh type variable for the field
693        let field_type = Type::Var(self.fresh_var());
694
695        // Create a record type with at least this field
696        let mut expected_fields = HashMap::new();
697        expected_fields.insert(field.to_string(), field_type.clone());
698        let expected_record = Type::Record(expected_fields);
699
700        // Note: This is a simplified version. A full implementation would use
701        // row polymorphism or structural typing for better record handling.
702        self.add_constraint(Constraint::Equal(record_type, expected_record));
703
704        Ok(field_type)
705    }
706
707    /// Infer the type of record update.
708    fn infer_record_update(
709        &mut self,
710        record: &Expr,
711        fields: &[(String, Box<Expr>)],
712        env: &TypeEnv,
713    ) -> Result<Type, TypeError> {
714        let record_type = self.infer(record, env)?;
715
716        // Infer types of updated fields
717        let mut update_field_types = HashMap::new();
718        for (field_name, field_expr) in fields {
719            let field_type = self.infer(field_expr, env)?;
720            update_field_types.insert(field_name.clone(), field_type);
721        }
722
723        // The result has the same type as the input record
724        // (with potentially different field types for updated fields)
725        Ok(record_type)
726    }
727
728    /// Infer the type of a variant constructor.
729    fn infer_variant_construct(
730        &mut self,
731        _type_name: &str,
732        variant: &str,
733        fields: &[Box<Expr>],
734        env: &TypeEnv,
735    ) -> Result<Type, TypeError> {
736        // Infer types of all fields
737        let mut field_types = Vec::new();
738        for field in fields {
739            field_types.push(self.infer(field, env)?);
740        }
741
742        // Create variant type
743        Ok(Type::Variant(variant.to_string(), field_types))
744    }
745
746    /// Infer the type of a match expression.
747    ///
748    /// For `match scrutinee with | pat1 -> e1 | pat2 -> e2`:
749    /// 1. Infer type of scrutinee
750    /// 2. For each arm, check pattern matches scrutinee type
751    /// 3. Infer type of each arm body
752    /// 4. Constrain all arm bodies to have the same type
753    /// 5. Return the arm type
754    fn infer_match(
755        &mut self,
756        scrutinee: &Expr,
757        arms: &[MatchArm],
758        env: &TypeEnv,
759    ) -> Result<Type, TypeError> {
760        if arms.is_empty() {
761            return Err(TypeError::new(TypeErrorKind::Custom {
762                message: "Match expression must have at least one arm".to_string(),
763            }));
764        }
765
766        let scrutinee_type = self.infer(scrutinee, env)?;
767
768        // Infer the type of the first arm as the result type
769        let (_first_pattern_env, first_result_type) =
770            self.infer_match_arm(&arms[0], &scrutinee_type, env)?;
771
772        // Check remaining arms
773        for arm in &arms[1..] {
774            let (_, arm_type) = self.infer_match_arm(arm, &scrutinee_type, env)?;
775            self.add_constraint(Constraint::Equal(first_result_type.clone(), arm_type));
776        }
777
778        Ok(first_result_type)
779    }
780
781    /// Infer the type of a single match arm.
782    ///
783    /// Returns the extended environment from pattern bindings and the body type.
784    fn infer_match_arm(
785        &mut self,
786        arm: &MatchArm,
787        scrutinee_type: &Type,
788        env: &TypeEnv,
789    ) -> Result<(TypeEnv, Type), TypeError> {
790        // Check pattern against scrutinee type and get bindings
791        let pattern_env = self.infer_pattern(&arm.pattern, scrutinee_type, env)?;
792
793        // Infer body type in extended environment
794        let body_type = self.infer(&arm.body, &pattern_env)?;
795
796        Ok((pattern_env, body_type))
797    }
798
799    /// Infer pattern bindings and check pattern type matches scrutinee.
800    ///
801    /// Returns an extended environment with pattern variable bindings.
802    pub fn infer_pattern(
803        &mut self,
804        pattern: &Pattern,
805        scrutinee_ty: &Type,
806        env: &TypeEnv,
807    ) -> Result<TypeEnv, TypeError> {
808        match pattern {
809            // Wildcard matches anything, no bindings
810            Pattern::Wildcard => Ok(env.clone()),
811
812            // Variable binds the scrutinee value
813            Pattern::Var(name) => {
814                let scheme = TypeScheme::mono(scrutinee_ty.clone());
815                Ok(env.extend(name.clone(), scheme))
816            }
817
818            // Literal must match scrutinee type exactly
819            Pattern::Literal(lit) => {
820                let lit_type = self.infer_literal(lit);
821                self.add_constraint(Constraint::Equal(scrutinee_ty.clone(), lit_type));
822                Ok(env.clone())
823            }
824
825            // Tuple pattern
826            Pattern::Tuple(patterns) => {
827                // Scrutinee must be a tuple with matching arity
828                let mut pattern_types = Vec::new();
829                for _ in patterns {
830                    pattern_types.push(Type::Var(self.fresh_var()));
831                }
832
833                let expected_tuple = Type::Tuple(pattern_types.clone());
834                self.add_constraint(Constraint::Equal(scrutinee_ty.clone(), expected_tuple));
835
836                // Process each sub-pattern
837                let mut extended_env = env.clone();
838                for (pattern, pattern_type) in patterns.iter().zip(pattern_types.iter()) {
839                    extended_env = self.infer_pattern(pattern, pattern_type, &extended_env)?;
840                }
841
842                Ok(extended_env)
843            }
844
845            // Variant pattern
846            Pattern::Variant { variant, patterns } => {
847                // Create types for variant fields
848                let mut field_types = Vec::new();
849                for _ in patterns {
850                    field_types.push(Type::Var(self.fresh_var()));
851                }
852
853                let expected_variant = Type::Variant(variant.clone(), field_types.clone());
854                self.add_constraint(Constraint::Equal(scrutinee_ty.clone(), expected_variant));
855
856                // Process field patterns
857                let mut extended_env = env.clone();
858                for (pattern, field_type) in patterns.iter().zip(field_types.iter()) {
859                    extended_env = self.infer_pattern(pattern, field_type, &extended_env)?;
860                }
861
862                Ok(extended_env)
863            }
864        }
865    }
866
867    /// Solve all accumulated constraints using unification.
868    ///
869    /// Returns a substitution that satisfies all constraints.
870    pub fn solve_constraints(&mut self) -> Result<Substitution, TypeError> {
871        let mut subst = Substitution::empty();
872
873        for constraint in &self.constraints {
874            match constraint {
875                Constraint::Equal(t1, t2) => {
876                    // Apply current substitution to both sides
877                    let t1_subst = t1.apply(&subst);
878                    let t2_subst = t2.apply(&subst);
879
880                    // Unify and compose substitutions
881                    let new_subst = self.unify(&t1_subst, &t2_subst)?;
882                    subst = Substitution::compose(&new_subst, &subst);
883                }
884            }
885        }
886
887        Ok(subst)
888    }
889
890    /// Unify two types using Robinson's unification algorithm.
891    ///
892    /// Returns a substitution that makes the types equal, or an error if unification fails.
893    #[allow(clippy::only_used_in_recursion)]
894    pub fn unify(&self, t1: &Type, t2: &Type) -> Result<Substitution, TypeError> {
895        match (t1, t2) {
896            // Identical types unify trivially
897            (Type::Int, Type::Int)
898            | (Type::Bool, Type::Bool)
899            | (Type::String, Type::String)
900            | (Type::Unit, Type::Unit)
901            | (Type::Float, Type::Float) => Ok(Substitution::empty()),
902
903            // Same type variable
904            (Type::Var(v1), Type::Var(v2)) if v1 == v2 => Ok(Substitution::empty()),
905
906            // Type variable unifies with any type (with occurs check)
907            (Type::Var(v), t) | (t, Type::Var(v)) => {
908                if t.occurs_check(v) {
909                    Err(TypeError::new(TypeErrorKind::OccursCheck {
910                        var: v.clone(),
911                        in_type: t.clone(),
912                    }))
913                } else {
914                    Ok(Substitution::singleton(v.clone(), t.clone()))
915                }
916            }
917
918            // Function types unify if domain and codomain unify
919            (Type::Function(a1, r1), Type::Function(a2, r2)) => {
920                let subst1 = self.unify(a1, a2)?;
921                let r1_subst = r1.apply(&subst1);
922                let r2_subst = r2.apply(&subst1);
923                let subst2 = self.unify(&r1_subst, &r2_subst)?;
924                Ok(Substitution::compose(&subst2, &subst1))
925            }
926
927            // Tuple types unify if they have the same arity and elements unify
928            (Type::Tuple(ts1), Type::Tuple(ts2)) => {
929                if ts1.len() != ts2.len() {
930                    return Err(TypeError::new(TypeErrorKind::Mismatch {
931                        expected: t1.clone(),
932                        got: t2.clone(),
933                    }));
934                }
935
936                let mut subst = Substitution::empty();
937                for (ty1, ty2) in ts1.iter().zip(ts2.iter()) {
938                    let ty1_subst = ty1.apply(&subst);
939                    let ty2_subst = ty2.apply(&subst);
940                    let new_subst = self.unify(&ty1_subst, &ty2_subst)?;
941                    subst = Substitution::compose(&new_subst, &subst);
942                }
943                Ok(subst)
944            }
945
946            // List types unify if element types unify
947            (Type::List(t1), Type::List(t2)) => self.unify(t1, t2),
948
949            // Array types unify if element types unify
950            (Type::Array(t1), Type::Array(t2)) => self.unify(t1, t2),
951
952            // Record types unify if they have the same fields with unifying types
953            (Type::Record(fields1), Type::Record(fields2)) => {
954                if fields1.len() != fields2.len() {
955                    return Err(TypeError::new(TypeErrorKind::Mismatch {
956                        expected: t1.clone(),
957                        got: t2.clone(),
958                    }));
959                }
960
961                let mut subst = Substitution::empty();
962                for (name, ty1) in fields1 {
963                    match fields2.get(name) {
964                        Some(ty2) => {
965                            let ty1_subst = ty1.apply(&subst);
966                            let ty2_subst = ty2.apply(&subst);
967                            let new_subst = self.unify(&ty1_subst, &ty2_subst)?;
968                            subst = Substitution::compose(&new_subst, &subst);
969                        }
970                        None => {
971                            return Err(TypeError::new(TypeErrorKind::FieldNotFound {
972                                record_type: t1.clone(),
973                                field: name.clone(),
974                            }));
975                        }
976                    }
977                }
978                Ok(subst)
979            }
980
981            // Variant types unify if same variant name and field types unify
982            (Type::Variant(name1, fields1), Type::Variant(name2, fields2)) => {
983                if name1 != name2 {
984                    return Err(TypeError::new(TypeErrorKind::Mismatch {
985                        expected: t1.clone(),
986                        got: t2.clone(),
987                    }));
988                }
989
990                if fields1.len() != fields2.len() {
991                    return Err(TypeError::new(TypeErrorKind::Mismatch {
992                        expected: t1.clone(),
993                        got: t2.clone(),
994                    }));
995                }
996
997                let mut subst = Substitution::empty();
998                for (ty1, ty2) in fields1.iter().zip(fields2.iter()) {
999                    let ty1_subst = ty1.apply(&subst);
1000                    let ty2_subst = ty2.apply(&subst);
1001                    let new_subst = self.unify(&ty1_subst, &ty2_subst)?;
1002                    subst = Substitution::compose(&new_subst, &subst);
1003                }
1004                Ok(subst)
1005            }
1006
1007            // All other cases are type mismatches
1008            _ => Err(TypeError::new(TypeErrorKind::Mismatch {
1009                expected: t1.clone(),
1010                got: t2.clone(),
1011            })),
1012        }
1013    }
1014
1015    /// Convenience method: infer type and solve constraints in one step.
1016    ///
1017    /// This is the main entry point for type checking.
1018    ///
1019    /// # Example
1020    ///
1021    /// ```rust
1022    /// use fusabi_frontend::inference::TypeInference;
1023    /// use fusabi_frontend::types::TypeEnv;
1024    /// use fusabi_frontend::ast::{Expr, Literal};
1025    ///
1026    /// let mut inference = TypeInference::new();
1027    /// let env = TypeEnv::new();
1028    /// let expr = Expr::Lit(Literal::Int(42));
1029    ///
1030    /// let ty = inference.infer_and_solve(&expr, &env).unwrap();
1031    /// ```
1032    pub fn infer_and_solve(&mut self, expr: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
1033        // Clear any previous constraints
1034        self.constraints.clear();
1035
1036        // Infer the type (generating constraints)
1037        let ty = self.infer(expr, env)?;
1038
1039        // Solve the constraints
1040        let subst = self.solve_constraints()?;
1041
1042        // Apply the substitution to the result type
1043        Ok(ty.apply(&subst))
1044    }
1045}
1046
1047impl Default for TypeInference {
1048    fn default() -> Self {
1049        Self::new()
1050    }
1051}
1052
#[cfg(test)]
mod tests {
    use super::*;

    // Helper to create simple test expressions
    fn lit_int(n: i64) -> Expr {
        Expr::Lit(Literal::Int(n))
    }

    fn var(name: &str) -> Expr {
        Expr::Var(name.to_string())
    }

    fn lambda(param: &str, body: Expr) -> Expr {
        Expr::Lambda {
            param: param.to_string(),
            body: Box::new(body),
        }
    }

    fn app(func: Expr, arg: Expr) -> Expr {
        Expr::App {
            func: Box::new(func),
            arg: Box::new(arg),
        }
    }

    fn let_expr(name: &str, value: Expr, body: Expr) -> Expr {
        Expr::Let {
            name: name.to_string(),
            value: Box::new(value),
            body: Box::new(body),
        }
    }

    // ========================================================================
    // Basic Inference Tests
    // ========================================================================

    #[test]
    fn test_infer_literal_int() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();
        let expr = lit_int(42);

        let ty = inf.infer_and_solve(&expr, &env).unwrap();
        assert_eq!(ty, Type::Int);
    }

    #[test]
    fn test_infer_literal_bool() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();
        let expr = Expr::Lit(Literal::Bool(true));

        let ty = inf.infer_and_solve(&expr, &env).unwrap();
        assert_eq!(ty, Type::Bool);
    }

    #[test]
    fn test_infer_identity_function() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();
        // fun x -> x
        let expr = lambda("x", var("x"));

        let ty = inf.infer_and_solve(&expr, &env).unwrap();
        // Should be 'a -> 'a (with some type variable)
        match ty {
            Type::Function(arg, ret) => match (*arg, *ret) {
                (Type::Var(v1), Type::Var(v2)) => assert_eq!(v1, v2),
                _ => panic!("Expected function with type variables"),
            },
            _ => panic!("Expected function type"),
        }
    }

    #[test]
    fn test_infer_const_function() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();
        // fun x -> 42
        let expr = lambda("x", lit_int(42));

        let ty = inf.infer_and_solve(&expr, &env).unwrap();
        // Should be 'a -> int
        match ty {
            Type::Function(_, ret) => assert_eq!(*ret, Type::Int),
            _ => panic!("Expected function type"),
        }
    }

    #[test]
    fn test_infer_application() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();
        // (fun x -> x) 42
        let expr = app(lambda("x", var("x")), lit_int(42));

        let ty = inf.infer_and_solve(&expr, &env).unwrap();
        assert_eq!(ty, Type::Int);
    }

    #[test]
    fn test_infer_unbound_variable() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();
        let expr = var("x");

        let result = inf.infer_and_solve(&expr, &env);
        assert!(result.is_err());
        match result.unwrap_err().kind {
            TypeErrorKind::UnboundVariable { name } => assert_eq!(name, "x"),
            _ => panic!("Expected UnboundVariable error"),
        }
    }

    // ========================================================================
    // Auto-Recursive Detection Tests (Issue #126)
    // ========================================================================

    #[test]
    fn test_auto_recursive_lambda_factorial() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();

        // let factorial = fun n ->
        //     if n <= 1 then 1
        //     else n * factorial (n - 1)
        // in factorial 5
        let cond = Expr::BinOp {
            op: BinOp::Lte,
            left: Box::new(var("n")),
            right: Box::new(lit_int(1)),
        };
        let then_branch = lit_int(1);
        let else_branch = Expr::BinOp {
            op: BinOp::Mul,
            left: Box::new(var("n")),
            right: Box::new(app(
                var("factorial"),
                Expr::BinOp {
                    op: BinOp::Sub,
                    left: Box::new(var("n")),
                    right: Box::new(lit_int(1)),
                },
            )),
        };
        let factorial_body = Expr::If {
            cond: Box::new(cond),
            then_branch: Box::new(then_branch),
            else_branch: Box::new(else_branch),
        };
        let factorial_lambda = lambda("n", factorial_body);
        let expr = let_expr(
            "factorial",
            factorial_lambda,
            app(var("factorial"), lit_int(5)),
        );

        // Should successfully infer type without needing explicit 'rec'
        let ty = inf.infer_and_solve(&expr, &env).unwrap();
        assert_eq!(ty, Type::Int);
    }

    #[test]
    fn test_auto_recursive_simple() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();

        // let f = fun x -> f x in f 42
        let f_body = app(var("f"), var("x"));
        let f_lambda = lambda("x", f_body);
        let expr = let_expr("f", f_lambda, app(var("f"), lit_int(42)));

        // Should infer type (may be polymorphic)
        let result = inf.infer_and_solve(&expr, &env);
        assert!(result.is_ok());
    }

    #[test]
    fn test_non_recursive_lambda_still_works() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();

        // let double = fun x -> x * 2 in double 21
        let double_body = Expr::BinOp {
            op: BinOp::Mul,
            left: Box::new(var("x")),
            right: Box::new(lit_int(2)),
        };
        let double_lambda = lambda("x", double_body);
        let expr = let_expr("double", double_lambda, app(var("double"), lit_int(21)));

        let ty = inf.infer_and_solve(&expr, &env).unwrap();
        assert_eq!(ty, Type::Int);
    }

    #[test]
    fn test_shadowing_prevents_auto_recursion() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();

        // let f = fun f -> f in f 42
        // The parameter 'f' shadows the binding name, so this is not recursive
        // f is the identity function: 'a -> 'a
        // When applied to 42, it returns 42
        let f_lambda = lambda("f", var("f"));
        let expr = let_expr("f", f_lambda, app(var("f"), lit_int(42)));

        let ty = inf.infer_and_solve(&expr, &env).unwrap();
        // Result type should be int (identity function applied to int gives int)
        assert_eq!(ty, Type::Int);
    }

    // ========================================================================
    // Unification Tests
    // ========================================================================

    #[test]
    fn test_unify_primitive_mismatch() {
        let inf = TypeInference::new();

        match inf.unify(&Type::Int, &Type::Bool) {
            Err(e) => assert!(matches!(e.kind, TypeErrorKind::Mismatch { .. })),
            Ok(_) => panic!("Expected Mismatch error for int vs bool"),
        }
    }

    #[test]
    fn test_unify_binds_type_variable() {
        let mut inf = TypeInference::new();
        let v = inf.fresh_var();

        let subst = inf.unify(&Type::Var(v.clone()), &Type::Int).unwrap();
        assert_eq!(Type::Var(v).apply(&subst), Type::Int);
    }

    #[test]
    fn test_unify_occurs_check() {
        let mut inf = TypeInference::new();
        let v = inf.fresh_var();
        // 'a = 'a -> int would be an infinite type.
        let func = Type::Function(Box::new(Type::Var(v.clone())), Box::new(Type::Int));

        match inf.unify(&Type::Var(v), &func) {
            Err(e) => assert!(matches!(e.kind, TypeErrorKind::OccursCheck { .. })),
            Ok(_) => panic!("Expected OccursCheck error"),
        }
    }

    #[test]
    fn test_unify_tuple_arity_mismatch() {
        let inf = TypeInference::new();
        let pair = Type::Tuple(vec![Type::Int, Type::Bool]);
        let triple = Type::Tuple(vec![Type::Int, Type::Bool, Type::Unit]);

        match inf.unify(&pair, &triple) {
            Err(e) => assert!(matches!(e.kind, TypeErrorKind::Mismatch { .. })),
            Ok(_) => panic!("Expected Mismatch error for tuples of different arity"),
        }
    }

    #[test]
    fn test_unify_variant_name_mismatch() {
        let inf = TypeInference::new();
        let some = Type::Variant("Some".to_string(), vec![Type::Int]);
        let none = Type::Variant("None".to_string(), vec![]);

        match inf.unify(&some, &none) {
            Err(e) => assert!(matches!(e.kind, TypeErrorKind::Mismatch { .. })),
            Ok(_) => panic!("Expected Mismatch error for different variant names"),
        }
    }

    // ========================================================================
    // Pattern and Constraint-Lifecycle Tests
    // ========================================================================

    #[test]
    fn test_literal_pattern_matching_type_solves() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();

        inf.infer_pattern(&Pattern::Literal(Literal::Int(1)), &Type::Int, &env)
            .unwrap();
        assert!(inf.solve_constraints().is_ok());
    }

    #[test]
    fn test_literal_pattern_mismatch_fails_to_solve() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();

        // An int literal pattern against a bool scrutinee records an
        // unsatisfiable constraint; the failure surfaces when solving.
        inf.infer_pattern(&Pattern::Literal(Literal::Int(1)), &Type::Bool, &env)
            .unwrap();
        assert!(inf.solve_constraints().is_err());
    }

    #[test]
    fn test_infer_and_solve_discards_stale_constraints() {
        let mut inf = TypeInference::new();
        let env = TypeEnv::new();

        // Leave an unsatisfiable constraint behind...
        inf.infer_pattern(&Pattern::Literal(Literal::Int(1)), &Type::Bool, &env)
            .unwrap();

        // ...which must not leak into a fresh top-level inference.
        let ty = inf.infer_and_solve(&lit_int(7), &env).unwrap();
        assert_eq!(ty, Type::Int);
    }
}