Skip to main content

logicaffeine_compile/analysis/
check.rs

//! Bidirectional type checker for the LOGOS compilation pipeline.
//!
//! Replaces `TypeEnv::infer_program()` with a proper constraint-solving pass
//! that eliminates `Unknown` for field access, empty collections, option literals,
//! pipe receives, inspect arm bindings, and closure call expressions.
//!
//! # Architecture
//!
//! ```text
//! AST
//!  │
//!  ├── preregister_functions   ← forward-reference pre-pass
//!  │
//!  └── infer_stmt / infer_expr ← bidirectional checking
//!           │
//!           └── UnificationTable ← Robinson unification (from unify.rs)
//!                    │
//!                    └── zonk → TypeEnv (LogosType) → codegen
//! ```
20
21use std::collections::{HashMap, HashSet};
22
23use crate::analysis::unify::{InferType, TyVar, TypeScheme, TypeError, UnificationTable, infer_to_logos, unify_numeric};
24use crate::analysis::{FnSig, LogosType, TypeDef, TypeEnv, TypeRegistry};
25use crate::ast::stmt::{BinaryOpKind, Expr, OptFlag, Pattern, Stmt};
26use crate::intern::{Interner, Symbol};
27
28// ============================================================================
29// Data structures
30// ============================================================================
31
32/// A registered function's signature, supporting both monomorphic and generic functions.
33///
34/// For generic functions (non-empty `scheme.vars`), each call site must instantiate
35/// the scheme to get fresh type variables, preventing cross-call contamination.
36/// For monomorphic functions (`scheme.vars` is empty), `scheme.body` is the direct type.
37#[derive(Clone, Debug)]
38struct FunctionRecord {
39    /// Parameter names (for binding in the function scope).
40    param_names: Vec<Symbol>,
41    /// The quantified type scheme: `forall vars. Function(param_types, return_type)`.
42    /// For monomorphic functions, `vars` is empty and body is used directly.
43    scheme: TypeScheme,
44}
45
46/// Bidirectional type checking environment.
47///
48/// Scopes are pushed/popped around function bodies and match arms.
49/// All bindings are also written to `all_vars` for later `TypeEnv` output.
50struct CheckEnv<'r> {
51    /// Stacked scopes (innermost last). Variables resolved from inner-to-outer.
52    scopes: Vec<HashMap<Symbol, InferType>>,
53    /// Flat map of every variable ever bound — accumulated for `TypeEnv` output.
54    all_vars: HashMap<Symbol, InferType>,
55    /// Registered function signatures.
56    functions: HashMap<Symbol, FunctionRecord>,
57    /// Expected return type inside the current function body.
58    current_return_type: Option<InferType>,
59    /// Unification table for type variables.
60    table: UnificationTable,
61    registry: &'r TypeRegistry,
62    interner: &'r Interner,
63}
64
65impl<'r> CheckEnv<'r> {
66    fn new(registry: &'r TypeRegistry, interner: &'r Interner) -> Self {
67        Self {
68            scopes: vec![HashMap::new()],
69            all_vars: HashMap::new(),
70            functions: HashMap::new(),
71            current_return_type: None,
72            table: UnificationTable::new(),
73            registry,
74            interner,
75        }
76    }
77
78    fn push_scope(&mut self) {
79        self.scopes.push(HashMap::new());
80    }
81
82    fn pop_scope(&mut self) {
83        self.scopes.pop();
84    }
85
86    /// Bind a variable in the current scope, also recording in `all_vars`.
87    fn bind_var(&mut self, sym: Symbol, ty: InferType) {
88        if let Some(scope) = self.scopes.last_mut() {
89            scope.insert(sym, ty.clone());
90        }
91        self.all_vars.insert(sym, ty);
92    }
93
94    /// Look up a variable, searching scopes from innermost to outermost.
95    ///
96    /// Uses `resolve` (not `zonk`) so that unbound type variables from generic
97    /// function parameters remain as `Var(tv)` during inference, enabling
98    /// proper unification at call sites.
99    fn lookup_var(&self, sym: Symbol) -> Option<InferType> {
100        for scope in self.scopes.iter().rev() {
101            if let Some(ty) = scope.get(&sym) {
102                return Some(self.table.resolve(ty));
103            }
104        }
105        None
106    }
107
108    /// Convert the check environment into a `TypeEnv` for codegen.
109    fn into_type_env(self) -> TypeEnv {
110        let mut type_env = TypeEnv::new();
111
112        // Collect all variable bindings, zonk each to a concrete LogosType
113        for (sym, ty) in self.all_vars {
114            let logos_ty = self.table.to_logos_type(&ty);
115            type_env.register(sym, logos_ty);
116        }
117
118        // Collect function signatures — instantiate monomorphic view for codegen
119        for (name, rec) in self.functions {
120            // Zonk the scheme body to extract concrete param/return types for TypeEnv.
121            // For generic functions, unsolved vars zonk to Unknown, which is fine
122            // since codegen uses TypeExpr (not TypeEnv) for generic param types.
123            if let InferType::Function(param_types, ret_box) = &rec.scheme.body {
124                let ret_logos = self.table.to_logos_type(ret_box);
125                let params: Vec<(Symbol, LogosType)> = rec.param_names.iter()
126                    .zip(param_types.iter())
127                    .map(|(sym, ty)| (*sym, self.table.to_logos_type(ty)))
128                    .collect();
129                type_env.register_fn(name, FnSig { params, return_type: ret_logos });
130            }
131        }
132
133        type_env
134    }
135}
136
137// ============================================================================
138// Pre-pass: forward reference registration
139// ============================================================================
140
141impl<'r> CheckEnv<'r> {
142    /// Register all top-level function signatures before the main checking pass.
143    ///
144    /// This enables forward references and mutual recursion: any function can
145    /// call any other function regardless of declaration order.
146    ///
147    /// For generic functions (non-empty `generics`), allocates a fresh `TyVar` per
148    /// type parameter and builds a `TypeScheme` so call sites can instantiate them.
149    fn preregister_functions(&mut self, stmts: &[Stmt]) {
150        for stmt in stmts {
151            if let Stmt::FunctionDef { name, generics, params, return_type, .. } = stmt {
152                // Allocate one TyVar per generic type parameter
153                let type_param_map: HashMap<Symbol, TyVar> = generics
154                    .iter()
155                    .map(|&sym| (sym, self.table.fresh_var()))
156                    .collect();
157
158                let param_types: Vec<InferType> = params
159                    .iter()
160                    .map(|(_, ty_expr)| {
161                        InferType::from_type_expr_with_params(ty_expr, self.interner, &type_param_map)
162                    })
163                    .collect();
164                let param_names: Vec<Symbol> = params.iter().map(|(sym, _)| *sym).collect();
165
166                let ret_type = if let Some(rt) = return_type {
167                    InferType::from_type_expr_with_params(rt, self.interner, &type_param_map)
168                } else {
169                    self.table.fresh()
170                };
171
172                let generic_vars: Vec<TyVar> = generics
173                    .iter()
174                    .filter_map(|sym| type_param_map.get(sym).copied())
175                    .collect();
176
177                let scheme = TypeScheme {
178                    vars: generic_vars,
179                    body: InferType::Function(param_types, Box::new(ret_type)),
180                };
181
182                self.functions.insert(*name, FunctionRecord { param_names, scheme });
183            }
184        }
185    }
186}
187
188// ============================================================================
189// Core inference
190// ============================================================================
191
192impl<'r> CheckEnv<'r> {
193    /// Check an expression against an expected type (checking mode).
194    ///
195    /// Handles numeric literal coercion (`5` against `Real` → `Float`) and
196    /// structural checking before falling through to synthesis + unification.
197    fn check_expr(
198        &mut self,
199        expr: &Expr,
200        expected: &InferType,
201    ) -> Result<InferType, TypeError> {
202        use crate::ast::stmt::Literal;
203
204        // Number literals are polymorphic: 5 checks against Int, Float, Nat, or Byte
205        if let Expr::Literal(Literal::Number(_)) = expr {
206            match expected {
207                InferType::Float => return Ok(InferType::Float),
208                InferType::Nat => return Ok(InferType::Nat),
209                InferType::Int => return Ok(InferType::Int),
210                InferType::Byte => return Ok(InferType::Byte),
211                _ => {}
212            }
213        }
214
215        // `nothing` is polymorphic: it is `None` when checked against Option(T),
216        // and the unit value `()` in all other contexts.
217        if let Expr::Literal(Literal::Nothing) = expr {
218            if let InferType::Option(_) = expected {
219                return Ok(expected.clone());
220            }
221        }
222
223        // Default: synthesize then unify
224        let inferred = self.infer_expr(expr)?;
225        self.table.unify(&inferred, expected)?;
226        Ok(self.table.zonk(expected))
227    }
228
229    /// Infer the type of an expression (synthesis mode).
230    fn infer_expr(&mut self, expr: &Expr) -> Result<InferType, TypeError> {
231        match expr {
232            Expr::Literal(lit) => Ok(InferType::from_literal(lit)),
233
234            Expr::Identifier(sym) => {
235                Ok(self.lookup_var(*sym).unwrap_or(InferType::Unknown))
236            }
237
238            Expr::BinaryOp { op, left, right } => {
239                self.infer_binary_op(*op, left, right)
240            }
241
242            Expr::Length { .. } => Ok(InferType::Int),
243
244            Expr::Call { function, args } => {
245                self.infer_call(*function, args)
246            }
247
248            Expr::Index { collection, .. } => {
249                let coll_ty = self.infer_expr(collection)?;
250                let walked = self.table.zonk(&coll_ty);
251                match walked {
252                    InferType::Seq(inner) => Ok(*inner),
253                    InferType::Map(_, v) => Ok(*v),
254                    _ => Ok(InferType::Unknown),
255                }
256            }
257
258            Expr::List(items) => {
259                if items.is_empty() {
260                    let elem_var = self.table.fresh();
261                    Ok(InferType::Seq(Box::new(elem_var)))
262                } else {
263                    let elem_type = self.infer_expr(items[0])?;
264                    Ok(InferType::Seq(Box::new(elem_type)))
265                }
266            }
267
268            Expr::OptionSome { value } => {
269                let inner = self.infer_expr(value)?;
270                Ok(InferType::Option(Box::new(inner)))
271            }
272
273            Expr::OptionNone => {
274                let elem_var = self.table.fresh();
275                Ok(InferType::Option(Box::new(elem_var)))
276            }
277
278            Expr::Range { .. } => Ok(InferType::Seq(Box::new(InferType::Int))),
279
280            Expr::Contains { .. } => Ok(InferType::Bool),
281
282            Expr::Copy { expr: inner } | Expr::Give { value: inner } => {
283                self.infer_expr(inner)
284            }
285
286            Expr::WithCapacity { value, .. } => self.infer_expr(value),
287
288            Expr::FieldAccess { object, field } => {
289                let obj_ty = self.infer_expr(object)?;
290                self.infer_field_access(obj_ty, *field)
291            }
292
293            Expr::New { type_name, type_args, .. } => {
294                let name = self.interner.resolve(*type_name);
295                match name {
296                    "Seq" | "List" | "Vec" => {
297                        let elem = type_args
298                            .first()
299                            .map(|t| InferType::from_type_expr(t, self.interner))
300                            .unwrap_or_else(|| self.table.fresh());
301                        Ok(InferType::Seq(Box::new(elem)))
302                    }
303                    "Map" | "HashMap" => {
304                        let key = type_args
305                            .first()
306                            .map(|t| InferType::from_type_expr(t, self.interner))
307                            .unwrap_or(InferType::String);
308                        let val = type_args
309                            .get(1)
310                            .map(|t| InferType::from_type_expr(t, self.interner))
311                            .unwrap_or(InferType::String);
312                        Ok(InferType::Map(Box::new(key), Box::new(val)))
313                    }
314                    "Set" | "HashSet" => {
315                        let elem = type_args
316                            .first()
317                            .map(|t| InferType::from_type_expr(t, self.interner))
318                            .unwrap_or_else(|| self.table.fresh());
319                        Ok(InferType::Set(Box::new(elem)))
320                    }
321                    _ => Ok(InferType::UserDefined(*type_name)),
322                }
323            }
324
325            Expr::NewVariant { enum_name, .. } => {
326                Ok(InferType::UserDefined(*enum_name))
327            }
328
329            Expr::CallExpr { callee, args } => {
330                self.infer_call_expr(callee, args)
331            }
332
333            Expr::Closure { params, body: closure_body, return_type } => {
334                self.infer_closure(params, closure_body, return_type)
335            }
336
337            Expr::InterpolatedString(_) => Ok(InferType::String),
338
339            Expr::Slice { collection, .. } => self.infer_expr(collection),
340
341            Expr::Union { left, .. } | Expr::Intersection { left, .. } => {
342                self.infer_expr(left)
343            }
344
345            // Tuple, ManifestOf, ChunkAt, Escape → Unknown (not typed)
346            _ => Ok(InferType::Unknown),
347        }
348    }
349
350    /// Infer a binary operation's result type.
351    fn infer_binary_op(
352        &mut self,
353        op: BinaryOpKind,
354        left: &Expr,
355        right: &Expr,
356    ) -> Result<InferType, TypeError> {
357        match op {
358            BinaryOpKind::Eq
359            | BinaryOpKind::NotEq
360            | BinaryOpKind::Lt
361            | BinaryOpKind::Gt
362            | BinaryOpKind::LtEq
363            | BinaryOpKind::GtEq => Ok(InferType::Bool),
364
365            // And/Or: type-aware — integer operands → Int (bitwise), else → Bool (logical)
366            BinaryOpKind::And | BinaryOpKind::Or => {
367                let lt = self.infer_expr(left)?;
368                if lt == InferType::Int {
369                    Ok(InferType::Int)
370                } else {
371                    Ok(InferType::Bool)
372                }
373            }
374
375            BinaryOpKind::Concat => Ok(InferType::String),
376
377            BinaryOpKind::BitXor | BinaryOpKind::Shl | BinaryOpKind::Shr => Ok(InferType::Int),
378
379            BinaryOpKind::Add => {
380                let lt = self.infer_expr(left)?;
381                let rt = self.infer_expr(right)?;
382                if lt == InferType::String || rt == InferType::String {
383                    Ok(InferType::String)
384                } else if lt == InferType::Unknown || rt == InferType::Unknown {
385                    Ok(InferType::Unknown)
386                } else {
387                    unify_numeric(&lt, &rt).or(Ok(InferType::Unknown))
388                }
389            }
390
391            BinaryOpKind::Subtract
392            | BinaryOpKind::Multiply
393            | BinaryOpKind::Divide
394            | BinaryOpKind::Modulo => {
395                let lt = self.infer_expr(left)?;
396                let rt = self.infer_expr(right)?;
397                if lt == InferType::Unknown || rt == InferType::Unknown {
398                    Ok(InferType::Unknown)
399                } else {
400                    unify_numeric(&lt, &rt).or(Ok(InferType::Unknown))
401                }
402            }
403        }
404    }
405
406    /// Infer a named function call.
407    ///
408    /// For generic functions, instantiates the `TypeScheme` with fresh type variables,
409    /// then unifies the instantiated parameter types with the argument types. The
410    /// instantiated return type is then zonked and returned as the call result type.
411    fn infer_call(&mut self, function: Symbol, args: &[&Expr]) -> Result<InferType, TypeError> {
412        let name = self.interner.resolve(function);
413        match name {
414            "sqrt" | "parseFloat" | "pow" => Ok(InferType::Float),
415            "parseInt" | "floor" | "ceil" | "round" => Ok(InferType::Int),
416            "abs" | "min" | "max" => {
417                if let Some(first) = args.first() {
418                    self.infer_expr(first)
419                } else {
420                    Ok(InferType::Unknown)
421                }
422            }
423            _ => {
424                if let Some(rec) = self.functions.get(&function).cloned() {
425                    // Instantiate the scheme: each call site gets fresh type variables
426                    // for generic params so calls don't interfere with each other.
427                    let instantiated = self.table.instantiate(&rec.scheme);
428
429                    if let InferType::Function(param_types, ret_box) = instantiated {
430                        // Unify each argument type with the instantiated parameter type
431                        for (arg, param_ty) in args.iter().zip(param_types.iter()) {
432                            let arg_ty = self.infer_expr(arg)?;
433                            self.table.unify(&arg_ty, param_ty)?;
434                        }
435                        Ok(self.table.zonk(&ret_box))
436                    } else {
437                        // Should not happen, but fall back gracefully
438                        Ok(InferType::Unknown)
439                    }
440                } else {
441                    Ok(InferType::Unknown)
442                }
443            }
444        }
445    }
446
447    /// Infer a call-expression (calling a closure/function-value).
448    fn infer_call_expr(
449        &mut self,
450        callee: &Expr,
451        args: &[&Expr],
452    ) -> Result<InferType, TypeError> {
453        let callee_ty = self.infer_expr(callee)?;
454        let ret_var = self.table.fresh();
455        let arg_types: Vec<InferType> = args
456            .iter()
457            .map(|a| self.infer_expr(a))
458            .collect::<Result<_, _>>()?;
459        let fn_ty = InferType::Function(arg_types, Box::new(ret_var.clone()));
460
461        let walked = self.table.zonk(&callee_ty);
462        match walked {
463            InferType::Unknown => Ok(ret_var),
464            InferType::Function(_, _) => {
465                self.table.unify(&walked, &fn_ty)?;
466                Ok(ret_var)
467            }
468            InferType::Var(_) => {
469                self.table.unify(&walked, &fn_ty)?;
470                Ok(ret_var)
471            }
472            other => Err(TypeError::NotAFunction { found: other }),
473        }
474    }
475
476    /// Infer a closure literal.
477    fn infer_closure(
478        &mut self,
479        params: &[(Symbol, &crate::ast::stmt::TypeExpr)],
480        body: &crate::ast::stmt::ClosureBody,
481        return_type: &Option<&crate::ast::stmt::TypeExpr>,
482    ) -> Result<InferType, TypeError> {
483        let param_types: Vec<InferType> = params
484            .iter()
485            .map(|(_, ty_expr)| InferType::from_type_expr(ty_expr, self.interner))
486            .collect();
487
488        let ret_type = if let Some(rt) = return_type {
489            InferType::from_type_expr(rt, self.interner)
490        } else {
491            self.table.fresh()
492        };
493
494        self.push_scope();
495        for ((sym, _), ty) in params.iter().zip(param_types.iter()) {
496            self.bind_var(*sym, ty.clone());
497        }
498
499        let prev_return = self.current_return_type.take();
500        self.current_return_type = Some(ret_type.clone());
501
502        match body {
503            crate::ast::stmt::ClosureBody::Expression(expr) => {
504                let body_ty = self.infer_expr(expr)?;
505                // Best-effort unification: won't fail compilation on ambiguity
506                self.table.unify(&body_ty, &ret_type).ok();
507            }
508            crate::ast::stmt::ClosureBody::Block(stmts) => {
509                for stmt in *stmts {
510                    self.infer_stmt(stmt)?;
511                }
512            }
513        }
514
515        self.current_return_type = prev_return;
516        self.pop_scope();
517
518        Ok(InferType::Function(param_types, Box::new(ret_type)))
519    }
520
521    /// Infer the type of a field access on a struct.
522    fn infer_field_access(
523        &self,
524        obj_ty: InferType,
525        field: Symbol,
526    ) -> Result<InferType, TypeError> {
527        let resolved = self.table.zonk(&obj_ty);
528        match &resolved {
529            InferType::UserDefined(type_sym) => {
530                if let Some(TypeDef::Struct { fields, .. }) = self.registry.get(*type_sym) {
531                    if let Some(field_def) = fields.iter().find(|f| f.name == field) {
532                        Ok(InferType::from_field_type(
533                            &field_def.ty,
534                            self.interner,
535                            &HashMap::new(),
536                        ))
537                    } else {
538                        Err(TypeError::FieldNotFound {
539                            type_name: *type_sym,
540                            field_name: field,
541                        })
542                    }
543                } else {
544                    // Not a struct in registry → Unknown (defensive)
545                    Ok(InferType::Unknown)
546                }
547            }
548            // Can't resolve field on non-struct type
549            _ => Ok(InferType::Unknown),
550        }
551    }
552}
553
554// ============================================================================
555// Statement inference
556// ============================================================================
557
558impl<'r> CheckEnv<'r> {
559    fn infer_stmt(&mut self, stmt: &Stmt) -> Result<(), TypeError> {
560        match stmt {
561            Stmt::Let { var, ty, value, .. } => {
562                let final_ty = if let Some(type_expr) = ty {
563                    let annotated = InferType::from_type_expr(type_expr, self.interner);
564                    if annotated != InferType::Unknown {
565                        // Checking mode: value must be compatible with annotation
566                        self.check_expr(value, &annotated)?
567                    } else {
568                        self.infer_expr(value)?
569                    }
570                } else {
571                    self.infer_expr(value)?
572                };
573                self.bind_var(*var, final_ty);
574                Ok(())
575            }
576
577            Stmt::Set { target, value } => {
578                let inferred = self.infer_expr(value)?;
579                // If target already has a type, unify. Otherwise just bind.
580                if let Some(existing) = self.lookup_var(*target) {
581                    if existing != InferType::Unknown {
582                        self.table.unify(&inferred, &existing).ok();
583                    }
584                }
585                // Update binding
586                let resolved = self.table.zonk(&inferred);
587                if resolved != InferType::Unknown {
588                    self.bind_var(*target, resolved);
589                }
590                Ok(())
591            }
592
593            Stmt::FunctionDef {
594                name,
595                generics,
596                params,
597                body,
598                return_type,
599                is_native,
600                ..
601            } => {
602                // Build a type-param map: Symbol("T") → TyVar
603                // Re-use the TyVars already allocated in preregister_functions if present,
604                // or allocate fresh ones if this function was not pre-registered.
605                let type_param_map: HashMap<Symbol, TyVar> = {
606                    // Try to recover the same TyVars from the pre-registered scheme
607                    let existing_vars: Vec<TyVar> = self.functions
608                        .get(name)
609                        .map(|rec| rec.scheme.vars.clone())
610                        .unwrap_or_default();
611                    if existing_vars.len() == generics.len() {
612                        generics.iter().copied().zip(existing_vars).collect()
613                    } else {
614                        generics.iter().map(|&sym| (sym, self.table.fresh_var())).collect()
615                    }
616                };
617
618                let param_types: Vec<InferType> = params
619                    .iter()
620                    .map(|(_, ty_expr)| {
621                        InferType::from_type_expr_with_params(ty_expr, self.interner, &type_param_map)
622                    })
623                    .collect();
624                let param_names: Vec<Symbol> = params.iter().map(|(sym, _)| *sym).collect();
625
626                let ret_type = if let Some(rt) = return_type {
627                    InferType::from_type_expr_with_params(rt, self.interner, &type_param_map)
628                } else if let Some(rec) = self.functions.get(name) {
629                    // Recover pre-registered return type from the scheme body
630                    if let InferType::Function(_, ret_box) = &rec.scheme.body {
631                        *ret_box.clone()
632                    } else {
633                        self.table.fresh()
634                    }
635                } else {
636                    self.table.fresh()
637                };
638
639                let generic_vars: Vec<TyVar> = generics
640                    .iter()
641                    .filter_map(|sym| type_param_map.get(sym).copied())
642                    .collect();
643
644                // Native functions: register scheme, no body to check
645                if *is_native {
646                    let scheme = TypeScheme {
647                        vars: generic_vars,
648                        body: InferType::Function(param_types, Box::new(ret_type)),
649                    };
650                    self.functions.insert(*name, FunctionRecord { param_names, scheme });
651                    return Ok(());
652                }
653
654                // Save previous return context
655                let prev_return_type = self.current_return_type.take();
656                self.current_return_type = Some(ret_type.clone());
657
658                // Check body in a new scope with params bound
659                self.push_scope();
660                for (sym, ty) in param_names.iter().zip(param_types.iter()) {
661                    self.bind_var(*sym, ty.clone());
662                }
663                for s in *body {
664                    self.infer_stmt(s)?;
665                }
666                self.pop_scope();
667
668                self.current_return_type = prev_return_type;
669
670                // After checking the body, update the registered scheme with resolved types.
671                // Use `resolve` (not `zonk`) so generic TyVars remain as `Var(tv)` in
672                // the scheme body — they will be instantiated fresh at each call site.
673                let resolved_params: Vec<InferType> = param_types
674                    .iter()
675                    .map(|ty| self.table.resolve(ty))
676                    .collect();
677                let resolved_ret = self.table.resolve(&ret_type);
678
679                let scheme = TypeScheme {
680                    vars: generic_vars,
681                    body: InferType::Function(resolved_params, Box::new(resolved_ret)),
682                };
683                self.functions.insert(*name, FunctionRecord { param_names, scheme });
684                Ok(())
685            }
686
687            Stmt::Return { value } => {
688                let ty = match value {
689                    Some(expr) => self.infer_expr(expr)?,
690                    None => InferType::Unit,
691                };
692                if let Some(expected) = self.current_return_type.clone() {
693                    // Hard check for explicit return type annotations
694                    if expected != InferType::Unknown {
695                        self.table.unify(&ty, &expected)?;
696                    }
697                }
698                Ok(())
699            }
700
701            Stmt::Repeat { pattern, iterable, body } => {
702                let iterable_ty = self.infer_expr(iterable)?;
703                let elem_ty = match self.table.zonk(&iterable_ty) {
704                    InferType::Seq(inner) | InferType::Set(inner) => *inner,
705                    InferType::Map(k, _) => *k,
706                    _ => InferType::Unknown,
707                };
708                match pattern {
709                    Pattern::Identifier(sym) => self.bind_var(*sym, elem_ty),
710                    Pattern::Tuple(syms) => {
711                        for sym in syms {
712                            self.bind_var(*sym, InferType::Unknown);
713                        }
714                    }
715                }
716                for s in *body {
717                    self.infer_stmt(s)?;
718                }
719                Ok(())
720            }
721
722            Stmt::If { then_block, else_block, .. } => {
723                for s in *then_block {
724                    self.infer_stmt(s)?;
725                }
726                if let Some(else_b) = else_block {
727                    for s in *else_b {
728                        self.infer_stmt(s)?;
729                    }
730                }
731                Ok(())
732            }
733
734            Stmt::While { body, .. } => {
735                for s in *body {
736                    self.infer_stmt(s)?;
737                }
738                Ok(())
739            }
740
741            Stmt::Inspect { target, arms, .. } => {
742                let _target_ty = self.infer_expr(target)?;
743                for arm in arms {
744                    self.push_scope();
745                    self.infer_inspect_arm(arm)?;
746                    self.pop_scope();
747                }
748                Ok(())
749            }
750
751            Stmt::Zone { body, .. } => {
752                for s in *body {
753                    self.infer_stmt(s)?;
754                }
755                Ok(())
756            }
757
758            Stmt::ReadFrom { var, .. } => {
759                self.bind_var(*var, InferType::String);
760                Ok(())
761            }
762
763            Stmt::CreatePipe { var, element_type, .. } => {
764                let elem = InferType::from_type_name(self.interner.resolve(*element_type));
765                self.bind_var(*var, elem);
766                Ok(())
767            }
768
769            Stmt::ReceivePipe { var, pipe } => {
770                // Pipe var was registered with its element type by CreatePipe
771                let elem_ty = self.infer_expr(pipe)?;
772                self.bind_var(*var, elem_ty);
773                Ok(())
774            }
775
776            Stmt::TryReceivePipe { var, pipe } => {
777                let elem_ty = self.infer_expr(pipe)?;
778                // TryReceivePipe yields Option of elem type
779                self.bind_var(*var, InferType::Option(Box::new(elem_ty)));
780                Ok(())
781            }
782
783            Stmt::Pop { into: Some(var), collection } => {
784                let coll_ty = self.infer_expr(collection)?;
785                let elem_ty = match self.table.zonk(&coll_ty) {
786                    InferType::Seq(inner) | InferType::Set(inner) => *inner,
787                    _ => InferType::Unknown,
788                };
789                self.bind_var(*var, elem_ty);
790                Ok(())
791            }
792
793            Stmt::AwaitMessage { into, .. } => {
794                self.bind_var(*into, InferType::Unknown);
795                Ok(())
796            }
797
798            Stmt::LaunchTaskWithHandle { handle, .. } => {
799                self.bind_var(*handle, InferType::Unknown);
800                Ok(())
801            }
802
803            Stmt::Concurrent { tasks } | Stmt::Parallel { tasks } => {
804                for s in *tasks {
805                    self.infer_stmt(s)?;
806                }
807                Ok(())
808            }
809
810            Stmt::Select { branches } => {
811                for branch in branches {
812                    match branch {
813                        crate::ast::stmt::SelectBranch::Receive { var, pipe, body } => {
814                            let elem_ty = self.infer_expr(pipe)?;
815                            self.push_scope();
816                            self.bind_var(*var, elem_ty);
817                            for s in *body {
818                                self.infer_stmt(s)?;
819                            }
820                            self.pop_scope();
821                        }
822                        crate::ast::stmt::SelectBranch::Timeout { body, .. } => {
823                            for s in *body {
824                                self.infer_stmt(s)?;
825                            }
826                        }
827                    }
828                }
829                Ok(())
830            }
831
832            _ => Ok(()),
833        }
834    }
835
836    /// Process a single Inspect match arm, binding variant field types.
837    fn infer_inspect_arm(
838        &mut self,
839        arm: &crate::ast::stmt::MatchArm,
840    ) -> Result<(), TypeError> {
841        if let Some(variant_sym) = arm.variant {
842            if let Some((_, variant_def)) = self.registry.find_variant(variant_sym) {
843                // Clone what we need to avoid borrow issues
844                let fields: Vec<_> = variant_def
845                    .fields
846                    .iter()
847                    .map(|f| (f.name, f.ty.clone()))
848                    .collect();
849
850                for (field_sym, binding_sym) in &arm.bindings {
851                    let ty = fields
852                        .iter()
853                        .find(|(name, _)| *name == *field_sym)
854                        .map(|(_, ty)| {
855                            InferType::from_field_type(ty, self.interner, &HashMap::new())
856                        })
857                        .unwrap_or(InferType::Unknown);
858                    self.bind_var(*binding_sym, ty);
859                }
860            } else {
861                // Unknown variant → bind all as Unknown
862                for (_, binding_sym) in &arm.bindings {
863                    self.bind_var(*binding_sym, InferType::Unknown);
864                }
865            }
866        } else {
867            // Otherwise arm: wildcard bindings
868            for (_, binding_sym) in &arm.bindings {
869                self.bind_var(*binding_sym, InferType::Unknown);
870            }
871        }
872
873        for s in arm.body {
874            self.infer_stmt(s)?;
875        }
876        Ok(())
877    }
878}
879
880// ============================================================================
881// Entry point
882// ============================================================================
883
884/// Check a LOGOS program and return a typed `TypeEnv` for codegen.
885///
886/// Replaces `TypeEnv::infer_program`. Returns `Err(TypeError)` only on
887/// genuine type contradictions (e.g., `Let x: Int be "hello"`).
888/// Ambiguous types fall back to `LogosType::Unknown` silently.
889pub fn check_program(
890    stmts: &[Stmt],
891    interner: &Interner,
892    registry: &TypeRegistry,
893) -> Result<TypeEnv, TypeError> {
894    let mut env = CheckEnv::new(registry, interner);
895
896    // Pre-pass: register top-level function signatures for forward references
897    env.preregister_functions(stmts);
898
899    // Main pass: check all top-level statements
900    for stmt in stmts {
901        env.infer_stmt(stmt)?;
902    }
903
904    Ok(env.into_type_env())
905}
906
907// ============================================================================
908// Tests
909// ============================================================================
910
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::stmt::{Expr, Literal, Stmt, TypeExpr};
    use crate::intern::Interner;

    // =========================================================================
    // Helpers
    // =========================================================================

    /// Builds a `Stmt::FunctionDef` for a plain function: not native, not
    /// exported, and with no optimization flags.
    ///
    /// This is a macro rather than a helper function so that the lifetimes of the
    /// borrowed parameter types, body, and return type are inferred at each call
    /// site instead of having to be named in a function signature.
    macro_rules! function_def {
        (
            name: $name:expr,
            generics: $generics:expr,
            params: $params:expr,
            body: $body:expr,
            return_type: $return_type:expr $(,)?
        ) => {
            Stmt::FunctionDef {
                name: $name,
                generics: $generics,
                params: $params,
                body: $body,
                return_type: $return_type,
                is_native: false,
                native_path: None,
                is_exported: false,
                export_target: None,
                opt_flags: HashSet::new(),
            }
        };
    }

    fn mk_interner() -> Interner {
        Interner::new()
    }

    /// Runs `check_program` with an empty type registry, panicking on a type error.
    fn run(stmts: &[Stmt], interner: &Interner) -> TypeEnv {
        check_program(stmts, interner, &TypeRegistry::new()).expect("check_program failed")
    }

    // =========================================================================
    // Let literal inference
    // =========================================================================

    #[test]
    fn let_literal_int() {
        let mut interner = mk_interner();
        let x = interner.intern("x");
        let val = Expr::Literal(Literal::Number(42));
        let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(x), &LogosType::Int);
    }

    #[test]
    fn let_literal_float() {
        let mut interner = mk_interner();
        let x = interner.intern("x");
        let val = Expr::Literal(Literal::Float(3.14));
        let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(x), &LogosType::Float);
    }

    #[test]
    fn let_literal_string() {
        let mut interner = mk_interner();
        let s = interner.intern("s");
        let hello = interner.intern("hello");
        let val = Expr::Literal(Literal::Text(hello));
        let stmts = [Stmt::Let { var: s, ty: None, value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(s), &LogosType::String);
    }

    // =========================================================================
    // Let with type annotation
    // =========================================================================

    #[test]
    fn let_with_annotation_uses_annotation() {
        let mut interner = mk_interner();
        let x = interner.intern("x");
        let float_sym = interner.intern("Real");
        let val = Expr::Literal(Literal::Number(5)); // Int value
        let ty_ann = TypeExpr::Primitive(float_sym);
        let stmts = [Stmt::Let { var: x, ty: Some(&ty_ann), value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        // Annotation wins: Int unifies with Float (numeric)
        assert_eq!(env.lookup(x), &LogosType::Float);
    }

    #[test]
    fn let_type_mismatch_fails() {
        let mut interner = mk_interner();
        let x = interner.intern("x");
        let int_sym = interner.intern("Int");
        let val = Expr::Literal(Literal::Text(Symbol::EMPTY));
        let ty_ann = TypeExpr::Primitive(int_sym);
        let stmts = [Stmt::Let { var: x, ty: Some(&ty_ann), value: &val, mutable: false }];
        let result = check_program(&stmts, &interner, &TypeRegistry::new());
        assert!(result.is_err(), "Int and Text should not unify");
    }

    // =========================================================================
    // Empty list → Seq(Unknown)
    // =========================================================================

    #[test]
    fn empty_list_is_seq_unknown() {
        let mut interner = mk_interner();
        let xs = interner.intern("xs");
        let val = Expr::List(vec![]);
        let stmts = [Stmt::Let { var: xs, ty: None, value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        // Should be Seq of something (Unknown because element type is unsolved)
        assert!(matches!(env.lookup(xs), LogosType::Seq(_)));
    }

    #[test]
    fn non_empty_list_infers_element_type() {
        let mut interner = mk_interner();
        let xs = interner.intern("xs");
        let one = Expr::Literal(Literal::Number(1));
        let two = Expr::Literal(Literal::Number(2));
        let val = Expr::List(vec![&one, &two]);
        let stmts = [Stmt::Let { var: xs, ty: None, value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(xs), &LogosType::Seq(Box::new(LogosType::Int)));
    }

    // =========================================================================
    // OptionNone → Option(Unknown)
    // =========================================================================

    #[test]
    fn option_none_is_option_unknown() {
        let mut interner = mk_interner();
        let x = interner.intern("x");
        let val = Expr::OptionNone;
        let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        assert!(matches!(env.lookup(x), LogosType::Option(_)));
    }

    #[test]
    fn option_some_infers_inner_type() {
        let mut interner = mk_interner();
        let x = interner.intern("x");
        let inner = Expr::Literal(Literal::Number(42));
        let val = Expr::OptionSome { value: &inner };
        let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(x), &LogosType::Option(Box::new(LogosType::Int)));
    }

    // =========================================================================
    // Function def and call
    // =========================================================================

    #[test]
    fn function_def_registers_signature() {
        let mut interner = mk_interner();
        let f = interner.intern("double");
        let x_param = interner.intern("x");
        let int_sym = interner.intern("Int");
        let int_ty = TypeExpr::Primitive(int_sym);
        let ret_ty = TypeExpr::Primitive(int_sym);
        let lit = Expr::Literal(Literal::Number(0));
        let body = [Stmt::Return { value: Some(&lit) }];
        let stmts = [function_def! {
            name: f,
            generics: vec![],
            params: vec![(x_param, &int_ty)],
            body: &body,
            return_type: Some(&ret_ty),
        }];
        let env = run(&stmts, &interner);
        let sig = env.lookup_fn(f).expect("function should be registered");
        assert_eq!(sig.return_type, LogosType::Int);
        assert_eq!(sig.params.len(), 1);
        assert_eq!(sig.params[0].1, LogosType::Int);
    }

    #[test]
    fn function_call_returns_registered_type() {
        let mut interner = mk_interner();
        let f = interner.intern("compute");
        let float_sym = interner.intern("Real");
        let float_ty = TypeExpr::Primitive(float_sym);
        let lit = Expr::Literal(Literal::Float(1.0));
        let body = [Stmt::Return { value: Some(&lit) }];
        let fn_def = function_def! {
            name: f,
            generics: vec![],
            params: vec![],
            body: &body,
            return_type: Some(&float_ty),
        };
        let result_var = interner.intern("result");
        let call = Expr::Call { function: f, args: vec![] };
        let let_stmt = Stmt::Let { var: result_var, ty: None, value: &call, mutable: false };
        let stmts = [fn_def, let_stmt];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(result_var), &LogosType::Float);
    }

    // =========================================================================
    // ReadFrom is String
    // =========================================================================

    #[test]
    fn readfrom_is_string() {
        let mut interner = mk_interner();
        let v = interner.intern("input");
        let stmts = [Stmt::ReadFrom {
            var: v,
            source: crate::ast::stmt::ReadSource::Console,
        }];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(v), &LogosType::String);
    }

    // =========================================================================
    // Repeat loop variable gets element type
    // =========================================================================

    #[test]
    fn repeat_loop_var_gets_element_type() {
        let mut interner = mk_interner();
        let items = interner.intern("items");
        let elem = interner.intern("elem");
        let one = Expr::Literal(Literal::Number(1));
        let list = Expr::List(vec![&one]);
        let let_items = Stmt::Let { var: items, ty: None, value: &list, mutable: false };
        let items_ref = Expr::Identifier(items);
        let repeat = Stmt::Repeat {
            pattern: Pattern::Identifier(elem),
            iterable: &items_ref,
            body: &[],
        };
        let stmts = [let_items, repeat];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(elem), &LogosType::Int);
    }

    // =========================================================================
    // Field access resolves to struct field type (uses registry)
    // =========================================================================

    #[test]
    fn field_access_resolves_with_registry() {
        use crate::analysis::{FieldDef, FieldType, TypeDef};

        let mut interner = mk_interner();
        let point_sym = interner.intern("Point");
        let x_field_sym = interner.intern("x");
        let int_sym = interner.intern("Int");
        let p_var = interner.intern("p");
        let result_var = interner.intern("px");

        // Build a registry with a struct Point { x: Int }
        let mut registry = TypeRegistry::new();
        registry.register(
            point_sym,
            TypeDef::Struct {
                fields: vec![FieldDef {
                    name: x_field_sym,
                    ty: FieldType::Primitive(int_sym),
                    is_public: true,
                }],
                generics: vec![],
                is_portable: false,
                is_shared: false,
            },
        );

        // Let p be a new Point.
        let new_point = Expr::New { type_name: point_sym, type_args: vec![], init_fields: vec![] };
        let let_p = Stmt::Let { var: p_var, ty: None, value: &new_point, mutable: false };

        // Let px be p's x.
        let p_ref = Expr::Identifier(p_var);
        let field_access = Expr::FieldAccess { object: &p_ref, field: x_field_sym };
        let let_px = Stmt::Let { var: result_var, ty: None, value: &field_access, mutable: false };

        let stmts = [let_p, let_px];
        let env = check_program(&stmts, &interner, &registry).expect("check_program failed");
        assert_eq!(env.lookup(result_var), &LogosType::Int);
    }

    // =========================================================================
    // Forward reference: calling a function defined later
    // =========================================================================

    #[test]
    fn forward_reference_function_call() {
        let mut interner = mk_interner();
        let f = interner.intern("later_fn");
        let result_var = interner.intern("r");
        let bool_sym = interner.intern("Bool");
        let bool_ty = TypeExpr::Primitive(bool_sym);

        // Let r be later_fn().  (before the function def)
        let call = Expr::Call { function: f, args: vec![] };
        let let_r = Stmt::Let { var: result_var, ty: None, value: &call, mutable: false };

        // ## Function later_fn -> Bool:
        let lit = Expr::Literal(Literal::Boolean(true));
        let body = [Stmt::Return { value: Some(&lit) }];
        let fn_def = function_def! {
            name: f,
            generics: vec![],
            params: vec![],
            body: &body,
            return_type: Some(&bool_ty),
        };

        // Note: let_r comes BEFORE fn_def in the slice
        let stmts = [let_r, fn_def];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(result_var), &LogosType::Bool);
    }

    // =========================================================================
    // Type mismatch on return
    // =========================================================================

    #[test]
    fn return_type_mismatch_fails() {
        let mut interner = mk_interner();
        let f = interner.intern("f");
        let int_sym = interner.intern("Int");
        let int_ty = TypeExpr::Primitive(int_sym);
        // Function annotated as -> Int but returns Text
        let lit = Expr::Literal(Literal::Text(Symbol::EMPTY));
        let body = [Stmt::Return { value: Some(&lit) }];
        let stmts = [function_def! {
            name: f,
            generics: vec![],
            params: vec![],
            body: &body,
            return_type: Some(&int_ty),
        }];
        let result = check_program(&stmts, &interner, &TypeRegistry::new());
        assert!(result.is_err(), "returning Text from Int function should fail");
    }

    // =========================================================================
    // New user-defined type → UserDefined
    // =========================================================================

    #[test]
    fn new_user_defined_is_user_defined_type() {
        let mut interner = mk_interner();
        let point = interner.intern("Point");
        let p = interner.intern("p");
        let new_point = Expr::New { type_name: point, type_args: vec![], init_fields: vec![] };
        let stmts = [Stmt::Let { var: p, ty: None, value: &new_point, mutable: false }];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(p), &LogosType::UserDefined(point));
    }

    // =========================================================================
    // Legacy API preserved: to_legacy_variable_types / to_legacy_string_vars
    // =========================================================================

    #[test]
    fn string_vars_in_legacy_api() {
        let mut interner = mk_interner();
        let s = interner.intern("name");
        let hello = interner.intern("hello");
        let val = Expr::Literal(Literal::Text(hello));
        let stmts = [Stmt::Let { var: s, ty: None, value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        assert!(env.to_legacy_string_vars().contains(&s));
    }

    #[test]
    fn unknown_vars_filtered_in_legacy_api() {
        let mut interner = mk_interner();
        let x = interner.intern("x");
        let val = Expr::OptionNone; // Unknown inner type
        let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
        let env = run(&stmts, &interner);
        // The unsolved payload must not erase the outer Option shape.
        assert!(matches!(env.lookup(x), LogosType::Option(_)));
        // Option(Unknown) is not a string variable.
        assert!(
            !env.to_legacy_string_vars().contains(&x),
            "an Option variable must not be reported as a string var"
        );
        // Building the legacy type map must not panic on the unsolved inner type.
        let _ = env.to_legacy_variable_types();
    }

    // =========================================================================
    // Generic (polymorphic) functions — Phase 3
    // =========================================================================

    #[test]
    fn generic_identity_infers_int_return() {
        // ## To identity of [T] (x: T) -> T:
        //     Return x.
        // Let r be identity(42).  → r is Int
        let mut interner = mk_interner();
        let f = interner.intern("identity");
        let x_param = interner.intern("x");
        let t_sym = interner.intern("T");
        let t_ty = TypeExpr::Primitive(t_sym);
        let x_ref = Expr::Identifier(x_param);
        let body = [Stmt::Return { value: Some(&x_ref) }];
        let fn_def = function_def! {
            name: f,
            generics: vec![t_sym],
            params: vec![(x_param, &t_ty)],
            body: &body,
            return_type: Some(&t_ty),
        };
        let r = interner.intern("r");
        let lit = Expr::Literal(Literal::Number(42));
        let call = Expr::Call { function: f, args: vec![&lit] };
        let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
        let stmts = [fn_def, let_r];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(r), &LogosType::Int,
            "identity(42) should return Int, got {:?}", env.lookup(r));
    }

    #[test]
    fn generic_identity_infers_bool_return() {
        // Same identity function, called with Bool → returns Bool.
        let mut interner = mk_interner();
        let f = interner.intern("identity");
        let x_param = interner.intern("x");
        let t_sym = interner.intern("T");
        let t_ty = TypeExpr::Primitive(t_sym);
        let x_ref = Expr::Identifier(x_param);
        let body = [Stmt::Return { value: Some(&x_ref) }];
        let fn_def = function_def! {
            name: f,
            generics: vec![t_sym],
            params: vec![(x_param, &t_ty)],
            body: &body,
            return_type: Some(&t_ty),
        };
        let r = interner.intern("r");
        let lit = Expr::Literal(Literal::Boolean(true));
        let call = Expr::Call { function: f, args: vec![&lit] };
        let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
        let stmts = [fn_def, let_r];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(r), &LogosType::Bool,
            "identity(true) should return Bool, got {:?}", env.lookup(r));
    }

    #[test]
    fn generic_two_type_params_first() {
        // ## To first of [A] and [B] (a: A, b: B) -> A:
        //     Return a.
        // Let r be first(42, true).  → r is Int (first type param)
        let mut interner = mk_interner();
        let f = interner.intern("first");
        let a_param = interner.intern("a");
        let b_param = interner.intern("b");
        let a_sym = interner.intern("A");
        let b_sym = interner.intern("B");
        let a_ty = TypeExpr::Primitive(a_sym);
        let b_ty = TypeExpr::Primitive(b_sym);
        let a_ref = Expr::Identifier(a_param);
        let body = [Stmt::Return { value: Some(&a_ref) }];
        let fn_def = function_def! {
            name: f,
            generics: vec![a_sym, b_sym],
            params: vec![(a_param, &a_ty), (b_param, &b_ty)],
            body: &body,
            return_type: Some(&a_ty),
        };
        let r = interner.intern("r");
        let lit_int = Expr::Literal(Literal::Number(42));
        let lit_bool = Expr::Literal(Literal::Boolean(true));
        let call = Expr::Call { function: f, args: vec![&lit_int, &lit_bool] };
        let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
        let stmts = [fn_def, let_r];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(r), &LogosType::Int,
            "first(42, true) should return Int (first param type), got {:?}", env.lookup(r));
    }

    #[test]
    fn generic_calls_are_independent() {
        // Each call to a generic function gets its own fresh type variables.
        // identity(42) → Int, identity(true) → Bool, independent results.
        let mut interner = mk_interner();
        let f = interner.intern("identity");
        let x_param = interner.intern("x");
        let t_sym = interner.intern("T");
        let t_ty = TypeExpr::Primitive(t_sym);
        let x_ref = Expr::Identifier(x_param);
        let body = [Stmt::Return { value: Some(&x_ref) }];
        let fn_def = function_def! {
            name: f,
            generics: vec![t_sym],
            params: vec![(x_param, &t_ty)],
            body: &body,
            return_type: Some(&t_ty),
        };
        let r1 = interner.intern("r1");
        let r2 = interner.intern("r2");
        let lit_int = Expr::Literal(Literal::Number(42));
        let lit_bool = Expr::Literal(Literal::Boolean(true));
        let call1 = Expr::Call { function: f, args: vec![&lit_int] };
        let call2 = Expr::Call { function: f, args: vec![&lit_bool] };
        let let_r1 = Stmt::Let { var: r1, ty: None, value: &call1, mutable: false };
        let let_r2 = Stmt::Let { var: r2, ty: None, value: &call2, mutable: false };
        let stmts = [fn_def, let_r1, let_r2];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(r1), &LogosType::Int,
            "identity(42) should be Int, got {:?}", env.lookup(r1));
        assert_eq!(env.lookup(r2), &LogosType::Bool,
            "identity(true) should be Bool, got {:?}", env.lookup(r2));
    }

    #[test]
    fn monomorphic_functions_unaffected_by_generics() {
        // Non-generic functions still work correctly with the updated machinery.
        let mut interner = mk_interner();
        let f = interner.intern("double");
        let x_param = interner.intern("x");
        let int_sym = interner.intern("Int");
        let int_ty = TypeExpr::Primitive(int_sym);
        let x_ref = Expr::Identifier(x_param);
        let lit2 = Expr::Literal(Literal::Number(2));
        let mul = Expr::BinaryOp {
            op: BinaryOpKind::Multiply,
            left: &x_ref,
            right: &lit2,
        };
        let body = [Stmt::Return { value: Some(&mul) }];
        let fn_def = function_def! {
            name: f,
            generics: vec![],
            params: vec![(x_param, &int_ty)],
            body: &body,
            return_type: Some(&int_ty),
        };
        let r = interner.intern("r");
        let lit5 = Expr::Literal(Literal::Number(5));
        let call = Expr::Call { function: f, args: vec![&lit5] };
        let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
        let stmts = [fn_def, let_r];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(r), &LogosType::Int,
            "double(5) should return Int, got {:?}", env.lookup(r));
    }

    #[test]
    fn generic_forward_reference_resolves() {
        // Let r be identity(42).
        // ## To identity of [T] (x: T) -> T:  ← defined AFTER the call
        //     Return x.
        // The pre-pass must register generics before the main pass sees the call.
        let mut interner = mk_interner();
        let f = interner.intern("identity");
        let x_param = interner.intern("x");
        let t_sym = interner.intern("T");
        let t_ty = TypeExpr::Primitive(t_sym);
        let x_ref = Expr::Identifier(x_param);
        let body = [Stmt::Return { value: Some(&x_ref) }];
        let fn_def = function_def! {
            name: f,
            generics: vec![t_sym],
            params: vec![(x_param, &t_ty)],
            body: &body,
            return_type: Some(&t_ty),
        };
        let r = interner.intern("r");
        let lit = Expr::Literal(Literal::Number(99));
        let call = Expr::Call { function: f, args: vec![&lit] };
        let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
        // Call appears BEFORE the function definition
        let stmts = [let_r, fn_def];
        let env = run(&stmts, &interner);
        assert_eq!(env.lookup(r), &LogosType::Int,
            "forward-ref identity(99) should be Int, got {:?}", env.lookup(r));
    }
}