// mimium_lang/compiler/typing.rs
1use crate::ast::program::TypeAliasMap;
2use crate::ast::{Expr, Literal, RecordField};
3use crate::compiler::{EvalStage, intrinsics};
4use crate::interner::{ExprKey, ExprNodeId, Symbol, ToSymbol, TypeNodeId};
5use crate::pattern::{Pattern, TypedId, TypedPattern};
6use crate::types::{IntermediateId, PType, RecordTypeField, Type, TypeSchemeId, TypeVar};
7use crate::utils::metadata::Location;
8use crate::utils::{environment::Environment, error::ReportableError};
9use crate::{function, integer, numeric, unit};
10use itertools::Itertools;
11use std::collections::{BTreeMap, HashMap};
12use std::path::PathBuf;
13use std::sync::{Arc, RwLock};
14use thiserror::Error;
15
16mod unification;
17pub(crate) use unification::Relation;
18use unification::{Error as UnificationError, unify_types};
19
20#[derive(Clone, Debug, Error)]
21#[error("Type Inference Error")]
22pub enum Error {
23    TypeMismatch {
24        left: (TypeNodeId, Location),
25        right: (TypeNodeId, Location),
26    },
27    EscapeRequiresCodeType {
28        found: (TypeNodeId, Location),
29    },
30    LengthMismatch {
31        left: (usize, Location),
32        right: (usize, Location),
33    },
34    PatternMismatch((TypeNodeId, Location), (Pattern, Location)),
35    NonFunctionForLetRec(TypeNodeId, Location),
36    NonFunctionForApply(TypeNodeId, Location),
37    NonSupertypeArgument {
38        location: Location,
39        expected: TypeNodeId,
40        found: TypeNodeId,
41    },
42    CircularType(Location, Location),
43    IndexOutOfRange {
44        len: u16,
45        idx: u16,
46        loc: Location,
47    },
48    IndexForNonTuple(Location, TypeNodeId),
49    FieldForNonRecord(Location, TypeNodeId),
50    FieldNotExist {
51        field: Symbol,
52        loc: Location,
53        et: TypeNodeId,
54    },
55    DuplicateKeyInRecord {
56        key: Vec<Symbol>,
57        loc: Location,
58    },
59    DuplicateKeyInParams(Vec<(Symbol, Location)>),
60    /// The error of records, which contains both subtypes and supertypes.
61    IncompatibleKeyInRecord {
62        left: (Vec<(Symbol, TypeNodeId)>, Location),
63        right: (Vec<(Symbol, TypeNodeId)>, Location),
64    },
65    VariableNotFound(Symbol, Location),
66    /// Module not found in the current scope
67    ModuleNotFound {
68        module_path: Vec<Symbol>,
69        location: Location,
70    },
71    /// Member not found in a module
72    MemberNotFound {
73        module_path: Vec<Symbol>,
74        member: Symbol,
75        location: Location,
76    },
77    /// Attempted to access a private module member
78    PrivateMemberAccess {
79        module_path: Vec<Symbol>,
80        member: Symbol,
81        location: Location,
82    },
83    StageMismatch {
84        variable: Symbol,
85        expected_stage: EvalStage,
86        found_stage: EvalStage,
87        location: Location,
88    },
89    NonPrimitiveInFeed(Location),
90    /// Constructor pattern doesn't match any variant of the union type
91    ConstructorNotInUnion {
92        constructor: Symbol,
93        union_type: TypeNodeId,
94        location: Location,
95    },
96    /// Expected a union type for constructor pattern matching
97    ExpectedUnionType {
98        found: TypeNodeId,
99        location: Location,
100    },
101    /// Match expression is not exhaustive (missing patterns)
102    NonExhaustiveMatch {
103        missing_constructors: Vec<Symbol>,
104        location: Location,
105    },
106    /// Recursive type alias detected (infinite expansion)
107    RecursiveTypeAlias {
108        type_name: Symbol,
109        cycle: Vec<Symbol>,
110        location: Location,
111    },
112    /// Private type accessed from outside its module
113    PrivateTypeAccess {
114        module_path: Vec<Symbol>,
115        type_name: Symbol,
116        location: Location,
117    },
118    /// Public function leaking private type in its signature
119    PrivateTypeLeak {
120        function_name: Symbol,
121        private_type: Symbol,
122        location: Location,
123    },
124}
125
126impl ReportableError for Error {
127    fn get_message(&self) -> String {
128        match self {
129            Error::TypeMismatch { .. } => format!("Type mismatch"),
130            Error::EscapeRequiresCodeType { found: (ty, ..) } => {
131                format!(
132                    "Escape requires a code value, but found {}",
133                    ty.to_type().to_string_for_error()
134                )
135            }
136            Error::PatternMismatch(..) => format!("Pattern mismatch"),
137            Error::LengthMismatch { .. } => format!("Length of the elements are different"),
138            Error::NonFunctionForLetRec(_, _) => format!("`letrec` can take only function type."),
139            Error::NonFunctionForApply(_, _) => {
140                format!("This is not applicable because it is not a function type.")
141            }
142            Error::CircularType(_, _) => format!("Circular loop of type definition detected."),
143            Error::IndexOutOfRange { len, idx, .. } => {
144                format!("Length of tuple elements is {len} but index was {idx}")
145            }
146            Error::IndexForNonTuple(_, _) => {
147                format!("Index access for non-tuple variable.")
148            }
149            Error::VariableNotFound(symbol, _) => {
150                format!("Variable \"{symbol}\" not found in this scope")
151            }
152            Error::ModuleNotFound { module_path, .. } => {
153                let path_str = module_path
154                    .iter()
155                    .map(|s| s.to_string())
156                    .collect::<Vec<_>>()
157                    .join("::");
158                format!("Module \"{path_str}\" not found")
159            }
160            Error::MemberNotFound {
161                module_path,
162                member,
163                ..
164            } => {
165                let path_str = module_path
166                    .iter()
167                    .map(|s| s.to_string())
168                    .collect::<Vec<_>>()
169                    .join("::");
170                format!("Member \"{member}\" not found in module \"{path_str}\"")
171            }
172            Error::PrivateMemberAccess {
173                module_path,
174                member,
175                ..
176            } => {
177                let path_str = module_path
178                    .iter()
179                    .map(|s| s.to_string())
180                    .collect::<Vec<_>>()
181                    .join("::");
182                format!("Member \"{member}\" in module \"{path_str}\" is private")
183            }
184            Error::StageMismatch {
185                variable,
186                expected_stage,
187                found_stage,
188                ..
189            } => {
190                format!(
191                    "Variable {variable} is defined in stage {} but accessed from stage {}",
192                    found_stage.format_for_error(),
193                    expected_stage.format_for_error()
194                )
195            }
196            Error::NonPrimitiveInFeed(_) => {
197                format!("Function that uses `self` cannot return function type.")
198            }
199            Error::DuplicateKeyInParams { .. } => {
200                format!("Duplicate keys found in parameter list")
201            }
202            Error::DuplicateKeyInRecord { .. } => {
203                format!("Duplicate keys found in record type")
204            }
205            Error::FieldForNonRecord { .. } => {
206                format!("Field access for non-record variable.")
207            }
208            Error::FieldNotExist { field, .. } => {
209                format!("Field \"{field}\" does not exist in the record type")
210            }
211            Error::IncompatibleKeyInRecord { .. } => {
212                format!("Record type has incompatible keys.",)
213            }
214
215            Error::NonSupertypeArgument { .. } => {
216                format!("Arguments for functions are less than required.")
217            }
218            Error::ConstructorNotInUnion { constructor, .. } => {
219                format!("Constructor \"{constructor}\" is not a variant of the union type")
220            }
221            Error::ExpectedUnionType { found, .. } => {
222                format!(
223                    "Expected a union type but found {}",
224                    found.to_type().to_string_for_error()
225                )
226            }
227            Error::NonExhaustiveMatch {
228                missing_constructors,
229                ..
230            } => {
231                let missing = missing_constructors
232                    .iter()
233                    .map(|s| s.to_string())
234                    .collect::<Vec<_>>()
235                    .join(", ");
236                format!("Match expression is not exhaustive. Missing patterns: {missing}")
237            }
238            Error::RecursiveTypeAlias {
239                type_name, cycle, ..
240            } => {
241                let cycle_str = cycle
242                    .iter()
243                    .map(|s| s.to_string())
244                    .collect::<Vec<_>>()
245                    .join(" -> ");
246                format!(
247                    "Recursive type alias '{type_name}' detected. Cycle: {cycle_str} -> {type_name}. Use 'type rec' to declare recursive types."
248                )
249            }
250            Error::PrivateTypeAccess {
251                module_path,
252                type_name,
253                ..
254            } => {
255                let path_str = module_path
256                    .iter()
257                    .map(|s| s.to_string())
258                    .collect::<Vec<_>>()
259                    .join("::");
260                format!(
261                    "Type '{type_name}' in module '{path_str}' is private and cannot be accessed from outside"
262                )
263            }
264            Error::PrivateTypeLeak {
265                function_name,
266                private_type,
267                ..
268            } => {
269                format!(
270                    "Public function '{function_name}' cannot expose private type '{private_type}' in its signature"
271                )
272            }
273        }
274    }
275    fn get_labels(&self) -> Vec<(Location, String)> {
276        match self {
277            Error::TypeMismatch {
278                left: (lty, locl),
279                right: (rty, locr),
280            } => {
281                let expected = lty.get_root().to_type().to_string_for_error();
282                let found = rty.get_root().to_type().to_string_for_error();
283                let is_dummy = |loc: &Location| {
284                    loc.path.as_os_str().is_empty() || (loc.span.start == 0 && loc.span.end == 0)
285                };
286                let normalize_loc = |primary: &Location, fallback: &Location| {
287                    let mut loc = if is_dummy(primary) {
288                        fallback.clone()
289                    } else {
290                        primary.clone()
291                    };
292
293                    if loc.path.as_os_str().is_empty() {
294                        loc.path = if !primary.path.as_os_str().is_empty() {
295                            primary.path.clone()
296                        } else {
297                            fallback.path.clone()
298                        };
299                    }
300
301                    if loc.span.start == 0 && loc.span.end == 0 {
302                        if !(primary.span.start == 0 && primary.span.end == 0) {
303                            loc.span = primary.span.clone();
304                        } else if !(fallback.span.start == 0 && fallback.span.end == 0) {
305                            loc.span = fallback.span.clone();
306                        } else {
307                            loc.span = 0..1;
308                        }
309                    }
310                    loc
311                };
312
313                let left_loc = normalize_loc(locl, locr);
314                let right_loc = normalize_loc(locr, &left_loc);
315                if left_loc == right_loc {
316                    vec![(
317                        left_loc,
318                        format!("expected type: {expected}, found type: {found}"),
319                    )]
320                } else {
321                    vec![
322                        (left_loc, format!("expected type: {expected}")),
323                        (right_loc, format!("found type: {found}")),
324                    ]
325                }
326            }
327            Error::EscapeRequiresCodeType { found: (ty, loc) } => vec![(
328                loc.clone(),
329                format!(
330                    "escape expects `Code(T)`, but found {}. Escaping nested code containers such as arrays of quoted values is not supported",
331                    ty.to_type().to_string_for_error()
332                ),
333            )],
334            Error::PatternMismatch((ty, loct), (pat, locp)) => vec![
335                (loct.clone(), ty.to_type().to_string_for_error()),
336                (locp.clone(), pat.to_string()),
337            ],
338            Error::LengthMismatch {
339                left: (l, locl),
340                right: (r, locr),
341            } => vec![
342                (locl.clone(), format!("The length is {l}")),
343                (locr.clone(), format!("but the length for here is {r}")),
344            ],
345            Error::NonFunctionForLetRec(ty, loc) => {
346                vec![(loc.clone(), ty.to_type().to_string_for_error())]
347            }
348            Error::NonFunctionForApply(ty, loc) => {
349                vec![(loc.clone(), ty.to_type().to_string_for_error())]
350            }
351            Error::CircularType(loc1, loc2) => vec![
352                (loc1.clone(), format!("Circular type happens here")),
353                (loc2.clone(), format!("and here")),
354            ],
355            Error::IndexOutOfRange { loc, len, .. } => {
356                vec![(loc.clone(), format!("Length for this tuple is {len}"))]
357            }
358            Error::IndexForNonTuple(loc, ty) => {
359                vec![(
360                    loc.clone(),
361                    format!(
362                        "This is not tuple type but {}",
363                        ty.to_type().to_string_for_error()
364                    ),
365                )]
366            }
367            Error::VariableNotFound(symbol, loc) => {
368                vec![(loc.clone(), format!("{symbol} is not defined"))]
369            }
370            Error::ModuleNotFound {
371                module_path,
372                location,
373            } => {
374                let path_str = module_path
375                    .iter()
376                    .map(|s| s.to_string())
377                    .collect::<Vec<_>>()
378                    .join("::");
379                vec![(location.clone(), format!("Module \"{path_str}\" not found"))]
380            }
381            Error::MemberNotFound {
382                module_path,
383                member,
384                location,
385            } => {
386                let path_str = module_path
387                    .iter()
388                    .map(|s| s.to_string())
389                    .collect::<Vec<_>>()
390                    .join("::");
391                vec![(
392                    location.clone(),
393                    format!("\"{member}\" is not a member of \"{path_str}\""),
394                )]
395            }
396            Error::PrivateMemberAccess {
397                module_path,
398                member,
399                location,
400            } => {
401                let path_str = module_path
402                    .iter()
403                    .map(|s| s.to_string())
404                    .collect::<Vec<_>>()
405                    .join("::");
406                vec![(
407                    location.clone(),
408                    format!("\"{member}\" in \"{path_str}\" is private and cannot be accessed"),
409                )]
410            }
411            Error::StageMismatch {
412                variable,
413                expected_stage,
414                found_stage,
415                location,
416            } => {
417                vec![(
418                    location.clone(),
419                    format!(
420                        "Variable \"{variable}\" defined in stage {} cannot be accessed from stage {}",
421                        found_stage.format_for_error(),
422                        expected_stage.format_for_error()
423                    ),
424                )]
425            }
426            Error::NonPrimitiveInFeed(loc) => {
427                vec![(loc.clone(), format!("This cannot be function type."))]
428            }
429            Error::DuplicateKeyInRecord { key, loc } => {
430                vec![(
431                    loc.clone(),
432                    format!(
433                        "Duplicate keys \"{}\" found in record type",
434                        key.iter()
435                            .map(|s| s.to_string())
436                            .collect::<Vec<_>>()
437                            .join(", ")
438                    ),
439                )]
440            }
441            Error::DuplicateKeyInParams(keys) => keys
442                .iter()
443                .map(|(key, loc)| {
444                    (
445                        loc.clone(),
446                        format!("Duplicate key \"{key}\" found in parameter list"),
447                    )
448                })
449                .collect(),
450            Error::FieldForNonRecord(location, ty) => {
451                vec![(
452                    location.clone(),
453                    format!(
454                        "Field access for non-record type {}.",
455                        ty.to_type().to_string_for_error()
456                    ),
457                )]
458            }
459            Error::FieldNotExist { field, loc, et } => vec![(
460                loc.clone(),
461                format!(
462                    "Field \"{}\" does not exist in the type {}",
463                    field,
464                    et.to_type().to_string_for_error()
465                ),
466            )],
467            Error::IncompatibleKeyInRecord {
468                left: (left, lloc),
469                right: (right, rloc),
470            } => {
471                vec![
472                    (
473                        lloc.clone(),
474                        format!(
475                            "the record here contains{}",
476                            left.iter()
477                                .map(|(key, ty)| format!(
478                                    " \"{key}\":{}",
479                                    ty.to_type().to_string_for_error()
480                                ))
481                                .collect::<Vec<_>>()
482                                .join(", ")
483                        ),
484                    ),
485                    (
486                        rloc.clone(),
487                        format!(
488                            "but the record here contains {}",
489                            right
490                                .iter()
491                                .map(|(key, ty)| format!(
492                                    " \"{key}\":{}",
493                                    ty.to_type().to_string_for_error()
494                                ))
495                                .collect::<Vec<_>>()
496                                .join(", ")
497                        ),
498                    ),
499                ]
500            }
501
502            Error::NonSupertypeArgument {
503                location,
504                expected,
505                found,
506            } => {
507                vec![(
508                    location.clone(),
509                    format!(
510                        "Type {} is not a supertype of the expected type {}",
511                        found.to_type().to_string_for_error(),
512                        expected.to_type().to_string_for_error()
513                    ),
514                )]
515            }
516            Error::ConstructorNotInUnion {
517                constructor,
518                union_type,
519                location,
520            } => {
521                vec![(
522                    location.clone(),
523                    format!(
524                        "Constructor \"{constructor}\" is not a variant of {}",
525                        union_type.to_type().to_string_for_error()
526                    ),
527                )]
528            }
529            Error::ExpectedUnionType { found, location } => {
530                vec![(
531                    location.clone(),
532                    format!(
533                        "Expected a union type but found {}",
534                        found.to_type().to_string_for_error()
535                    ),
536                )]
537            }
538            Error::NonExhaustiveMatch {
539                missing_constructors,
540                location,
541            } => {
542                let missing = missing_constructors
543                    .iter()
544                    .map(|s| format!("\"{s}\""))
545                    .collect::<Vec<_>>()
546                    .join(", ");
547                vec![(location.clone(), format!("Missing patterns: {missing}"))]
548            }
549            Error::RecursiveTypeAlias {
550                type_name,
551                cycle,
552                location,
553            } => {
554                let cycle_str = cycle
555                    .iter()
556                    .map(|s| s.to_string())
557                    .collect::<Vec<_>>()
558                    .join(" -> ");
559                vec![(
560                    location.clone(),
561                    format!(
562                        "Type alias '{type_name}' creates a cycle: {cycle_str} -> {type_name}. Consider using 'type rec' instead of 'type alias'."
563                    ),
564                )]
565            }
566            Error::PrivateTypeAccess {
567                module_path,
568                type_name,
569                location,
570            } => {
571                let path_str = module_path
572                    .iter()
573                    .map(|s| s.to_string())
574                    .collect::<Vec<_>>()
575                    .join("::");
576                vec![(
577                    location.clone(),
578                    format!("Type '{type_name}' in module '{path_str}' is private"),
579                )]
580            }
581            Error::PrivateTypeLeak { location, .. } => {
582                vec![(
583                    location.clone(),
584                    "private type leaked in public function signature".to_string(),
585                )]
586            }
587        }
588    }
589}
590
591/// Information about a constructor in a user-defined sum type
592#[derive(Clone, Debug)]
593pub struct ConstructorInfo {
594    /// The type of the sum type this constructor belongs to
595    pub sum_type: TypeNodeId,
596    /// The index (tag) of this constructor in the sum type
597    pub tag_index: usize,
598    /// Optional payload type for this constructor
599    pub payload_type: Option<TypeNodeId>,
600}
601
602/// Map from constructor name to its info
603pub type ConstructorEnv = HashMap<Symbol, ConstructorInfo>;
604
605/// Result of looking up a field in a (possibly wrapped) type.
606enum FieldLookup {
607    /// Field was found with this type.
608    Found(TypeNodeId),
609    /// A record was reached but the field was not present.
610    RecordWithoutField,
611    /// No record type could be reached.
612    NotRecord,
613}
614
615#[derive(Clone, Debug)]
616pub struct InferContext {
617    interm_idx: IntermediateId,
618    typescheme_idx: TypeSchemeId,
619    level: u64,
620    stage: EvalStage,
621    instantiated_map: BTreeMap<TypeSchemeId, TypeNodeId>, //from type scheme to typevar
622    generalize_map: BTreeMap<IntermediateId, TypeSchemeId>,
623    result_memo: BTreeMap<ExprKey, TypeNodeId>,
624    explicit_type_param_scopes: Vec<BTreeMap<Symbol, TypeNodeId>>,
625    file_path: PathBuf,
626    pub env: Environment<(TypeNodeId, EvalStage)>,
627    /// Constructor environment for user-defined sum types
628    pub constructor_env: ConstructorEnv,
629    /// Type alias resolution map
630    pub type_aliases: HashMap<Symbol, TypeNodeId>,
631    /// Module information for visibility checking
632    module_info: Option<crate::ast::program::ModuleInfo>,
633    /// Match expressions to check for exhaustiveness after type resolution
634    match_expressions: Vec<(ExprNodeId, TypeNodeId)>,
635    pub errors: Vec<Error>,
636    /// Debug: unique ID for this infer_root call
637    pub infer_root_id: usize,
638}
639struct TypeCycle(pub Vec<Symbol>);
640
641impl InferContext {
642    pub fn new(
643        builtins: &[(Symbol, TypeNodeId)],
644        file_path: PathBuf,
645        type_declarations: Option<&crate::ast::program::TypeDeclarationMap>,
646        type_aliases: Option<&crate::ast::program::TypeAliasMap>,
647        module_info: Option<crate::ast::program::ModuleInfo>,
648    ) -> Self {
649        let mut res = Self {
650            interm_idx: Default::default(),
651            typescheme_idx: Default::default(),
652            level: Default::default(),
653            stage: EvalStage::Stage(0), // Start at stage 0
654            instantiated_map: Default::default(),
655            generalize_map: Default::default(),
656            result_memo: Default::default(),
657            explicit_type_param_scopes: Default::default(),
658            file_path,
659            env: Environment::<(TypeNodeId, EvalStage)>::default(),
660            constructor_env: Default::default(),
661            type_aliases: Default::default(),
662            module_info,
663            match_expressions: Default::default(),
664            errors: Default::default(),
665            infer_root_id: usize::MAX,
666        };
667        res.env.extend();
668        // Intrinsic types are persistent (available at all stages)
669        let intrinsics = Self::intrinsic_types()
670            .into_iter()
671            .map(|(name, ty)| (name, (ty, EvalStage::Persistent)))
672            .collect::<Vec<_>>();
673        res.env.add_bind(&intrinsics);
674        // Builtins are also persistent
675        let builtins = builtins
676            .iter()
677            .map(|(name, ty)| (*name, (*ty, EvalStage::Persistent)))
678            .collect::<Vec<_>>();
679        res.env.add_bind(&builtins);
680        // Register user-defined type constructors
681        if let Some(type_decls) = type_declarations {
682            res.register_type_declarations(type_decls);
683        }
684        // Register type aliases
685        if let Some(type_aliases) = type_aliases {
686            res.register_type_aliases(type_aliases);
687        }
688        res
689    }
690
691    fn is_explicit_type_param_name(name: Symbol) -> bool {
692        let s = name.as_str();
693        s.len() == 1 && s.as_bytes()[0].is_ascii_lowercase()
694    }
695
696    fn collect_explicit_type_params_in_type(ty: TypeNodeId, out: &mut BTreeMap<Symbol, Location>) {
697        match ty.to_type() {
698            Type::TypeAlias(name) if Self::is_explicit_type_param_name(name) => {
699                out.entry(name).or_insert_with(|| ty.to_loc());
700            }
701            Type::Array(elem) | Type::Ref(elem) | Type::Code(elem) | Type::Boxed(elem) => {
702                Self::collect_explicit_type_params_in_type(elem, out);
703            }
704            Type::Tuple(elems) | Type::Union(elems) => elems
705                .iter()
706                .for_each(|elem| Self::collect_explicit_type_params_in_type(*elem, out)),
707            Type::Record(fields) => fields
708                .iter()
709                .for_each(|field| Self::collect_explicit_type_params_in_type(field.ty, out)),
710            Type::Function { arg, ret } => {
711                Self::collect_explicit_type_params_in_type(arg, out);
712                Self::collect_explicit_type_params_in_type(ret, out);
713            }
714            _ => {}
715        }
716    }
717
718    fn with_explicit_type_param_scope_from_types<T>(
719        &mut self,
720        types: &[TypeNodeId],
721        f: impl FnOnce(&mut Self) -> T,
722    ) -> T {
723        let mut collected = BTreeMap::<Symbol, Location>::new();
724        types
725            .iter()
726            .for_each(|ty| Self::collect_explicit_type_params_in_type(*ty, &mut collected));
727        let map = collected
728            .into_iter()
729            .map(|(name, loc)| {
730                let ty = self
731                    .lookup_explicit_type_param(name)
732                    .unwrap_or_else(|| self.gen_typescheme(loc));
733                (name, ty)
734            })
735            .collect::<BTreeMap<_, _>>();
736        self.explicit_type_param_scopes.push(map);
737        let res = f(self);
738        let _ = self.explicit_type_param_scopes.pop();
739        res
740    }
741
742    fn lookup_explicit_type_param(&self, name: Symbol) -> Option<TypeNodeId> {
743        self.explicit_type_param_scopes
744            .iter()
745            .rev()
746            .find_map(|scope| scope.get(&name).copied())
747    }
748
749    /// Register type declarations from ModuleInfo into the constructor environment
750    /// Register type declarations from ModuleInfo into the constructor environment
751    fn register_type_declarations(
752        &mut self,
753        type_declarations: &crate::ast::program::TypeDeclarationMap,
754    ) {
755        // First pass: Create all UserSum types without recursive wrapping
756        // and register type names so that TypeAlias can be resolved
757        let mut sum_types: std::collections::HashMap<Symbol, TypeNodeId> =
758            std::collections::HashMap::new();
759
760        for (type_name, decl_info) in type_declarations {
761            let variants = &decl_info.variants;
762            let variant_data: Vec<(Symbol, Option<TypeNodeId>)> =
763                variants.iter().map(|v| (v.name, v.payload)).collect();
764
765            let sum_type = Type::UserSum {
766                name: *type_name,
767                variants: variant_data.clone(),
768            }
769            .into_id();
770
771            sum_types.insert(*type_name, sum_type);
772            // Register the type name itself as Persistent so it's accessible from all stages
773            self.env
774                .add_bind(&[(*type_name, (sum_type, EvalStage::Persistent))]);
775        }
776
777        // Second pass: For recursive types, wrap self-references in Boxed
778        for (type_name, decl_info) in type_declarations {
779            if !decl_info.is_recursive {
780                continue;
781            }
782
783            let variants = &decl_info.variants;
784            let sum_type_id = sum_types[type_name];
785
786            // Transform recursive references to Boxed
787            let variant_data: Vec<(Symbol, Option<TypeNodeId>)> = variants
788                .iter()
789                .map(|v| {
790                    let wrapped_payload = v.payload.map(|payload_type| {
791                        Self::wrap_recursive_refs_static(payload_type, *type_name, sum_type_id)
792                    });
793                    (v.name, wrapped_payload)
794                })
795                .collect();
796
797            // Update the UserSum type with wrapped variants
798            let new_sum_type = Type::UserSum {
799                name: *type_name,
800                variants: variant_data.clone(),
801            }
802            .into_id();
803
804            // Update the binding as Persistent
805            self.env
806                .add_bind(&[(*type_name, (new_sum_type, EvalStage::Persistent))]);
807
808            // Register each constructor
809            for (tag_index, (variant_name, payload_type)) in variant_data.iter().enumerate() {
810                self.constructor_env.insert(
811                    *variant_name,
812                    ConstructorInfo {
813                        sum_type: new_sum_type,
814                        tag_index,
815                        payload_type: *payload_type,
816                    },
817                );
818            }
819        }
820
821        // Register constructors for non-recursive types
822        for (type_name, decl_info) in type_declarations {
823            if decl_info.is_recursive {
824                continue;
825            }
826
827            let sum_type = sum_types[type_name];
828            let variants = &decl_info.variants;
829
830            for (tag_index, variant) in variants.iter().enumerate() {
831                self.constructor_env.insert(
832                    variant.name,
833                    ConstructorInfo {
834                        sum_type,
835                        tag_index,
836                        payload_type: variant.payload,
837                    },
838                );
839            }
840        }
841
842        // Check for recursive type declarations (not allowed without 'rec' keyword)
843        self.check_type_declaration_recursion(type_declarations);
844    }
845
846    /// Wrap direct self-references in Boxed type for recursive type declarations
847    /// This is a static function that transforms TypeAlias(self_name) -> Boxed(sum_type_id)
848    /// in Tuple/Record positions. Does NOT recurse into Function/Array which already provide indirection.
849    fn wrap_recursive_refs_static(
850        ty: TypeNodeId,
851        self_name: Symbol,
852        sum_type_id: TypeNodeId,
853    ) -> TypeNodeId {
854        match ty.to_type() {
855            Type::TypeAlias(name) if name == self_name => {
856                // Direct self-reference: wrap the sum type in Boxed
857                Type::Boxed(sum_type_id).into_id()
858            }
859            Type::Tuple(elements) => {
860                // Recursively wrap in tuple elements
861                let wrapped_elements: Vec<TypeNodeId> = elements
862                    .iter()
863                    .map(|&elem| Self::wrap_recursive_refs_static(elem, self_name, sum_type_id))
864                    .collect();
865                Type::Tuple(wrapped_elements).into_id()
866            }
867            Type::Record(fields) => {
868                // Recursively wrap in record fields
869                let wrapped_fields: Vec<RecordTypeField> = fields
870                    .iter()
871                    .map(|field| RecordTypeField {
872                        key: field.key,
873                        ty: Self::wrap_recursive_refs_static(field.ty, self_name, sum_type_id),
874                        has_default: field.has_default,
875                    })
876                    .collect();
877                Type::Record(wrapped_fields).into_id()
878            }
879            Type::Union(elements) => {
880                // Recursively wrap in union elements
881                let wrapped_elements: Vec<TypeNodeId> = elements
882                    .iter()
883                    .map(|&elem| Self::wrap_recursive_refs_static(elem, self_name, sum_type_id))
884                    .collect();
885                Type::Union(wrapped_elements).into_id()
886            }
887            // Do NOT recurse into Function, Array, Code, or Boxed - they already provide indirection
888            _ => ty,
889        }
890    }
891
892    /// Check for recursive references in type declarations
893    /// Recursion is only allowed when the `rec` keyword is used
894    fn check_type_declaration_recursion(
895        &mut self,
896        type_declarations: &crate::ast::program::TypeDeclarationMap,
897    ) {
898        for (type_name, decl_info) in type_declarations {
899            // Skip the recursion check for types declared with `type rec`
900            if decl_info.is_recursive {
901                continue;
902            }
903            if let Some(location) =
904                self.is_type_declaration_recursive(*type_name, &decl_info.variants)
905            {
906                self.errors.push(Error::RecursiveTypeAlias {
907                    type_name: *type_name,
908                    cycle: vec![*type_name],
909                    location,
910                });
911            }
912        }
913    }
914
915    /// Check if a type declaration contains recursive references
916    /// Returns Some(location) if recursion is found, None otherwise
917    fn is_type_declaration_recursive(
918        &self,
919        type_name: Symbol,
920        variants: &[crate::ast::program::VariantDef],
921    ) -> Option<Location> {
922        variants.iter().find_map(|variant| {
923            variant
924                .payload
925                .filter(|&payload_type| self.type_references_name(payload_type, type_name))
926                .map(|payload_type| payload_type.to_loc())
927        })
928    }
929
930    /// Check if a type references a specific type name (for recursion detection)
931    fn type_references_name(&self, type_id: TypeNodeId, target_name: Symbol) -> bool {
932        match type_id.to_type() {
933            Type::TypeAlias(name) if name == target_name => true,
934            Type::TypeAlias(name) => {
935                // Follow type alias to see if it eventually references target
936                if let Some(resolved_type) = self.type_aliases.get(&name) {
937                    self.type_references_name(*resolved_type, target_name)
938                } else {
939                    false
940                }
941            }
942            Type::Function { arg, ret } => {
943                self.type_references_name(arg, target_name)
944                    || self.type_references_name(ret, target_name)
945            }
946            Type::Tuple(elements) | Type::Union(elements) => elements
947                .iter()
948                .any(|t| self.type_references_name(*t, target_name)),
949            Type::Array(elem) | Type::Code(elem) => self.type_references_name(elem, target_name),
950            Type::Boxed(inner) => self.type_references_name(inner, target_name),
951            Type::Record(fields) => fields
952                .iter()
953                .any(|f| self.type_references_name(f.ty, target_name)),
954            Type::UserSum { name, .. } if name == target_name => true,
955            Type::UserSum { variants, .. } => variants
956                .iter()
957                .filter_map(|(_, payload)| *payload)
958                .any(|p| self.type_references_name(p, target_name)),
959            _ => false,
960        }
961    }
962
963    /// Register type aliases from ModuleInfo into the type environment
964    fn register_type_aliases(&mut self, type_aliases: &crate::ast::program::TypeAliasMap) {
965        // Store type aliases for resolution during unification
966        for (alias_name, target_type) in type_aliases {
967            self.type_aliases.insert(*alias_name, *target_type);
968            // Also add to environment for name resolution
969            self.env
970                .add_bind(&[(*alias_name, (*target_type, EvalStage::Persistent))]);
971        }
972
973        // Check for circular type aliases
974        self.check_type_alias_cycles(type_aliases);
975    }
976
977    /// Check for circular references in type aliases
978    fn check_type_alias_cycles(&mut self, type_aliases: &TypeAliasMap) {
979        let errors: Vec<_> = type_aliases
980            .iter()
981            .filter_map(|(alias_name, target_type)| {
982                Self::detect_type_alias_cycle(*alias_name, type_aliases).map(|cycle| {
983                    Error::RecursiveTypeAlias {
984                        type_name: *alias_name,
985                        cycle,
986                        location: target_type.to_loc(),
987                    }
988                })
989            })
990            .collect();
991
992        self.errors.extend(errors);
993    }
994
995    /// Detect a cycle starting from a given type alias name
996    /// Returns Some(cycle) if a cycle is found, None otherwise
997    fn detect_type_alias_cycle(start: Symbol, type_aliases: &TypeAliasMap) -> Option<Vec<Symbol>> {
998        Self::detect_cycle_helper(start, vec![], type_aliases).map(|t| t.0)
999    }
1000
1001    /// Helper function for cycle detection
1002    fn detect_cycle_helper(
1003        current: Symbol,
1004        path: Vec<Symbol>,
1005        type_aliases: &TypeAliasMap,
1006    ) -> Option<TypeCycle> {
1007        // If we've seen this type before in the current path, we have a cycle
1008        if let Some(cycle_start) = path.iter().position(|&s| s == current) {
1009            return Some(TypeCycle(path[cycle_start..].to_vec()));
1010        }
1011
1012        let new_path = [path, vec![current]].concat();
1013
1014        type_aliases.get(&current).and_then(|target_type| {
1015            Self::find_type_aliases_in_type(*target_type)
1016                .into_iter()
1017                .find_map(|ref_alias| {
1018                    Self::detect_cycle_helper(ref_alias, new_path.clone(), type_aliases)
1019                })
1020        })
1021    }
1022
1023    /// Find all type alias names referenced in a type
1024    fn find_type_aliases_in_type(type_id: TypeNodeId) -> Vec<Symbol> {
1025        match type_id.to_type() {
1026            Type::TypeAlias(name) => vec![name],
1027            Type::Function { arg, ret } => {
1028                let mut aliases = Self::find_type_aliases_in_type(arg);
1029                aliases.extend(Self::find_type_aliases_in_type(ret));
1030                aliases
1031            }
1032            Type::Tuple(elements) | Type::Union(elements) => elements
1033                .iter()
1034                .flat_map(|t| Self::find_type_aliases_in_type(*t))
1035                .collect(),
1036            Type::Array(elem) | Type::Code(elem) => Self::find_type_aliases_in_type(elem),
1037            Type::Record(fields) => fields
1038                .iter()
1039                .flat_map(|f| Self::find_type_aliases_in_type(f.ty))
1040                .collect(),
1041            Type::UserSum { variants, .. } => variants
1042                .iter()
1043                .filter_map(|(_, payload)| *payload)
1044                .flat_map(Self::find_type_aliases_in_type)
1045                .collect(),
1046            _ => vec![],
1047        }
1048    }
1049
1050    /// Resolve type aliases recursively
1051    pub fn resolve_type_alias(&self, type_id: TypeNodeId) -> TypeNodeId {
1052        match type_id.to_type() {
1053            Type::TypeAlias(alias_name) => {
1054                let resolved_alias_name = self.resolve_type_alias_symbol_fallback(alias_name);
1055                if let Some(resolved_type) = self.type_aliases.get(&resolved_alias_name) {
1056                    // Recursively resolve in case the alias points to another alias
1057                    self.resolve_type_alias(*resolved_type)
1058                } else {
1059                    type_id // Return original if not found (shouldn't happen)
1060                }
1061            }
1062            _ => type_id.apply_fn(|t| self.resolve_type_alias(t)),
1063        }
1064    }
1065}
1066impl InferContext {
1067    const TUPLE_BINOP_MAX_ARITY: usize = 16;
1068
1069    fn intrinsic_types() -> Vec<(Symbol, TypeNodeId)> {
1070        let binop_ty = function!(vec![numeric!(), numeric!()], numeric!());
1071        let binop_names = [
1072            intrinsics::ADD,
1073            intrinsics::SUB,
1074            intrinsics::MULT,
1075            intrinsics::DIV,
1076            intrinsics::MODULO,
1077            intrinsics::POW,
1078            intrinsics::GT,
1079            intrinsics::LT,
1080            intrinsics::GE,
1081            intrinsics::LE,
1082            intrinsics::EQ,
1083            intrinsics::NE,
1084            intrinsics::AND,
1085            intrinsics::OR,
1086        ];
1087        let uniop_ty = function!(vec![numeric!()], numeric!());
1088        let uniop_names = [
1089            intrinsics::NEG,
1090            intrinsics::MEM,
1091            intrinsics::SIN,
1092            intrinsics::COS,
1093            intrinsics::ABS,
1094            intrinsics::LOG,
1095            intrinsics::SQRT,
1096        ];
1097
1098        let binds = binop_names.map(|n| (n.to_symbol(), binop_ty));
1099        let unibinds = uniop_names.map(|n| (n.to_symbol(), uniop_ty));
1100        [
1101            (
1102                intrinsics::DELAY.to_symbol(),
1103                function!(vec![numeric!(), numeric!(), numeric!()], numeric!()),
1104            ),
1105            (
1106                intrinsics::TOFLOAT.to_symbol(),
1107                function!(vec![integer!()], numeric!()),
1108            ),
1109        ]
1110        .into_iter()
1111        .chain(binds)
1112        .chain(unibinds)
1113        .collect()
1114    }
1115
1116    fn is_tuple_arithmetic_binop_label(label: Symbol) -> bool {
1117        matches!(
1118            label.as_str(),
1119            intrinsics::ADD | intrinsics::SUB | intrinsics::MULT | intrinsics::DIV
1120        )
1121    }
1122
1123    fn try_get_tuple_arithmetic_binop_label(&self, fun: ExprNodeId) -> Option<Symbol> {
1124        match fun.to_expr() {
1125            Expr::Var(name) if Self::is_tuple_arithmetic_binop_label(name) => Some(name),
1126            _ => None,
1127        }
1128    }
1129
1130    fn resolve_for_tuple_binop(&self, ty: TypeNodeId) -> TypeNodeId {
1131        let resolved_alias = self.resolve_type_alias(ty);
1132        Self::substitute_type(resolved_alias)
1133    }
1134
1135    fn type_loc_or_expr_loc(&self, ty: TypeNodeId, expr_loc: &Location) -> Location {
1136        let ty_loc = ty.to_loc();
1137        if ty_loc.path.as_os_str().is_empty() {
1138            expr_loc.clone()
1139        } else {
1140            ty_loc
1141        }
1142    }
1143
1144    fn is_numeric_scalar_for_tuple_binop(&self, ty: TypeNodeId) -> bool {
1145        matches!(
1146            self.resolve_for_tuple_binop(ty).to_type(),
1147            Type::Primitive(PType::Numeric) | Type::Primitive(PType::Int)
1148        )
1149    }
1150
1151    fn make_tuple_binop_arity_error(&self, actual_arity: usize, loc: &Location) -> Error {
1152        Error::TypeMismatch {
1153            left: (
1154                Type::Tuple(vec![numeric!(); Self::TUPLE_BINOP_MAX_ARITY])
1155                    .into_id_with_location(loc.clone()),
1156                loc.clone(),
1157            ),
1158            right: (
1159                Type::Tuple(vec![numeric!(); actual_arity]).into_id_with_location(loc.clone()),
1160                loc.clone(),
1161            ),
1162        }
1163    }
1164
1165    fn infer_tuple_arithmetic_binop_type_rec(
1166        &mut self,
1167        lhs_ty: TypeNodeId,
1168        rhs_ty: TypeNodeId,
1169        loc: &Location,
1170        errs: &mut Vec<Error>,
1171    ) -> Option<TypeNodeId> {
1172        let lhs_resolved = self.resolve_for_tuple_binop(lhs_ty);
1173        let rhs_resolved = self.resolve_for_tuple_binop(rhs_ty);
1174
1175        match (lhs_resolved.to_type(), rhs_resolved.to_type()) {
1176            (Type::Tuple(lhs_elems), Type::Tuple(rhs_elems)) => {
1177                if lhs_elems.len() != rhs_elems.len() {
1178                    errs.push(Error::TypeMismatch {
1179                        left: (lhs_ty, loc.clone()),
1180                        right: (rhs_ty, loc.clone()),
1181                    });
1182                    return None;
1183                }
1184                if lhs_elems.len() > Self::TUPLE_BINOP_MAX_ARITY {
1185                    errs.push(self.make_tuple_binop_arity_error(lhs_elems.len(), loc));
1186                    return None;
1187                }
1188
1189                let result_elems = lhs_elems
1190                    .iter()
1191                    .zip(rhs_elems.iter())
1192                    .filter_map(|(lt, rt)| {
1193                        self.infer_tuple_arithmetic_binop_type_rec(*lt, *rt, loc, errs)
1194                    })
1195                    .collect::<Vec<_>>();
1196
1197                if result_elems.len() != lhs_elems.len() {
1198                    None
1199                } else {
1200                    Some(Type::Tuple(result_elems).into_id_with_location(loc.clone()))
1201                }
1202            }
1203            (Type::Tuple(tuple_elems), _) => {
1204                if tuple_elems.len() > Self::TUPLE_BINOP_MAX_ARITY {
1205                    errs.push(self.make_tuple_binop_arity_error(tuple_elems.len(), loc));
1206                    return None;
1207                }
1208                if !self.is_numeric_scalar_for_tuple_binop(rhs_ty) {
1209                    let rhs_loc = self.type_loc_or_expr_loc(rhs_ty, loc);
1210                    errs.push(Error::TypeMismatch {
1211                        left: (numeric!(), rhs_loc.clone()),
1212                        right: (rhs_ty, rhs_loc),
1213                    });
1214                    return None;
1215                }
1216
1217                let result_elems = tuple_elems
1218                    .iter()
1219                    .filter_map(|elem_ty| {
1220                        self.infer_tuple_arithmetic_binop_type_rec(*elem_ty, rhs_ty, loc, errs)
1221                    })
1222                    .collect::<Vec<_>>();
1223
1224                if result_elems.len() != tuple_elems.len() {
1225                    None
1226                } else {
1227                    Some(Type::Tuple(result_elems).into_id_with_location(loc.clone()))
1228                }
1229            }
1230            (_, Type::Tuple(tuple_elems)) => {
1231                if tuple_elems.len() > Self::TUPLE_BINOP_MAX_ARITY {
1232                    errs.push(self.make_tuple_binop_arity_error(tuple_elems.len(), loc));
1233                    return None;
1234                }
1235                if !self.is_numeric_scalar_for_tuple_binop(lhs_ty) {
1236                    let lhs_loc = self.type_loc_or_expr_loc(lhs_ty, loc);
1237                    errs.push(Error::TypeMismatch {
1238                        left: (numeric!(), lhs_loc.clone()),
1239                        right: (lhs_ty, lhs_loc),
1240                    });
1241                    return None;
1242                }
1243
1244                let result_elems = tuple_elems
1245                    .iter()
1246                    .filter_map(|elem_ty| {
1247                        self.infer_tuple_arithmetic_binop_type_rec(lhs_ty, *elem_ty, loc, errs)
1248                    })
1249                    .collect::<Vec<_>>();
1250
1251                if result_elems.len() != tuple_elems.len() {
1252                    None
1253                } else {
1254                    Some(Type::Tuple(result_elems).into_id_with_location(loc.clone()))
1255                }
1256            }
1257            _ => {
1258                let mut valid = true;
1259                if !self.is_numeric_scalar_for_tuple_binop(lhs_ty) {
1260                    let lhs_loc = self.type_loc_or_expr_loc(lhs_ty, loc);
1261                    errs.push(Error::TypeMismatch {
1262                        left: (numeric!(), lhs_loc.clone()),
1263                        right: (lhs_ty, lhs_loc),
1264                    });
1265                    valid = false;
1266                }
1267                if !self.is_numeric_scalar_for_tuple_binop(rhs_ty) {
1268                    let rhs_loc = self.type_loc_or_expr_loc(rhs_ty, loc);
1269                    errs.push(Error::TypeMismatch {
1270                        left: (numeric!(), rhs_loc.clone()),
1271                        right: (rhs_ty, rhs_loc),
1272                    });
1273                    valid = false;
1274                }
1275                if valid { Some(numeric!()) } else { None }
1276            }
1277        }
1278    }
1279
1280    fn infer_tuple_arithmetic_binop_type(
1281        &mut self,
1282        lhs_ty: TypeNodeId,
1283        rhs_ty: TypeNodeId,
1284        loc: Location,
1285    ) -> Result<TypeNodeId, Vec<Error>> {
1286        let mut errs = vec![];
1287        let result_ty = self.infer_tuple_arithmetic_binop_type_rec(lhs_ty, rhs_ty, &loc, &mut errs);
1288        if !errs.is_empty() {
1289            return Err(errs);
1290        }
1291        result_ty.ok_or_else(|| {
1292            vec![Error::TypeMismatch {
1293                left: (lhs_ty, loc.clone()),
1294                right: (rhs_ty, loc),
1295            }]
1296        })
1297    }
1298
1299    fn is_auto_spread_endpoint_type(&self, ty: TypeNodeId) -> bool {
1300        matches!(
1301            self.resolve_for_tuple_binop(ty).to_type(),
1302            Type::Primitive(PType::Numeric)
1303                | Type::Primitive(PType::Int)
1304                | Type::Intermediate(_)
1305                | Type::TypeScheme(_)
1306                | Type::Unknown
1307                | Type::Failure
1308        )
1309    }
1310
1311    fn auto_spread_param_endpoint_type(&self, param_ty: TypeNodeId) -> Option<TypeNodeId> {
1312        let resolved = self.resolve_for_tuple_binop(param_ty);
1313        match resolved.to_type() {
1314            Type::Record(fields) if fields.len() == 1 => Some(fields[0].ty),
1315            _ => Some(resolved),
1316        }
1317    }
1318
1319    fn is_numeric_to_numeric_function_for_auto_spread(&self, fn_ty: TypeNodeId) -> bool {
1320        let resolved = self.resolve_for_tuple_binop(fn_ty);
1321        matches!(
1322            resolved.to_type(),
1323            Type::Function { arg, ret }
1324                if self
1325                    .auto_spread_param_endpoint_type(arg)
1326                    .is_some_and(|endpoint| self.is_auto_spread_endpoint_type(endpoint))
1327                    && self.is_auto_spread_endpoint_type(ret)
1328        )
1329    }
1330
1331    fn infer_auto_spread_type_rec(
1332        &mut self,
1333        arg_ty: TypeNodeId,
1334        loc: &Location,
1335        errs: &mut Vec<Error>,
1336    ) -> Option<TypeNodeId> {
1337        let resolved = self.resolve_for_tuple_binop(arg_ty);
1338        match resolved.to_type() {
1339            Type::Tuple(elems) => {
1340                if elems.len() > Self::TUPLE_BINOP_MAX_ARITY {
1341                    errs.push(self.make_tuple_binop_arity_error(elems.len(), loc));
1342                    return None;
1343                }
1344                let mapped = elems
1345                    .iter()
1346                    .filter_map(|elem_ty| self.infer_auto_spread_type_rec(*elem_ty, loc, errs))
1347                    .collect::<Vec<_>>();
1348                if mapped.len() != elems.len() {
1349                    None
1350                } else {
1351                    Some(Type::Tuple(mapped).into_id_with_location(loc.clone()))
1352                }
1353            }
1354            _ => {
1355                if self.is_numeric_scalar_for_tuple_binop(arg_ty) {
1356                    Some(numeric!())
1357                } else {
1358                    let arg_loc = self.type_loc_or_expr_loc(arg_ty, loc);
1359                    errs.push(Error::TypeMismatch {
1360                        left: (numeric!(), arg_loc.clone()),
1361                        right: (arg_ty, arg_loc),
1362                    });
1363                    None
1364                }
1365            }
1366        }
1367    }
1368
1369    fn infer_auto_spread_type(
1370        &mut self,
1371        fn_ty: TypeNodeId,
1372        arg_ty: TypeNodeId,
1373        loc: Location,
1374    ) -> Result<TypeNodeId, Vec<Error>> {
1375        let mut errs = vec![];
1376        let result_ty = self.infer_auto_spread_type_rec(arg_ty, &loc, &mut errs);
1377        if !errs.is_empty() {
1378            return Err(errs);
1379        }
1380        result_ty.ok_or_else(|| {
1381            vec![Error::TypeMismatch {
1382                left: (arg_ty, loc.clone()),
1383                right: (arg_ty, loc),
1384            }]
1385        })
1386    }
1387
1388    /// Get the type associated with a constructor name from a union or user-defined sum type
1389    /// For primitive types in unions like `float | string`, the constructor names are "float" and "string"
1390    /// For user-defined sum types, returns Unit for payloadless constructors
1391    fn get_constructor_type_from_union(
1392        &self,
1393        union_ty: TypeNodeId,
1394        constructor_name: Symbol,
1395    ) -> TypeNodeId {
1396        // First, try to look up the constructor directly in constructor_env.
1397        // This handles cases where the union type is still an unresolved intermediate type.
1398        if let Some(constructor_info) = self.constructor_env.get(&constructor_name) {
1399            return constructor_info.payload_type.unwrap_or_else(|| unit!());
1400        }
1401
1402        let resolved = Self::substitute_type(union_ty);
1403        match resolved.to_type() {
1404            Type::Union(variants) => {
1405                // Find a variant that matches the constructor name
1406                for variant_ty in variants.iter() {
1407                    let variant_resolved = Self::substitute_type(*variant_ty);
1408                    let variant_name = Self::type_constructor_name(&variant_resolved.to_type());
1409                    if variant_name == Some(constructor_name) {
1410                        return *variant_ty;
1411                    }
1412                }
1413                // Constructor not found in union - return Unknown as placeholder
1414                Type::Unknown.into_id_with_location(union_ty.to_loc())
1415            }
1416            Type::UserSum { name: _, variants } => {
1417                // Check if constructor_name is one of the variants
1418                if let Some((_, payload_ty)) =
1419                    variants.iter().find(|(name, _)| *name == constructor_name)
1420                {
1421                    // Return the payload type if available, otherwise Unit
1422                    payload_ty.unwrap_or_else(|| unit!())
1423                } else {
1424                    Type::Unknown.into_id_with_location(union_ty.to_loc())
1425                }
1426            }
1427            // If not a union type, check if it matches the constructor directly
1428            other => {
1429                let type_name = Self::type_constructor_name(&other);
1430                if type_name == Some(constructor_name) {
1431                    resolved
1432                } else {
1433                    Type::Unknown.into_id_with_location(union_ty.to_loc())
1434                }
1435            }
1436        }
1437    }
1438
1439    /// Get the constructor name for a type (used for matching in union types)
1440    /// Primitive types use their type name as constructor (e.g., "float", "string")
1441    fn type_constructor_name(ty: &Type) -> Option<Symbol> {
1442        match ty {
1443            Type::Primitive(PType::Numeric) => Some("float".to_symbol()),
1444            Type::Primitive(PType::String) => Some("string".to_symbol()),
1445            Type::Primitive(PType::Int) => Some("int".to_symbol()),
1446            Type::Primitive(PType::Unit) => Some("unit".to_symbol()),
1447            // For other types, we don't have built-in constructor names yet
1448            _ => None,
1449        }
1450    }
1451
1452    /// Add bindings for a match pattern to the current environment
1453    /// Handles variable bindings, tuple patterns, and nested patterns
1454    fn add_pattern_bindings(&mut self, pattern: &crate::ast::MatchPattern, ty: TypeNodeId) {
1455        use crate::ast::MatchPattern;
1456        // Resolve the type to its concrete form (unwrap intermediate types)
1457        let resolved_ty = ty.get_root().to_type();
1458        match pattern {
1459            MatchPattern::Variable(var) => {
1460                self.env.add_bind(&[(*var, (ty, self.stage))]);
1461            }
1462            MatchPattern::Wildcard => {
1463                // No bindings for wildcard
1464            }
1465            MatchPattern::Literal(_) => {
1466                // No bindings for literal patterns
1467            }
1468            MatchPattern::Tuple(patterns) => {
1469                // For tuple patterns, we need to bind each element
1470                // The type should be a tuple type with matching elements
1471                if let Type::Tuple(elem_types) = resolved_ty {
1472                    for (pat, elem_ty) in patterns.iter().zip(elem_types.iter()) {
1473                        self.add_pattern_bindings(pat, *elem_ty);
1474                    }
1475                } else {
1476                    // If we have a single-element tuple pattern, try to unwrap and bind
1477                    // This handles the case of Tuple([inner_pattern]) where we should
1478                    // pass the type directly to the inner pattern
1479                    if patterns.len() == 1 {
1480                        self.add_pattern_bindings(&patterns[0], ty);
1481                    }
1482                }
1483            }
1484            MatchPattern::Constructor(_, inner) => {
1485                // For constructor patterns, recursively handle the inner pattern
1486                if let Some(inner_pat) = inner {
1487                    self.add_pattern_bindings(inner_pat, ty);
1488                }
1489            }
1490        }
1491    }
1492
1493    /// Check a pattern against a type and add variable bindings
1494    /// This is used for tuple patterns in multi-scrutinee matching
1495    fn check_pattern_against_type(
1496        &mut self,
1497        pattern: &crate::ast::MatchPattern,
1498        ty: TypeNodeId,
1499        loc: &Location,
1500    ) {
1501        use crate::ast::MatchPattern;
1502        match pattern {
1503            MatchPattern::Literal(lit) => {
1504                // For literal patterns, unify with expected type
1505                let pat_ty = match lit {
1506                    crate::ast::Literal::Int(_) | crate::ast::Literal::Float(_) => {
1507                        Type::Primitive(PType::Numeric).into_id_with_location(loc.clone())
1508                    }
1509                    _ => Type::Failure.into_id_with_location(loc.clone()),
1510                };
1511                let _ = self.unify_types(ty, pat_ty);
1512            }
1513            MatchPattern::Wildcard => {
1514                // Wildcard matches anything, no binding
1515            }
1516            MatchPattern::Variable(var) => {
1517                // Bind variable to the expected type
1518                self.env.add_bind(&[(*var, (ty, self.stage))]);
1519            }
1520            MatchPattern::Constructor(constructor_name, inner) => {
1521                // Get the payload type for this constructor from the union/enum type
1522                let binding_ty = self.get_constructor_type_from_union(ty, *constructor_name);
1523                if let Some(inner_pat) = inner {
1524                    self.add_pattern_bindings(inner_pat, binding_ty);
1525                }
1526            }
1527            MatchPattern::Tuple(patterns) => {
1528                // Recursively check nested tuple pattern
1529                let resolved_ty = ty.get_root().to_type();
1530                if let Type::Tuple(elem_types) = resolved_ty {
1531                    for (pat, elem_ty) in patterns.iter().zip(elem_types.iter()) {
1532                        self.check_pattern_against_type(pat, *elem_ty, loc);
1533                    }
1534                }
1535            }
1536        }
1537    }
1538
1539    fn unwrap_result(&mut self, res: Result<TypeNodeId, Vec<Error>>) -> TypeNodeId {
1540        match res {
1541            Ok(t) => t,
1542            Err(mut e) => {
1543                let loc = &e[0].get_labels()[0].0; //todo
1544                self.errors.append(&mut e);
1545                Type::Failure.into_id_with_location(loc.clone())
1546            }
1547        }
1548    }
1549    fn get_typescheme(&mut self, tvid: IntermediateId, loc: Location) -> TypeNodeId {
1550        self.generalize_map.get(&tvid).cloned().map_or_else(
1551            || self.gen_typescheme(loc),
1552            |id| Type::TypeScheme(id).into_id(),
1553        )
1554    }
1555    fn gen_typescheme(&mut self, loc: Location) -> TypeNodeId {
1556        let res = Type::TypeScheme(self.typescheme_idx).into_id_with_location(loc);
1557        self.typescheme_idx.0 += 1;
1558        res
1559    }
1560
1561    fn gen_intermediate_type_with_location(&mut self, loc: Location) -> TypeNodeId {
1562        let res = Type::Intermediate(Arc::new(RwLock::new(TypeVar::new(
1563            self.interm_idx,
1564            self.level,
1565        ))))
1566        .into_id_with_location(loc);
1567        self.interm_idx.0 += 1;
1568        res
1569    }
1570
1571    fn resolve_type_alias_symbol_fallback(&self, name: Symbol) -> Symbol {
1572        if name.as_str().contains('$') {
1573            return name;
1574        }
1575
1576        if let Some(ref module_info) = self.module_info
1577            && let Some(mapped) = module_info.use_alias_map.get(&name)
1578        {
1579            return *mapped;
1580        }
1581
1582        if self.type_aliases.contains_key(&name) {
1583            return name;
1584        }
1585
1586        // Also check type_declarations directly (for UserSum types defined without module prefix)
1587        if let Some(ref module_info) = self.module_info
1588            && module_info.type_declarations.contains_key(&name)
1589        {
1590            return name;
1591        }
1592
1593        // Search for mangled names ending with $<name> in both type_aliases and type_declarations
1594        let suffix = format!("${}", name.as_str());
1595        let mut candidates: Vec<Symbol> = self
1596            .type_aliases
1597            .keys()
1598            .copied()
1599            .filter(|symbol| symbol.as_str().ends_with(&suffix))
1600            .collect();
1601
1602        if let Some(ref module_info) = self.module_info {
1603            candidates.extend(
1604                module_info
1605                    .type_declarations
1606                    .keys()
1607                    .copied()
1608                    .filter(|symbol| symbol.as_str().ends_with(&suffix)),
1609            );
1610        }
1611
1612        if candidates.len() == 1 {
1613            candidates[0]
1614        } else {
1615            name
1616        }
1617    }
1618
    /// Replace placeholder types inside `t` with types usable by inference.
    ///
    /// - `Type::Unknown` becomes a fresh intermediate type variable at `loc`.
    /// - A `Type::TypeAlias` whose name is an explicit type parameter becomes
    ///   whatever `lookup_explicit_type_param` returns for it, or a fresh type
    ///   scheme when that lookup yields nothing.
    /// - Any other `Type::TypeAlias` is resolved to its registered symbol via
    ///   `resolve_type_alias_symbol_fallback`. If it is marked non-public, it is
    ///   reported as `Error::PrivateTypeAccess`. It is then looked up in the
    ///   environment, alias-resolved, and converted recursively. If the lookup
    ///   fails, a warning is logged and a fresh intermediate variable is returned.
    /// - Every other type is rebuilt by applying this conversion to its
    ///   children through `apply_fn`.
    fn convert_unknown_to_intermediate(&mut self, t: TypeNodeId, loc: Location) -> TypeNodeId {
        match t.to_type() {
            Type::Unknown => self.gen_intermediate_type_with_location(loc.clone()),
            Type::TypeAlias(name) => {
                // Explicit type parameters are not looked up as aliases.
                if Self::is_explicit_type_param_name(name) {
                    return self
                        .lookup_explicit_type_param(name)
                        .unwrap_or_else(|| self.gen_typescheme(loc.clone()));
                }
                let resolved_name = self.resolve_type_alias_symbol_fallback(name);

                log::trace!(
                    "Resolving TypeAlias: {} -> {}",
                    name.as_str(),
                    resolved_name.as_str()
                );

                // Check visibility if module_info is available.
                // Only names explicitly mapped to `false` in `visibility_map` are
                // treated as private here; names absent from the map pass.
                // NOTE(review): no comparison with the *accessing* module is made
                // here — confirm that uses inside the defining module are not
                // reported as private access.
                if let Some(ref module_info) = self.module_info
                    && let Some(&is_public) = module_info.visibility_map.get(&resolved_name)
                    && !is_public
                {
                    // Type is private - report error for accessing it from outside.
                    // Mangled names use `$` as the module separator.
                    let type_path: Vec<&str> = resolved_name.as_str().split('$').collect();
                    if type_path.len() > 1 {
                        // This is a module member type: all but the last segment
                        // form the module path, the last one is the type name.
                        let module_path: Vec<crate::interner::Symbol> = type_path
                            [..type_path.len() - 1]
                            .iter()
                            .map(ToSymbol::to_symbol)
                            .collect();
                        let type_name = type_path.last().unwrap().to_symbol();

                        // Report error for private type access (inference continues).
                        self.errors.push(Error::PrivateTypeAccess {
                            module_path,
                            type_name,
                            location: loc.clone(),
                        });
                    }
                }

                // Resolve type alias by looking it up in the environment
                match self.lookup(resolved_name, loc.clone()) {
                    Ok(resolved_ty) => {
                        // The aliased type may itself contain aliases/unknowns,
                        // so convert it recursively.
                        let resolved_ty = self.resolve_type_alias(resolved_ty);
                        let resolved_ty =
                            self.convert_unknown_to_intermediate(resolved_ty, loc.clone());
                        log::trace!(
                            "Resolved TypeAlias {} to {}",
                            resolved_name.as_str(),
                            resolved_ty.to_type()
                        );
                        resolved_ty
                    }
                    Err(_) => {
                        log::warn!(
                            "TypeAlias {} not found, treating as Unknown",
                            resolved_name.as_str()
                        );
                        // If not found, treat as Unknown and convert to intermediate.
                        // The lookup error itself is not recorded.
                        self.gen_intermediate_type_with_location(loc.clone())
                    }
                }
            }
            // Structural types: convert every child type.
            _ => t.apply_fn(|t| self.convert_unknown_to_intermediate(t, loc.clone())),
        }
    }
1687
1688    fn provisional_lambda_function_type(
1689        &mut self,
1690        params: &[TypedId],
1691        rtype: Option<TypeNodeId>,
1692        loc: Location,
1693    ) -> TypeNodeId {
1694        let param_fields = params
1695            .iter()
1696            .map(|param| {
1697                let annotated_ty =
1698                    self.convert_unknown_to_intermediate(param.ty, param.ty.to_loc());
1699                RecordTypeField {
1700                    key: param.id,
1701                    ty: self.resolve_type_alias(annotated_ty),
1702                    has_default: param.default_value.is_some(),
1703                }
1704            })
1705            .collect::<Vec<_>>();
1706
1707        let arg_ty = match param_fields.len() {
1708            0 => Type::Primitive(PType::Unit).into_id_with_location(loc.clone()),
1709            1 => param_fields[0].ty,
1710            _ => Type::Record(param_fields).into_id_with_location(loc.clone()),
1711        };
1712
1713        let ret_ty = rtype
1714            .map(|ret| {
1715                let annotated_ret = self.convert_unknown_to_intermediate(ret, ret.to_loc());
1716                self.resolve_type_alias(annotated_ret)
1717            })
1718            .unwrap_or_else(|| self.gen_intermediate_type_with_location(loc.clone()));
1719
1720        Type::Function {
1721            arg: arg_ty,
1722            ret: ret_ty,
1723        }
1724        .into_id_with_location(loc)
1725    }
1726
1727    fn provisional_letrec_binding_type(
1728        &mut self,
1729        id: &TypedId,
1730        body: ExprNodeId,
1731        loc: Location,
1732    ) -> TypeNodeId {
1733        match body.to_expr() {
1734            Expr::Lambda(params, rtype, _) => {
1735                let has_explicit_lambda_signature =
1736                    params.iter().any(|param| !matches!(param.ty.to_type(), Type::Unknown))
1737                        || rtype.is_some();
1738
1739                if has_explicit_lambda_signature || matches!(id.ty.to_type(), Type::Unknown) {
1740                    self.provisional_lambda_function_type(params.as_slice(), rtype, loc)
1741                } else {
1742                    self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc())
1743                }
1744            }
1745            _ if !matches!(id.ty.to_type(), Type::Unknown) => {
1746                self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc())
1747            }
1748            _ => self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc()),
1749        }
1750    }
1751
1752    /// Check if a symbol is public based on the visibility map
1753    fn is_public(&self, name: &Symbol) -> bool {
1754        let resolved_name = self.resolve_type_alias_symbol_fallback(*name);
1755        self.module_info
1756            .as_ref()
1757            .and_then(|info| info.visibility_map.get(&resolved_name))
1758            .is_some_and(|vis| *vis)
1759    }
1760
1761    fn is_private(&self, name: &Symbol) -> bool {
1762        !self.is_public(name)
1763    }
1764
1765    /// Check if a public function leaks private types in its signature
1766    fn check_private_type_leak(&mut self, name: Symbol, ty: TypeNodeId, loc: Location) {
1767        // Check if the function is public
1768        if !self.is_public(&name) {
1769            return; // Private functions can use private types
1770        }
1771
1772        // Check if the type contains any private type references
1773        if let Some(type_name) = self.contains_private_type(ty) {
1774            self.errors.push(Error::PrivateTypeLeak {
1775                function_name: name,
1776                private_type: type_name,
1777                location: loc,
1778            });
1779        }
1780    }
1781
1782    /// Recursively check if a type contains references to private types
1783    /// Returns Some(type_name) if a private type is found
1784    fn contains_private_type(&self, ty: TypeNodeId) -> Option<Symbol> {
1785        let resolved = Self::substitute_type(ty);
1786        match resolved.to_type() {
1787            Type::TypeAlias(name) => {
1788                if Self::is_explicit_type_param_name(name) {
1789                    return None;
1790                }
1791                let resolved_name = self.resolve_type_alias_symbol_fallback(name);
1792                // Check if this type alias is private
1793                if self.is_private(&resolved_name) {
1794                    return Some(resolved_name);
1795                }
1796
1797                // If it's a qualified name, extract type name and check visibility
1798                let name_str = name.as_str();
1799                if name_str.contains("::") {
1800                    let parts: Vec<&str> = name_str.split("::").collect();
1801                    if parts.len() >= 2 {
1802                        let module_path: Vec<Symbol> = parts[..parts.len() - 1]
1803                            .iter()
1804                            .map(|s| s.to_symbol())
1805                            .collect();
1806                        let type_name = parts[parts.len() - 1].to_symbol();
1807
1808                        let module_path_str = module_path
1809                            .iter()
1810                            .map(|s| s.as_str())
1811                            .collect::<Vec<_>>()
1812                            .join("::");
1813                        let mangled_name =
1814                            format!("{}::{}", module_path_str, type_name.as_str()).to_symbol();
1815
1816                        if self.is_private(&mangled_name) {
1817                            return Some(type_name);
1818                        }
1819                    }
1820                }
1821                None
1822            }
1823            Type::Function { arg, ret } => {
1824                // Check argument type (can be a single type or a record of multiple args)
1825                if let Some(private_type) = self.contains_private_type(arg) {
1826                    return Some(private_type);
1827                }
1828                // Check return type
1829                self.contains_private_type(ret)
1830            }
1831            Type::Tuple(ref elements) => {
1832                for elem_ty in elements.iter() {
1833                    if let Some(private_type) = self.contains_private_type(*elem_ty) {
1834                        return Some(private_type);
1835                    }
1836                }
1837                None
1838            }
1839            Type::Array(elem_ty) => self.contains_private_type(elem_ty),
1840            Type::Record(ref fields) => {
1841                for field in fields.iter() {
1842                    if let Some(private_type) = self.contains_private_type(field.ty) {
1843                        return Some(private_type);
1844                    }
1845                }
1846                None
1847            }
1848            Type::Union(ref variants) => {
1849                for variant_ty in variants.iter() {
1850                    if let Some(private_type) = self.contains_private_type(*variant_ty) {
1851                        return Some(private_type);
1852                    }
1853                }
1854                None
1855            }
1856            Type::Ref(inner_ty) => self.contains_private_type(inner_ty),
1857            Type::Code(inner_ty) => self.contains_private_type(inner_ty),
1858            Type::Boxed(inner_ty) => self.contains_private_type(inner_ty),
1859            Type::UserSum { name, variants } => {
1860                // Check if the user-defined sum type itself is private
1861                if self.is_private(&name) {
1862                    return Some(name);
1863                }
1864
1865                // Check payload types of variants
1866                for (_variant_name, payload_ty_opt) in variants.iter() {
1867                    if let Some(payload_ty) = payload_ty_opt
1868                        && let Some(private_type) = self.contains_private_type(*payload_ty)
1869                    {
1870                        return Some(private_type);
1871                    }
1872                }
1873                None
1874            }
1875            Type::Intermediate(_)
1876            | Type::Primitive(_)
1877            | Type::TypeScheme(_)
1878            | Type::Any
1879            | Type::Failure
1880            | Type::Unknown => None,
1881        }
1882    }
1883
1884    fn convert_unify_error(&self, e: UnificationError) -> Error {
1885        let gen_loc = |span| Location::new(span, self.file_path.clone());
1886        match e {
1887            UnificationError::TypeMismatch {
1888                left: (left, lspan),
1889                right: (right, rspan),
1890            } => Error::TypeMismatch {
1891                left: (left, gen_loc(lspan)),
1892                right: (right, gen_loc(rspan)),
1893            },
1894            UnificationError::LengthMismatch {
1895                left: (left, lspan),
1896                right: (right, rspan),
1897            } => Error::LengthMismatch {
1898                left: (left.len(), gen_loc(lspan)),
1899                right: (right.len(), gen_loc(rspan)),
1900            },
1901            UnificationError::CircularType { left, right } => {
1902                Error::CircularType(gen_loc(left), gen_loc(right))
1903            }
1904            UnificationError::ImcompatibleRecords {
1905                left: (left, lspan),
1906                right: (right, rspan),
1907            } => Error::IncompatibleKeyInRecord {
1908                left: (left, gen_loc(lspan)),
1909                right: (right, gen_loc(rspan)),
1910            },
1911        }
1912    }
1913    fn unify_types(&self, t1: TypeNodeId, t2: TypeNodeId) -> Result<Relation, Vec<Error>> {
1914        // Resolve type aliases before unification
1915        let resolved_t1 = self.resolve_type_alias(t1);
1916        let resolved_t2 = self.resolve_type_alias(t2);
1917
1918        unify_types(resolved_t1, resolved_t2)
1919            .map_err(|e| e.into_iter().map(|e| self.convert_unify_error(e)).collect())
1920    }
1921    // helper function
1922    fn merge_rel_result(
1923        &self,
1924        rel1: Result<Relation, Vec<Error>>,
1925        rel2: Result<Relation, Vec<Error>>,
1926        t1: TypeNodeId,
1927        t2: TypeNodeId,
1928    ) -> Result<(), Vec<Error>> {
1929        match (rel1, rel2) {
1930            (Ok(Relation::Identical), Ok(Relation::Identical)) => Ok(()),
1931            (Ok(_), Ok(_)) => Err(vec![Error::TypeMismatch {
1932                left: (t1, Location::new(t1.to_span(), self.file_path.clone())),
1933                right: (t2, Location::new(t2.to_span(), self.file_path.clone())),
1934            }]),
1935            (Err(e1), Err(e2)) => Err(e1.into_iter().chain(e2).collect()),
1936            (Err(e), _) | (_, Err(e)) => Err(e),
1937        }
1938    }
1939    pub fn substitute_type(t: TypeNodeId) -> TypeNodeId {
1940        match t.to_type() {
1941            Type::Intermediate(cell) => {
1942                let TypeVar { parent, .. } = &*cell.read().unwrap() as &TypeVar;
1943                match parent {
1944                    Some(p) => Self::substitute_type(*p),
1945                    None => Type::Unknown.into_id_with_location(t.to_loc()),
1946                }
1947            }
1948            _ => t.apply_fn(Self::substitute_type),
1949        }
1950    }
1951    fn substitute_all_intermediates(&mut self) {
1952        let mut e_list = self
1953            .result_memo
1954            .iter()
1955            .map(|(e, t)| (*e, Self::substitute_type(*t)))
1956            .collect::<Vec<_>>();
1957
1958        e_list.iter_mut().for_each(|(e, t)| {
1959            log::trace!("e: {:?} t: {}", e, t.to_type());
1960            let _old = self.result_memo.insert(*e, *t);
1961        })
1962    }
1963
1964    fn generalize(&mut self, t: TypeNodeId) -> TypeNodeId {
1965        match t.to_type() {
1966            Type::Intermediate(tvar) => {
1967                let &TypeVar { level, var, .. } = &*tvar.read().unwrap() as &TypeVar;
1968                if level > self.level {
1969                    self.get_typescheme(var, t.to_loc())
1970                } else {
1971                    t
1972                }
1973            }
1974            _ => t.apply_fn(|t| self.generalize(t)),
1975        }
1976    }
1977
1978    fn instantiate(&mut self, t: TypeNodeId) -> TypeNodeId {
1979        match t.to_type() {
1980            Type::TypeScheme(id) => {
1981                log::debug!("instantiate typescheme id: {id:?}");
1982                if let Some(tvar) = self.instantiated_map.get(&id) {
1983                    *tvar
1984                } else {
1985                    let res = self.gen_intermediate_type_with_location(t.to_loc());
1986                    self.instantiated_map.insert(id, res);
1987                    res
1988                }
1989            }
1990            _ => t.apply_fn(|t| self.instantiate(t)),
1991        }
1992    }
1993
1994    fn instantiate_fresh(&mut self, t: TypeNodeId) -> TypeNodeId {
1995        self.instantiated_map.clear();
1996        let res = self.instantiate(t);
1997        self.instantiated_map.clear();
1998        res
1999    }
2000
2001    // Note: the third argument `span` is used for the error location in case of
2002    // type mismatch. This is needed because `t`'s span refers to the location
2003    // where it originally defined (e.g. the explicit return type of the
2004    // function) and is not necessarily the same as where the error happens.
2005    fn bind_pattern(
2006        &mut self,
2007        pat: (TypedPattern, Location),
2008        body: (TypeNodeId, Location),
2009    ) -> Result<TypeNodeId, Vec<Error>> {
2010        let (TypedPattern { pat, ty, .. }, loc_p) = pat;
2011        let (body_t, loc_b) = body.clone();
2012        let should_generalize =
2013            !matches!(&pat, Pattern::Single(id) if *id == "record_update_temp".to_symbol());
2014        let mut bind_item = |pat| {
2015            let newloc = ty.to_loc();
2016            let ity = self.gen_intermediate_type_with_location(newloc.clone());
2017            let p = TypedPattern::new(pat, ity);
2018            self.bind_pattern((p, newloc.clone()), (ity, newloc))
2019        };
2020        let pat_t = match pat {
2021            Pattern::Single(id) => {
2022                let pat_t = self.convert_unknown_to_intermediate(ty, loc_p);
2023                log::trace!("bind {} : {}", id, pat_t.to_type());
2024                self.env.add_bind(&[(id, (pat_t, self.stage))]);
2025                Ok::<TypeNodeId, Vec<Error>>(pat_t)
2026            }
2027            Pattern::Placeholder => {
2028                // Placeholder doesn't bind anything, just check the type
2029                let pat_t = self.convert_unknown_to_intermediate(ty, loc_p);
2030                log::trace!("bind _ (placeholder) : {}", pat_t.to_type());
2031                Ok::<TypeNodeId, Vec<Error>>(pat_t)
2032            }
2033            Pattern::Tuple(pats) => {
2034                let elems = pats.iter().map(|p| bind_item(p.clone())).try_collect()?; //todo multiple errors
2035                let res = Type::Tuple(elems).into_id_with_location(loc_p);
2036                let target = self.convert_unknown_to_intermediate(ty, loc_b);
2037                let rel = self.unify_types(res, target)?;
2038                Ok(res)
2039            }
2040            Pattern::Record(items) => {
2041                let res = items
2042                    .iter()
2043                    .map(|(key, v)| {
2044                        bind_item(v.clone()).map(|ty| RecordTypeField {
2045                            key: *key,
2046                            ty,
2047                            has_default: false,
2048                        })
2049                    })
2050                    .try_collect()?; //todo multiple errors
2051                let res = Type::Record(res).into_id_with_location(loc_p);
2052                let target = self.convert_unknown_to_intermediate(ty, loc_b);
2053                let rel = self.unify_types(res, target)?;
2054                Ok(res)
2055            }
2056            Pattern::Error => Err(vec![Error::PatternMismatch(
2057                (
2058                    Type::Failure.into_id_with_location(loc_p.clone()),
2059                    loc_b.clone(),
2060                ),
2061                (pat, loc_p.clone()),
2062            )]),
2063        }?;
2064        let rel = self.unify_types(pat_t, body_t)?;
2065        if should_generalize {
2066            Ok(self.generalize(pat_t))
2067        } else {
2068            Ok(pat_t)
2069        }
2070    }
2071
2072    pub fn lookup(&self, name: Symbol, loc: Location) -> Result<TypeNodeId, Error> {
2073        use crate::utils::environment::LookupRes;
2074        let lookup_res = self.env.lookup_cls(&name);
2075        match lookup_res {
2076            LookupRes::Local((ty, bound_stage)) if self.stage == *bound_stage => Ok(*ty),
2077            LookupRes::UpValue(_, (ty, bound_stage)) if self.stage == *bound_stage => Ok(*ty),
2078            LookupRes::Global((ty, bound_stage))
2079                if self.stage == *bound_stage || *bound_stage == EvalStage::Persistent =>
2080            {
2081                Ok(*ty)
2082            }
2083            LookupRes::None => Err(Error::VariableNotFound(name, loc)),
2084            LookupRes::Local((_, bound_stage))
2085            | LookupRes::UpValue(_, (_, bound_stage))
2086            | LookupRes::Global((_, bound_stage)) => Err(Error::StageMismatch {
2087                variable: name,
2088                expected_stage: self.stage,
2089                found_stage: *bound_stage,
2090                location: loc,
2091            }),
2092        }
2093    }
2094
2095    /// Resolve a type through intermediate type variables, type aliases, and
2096    /// single-element wrappers, returning the concrete inner type.
2097    fn peel_to_inner(&self, ty: TypeNodeId) -> TypeNodeId {
2098        let resolved = self.resolve_type_alias(ty);
2099        match resolved.to_type() {
2100            Type::Intermediate(tv) => {
2101                let tv = tv.read().unwrap();
2102                let next = tv.parent.unwrap_or(tv.bound.lower);
2103                if next.0 == resolved.0 {
2104                    resolved
2105                } else {
2106                    self.peel_to_inner(next)
2107                }
2108            }
2109            Type::Tuple(elems) if elems.len() == 1 => self.peel_to_inner(elems[0]),
2110            _ => resolved,
2111        }
2112    }
2113
2114    /// Look up `field` in a type that may be wrapped in intermediates,
2115    /// type aliases, or single-element tuples.
2116    fn lookup_field_in_type(&self, ty: TypeNodeId, field: Symbol) -> FieldLookup {
2117        let peeled = self.peel_to_inner(ty);
2118        match peeled.to_type() {
2119            Type::Record(fields) => fields
2120                .iter()
2121                .find(|f| f.key == field)
2122                .map(|f| FieldLookup::Found(f.ty))
2123                .unwrap_or(FieldLookup::RecordWithoutField),
2124            _ => FieldLookup::NotRecord,
2125        }
2126    }
2127
2128    /// Type-check a field access expression (`expr.field`).
2129    ///
2130    /// Three cases:
2131    /// 1. Completely unresolved intermediate — constrain with `{field: ?a}`.
2132    /// 2. Resolves to a record containing `field` — return the field type.
2133    /// 3. Record exists but `field` is missing and the outer type is an
2134    ///    intermediate — extend the record constraint via unification.
2135    fn infer_field_access(
2136        &mut self,
2137        et: TypeNodeId,
2138        field: Symbol,
2139        loc: Location,
2140    ) -> Result<TypeNodeId, Vec<Error>> {
2141        // Case 1: completely unresolved intermediate — constrain as {field: ?a}.
2142        if let Type::Intermediate(tv) = et.to_type() {
2143            let is_unresolved = {
2144                let tv = tv.read().unwrap();
2145                let lower_is_record_like = match tv.bound.lower.to_type() {
2146                    Type::Record(_) => true,
2147                    Type::Tuple(elems) => elems.len() == 1,
2148                    _ => false,
2149                };
2150                tv.parent.is_none() && !lower_is_record_like
2151            };
2152            if is_unresolved {
2153                let field_ty = self.gen_intermediate_type_with_location(loc.clone());
2154                let expected = Type::Record(vec![RecordTypeField {
2155                    key: field,
2156                    ty: field_ty,
2157                    has_default: false,
2158                }])
2159                .into_id_with_location(loc);
2160                let _rel = self.unify_types(et, expected)?;
2161                return Ok(field_ty);
2162            }
2163        }
2164
2165        // Cases 2 & 3: peel wrappers and look up the field in the record.
2166        match self.lookup_field_in_type(et, field) {
2167            FieldLookup::Found(field_ty) => Ok(field_ty),
2168            FieldLookup::RecordWithoutField => self.extend_record_with_field(et, field, loc),
2169            FieldLookup::NotRecord => Err(vec![Error::FieldForNonRecord(loc, et)]),
2170        }
2171    }
2172
2173    /// When a record type is known but doesn't yet contain `field`, extend
2174    /// the record constraint via unification. Falls back to `FieldNotExist`
2175    /// when the expression type is not an intermediate.
2176    fn extend_record_with_field(
2177        &mut self,
2178        et: TypeNodeId,
2179        field: Symbol,
2180        loc: Location,
2181    ) -> Result<TypeNodeId, Vec<Error>> {
2182        if let Type::Intermediate(tv) = et.to_type() {
2183            let existing_fields = {
2184                let tv = tv.read().unwrap();
2185                match tv.parent.map(|p| p.to_type()) {
2186                    Some(Type::Record(fields)) => Some(fields),
2187                    _ => match tv.bound.lower.to_type() {
2188                        Type::Record(fields) => Some(fields),
2189                        _ => None,
2190                    },
2191                }
2192            };
2193            if let Some(mut fields) = existing_fields {
2194                let field_ty = self.gen_intermediate_type_with_location(loc.clone());
2195                if fields.iter().all(|f| f.key != field) {
2196                    fields.push(RecordTypeField {
2197                        key: field,
2198                        ty: field_ty,
2199                        has_default: false,
2200                    });
2201                }
2202                let extended = Type::Record(fields).into_id_with_location(loc);
2203                // Directly update the Intermediate's parent to the extended
2204                // record. We cannot use `unify_types` here because it calls
2205                // `get_root()`, which follows the parent chain past the
2206                // Intermediate to the old (incomplete) parent Record — ending
2207                // up in a Record-Record unification that never updates the
2208                // Intermediate itself.
2209                {
2210                    let mut guard = tv.write().unwrap();
2211                    guard.parent = Some(extended);
2212                }
2213                return Ok(field_ty);
2214            }
2215        }
2216        Err(vec![Error::FieldNotExist { field, loc, et }])
2217    }
2218
2219    pub(crate) fn infer_type_literal(e: &Literal, loc: Location) -> Result<TypeNodeId, Error> {
2220        let pt = match e {
2221            Literal::Float(_) | Literal::Now | Literal::SampleRate => PType::Numeric,
2222            Literal::Int(_s) => PType::Int,
2223            Literal::String(_s) => PType::String,
2224            Literal::SelfLit => panic!("\"self\" should not be shown at type inference stage"),
2225            Literal::PlaceHolder => panic!("\"_\" should not be shown at type inference stage"),
2226        };
2227        Ok(Type::Primitive(pt).into_id_with_location(loc))
2228    }
2229    fn infer_vec(&mut self, e: &[ExprNodeId]) -> Result<Vec<TypeNodeId>, Vec<Error>> {
2230        e.iter().map(|e| self.infer_type(*e)).try_collect()
2231    }
2232    fn infer_type_levelup(&mut self, e: ExprNodeId) -> TypeNodeId {
2233        self.level += 1;
2234        let res = self.infer_type_unwrapping(e);
2235        self.level -= 1;
2236        res
2237    }
2238    pub fn infer_type(&mut self, e: ExprNodeId) -> Result<TypeNodeId, Vec<Error>> {
2239        if let Some(r) = self.result_memo.get(&e.0) {
2240            //use cached result
2241            return Ok(*r);
2242        }
2243        let loc = e.to_location();
2244        let res: Result<TypeNodeId, Vec<Error>> = match &e.to_expr() {
2245            Expr::Literal(l) => Self::infer_type_literal(l, loc).map_err(|e| vec![e]),
2246            Expr::Tuple(e) => {
2247                Ok(Type::Tuple(self.infer_vec(e.as_slice())?).into_id_with_location(loc))
2248            }
2249            Expr::ArrayLiteral(e) => {
2250                let elem_types = self.infer_vec(e.as_slice())?;
2251                let first = elem_types
2252                    .first()
2253                    .copied()
2254                    .unwrap_or_else(|| self.gen_intermediate_type_with_location(loc.clone()));
2255                //todo:collect multiple errors
2256                let elem_t = elem_types
2257                    .iter()
2258                    .try_fold(first, |acc, t| self.unify_types(acc, *t).map(|rel| *t))?;
2259
2260                Ok(Type::Array(elem_t).into_id_with_location(loc.clone()))
2261            }
2262            Expr::ArrayAccess(e, idx) => {
2263                let arr_t = self.infer_type_unwrapping(*e);
2264                let loc_e = e.to_location();
2265                let idx_t = self.infer_type_unwrapping(*idx);
2266                let loc_i = idx.to_location();
2267
2268                let elem_t = self.gen_intermediate_type_with_location(loc_e.clone());
2269
2270                let rel1 = self.unify_types(
2271                    idx_t,
2272                    Type::Primitive(PType::Numeric).into_id_with_location(loc_i),
2273                );
2274                let rel2 = self.unify_types(
2275                    Type::Array(elem_t).into_id_with_location(loc_e.clone()),
2276                    arr_t,
2277                );
2278                self.merge_rel_result(rel1, rel2, arr_t, idx_t)?;
2279                Ok(elem_t)
2280            }
2281            Expr::Proj(e, idx) => {
2282                let tup = self.infer_type_unwrapping(*e);
2283                // we directly inspect if the intermediate type is a tuple or not.
2284                // this is because we can not infer the number of fields in the tuple from the fields access expression.
2285                // This rule will be loosened when structural subtyping is implemented.
2286                let vec_to_ans = |vec: &[_]| {
2287                    if vec.len() < *idx as usize {
2288                        Err(vec![Error::IndexOutOfRange {
2289                            len: vec.len() as u16,
2290                            idx: *idx as u16,
2291                            loc: loc.clone(),
2292                        }])
2293                    } else {
2294                        Ok(vec[*idx as usize])
2295                    }
2296                };
2297                match tup.to_type() {
2298                    Type::Tuple(vec) => vec_to_ans(&vec),
2299                    Type::Intermediate(tv) => {
2300                        let tv = tv.read().unwrap();
2301                        if let Some(parent) = tv.parent {
2302                            match parent.to_type() {
2303                                Type::Tuple(vec) => vec_to_ans(&vec),
2304                                _ => Err(vec![Error::IndexForNonTuple(loc, tup)]),
2305                            }
2306                        } else {
2307                            Err(vec![Error::IndexForNonTuple(loc, tup)])
2308                        }
2309                    }
2310                    _ => Err(vec![Error::IndexForNonTuple(loc, tup)]),
2311                }
2312            }
2313            Expr::RecordLiteral(kvs) => {
2314                let duplicate_keys = kvs
2315                    .iter()
2316                    .map(|RecordField { name, .. }| *name)
2317                    .duplicates();
2318                if duplicate_keys.clone().count() > 0 {
2319                    Err(vec![Error::DuplicateKeyInRecord {
2320                        key: duplicate_keys.collect(),
2321                        loc,
2322                    }])
2323                } else {
2324                    let kts: Vec<_> = kvs
2325                        .iter()
2326                        .map(|RecordField { name, expr }| {
2327                            let ty = self.infer_type_unwrapping(*expr);
2328                            RecordTypeField {
2329                                key: *name,
2330                                ty,
2331                                has_default: true,
2332                            }
2333                        })
2334                        .collect();
2335                    Ok(Type::Record(kts).into_id_with_location(loc))
2336                }
2337            }
2338            Expr::RecordUpdate(_, _) => {
2339                // RecordUpdate should never reach type inference as it gets expanded
2340                // to Block/Let/Assign expressions during syntax sugar conversion in convert_pronoun.rs
2341                unreachable!("RecordUpdate should be expanded before type inference")
2342            }
2343            Expr::FieldAccess(expr, field) => {
2344                let et = self.infer_type_unwrapping(*expr);
2345                log::trace!("field access {} : {}", field, et.to_type());
2346                self.infer_field_access(et, *field, loc)
2347            }
2348            Expr::Feed(id, body) => {
2349                //todo: add span to Feed expr for keeping the location of `self`.
2350                let feedv = self.gen_intermediate_type_with_location(loc);
2351
2352                self.env.add_bind(&[(*id, (feedv, self.stage))]);
2353                let bty = self.infer_type_unwrapping(*body);
2354                let _rel = self.unify_types(bty, feedv)?;
2355                if bty.to_type().contains_function() {
2356                    Err(vec![Error::NonPrimitiveInFeed(body.to_location())])
2357                } else {
2358                    Ok(bty)
2359                }
2360            }
2361            Expr::Lambda(p, rtype, body) => {
2362                let mut scoped_types = p
2363                    .iter()
2364                    .map(|id| id.ty)
2365                    .filter(|ty| ty.to_type() != Type::Unknown)
2366                    .collect::<Vec<_>>();
2367                rtype.iter().copied().for_each(|ty| scoped_types.push(ty));
2368                self.with_explicit_type_param_scope_from_types(&scoped_types, |this| {
2369                    this.env.extend();
2370                    let lambda_res = (|| -> Result<TypeNodeId, Vec<Error>> {
2371                        this.instantiated_map.clear();
2372                        let dup = p.iter().duplicates_by(|id| id.id).map(|id| {
2373                            let loc = Location::new(id.to_span(), this.file_path.clone());
2374                            (id.id, loc)
2375                        });
2376                        if dup.clone().count() > 0 {
2377                            return Err(vec![Error::DuplicateKeyInParams(dup.collect())]);
2378                        }
2379                        let pvec = p
2380                            .iter()
2381                            .map(|id| {
2382                                let annotated_ty =
2383                                    this.convert_unknown_to_intermediate(id.ty, id.ty.to_loc());
2384                                let annotated_ty = this.resolve_type_alias(annotated_ty);
2385                                let ity = this.instantiate(annotated_ty);
2386                                this.env.add_bind(&[(id.id, (ity, this.stage))]);
2387                                RecordTypeField {
2388                                    key: id.id,
2389                                    ty: ity,
2390                                    has_default: id.default_value.is_some(),
2391                                }
2392                            })
2393                            .collect::<Vec<_>>();
2394                        let ptype = if pvec.is_empty() {
2395                            Type::Primitive(PType::Unit).into_id_with_location(loc.clone())
2396                        } else if pvec.len() == 1 {
2397                            pvec[0].ty
2398                        } else {
2399                            Type::Record(pvec).into_id_with_location(loc.clone())
2400                        };
2401                        let bty = if let Some(r) = rtype {
2402                            let annotated_ret =
2403                                this.convert_unknown_to_intermediate(*r, r.to_loc());
2404                            let annotated_ret = this.resolve_type_alias(annotated_ret);
2405                            let expected_ret = this.instantiate(annotated_ret);
2406                            let bty = this.infer_type_unwrapping(*body);
2407                            let _rel = this.unify_types(expected_ret, bty)?;
2408                            bty
2409                        } else {
2410                            this.infer_type_unwrapping(*body)
2411                        };
2412                        this.instantiated_map.clear();
2413                        Ok(Type::Function {
2414                            arg: ptype,
2415                            ret: bty,
2416                        }
2417                        .into_id_with_location(e.to_location()))
2418                    })();
2419                    this.env.to_outer();
2420                    this.instantiated_map.clear();
2421                    lambda_res
2422                })
2423            }
2424            Expr::Let(tpat, body, then) => {
2425                let bodyt = self.infer_type_levelup(*body);
2426
2427                let loc_p = tpat.to_loc();
2428                let loc_b = body.to_location();
2429
2430                // Check for private type leak in public function declarations
2431                // Use the original type before resolution to catch TypeAlias references
2432                if let Pattern::Single(name) = &tpat.pat {
2433                    log::trace!(
2434                        "Checking private type leak for Let binding: {}",
2435                        name.as_str()
2436                    );
2437                    log::trace!("Original type before resolution: {:?}", tpat.ty.to_type());
2438                    self.check_private_type_leak(*name, tpat.ty, loc_p.clone());
2439                }
2440
2441                let pat_t = self.with_explicit_type_param_scope_from_types(&[tpat.ty], |this| {
2442                    this.bind_pattern((tpat.clone(), loc_p), (bodyt, loc_b))
2443                });
2444                let _pat_t = self.unwrap_result(pat_t);
2445                match then {
2446                    Some(e) => self.infer_type(*e),
2447                    None => Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
2448                }
2449            }
2450            Expr::LetRec(id, body, then) => {
2451                let body_expr = *body;
2452                let mut scoped_types = vec![id.ty];
2453                if let Expr::Lambda(params, rtype, _) = body_expr.to_expr() {
2454                    params
2455                        .iter()
2456                        .filter(|param| param.ty.to_type() != Type::Unknown)
2457                        .for_each(|param| scoped_types.push(param.ty));
2458                    rtype.iter().copied().for_each(|ret| scoped_types.push(ret));
2459                }
2460
2461                self.with_explicit_type_param_scope_from_types(&scoped_types, |this| {
2462                    let idt = this.provisional_letrec_binding_type(id, body_expr, loc.clone());
2463                    this.env.add_bind(&[(id.id, (idt, this.stage))]);
2464                    //polymorphic inference is not allowed in recursive function.
2465
2466                    let bodyt = this.infer_type_levelup(body_expr);
2467
2468                    let _res = this.unify_types(idt, bodyt);
2469
2470                    // Check if public function leaks private type in its declared signature
2471                    this.check_private_type_leak(id.id, id.ty, loc.clone());
2472                });
2473
2474                match then {
2475                    Some(e) => self.infer_type(*e),
2476                    None => Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
2477                }
2478            }
2479            Expr::Assign(assignee, expr) => {
2480                match assignee.to_expr() {
2481                    Expr::Var(name) => {
2482                        let assignee_t =
2483                            self.unwrap_result(self.lookup(name, loc).map_err(|e| vec![e]));
2484                        let e_t = self.infer_type_unwrapping(*expr);
2485                        let _rel = self.unify_types(assignee_t, e_t)?;
2486                        Ok(unit!())
2487                    }
2488                    Expr::FieldAccess(record, field_name) => {
2489                        // Handle field assignment: record.field = value
2490                        let _record_type = self.infer_type_unwrapping(record);
2491                        let value_type = self.infer_type_unwrapping(*expr);
2492                        let field_type = self.infer_type_unwrapping(*assignee);
2493                        let _rel = self.unify_types(field_type, value_type)?;
2494                        Ok(unit!())
2495                    }
2496                    Expr::ArrayAccess(_, _) => {
2497                        unimplemented!("Assignment to array is not implemented yet.")
2498                    }
2499                    _ => {
2500                        // This should be caught by parser, but add a generic error just in case
2501                        Err(vec![Error::VariableNotFound(
2502                            "invalid_assignment_target".to_symbol(),
2503                            loc.clone(),
2504                        )])
2505                    }
2506                }
2507            }
2508            Expr::Then(e, then) => {
2509                let _ = self.infer_type(*e)?;
2510                then.map_or(Ok(unit!()), |t| self.infer_type(t))
2511            }
2512            Expr::Var(name) => {
2513                // First check if this is a constructor from a user-defined sum type
2514                if let Some(constructor_info) = self.constructor_env.get(name) {
2515                    if let Some(payload_ty) = constructor_info.payload_type {
2516                        // Constructor with payload: type is `payload_type -> sum_type`
2517                        let fn_type = Type::Function {
2518                            arg: payload_ty,
2519                            ret: constructor_info.sum_type,
2520                        }
2521                        .into_id_with_location(loc.clone());
2522                        return Ok(fn_type);
2523                    } else {
2524                        // Constructor without payload: type is the sum type itself
2525                        return Ok(constructor_info.sum_type);
2526                    }
2527                }
2528                // Aliases and wildcards are already resolved by convert_qualified_names
2529                let res = self.unwrap_result(self.lookup(*name, loc).map_err(|e| vec![e]));
2530                Ok(self.instantiate_fresh(res))
2531            }
2532            Expr::QualifiedVar(path) => {
2533                unreachable!("Qualified Var should be removed in the previous step.")
2534            }
2535            Expr::Apply(fun, callee) => {
2536                let loc_f = fun.to_location();
2537                if callee.len() == 2 && self.try_get_tuple_arithmetic_binop_label(*fun).is_some() {
2538                    let lhs_ty = self.infer_type_unwrapping(callee[0]);
2539                    let rhs_ty = self.infer_type_unwrapping(callee[1]);
2540                    let lhs_is_tuple = matches!(
2541                        self.resolve_for_tuple_binop(lhs_ty).to_type(),
2542                        Type::Tuple(_)
2543                    );
2544                    let rhs_is_tuple = matches!(
2545                        self.resolve_for_tuple_binop(rhs_ty).to_type(),
2546                        Type::Tuple(_)
2547                    );
2548                    if lhs_is_tuple || rhs_is_tuple {
2549                        return self.infer_tuple_arithmetic_binop_type(
2550                            lhs_ty,
2551                            rhs_ty,
2552                            loc_f.clone(),
2553                        );
2554                    }
2555                }
2556
2557                if callee.len() == 1 {
2558                    let fnl = self.infer_type_unwrapping(*fun);
2559                    let arg_ty = self.infer_type_unwrapping(callee[0]);
2560                    let arg_is_tuple = matches!(
2561                        self.resolve_for_tuple_binop(arg_ty).to_type(),
2562                        Type::Tuple(_)
2563                    );
2564                    if arg_is_tuple && self.is_numeric_to_numeric_function_for_auto_spread(fnl) {
2565                        return self.infer_auto_spread_type(fnl, arg_ty, loc_f.clone());
2566                    }
2567
2568                    let try_record_default_pack = || -> Result<Option<TypeNodeId>, Vec<Error>> {
2569                        let fn_ty = self.peel_to_inner(fnl);
2570                        let arg_ty_resolved = self.peel_to_inner(arg_ty);
2571                        let (fn_arg, fn_ret) = match fn_ty.to_type() {
2572                            Type::Function { arg, ret } => (arg, ret),
2573                            _ => return Ok(None),
2574                        };
2575                        let fn_arg_resolved = self.peel_to_inner(fn_arg);
2576                        let (param_fields, provided_fields) =
2577                            match (fn_arg_resolved.to_type(), arg_ty_resolved.to_type()) {
2578                                (Type::Record(param_fields), Type::Record(provided_fields)) => {
2579                                    (param_fields, provided_fields)
2580                                }
2581                                _ => return Ok(None),
2582                            };
2583
2584                        let mut matched_any = false;
2585                        for param in param_fields.iter() {
2586                            if let Some(provided) =
2587                                provided_fields.iter().find(|field| field.key == param.key)
2588                            {
2589                                matched_any = true;
2590                                let _ = self.unify_types(param.ty, provided.ty)?;
2591                            } else if !param.has_default {
2592                                return Ok(None);
2593                            }
2594                        }
2595
2596                        Ok(matched_any.then_some(fn_ret))
2597                    };
2598
2599                    if let Some(ret_ty) = try_record_default_pack()? {
2600                        return Ok(ret_ty);
2601                    }
2602                }
2603
2604                let fnl = self.infer_type_unwrapping(*fun);
2605                let callee_t = match callee.len() {
2606                    0 => Type::Primitive(PType::Unit).into_id_with_location(loc.clone()),
2607                    1 => self.infer_type_unwrapping(callee[0]),
2608                    _ => {
2609                        let at_vec = self.infer_vec(callee.as_slice())?;
2610
2611                        let span = callee[0].to_span().start..callee.last().unwrap().to_span().end;
2612                        let loc = Location::new(span, self.file_path.clone());
2613                        Type::Tuple(at_vec).into_id_with_location(loc)
2614                    }
2615                };
2616                let res_t = self.gen_intermediate_type_with_location(loc);
2617                let fntype = Type::Function {
2618                    arg: callee_t,
2619                    ret: res_t,
2620                }
2621                .into_id_with_location(loc_f.clone());
2622                match self.unify_types(fnl, fntype)? {
2623                    Relation::Subtype => Err(vec![Error::NonSupertypeArgument {
2624                        location: loc_f.clone(),
2625                        expected: fnl,
2626                        found: fntype,
2627                    }]),
2628                    _ => Ok(res_t),
2629                }
2630            }
2631            Expr::If(cond, then, opt_else) => {
2632                let condt = self.infer_type_unwrapping(*cond);
2633                let cond_loc = cond.to_location();
2634                let bt = self.unify_types(
2635                    Type::Primitive(PType::Numeric).into_id_with_location(cond_loc),
2636                    condt,
2637                )?; //todo:boolean type
2638                //todo: introduce row polymophism so that not narrowing the type of `then` and `else` too much.
2639                let thent = self.infer_type_unwrapping(*then);
2640                let elset = opt_else.map_or(Type::Primitive(PType::Unit).into_id(), |e| {
2641                    self.infer_type_unwrapping(e)
2642                });
2643                let rel = self.unify_types(thent, elset)?;
2644                Ok(thent)
2645            }
2646            Expr::Block(expr) => expr.map_or(
2647                Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
2648                |e| {
2649                    self.env.extend(); //block creates local scope.
2650                    let res = self.infer_type(e);
2651                    self.env.to_outer();
2652                    res
2653                },
2654            ),
2655            Expr::Escape(e) => {
2656                let loc_e = loc.clone();
2657                let prev_stage = self.stage;
2658                // Decrease stage for escape expression
2659                self.stage = prev_stage.decrement();
2660                log::trace!("Unstaging escape expression, stage => {:?}", self.stage);
2661                let res = self.infer_type_unwrapping(*e);
2662                // Restore previous stage regardless of saturation behavior
2663                self.stage = prev_stage;
2664                if matches!(res.get_root().to_type(), Type::Primitive(PType::Unit)) {
2665                    return Ok(Type::Primitive(PType::Unit).into_id_with_location(loc_e));
2666                }
2667                if !matches!(res.get_root().to_type(), Type::Code(_))
2668                    && res.get_root().to_type().contains_code()
2669                {
2670                    return Err(vec![Error::EscapeRequiresCodeType {
2671                        found: (res.get_root(), loc_e),
2672                    }]);
2673                }
2674                let intermediate = self.gen_intermediate_type_with_location(loc_e.clone());
2675                let rel = self.unify_types(
2676                    res,
2677                    Type::Code(intermediate).into_id_with_location(loc_e.clone()),
2678                )?;
2679                Ok(intermediate)
2680            }
2681            Expr::Bracket(e) => {
2682                let loc_e = loc.clone();
2683                let prev_stage = self.stage;
2684                // Increase stage for bracket expression
2685                self.stage = prev_stage.increment();
2686                log::trace!("Staging bracket expression, stage => {:?}", self.stage);
2687                let res = self.infer_type_unwrapping(*e);
2688                // Restore previous stage regardless of boundary behavior
2689                self.stage = prev_stage;
2690                Ok(Type::Code(res).into_id_with_location(loc_e))
2691            }
2692            Expr::Match(scrutinee, arms) => {
2693                // Infer type of scrutinee
2694                let scrut_ty = self.infer_type_unwrapping(*scrutinee);
2695
2696                // Infer types of all arm bodies, handling patterns with variable bindings
2697                let arm_tys: Vec<TypeNodeId> = arms
2698                    .iter()
2699                    .map(|arm| {
2700                        match &arm.pattern {
2701                            crate::ast::MatchPattern::Literal(lit) => {
2702                                // For numeric patterns, check scrutinee is numeric
2703                                let pat_ty = match lit {
2704                                    crate::ast::Literal::Int(_) | crate::ast::Literal::Float(_) => {
2705                                        Type::Primitive(PType::Numeric)
2706                                            .into_id_with_location(loc.clone())
2707                                    }
2708                                    _ => Type::Failure.into_id_with_location(loc.clone()),
2709                                };
2710                                let _ = self.unify_types(scrut_ty, pat_ty);
2711                                self.infer_type_unwrapping(arm.body)
2712                            }
2713                            crate::ast::MatchPattern::Wildcard => {
2714                                // Wildcard matches anything
2715                                self.infer_type_unwrapping(arm.body)
2716                            }
2717                            crate::ast::MatchPattern::Variable(_) => {
2718                                // Variable pattern binds the whole value
2719                                // This should typically be handled by Constructor pattern
2720                                self.infer_type_unwrapping(arm.body)
2721                            }
2722                            crate::ast::MatchPattern::Constructor(constructor_name, binding) => {
2723                                // Handle constructor patterns for union types
2724                                // Find the type associated with this constructor in the union
2725                                let binding_ty = self
2726                                    .get_constructor_type_from_union(scrut_ty, *constructor_name);
2727
2728                                if let Some(inner_pattern) = binding {
2729                                    // Add bindings for the inner pattern
2730                                    self.env.extend();
2731                                    self.add_pattern_bindings(inner_pattern, binding_ty);
2732                                    let body_ty = self.infer_type_unwrapping(arm.body);
2733                                    self.env.to_outer();
2734                                    body_ty
2735                                } else {
2736                                    self.infer_type_unwrapping(arm.body)
2737                                }
2738                            }
2739                            crate::ast::MatchPattern::Tuple(patterns) => {
2740                                // Tuple pattern in a match arm for multi-scrutinee matching
2741                                // The scrutinee should be a tuple and we need to bind variables
2742                                // from each sub-pattern
2743                                self.env.extend();
2744
2745                                // Get the scrutinee type and check it's a tuple
2746                                let resolved_scrut_ty = scrut_ty.get_root().to_type();
2747                                if let Type::Tuple(elem_types) = resolved_scrut_ty {
2748                                    // Type check each pattern element against corresponding
2749                                    // scrutinee element
2750                                    for (pat, elem_ty) in patterns.iter().zip(elem_types.iter()) {
2751                                        self.check_pattern_against_type(pat, *elem_ty, &loc);
2752                                    }
2753                                } else {
2754                                    // If scrutinee is not a tuple, check each pattern against
2755                                    // the whole type (for error reporting)
2756                                    for pat in patterns.iter() {
2757                                        self.check_pattern_against_type(pat, scrut_ty, &loc);
2758                                    }
2759                                }
2760
2761                                let body_ty = self.infer_type_unwrapping(arm.body);
2762                                self.env.to_outer();
2763                                body_ty
2764                            }
2765                        }
2766                    })
2767                    .collect();
2768
2769                // Record Match expression for exhaustiveness checking after type resolution
2770                self.match_expressions.push((e, scrut_ty));
2771
2772                if arm_tys.is_empty() {
2773                    Ok(Type::Primitive(PType::Unit).into_id_with_location(loc))
2774                } else {
2775                    let first = arm_tys[0];
2776                    for ty in arm_tys.iter().skip(1) {
2777                        let _ = self.unify_types(first, *ty);
2778                    }
2779                    Ok(first)
2780                }
2781            }
2782            _ => Ok(Type::Failure.into_id_with_location(loc)),
2783        };
2784        res.inspect(|ty| {
2785            self.result_memo.insert(e.0, *ty);
2786        })
2787    }
2788    fn infer_type_unwrapping(&mut self, e: ExprNodeId) -> TypeNodeId {
2789        match self.infer_type(e) {
2790            Ok(t) => t,
2791            Err(err) => {
2792                let failure_ty = Type::Failure
2793                    .into_id_with_location(Location::new(e.to_span(), self.file_path.clone()));
2794                self.errors.extend(err);
2795                self.result_memo.insert(e.0, failure_ty);
2796                failure_ty
2797            }
2798        }
2799    }
2800
2801    /// Check if a match expression is exhaustive
2802    /// Returns a list of missing constructor names if not exhaustive
2803    fn check_match_exhaustiveness(
2804        &self,
2805        scrutinee_ty: TypeNodeId,
2806        arms: &[crate::ast::MatchArm],
2807    ) -> Option<Vec<Symbol>> {
2808        // Get all constructors required for the scrutinee type
2809        let required_constructors = self.get_all_constructors(scrutinee_ty);
2810
2811        // If there are no constructors (e.g., primitive types), no exhaustiveness check needed
2812        if required_constructors.is_empty() {
2813            return None;
2814        }
2815
2816        // Check if there's a wildcard pattern (covers everything)
2817        let has_wildcard = arms.iter().any(|arm| {
2818            matches!(
2819                &arm.pattern,
2820                crate::ast::MatchPattern::Wildcard
2821                    | crate::ast::MatchPattern::Variable(_)
2822                    | crate::ast::MatchPattern::Tuple(_)
2823            )
2824        });
2825
2826        // If there's a wildcard, the match is exhaustive
2827        if has_wildcard {
2828            return None;
2829        }
2830
2831        // Collect constructors covered by patterns (only Constructor patterns contribute)
2832        let covered_constructors: Vec<Symbol> = arms
2833            .iter()
2834            .filter_map(|arm| {
2835                if let crate::ast::MatchPattern::Constructor(name, _) = &arm.pattern {
2836                    Some(*name)
2837                } else {
2838                    None
2839                }
2840            })
2841            .collect();
2842
2843        // Find missing constructors
2844        let missing: Vec<Symbol> = required_constructors
2845            .into_iter()
2846            .filter(|req| !covered_constructors.contains(req))
2847            .collect();
2848
2849        if missing.is_empty() {
2850            None
2851        } else {
2852            Some(missing)
2853        }
2854    }
2855
2856    /// Get all constructor names for a type
2857    /// For union types like `float | string`, returns ["float", "string"]
2858    /// For UserSum types, returns all constructor names
2859    fn get_all_constructors(&self, ty: TypeNodeId) -> Vec<Symbol> {
2860        // First resolve type aliases, then substitute intermediate types
2861        let resolved = self.resolve_type_alias(ty);
2862        let substituted = Self::substitute_type(resolved);
2863
2864        match substituted.to_type() {
2865            Type::Union(variants) => {
2866                // For union types, get the constructor name for each variant
2867                variants
2868                    .iter()
2869                    .filter_map(|v| {
2870                        let v_resolved = Self::substitute_type(*v);
2871                        Self::type_constructor_name(&v_resolved.to_type())
2872                    })
2873                    .collect()
2874            }
2875            Type::UserSum { name: _, variants } => {
2876                // For UserSum types, collect all constructor names
2877                variants.iter().map(|(name, _)| *name).collect()
2878            }
2879            _ => {
2880                // For non-union/sum types, no constructors to check
2881                Vec::new()
2882            }
2883        }
2884    }
2885
2886    /// Check exhaustiveness for all recorded Match expressions
2887    /// This should be called after type resolution (substitute_all_intermediates)
2888    pub fn check_all_match_exhaustiveness(&mut self) {
2889        let match_expressions = std::mem::take(&mut self.match_expressions);
2890
2891        let errors: Vec<_> = match_expressions
2892            .into_iter()
2893            .filter_map(|(match_expr, scrut_ty)| {
2894                if let Expr::Match(_scrutinee, arms) = &match_expr.to_expr() {
2895                    let resolved_scrut_ty = self.resolve_type_alias(scrut_ty);
2896                    let substituted_scrut_ty = Self::substitute_type(resolved_scrut_ty);
2897
2898                    self.check_match_exhaustiveness(substituted_scrut_ty, arms)
2899                        .map(|missing| Error::NonExhaustiveMatch {
2900                            missing_constructors: missing,
2901                            location: match_expr.to_location(),
2902                        })
2903                } else {
2904                    None
2905                }
2906            })
2907            .collect();
2908
2909        self.errors.extend(errors);
2910    }
2911}
2912
2913pub fn infer_root(
2914    e: ExprNodeId,
2915    builtin_types: &[(Symbol, TypeNodeId)],
2916    file_path: PathBuf,
2917    type_declarations: Option<&crate::ast::program::TypeDeclarationMap>,
2918    type_aliases: Option<&crate::ast::program::TypeAliasMap>,
2919    module_info: Option<crate::ast::program::ModuleInfo>,
2920) -> InferContext {
2921    use std::sync::atomic::{AtomicUsize, Ordering};
2922    static INFER_ROOT_COUNTER: AtomicUsize = AtomicUsize::new(0);
2923    let call_id = INFER_ROOT_COUNTER.fetch_add(1, Ordering::Relaxed);
2924    let mut ctx = InferContext::new(
2925        builtin_types,
2926        file_path.clone(),
2927        type_declarations,
2928        type_aliases,
2929        module_info,
2930    );
2931    ctx.infer_root_id = call_id;
2932    let _t = ctx
2933        .infer_type(e)
2934        .unwrap_or(Type::Failure.into_id_with_location(e.to_location()));
2935    ctx.substitute_all_intermediates();
2936    ctx.check_all_match_exhaustiveness();
2937    ctx
2938}
2939
// Tests for the type checker:
// - stage-aware variable lookup on a bare `InferContext`;
// - end-to-end checks that drive typing through `compiler::Context::emit_mir`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler;
    use crate::interner::ToSymbol;
    use crate::plugin;
    use crate::types::Type;
    use crate::utils::metadata::{Location, Span};

    /// Program whose `dsp` calls an inline-module function through a qualified path.
    /// Shared by the qualified-name tests below.
    const QUALIFIED_CALL_SRC: &str = r#"
mod mymath {
    pub fn add(x, y) {
        x + y
    }
}

fn dsp() {
    mymath::add(1.0, 2.0)
}
"#;

    /// Build an `InferContext` with no builtins, type declarations, aliases or module info.
    fn create_test_context() -> InferContext {
        InferContext::new(&[], PathBuf::from("test"), None, None, None)
    }

    /// Return a zero-width location in the dummy file `test`.
    fn create_test_location() -> Location {
        Location::new(Span { start: 0, end: 0 }, PathBuf::from("test"))
    }

    /// Build a compiler context for the dummy file `test` with no external functions
    /// and no macros registered.
    fn compiler_without_plugins() -> compiler::Context {
        let empty_ext_fns: Vec<compiler::ExtFunTypeInfo> = vec![];
        let empty_macros: Vec<Box<dyn crate::plugin::MacroFunction>> = vec![];
        compiler::Context::new(
            empty_ext_fns,
            empty_macros,
            Some(PathBuf::from("test")),
            compiler::Config::default(),
        )
    }

    /// Build a compiler context for `file_path` with the builtin plugin's external
    /// function types and macros registered.
    fn compiler_with_builtins(file_path: PathBuf) -> compiler::Context {
        let ext_fns = plugin::get_extfun_types(&[plugin::get_builtin_fns_as_plugins()])
            .collect::<Vec<_>>();
        let macros = plugin::get_macro_functions(&[plugin::get_builtin_fns_as_plugins()])
            .collect::<Vec<_>>();
        compiler::Context::new(ext_fns, macros, Some(file_path), compiler::Config::default())
    }

    #[test]
    fn test_stage_mismatch_detection() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Define a variable 'x' at stage 0
        let var_name = "x".to_symbol();
        let var_type =
            Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
        ctx.env
            .add_bind(&[(var_name, (var_type, EvalStage::Stage(0)))]);

        // Try to look it up from stage 0 - should succeed
        ctx.stage = EvalStage::Stage(0);
        let result = ctx.lookup(var_name, loc.clone());
        assert!(
            result.is_ok(),
            "Looking up variable from same stage should succeed"
        );

        // Try to look it up from stage 1 - should fail with stage mismatch
        ctx.stage = EvalStage::Stage(1);
        let result = ctx.lookup(var_name, loc.clone());
        assert!(
            result.is_err(),
            "Looking up variable from different stage should fail"
        );

        // The error must name the variable, the stage it was looked up from
        // (`expected_stage`) and the stage it was bound at (`found_stage`).
        if let Err(Error::StageMismatch {
            variable,
            expected_stage,
            found_stage,
            ..
        }) = result
        {
            assert_eq!(variable, var_name);
            assert_eq!(expected_stage, EvalStage::Stage(1));
            assert_eq!(found_stage, EvalStage::Stage(0));
        } else {
            panic!("Expected StageMismatch error, got: {result:?}");
        }
    }

    #[test]
    fn test_persistent_stage_access() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Define a variable at Persistent stage
        let var_name = "persistent_var".to_symbol();
        let var_type =
            Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
        ctx.env
            .add_bind(&[(var_name, (var_type, EvalStage::Persistent))]);

        // Try to access from different stages - should all succeed
        for stage in [0, 1, 2] {
            ctx.stage = EvalStage::Stage(stage);
            let result = ctx.lookup(var_name, loc.clone());
            assert!(
                result.is_ok(),
                "Persistent stage variables should be accessible from stage {stage}"
            );
        }
    }

    #[test]
    fn test_same_stage_access() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Define variables at different stages
        for stage in [0, 1, 2] {
            let var_name = format!("var_stage_{stage}").to_symbol();
            let var_type =
                Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
            ctx.env
                .add_bind(&[(var_name, (var_type, EvalStage::Stage(stage)))]);
        }

        // Each variable should only be accessible from its own stage
        for stage in [0, 1, 2] {
            ctx.stage = EvalStage::Stage(stage);
            let var_name = format!("var_stage_{stage}").to_symbol();
            let result = ctx.lookup(var_name, loc.clone());
            assert!(
                result.is_ok(),
                "Variable should be accessible from its own stage {stage}"
            );

            // Should not be accessible from other stages
            for other_stage in [0, 1, 2] {
                if other_stage != stage {
                    ctx.stage = EvalStage::Stage(other_stage);
                    let result = ctx.lookup(var_name, loc.clone());
                    assert!(
                        result.is_err(),
                        "Variable from stage {stage} should not be accessible from stage {other_stage}",
                    );
                }
            }
        }
    }

    #[test]
    fn test_stage_transitions_bracket_escape() {
        let mut ctx = create_test_context();

        // Test that stage transitions work correctly
        assert_eq!(ctx.stage, EvalStage::Stage(0), "Initial stage should be 0");

        // Simulate bracket behavior - stage increment
        ctx.stage = ctx.stage.increment();
        assert_eq!(
            ctx.stage,
            EvalStage::Stage(1),
            "Stage should increment to 1 in bracket"
        );

        // Simulate escape behavior - stage decrement
        ctx.stage = ctx.stage.decrement();
        assert_eq!(
            ctx.stage,
            EvalStage::Stage(0),
            "Stage should decrement back to 0 after escape"
        );
    }

    #[test]
    fn test_multi_stage_environment() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Create nested scope with different stages
        ctx.env.extend(); // Create new scope

        // Add variable at stage 0
        let var_stage0 = "x".to_symbol();
        let var_type =
            Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
        ctx.stage = EvalStage::Stage(0);
        ctx.env
            .add_bind(&[(var_stage0, (var_type, EvalStage::Stage(0)))]);

        ctx.env.extend(); // Create another scope

        // Add variable with same name at stage 1
        let var_stage1 = "x".to_symbol(); // Same name, different stage
        ctx.stage = EvalStage::Stage(1);
        ctx.env
            .add_bind(&[(var_stage1, (var_type, EvalStage::Stage(1)))]);

        // The innermost `x` is the stage-1 binding, so a stage-0 lookup of `x`
        // reaches that binding first and must fail.
        ctx.stage = EvalStage::Stage(0);
        let result = ctx.lookup(var_stage0, loc.clone());
        assert!(
            result.is_err(),
            "Stage 0 variable should not be accessible from nested stage 0 context due to shadowing"
        );

        ctx.stage = EvalStage::Stage(1);
        let result = ctx.lookup(var_stage1, loc.clone());
        assert!(
            result.is_ok(),
            "Stage 1 variable should be accessible from stage 1"
        );

        ctx.stage = EvalStage::Stage(0);
        let result = ctx.lookup(var_stage1, loc.clone());
        assert!(
            result.is_err(),
            "Stage 1 variable should not be accessible from stage 0"
        );

        // Clean up scopes
        ctx.env.to_outer();
        ctx.env.to_outer();
    }

    #[test]
    fn test_qualified_var_mangling() {
        // `emit_mir` runs the full pipeline, which rewrites qualified names
        // (convert_qualified_names) before type checking.
        let ctx = compiler_without_plugins();
        let result = ctx.emit_mir(QUALIFIED_CALL_SRC);

        // Check for compilation errors
        assert!(result.is_ok(), "Compilation failed: {:?}", result.err());

        // Module members are mangled with `$`, the same convention that makes an
        // imported module's function appear as `pattern_like$value_at_phase` in
        // `test_imported_staging_preserves_record_array_width`.
        let printed = format!("{}", result.unwrap());
        assert!(
            printed.contains("mymath$add"),
            "qualified call should resolve to the mangled name `mymath$add`, got MIR:\n{printed}"
        );
    }

    #[test]
    fn test_qualified_var_mir_generation() {
        // Use the compiler context to generate MIR
        let ctx = compiler_without_plugins();
        let result = ctx.emit_mir(QUALIFIED_CALL_SRC);

        // Check for compilation errors
        assert!(result.is_ok(), "MIR generation failed: {:?}", result.err());
    }

    #[test]
    fn test_macro_return_record_missing_field_reports_type_error() {
        let src = r#"
pub type alias Note = {v:float, gate:float}

#stage(macro)
fn make_note()->`Note{
    `({v = 60.0, gate = 1.0})
}

fn dsp(){
    let note = make_note!()
    note.val
}
"#;

        let ctx = compiler_without_plugins();
        let result = ctx.emit_mir(src);

        assert!(
            result.is_err(),
            "Compilation should fail for missing record field access"
        );

        let errors = result.err().unwrap();
        // NOTE:
        // Depending on current inference order, this scenario reports either a
        // direct missing-field diagnostic (`Field "val"`) or the more general
        // non-record access error. Both indicate the intended regression is
        // caught: accessing `note.val` is a type error.
        assert!(
            errors.iter().any(|e| {
                let message = e.get_message();
                message.contains("Field \"val\"")
                    || message.contains("Field access for non-record variable")
            }),
            "Expected field access type error for \"val\", got: {:?}",
            errors.iter().map(|e| e.get_message()).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_recursive_function_preserves_record_array_width_from_param_annotation() {
        let src = r#"
pub type alias Arc = {start:float, end:float}
pub type alias Event = {arc:Arc, active:Arc, val:float}

fn value_at_phase(events:[Event], phase:float, current:float)->float{
    if (len(events) > 0.0){
        let (head,rest) = events |> split_head
        if (phase >= head.arc.start){
            value_at_phase(rest, phase, head.val)
        }else{
            current
        }
    }else{
        current
    }
}

fn dsp(){
    let events = [{
        arc = {start = 0.0, end = 1.0},
        active = {start = 0.0, end = 1.0},
        val = 1.0,
    }]
    value_at_phase(events, 0.0, 0.0)
}
"#;

        let ctx = compiler_with_builtins(PathBuf::from("test"));
        let result = ctx.emit_mir(src);

        assert!(result.is_ok(), "MIR generation failed: {:?}", result.err());

        let mir = result.unwrap();
        let printed = format!("{mir}");
        let signature_line = printed
            .lines()
            .find(|line| line.starts_with("fn value_at_phase ["))
            .expect("value_at_phase should be present in MIR");

        assert!(
            signature_line.contains("active") && signature_line.contains("end:number"),
            "record-array parameter width should keep full Event shape, got: {signature_line}"
        );
        assert!(
            printed.contains("split_head$arity5"),
            "split_head should specialize for full Event width, got MIR:\n{printed}"
        );
    }

    #[test]
    fn test_imported_staging_preserves_record_array_width() {
        use std::fs;

        // The fixture lives on disk because module imports are resolved relative to
        // the main file's path. It is written under the repository's `tmp/` directory.
        let repo_root = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../..");
        let fixture_root = repo_root.join("tmp/staging_import_record_array_regression");
        let fixture_lib_dir = fixture_root.join("lib");
        let fixture_main = fixture_root.join("main.mmm");
        let fixture_module = fixture_lib_dir.join("pattern_like.mmm");

        fs::create_dir_all(&fixture_lib_dir).expect("fixture lib dir should be created");
        fs::write(
            &fixture_module,
            r#"
use osc::phasor

pub type alias Arc = {start:float, end:float}
pub type alias Event = {arc:Arc, active:Arc, val:float}

#stage(main)
fn value_at_phase(events:[Event], phase:float, current:float)->float{
    if ((events |> len) > 0.0){
        let (head,rest) = events |> split_head
        if (phase >= head.arc.start){
            value_at_phase(rest, phase, head.val)
        }else{
            current
        }
    }else{
        current
    }
}

#stage(macro)
pub fn run_value()->`float{
    let events = [{
        arc = {start = 0.0, end = 1.0},
        active = {start = 0.0, end = 1.0},
        val = 60.0,
    }]
    `{
        let phase = phasor(0.5, 0.0)
        let v = value_at_phase($(events |> lift), phase, 0.0)
        v
    }
}
"#,
        )
        .expect("fixture module should be written");
        fs::write(
            &fixture_main,
            r#"
use pattern_like::*

fn dsp(){
    let value = run_value!()
    (value, value)
}
"#,
        )
        .expect("fixture main should be written");

        let src = fs::read_to_string(&fixture_main).expect("fixture main should be readable");
        let (_ast, module_info, parse_errs) = crate::compiler::parser::parse_to_expr(
            &src,
            Some(fixture_main.clone()),
        );
        assert!(parse_errs.is_empty(), "fixture should parse cleanly");
        assert!(
            module_info
                .type_aliases
                .keys()
                .any(|name| name.as_str() == "pattern_like$Event"),
            "imported module type aliases should contain pattern_like$Event, got: {:?}",
            module_info
                .type_aliases
                .keys()
                .map(|name| name.as_str().to_string())
                .collect::<Vec<_>>()
        );

        let ctx = compiler_with_builtins(fixture_main.clone());
        let result = ctx.emit_mir(&src);

        assert!(result.is_ok(), "MIR generation failed: {:?}", result.err());

        let printed = format!("{}", result.unwrap());
        let signature_line = printed
            .lines()
            .find(|line| line.starts_with("fn pattern_like$value_at_phase ["))
            .expect("imported value_at_phase should be present in MIR");

        assert!(
            signature_line.contains("active") && signature_line.contains("end:number"),
            "imported staged record-array parameter width should keep full Event shape, got: {signature_line}"
        );
        assert!(
            printed.contains("split_head$arity5"),
            "imported staged split_head should specialize for full Event width, got MIR:\n{printed}"
        );

        // Remove the fixture only on success so a failing run leaves it behind for
        // inspection. Cleanup is best-effort; an error here must not fail the test.
        let _ = fs::remove_dir_all(&fixture_root);
    }
}