mimium_lang/compiler/
typing.rs

1use crate::ast::{Expr, Literal, RecordField};
2use crate::compiler::{EvalStage, intrinsics};
3use crate::interner::{ExprKey, ExprNodeId, Symbol, ToSymbol, TypeNodeId};
4use crate::pattern::{Pattern, TypedPattern};
5use crate::types::{IntermediateId, PType, RecordTypeField, Type, TypeSchemeId, TypeVar};
6use crate::utils::metadata::Location;
7use crate::utils::{environment::Environment, error::ReportableError};
8use crate::{function, integer, numeric, unit};
9use itertools::Itertools;
10use std::cell::RefCell;
11use std::collections::BTreeMap;
12use std::fmt;
13use std::path::PathBuf;
14use std::rc::Rc;
15use std::sync::{Arc, Mutex, RwLock};
16
17mod unification;
18use unification::{Error as UnificationError, Relation, unify_types};
19
/// Errors reported by type inference.
///
/// Each variant carries the [`Location`]s needed to label the offending
/// source spans; see the `ReportableError` impl for the rendered messages.
#[derive(Clone, Debug)]
pub enum Error {
    /// Two types could not be reconciled; each side keeps its own location.
    TypeMismatch {
        left: (TypeNodeId, Location),
        right: (TypeNodeId, Location),
    },
    /// Two element lists have different lengths.
    LengthMismatch {
        left: (usize, Location),
        right: (usize, Location),
    },
    /// A pattern cannot be bound against the given type.
    PatternMismatch((TypeNodeId, Location), (Pattern, Location)),
    /// Reported as "`letrec` can take only function type."
    NonFunctionForLetRec(TypeNodeId, Location),
    /// The applied value is not of a function type.
    NonFunctionForApply(TypeNodeId, Location),
    /// Labelled as `found` not being a supertype of `expected`.
    NonSupertypeArgument {
        location: Location,
        expected: TypeNodeId,
        found: TypeNodeId,
    },
    /// A circular type definition was detected between the two locations.
    CircularType(Location, Location),
    /// Tuple projection with index `idx` on a tuple of `len` elements.
    IndexOutOfRange {
        len: u16,
        idx: u16,
        loc: Location,
    },
    /// Index access on a value whose type is not a tuple.
    IndexForNonTuple(Location, TypeNodeId),
    /// Field access on a value whose type is not a record.
    FieldForNonRecord(Location, TypeNodeId),
    /// The record type `et` has no field named `field`.
    FieldNotExist {
        field: Symbol,
        loc: Location,
        et: TypeNodeId,
    },
    /// The same key appears more than once in a record.
    DuplicateKeyInRecord {
        key: Vec<Symbol>,
        loc: Location,
    },
    /// The same name appears more than once in a parameter list; one entry
    /// per duplicated parameter.
    DuplicateKeyInParams(Vec<(Symbol, Location)>),
    /// Two record types have incompatible keys. Each side lists its
    /// `(key, type)` pairs together with the location of the record.
    IncompatibleKeyInRecord {
        left: (Vec<(Symbol, TypeNodeId)>, Location),
        right: (Vec<(Symbol, TypeNodeId)>, Location),
    },
    /// The variable is not bound in the current scope.
    VariableNotFound(Symbol, Location),
    /// The variable was bound in `found_stage` but accessed from
    /// `expected_stage`.
    StageMismatch {
        variable: Symbol,
        expected_stage: EvalStage,
        found_stage: EvalStage,
        location: Location,
    },
    /// Reported as "Function that uses `self` cannot return function type."
    NonPrimitiveInFeed(Location),
}
70impl fmt::Display for Error {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        write!(f, "Type Inference Error")
73    }
74}
75
76impl std::error::Error for Error {}
77impl ReportableError for Error {
78    fn get_message(&self) -> String {
79        match self {
80            Error::TypeMismatch { .. } => format!("Type mismatch"),
81            Error::PatternMismatch(..) => format!("Pattern mismatch"),
82            Error::LengthMismatch { .. } => format!("Length of the elements are different"),
83            Error::NonFunctionForLetRec(_, _) => format!("`letrec` can take only function type."),
84            Error::NonFunctionForApply(_, _) => {
85                format!("This is not applicable because it is not a function type.")
86            }
87            Error::CircularType(_, _) => format!("Circular loop of type definition detected."),
88            Error::IndexOutOfRange { len, idx, .. } => {
89                format!("Length of tuple elements is {len} but index was {idx}")
90            }
91            Error::IndexForNonTuple(_, _) => {
92                format!("Index access for non-tuple variable.")
93            }
94            Error::VariableNotFound(symbol, _) => {
95                format!("Variable \"{symbol}\" not found in this scope")
96            }
97            Error::StageMismatch {
98                variable,
99                expected_stage,
100                found_stage,
101                ..
102            } => {
103                format!(
104                    "Variable {variable} is defined in stage {} but accessed from stage {}",
105                    found_stage.format_for_error(),
106                    expected_stage.format_for_error()
107                )
108            }
109            Error::NonPrimitiveInFeed(_) => {
110                format!("Function that uses `self` cannot return function type.")
111            }
112            Error::DuplicateKeyInParams { .. } => {
113                format!("Duplicate keys found in parameter list")
114            }
115            Error::DuplicateKeyInRecord { .. } => {
116                format!("Duplicate keys found in record type")
117            }
118            Error::FieldForNonRecord { .. } => {
119                format!("Field access for non-record variable.")
120            }
121            Error::FieldNotExist { field, .. } => {
122                format!("Field \"{field}\" does not exist in the record type")
123            }
124            Error::IncompatibleKeyInRecord { .. } => {
125                format!("Record type has incompatible keys.",)
126            }
127
128            Error::NonSupertypeArgument { .. } => {
129                format!("Arguments for functions are less than required.")
130            }
131        }
132    }
133    fn get_labels(&self) -> Vec<(Location, String)> {
134        match self {
135            Error::TypeMismatch {
136                left: (lty, locl),
137                right: (rty, locr),
138            } => vec![
139                (locl.clone(), lty.to_type().to_string_for_error()),
140                (locr.clone(), rty.to_type().to_string_for_error()),
141            ],
142            Error::PatternMismatch((ty, loct), (pat, locp)) => vec![
143                (loct.clone(), ty.to_type().to_string_for_error()),
144                (locp.clone(), pat.to_string()),
145            ],
146            Error::LengthMismatch {
147                left: (l, locl),
148                right: (r, locr),
149            } => vec![
150                (locl.clone(), format!("The length is {l}")),
151                (locr.clone(), format!("but the length for here is {r}")),
152            ],
153            Error::NonFunctionForLetRec(ty, loc) => {
154                vec![(loc.clone(), ty.to_type().to_string_for_error())]
155            }
156            Error::NonFunctionForApply(ty, loc) => {
157                vec![(loc.clone(), ty.to_type().to_string_for_error())]
158            }
159            Error::CircularType(loc1, loc2) => vec![
160                (loc1.clone(), format!("Circular type happens here")),
161                (loc2.clone(), format!("and here")),
162            ],
163            Error::IndexOutOfRange { loc, len, .. } => {
164                vec![(loc.clone(), format!("Length for this tuple is {len}"))]
165            }
166            Error::IndexForNonTuple(loc, ty) => {
167                vec![(
168                    loc.clone(),
169                    format!(
170                        "This is not tuple type but {}",
171                        ty.to_type().to_string_for_error()
172                    ),
173                )]
174            }
175            Error::VariableNotFound(symbol, loc) => {
176                vec![(loc.clone(), format!("{symbol} is not defined"))]
177            }
178            Error::StageMismatch {
179                variable,
180                expected_stage,
181                found_stage,
182                location,
183            } => {
184                vec![(
185                    location.clone(),
186                    format!(
187                        "Variable \"{variable}\" defined in stage {} cannot be accessed from stage {}",
188                        found_stage.format_for_error(),
189                        expected_stage.format_for_error()
190                    ),
191                )]
192            }
193            Error::NonPrimitiveInFeed(loc) => {
194                vec![(loc.clone(), format!("This cannot be function type."))]
195            }
196            Error::DuplicateKeyInRecord { key, loc } => {
197                vec![(
198                    loc.clone(),
199                    format!(
200                        "Duplicate keys \"{}\" found in record type",
201                        key.iter()
202                            .map(|s| s.to_string())
203                            .collect::<Vec<_>>()
204                            .join(", ")
205                    ),
206                )]
207            }
208            Error::DuplicateKeyInParams(keys) => keys
209                .iter()
210                .map(|(key, loc)| {
211                    (
212                        loc.clone(),
213                        format!("Duplicate key \"{key}\" found in parameter list"),
214                    )
215                })
216                .collect(),
217            Error::FieldForNonRecord(location, ty) => {
218                vec![(
219                    location.clone(),
220                    format!(
221                        "Field access for non-record type {}.",
222                        ty.to_type().to_string_for_error()
223                    ),
224                )]
225            }
226            Error::FieldNotExist { field, loc, et } => vec![(
227                loc.clone(),
228                format!(
229                    "Field \"{}\" does not exist in the type {}",
230                    field,
231                    et.to_type().to_string_for_error()
232                ),
233            )],
234            Error::IncompatibleKeyInRecord {
235                left: (left, lloc),
236                right: (right, rloc),
237            } => {
238                vec![
239                    (
240                        lloc.clone(),
241                        format!(
242                            "the record here contains{}",
243                            left.iter()
244                                .map(|(key, ty)| format!(
245                                    " \"{key}\":{}",
246                                    ty.to_type().to_string_for_error()
247                                ))
248                                .collect::<Vec<_>>()
249                                .join(", ")
250                        ),
251                    ),
252                    (
253                        rloc.clone(),
254                        format!(
255                            "but the record here contains {}",
256                            right
257                                .iter()
258                                .map(|(key, ty)| format!(
259                                    " \"{key}\":{}",
260                                    ty.to_type().to_string_for_error()
261                                ))
262                                .collect::<Vec<_>>()
263                                .join(", ")
264                        ),
265                    ),
266                ]
267            }
268
269            Error::NonSupertypeArgument {
270                location,
271                expected,
272                found,
273            } => {
274                vec![(
275                    location.clone(),
276                    format!(
277                        "Type {} is not a supertype of the expected type {}",
278                        found.to_type().to_string_for_error(),
279                        expected.to_type().to_string_for_error()
280                    ),
281                )]
282            }
283        }
284    }
285}
286
/// Mutable state threaded through type inference.
#[derive(Clone, Debug)]
pub struct InferContext {
    /// Id given to the next freshly generated intermediate type variable.
    interm_idx: IntermediateId,
    /// Id given to the next freshly generated type scheme.
    typescheme_idx: TypeSchemeId,
    /// Current level; stored in new type variables and compared against
    /// their level in `generalize`.
    level: u64,
    /// Stage recorded on new bindings and checked by `lookup`.
    stage: EvalStage,
    instantiated_map: BTreeMap<TypeSchemeId, TypeNodeId>, //from type scheme to typevar
    /// Consulted by `get_typescheme` to map a type variable id to a scheme id.
    generalize_map: BTreeMap<IntermediateId, TypeSchemeId>,
    /// Cache of already inferred expression types, keyed by expression.
    result_memo: BTreeMap<ExprKey, TypeNodeId>,
    /// Path attached to every `Location` built from a bare span.
    file_path: PathBuf,
    /// Name bindings, each paired with the stage it was bound in.
    pub env: Environment<(TypeNodeId, EvalStage)>,
    /// Errors accumulated so far (see `unwrap_result`).
    pub errors: Vec<Error>,
}
300impl InferContext {
301    fn new(builtins: &[(Symbol, TypeNodeId)], file_path: PathBuf) -> Self {
302        let mut res = Self {
303            interm_idx: Default::default(),
304            typescheme_idx: Default::default(),
305            level: Default::default(),
306            stage: EvalStage::Stage(0), // Start at stage 0
307            instantiated_map: Default::default(),
308            generalize_map: Default::default(),
309            result_memo: Default::default(),
310            file_path,
311            env: Environment::<(TypeNodeId, EvalStage)>::default(),
312            errors: Default::default(),
313        };
314        res.env.extend();
315        // Intrinsic types are persistent (available at all stages)
316        let intrinsics = Self::intrinsic_types()
317            .into_iter()
318            .map(|(name, ty)| (name, (ty, EvalStage::Persistent)))
319            .collect::<Vec<_>>();
320        res.env.add_bind(&intrinsics);
321        // Builtins are also persistent
322        let builtins = builtins
323            .iter()
324            .map(|(name, ty)| (*name, (*ty, EvalStage::Persistent)))
325            .collect::<Vec<_>>();
326        res.env.add_bind(&builtins);
327        res
328    }
329}
330impl InferContext {
331    fn intrinsic_types() -> Vec<(Symbol, TypeNodeId)> {
332        let binop_ty = function!(vec![numeric!(), numeric!()], numeric!());
333        let binop_names = [
334            intrinsics::ADD,
335            intrinsics::SUB,
336            intrinsics::MULT,
337            intrinsics::DIV,
338            intrinsics::MODULO,
339            intrinsics::POW,
340            intrinsics::GT,
341            intrinsics::LT,
342            intrinsics::GE,
343            intrinsics::LE,
344            intrinsics::EQ,
345            intrinsics::NE,
346            intrinsics::AND,
347            intrinsics::OR,
348        ];
349        let uniop_ty = function!(vec![numeric!()], numeric!());
350        let uniop_names = [
351            intrinsics::NEG,
352            intrinsics::MEM,
353            intrinsics::SIN,
354            intrinsics::COS,
355            intrinsics::ABS,
356            intrinsics::LOG,
357            intrinsics::SQRT,
358        ];
359
360        let binds = binop_names.map(|n| (n.to_symbol(), binop_ty));
361        let unibinds = uniop_names.map(|n| (n.to_symbol(), uniop_ty));
362        [
363            (
364                intrinsics::DELAY.to_symbol(),
365                function!(vec![numeric!(), numeric!(), numeric!()], numeric!()),
366            ),
367            (
368                intrinsics::TOFLOAT.to_symbol(),
369                function!(vec![integer!()], numeric!()),
370            ),
371        ]
372        .into_iter()
373        .chain(binds)
374        .chain(unibinds)
375        .collect()
376    }
377    fn unwrap_result(&mut self, res: Result<TypeNodeId, Vec<Error>>) -> TypeNodeId {
378        match res {
379            Ok(t) => t,
380            Err(mut e) => {
381                let loc = &e[0].get_labels()[0].0; //todo
382                self.errors.append(&mut e);
383                Type::Failure.into_id_with_location(loc.clone())
384            }
385        }
386    }
387    fn get_typescheme(&mut self, tvid: IntermediateId, loc: Location) -> TypeNodeId {
388        self.generalize_map.get(&tvid).cloned().map_or_else(
389            || self.gen_typescheme(loc),
390            |id| Type::TypeScheme(id).into_id(),
391        )
392    }
393    fn gen_typescheme(&mut self, loc: Location) -> TypeNodeId {
394        let res = Type::TypeScheme(self.typescheme_idx).into_id_with_location(loc);
395        self.typescheme_idx.0 += 1;
396        res
397    }
398
399    fn gen_intermediate_type_with_location(&mut self, loc: Location) -> TypeNodeId {
400        let res = Type::Intermediate(Arc::new(RwLock::new(TypeVar::new(
401            self.interm_idx,
402            self.level,
403        ))))
404        .into_id_with_location(loc);
405        self.interm_idx.0 += 1;
406        res
407    }
408    fn convert_unknown_to_intermediate(&mut self, t: TypeNodeId, loc: Location) -> TypeNodeId {
409        match t.to_type() {
410            Type::Unknown => self.gen_intermediate_type_with_location(loc.clone()),
411            _ => t.apply_fn(|t| self.convert_unknown_to_intermediate(t, loc.clone())),
412        }
413    }
414    fn convert_unify_error(&self, e: UnificationError) -> Error {
415        let gen_loc = |span| Location::new(span, self.file_path.clone());
416        match e {
417            UnificationError::TypeMismatch { left, right } => Error::TypeMismatch {
418                left: (left, gen_loc(left.to_span())),
419                right: (right, gen_loc(right.to_span())),
420            },
421            UnificationError::LengthMismatch {
422                left: (left, lspan),
423                right: (right, rspan),
424            } => Error::LengthMismatch {
425                left: (left.len(), gen_loc(lspan)),
426                right: (right.len(), gen_loc(rspan)),
427            },
428            UnificationError::CircularType { left, right } => {
429                Error::CircularType(gen_loc(left), gen_loc(right))
430            }
431            UnificationError::ImcompatibleRecords {
432                left: (left, lspan),
433                right: (right, rspan),
434            } => Error::IncompatibleKeyInRecord {
435                left: (left, gen_loc(lspan)),
436                right: (right, gen_loc(rspan)),
437            },
438        }
439    }
440    fn unify_types(&self, t1: TypeNodeId, t2: TypeNodeId) -> Result<Relation, Vec<Error>> {
441        unify_types(t1, t2)
442            .map_err(|e| e.into_iter().map(|e| self.convert_unify_error(e)).collect())
443    }
444    // helper function
445    fn merge_rel_result(
446        &self,
447        rel1: Result<Relation, Vec<Error>>,
448        rel2: Result<Relation, Vec<Error>>,
449        t1: TypeNodeId,
450        t2: TypeNodeId,
451    ) -> Result<(), Vec<Error>> {
452        match (rel1, rel2) {
453            (Ok(Relation::Identical), Ok(Relation::Identical)) => Ok(()),
454            (Ok(_), Ok(_)) => Err(vec![Error::TypeMismatch {
455                left: (t1, Location::new(t1.to_span(), self.file_path.clone())),
456                right: (t2, Location::new(t2.to_span(), self.file_path.clone())),
457            }]),
458            (Err(e1), Err(e2)) => Err(e1.into_iter().chain(e2).collect()),
459            (Err(e), _) | (_, Err(e)) => Err(e),
460        }
461    }
462    pub fn substitute_type(t: TypeNodeId) -> TypeNodeId {
463        match t.to_type() {
464            Type::Intermediate(cell) => {
465                let TypeVar { parent, .. } = &*cell.read().unwrap() as &TypeVar;
466                match parent {
467                    Some(p) => Self::substitute_type(*p),
468                    None => Type::Unknown.into_id_with_location(t.to_loc()),
469                }
470            }
471            _ => t.apply_fn(Self::substitute_type),
472        }
473    }
474    fn substitute_all_intermediates(&mut self) {
475        let mut e_list = self
476            .result_memo
477            .iter()
478            .map(|(e, t)| (*e, Self::substitute_type(*t)))
479            .collect::<Vec<_>>();
480
481        e_list.iter_mut().for_each(|(e, t)| {
482            log::trace!("e: {:?} t: {}", e, t.to_type());
483            let _old = self.result_memo.insert(*e, *t);
484        })
485    }
486
487    fn generalize(&mut self, t: TypeNodeId) -> TypeNodeId {
488        match t.to_type() {
489            Type::Intermediate(tvar) => {
490                let &TypeVar { level, var, .. } = &*tvar.read().unwrap() as &TypeVar;
491                if level > self.level {
492                    self.get_typescheme(var, t.to_loc())
493                } else {
494                    t
495                }
496            }
497            _ => t.apply_fn(|t| self.generalize(t)),
498        }
499    }
500    fn instantiate(&mut self, t: TypeNodeId) -> TypeNodeId {
501        match t.to_type() {
502            Type::TypeScheme(id) => {
503                if let Some(tvar) = self.instantiated_map.get(&id) {
504                    *tvar
505                } else {
506                    let res = self.gen_intermediate_type_with_location(t.to_loc());
507                    self.instantiated_map.insert(id, res);
508                    res
509                }
510            }
511            _ => t.apply_fn(|t| self.instantiate(t)),
512        }
513    }
514
515    // Note: the third argument `span` is used for the error location in case of
516    // type mismatch. This is needed because `t`'s span refers to the location
517    // where it originally defined (e.g. the explicit return type of the
518    // function) and is not necessarily the same as where the error happens.
519    fn bind_pattern(
520        &mut self,
521        pat: (TypedPattern, Location),
522        body: (TypeNodeId, Location),
523    ) -> Result<TypeNodeId, Vec<Error>> {
524        let (TypedPattern { pat, ty, .. }, loc_p) = pat;
525        let (body_t, loc_b) = body.clone();
526        let mut bind_item = |pat| {
527            let newloc = ty.to_loc();
528            let ity = self.gen_intermediate_type_with_location(newloc.clone());
529            let p = TypedPattern::new(pat, ity);
530            self.bind_pattern((p, newloc.clone()), (ity, newloc))
531        };
532        let pat_t = match pat {
533            Pattern::Single(id) => {
534                let pat_t = self.convert_unknown_to_intermediate(ty, loc_p);
535                log::trace!("bind {} : {}", id, pat_t.to_type().to_string());
536                self.env.add_bind(&[(id, (pat_t, self.stage))]);
537                Ok::<TypeNodeId, Vec<Error>>(pat_t)
538            }
539            Pattern::Tuple(pats) => {
540                let elems = pats.iter().map(|p| bind_item(p.clone())).try_collect()?; //todo multiple errors
541                let res = Type::Tuple(elems).into_id_with_location(loc_p);
542                let target = self.convert_unknown_to_intermediate(ty, loc_b);
543                let rel = self.unify_types(res, target)?;
544                Ok(res)
545            }
546            Pattern::Record(items) => {
547                let res = items
548                    .iter()
549                    .map(|(key, v)| {
550                        bind_item(v.clone()).map(|ty| RecordTypeField {
551                            key: *key,
552                            ty,
553                            has_default: false,
554                        })
555                    })
556                    .try_collect()?; //todo multiple errors
557                let res = Type::Record(res).into_id_with_location(loc_p);
558                let target = self.convert_unknown_to_intermediate(ty, loc_b);
559                let rel = self.unify_types(res, target)?;
560                Ok(res)
561            }
562            Pattern::Error => Err(vec![Error::PatternMismatch(
563                (
564                    Type::Failure.into_id_with_location(loc_p.clone()),
565                    loc_b.clone(),
566                ),
567                (pat, loc_p.clone()),
568            )]),
569        }?;
570        let rel = self.unify_types(pat_t, body_t)?;
571        Ok(self.generalize(pat_t))
572    }
573
574    pub fn lookup(&self, name: Symbol, loc: Location) -> Result<TypeNodeId, Error> {
575        use crate::utils::environment::LookupRes;
576        match self.env.lookup_cls(&name) {
577            LookupRes::Local((ty, bound_stage)) if self.stage == *bound_stage => Ok(*ty),
578            LookupRes::UpValue(_, (ty, bound_stage)) if self.stage == *bound_stage => Ok(*ty),
579            LookupRes::Global((ty, bound_stage))
580                if self.stage == *bound_stage || *bound_stage == EvalStage::Persistent =>
581            {
582                Ok(*ty)
583            }
584            LookupRes::None => Err(Error::VariableNotFound(name, loc)),
585            LookupRes::Local((_, bound_stage))
586            | LookupRes::UpValue(_, (_, bound_stage))
587            | LookupRes::Global((_, bound_stage)) => Err(Error::StageMismatch {
588                variable: name,
589                expected_stage: self.stage,
590                found_stage: *bound_stage,
591                location: loc,
592            }),
593        }
594    }
595    pub(crate) fn infer_type_literal(e: &Literal, loc: Location) -> Result<TypeNodeId, Error> {
596        let pt = match e {
597            Literal::Float(_) | Literal::Now | Literal::SampleRate => PType::Numeric,
598            Literal::Int(_s) => PType::Int,
599            Literal::String(_s) => PType::String,
600            Literal::SelfLit => panic!("\"self\" should not be shown at type inference stage"),
601            Literal::PlaceHolder => panic!("\"_\" should not be shown at type inference stage"),
602        };
603        Ok(Type::Primitive(pt).into_id_with_location(loc))
604    }
605    fn infer_vec(&mut self, e: &[ExprNodeId]) -> Result<Vec<TypeNodeId>, Vec<Error>> {
606        e.iter().map(|e| self.infer_type(*e)).try_collect()
607    }
608    fn infer_type_levelup(&mut self, e: ExprNodeId) -> TypeNodeId {
609        self.level += 1;
610        let res = self.infer_type_unwrapping(e);
611        self.level -= 1;
612        res
613    }
614    pub fn infer_type(&mut self, e: ExprNodeId) -> Result<TypeNodeId, Vec<Error>> {
615        if let Some(r) = self.result_memo.get(&e.0) {
616            //use cached result
617            return Ok(*r);
618        }
619        let loc = e.to_location();
620        let res: Result<TypeNodeId, Vec<Error>> = match &e.to_expr() {
621            Expr::Literal(l) => Self::infer_type_literal(l, loc).map_err(|e| vec![e]),
622            Expr::Tuple(e) => {
623                Ok(Type::Tuple(self.infer_vec(e.as_slice())?).into_id_with_location(loc))
624            }
625            Expr::ArrayLiteral(e) => {
626                let elem_types = self.infer_vec(e.as_slice())?;
627                let first = elem_types
628                    .first()
629                    .copied()
630                    .unwrap_or(Type::Unknown.into_id_with_location(loc.clone()));
631                //todo:collect multiple errors
632                let elem_t = elem_types
633                    .iter()
634                    .try_fold(first, |acc, t| self.unify_types(acc, *t).map(|rel| *t))?;
635
636                Ok(Type::Array(elem_t).into_id_with_location(loc.clone()))
637            }
638            Expr::ArrayAccess(e, idx) => {
639                let arr_t = self.infer_type_unwrapping(*e);
640                let loc_e = e.to_location();
641                let idx_t = self.infer_type_unwrapping(*idx);
642                let loc_i = idx.to_location();
643
644                let elem_t = self.gen_intermediate_type_with_location(loc_e.clone());
645
646                let rel1 = self.unify_types(
647                    idx_t,
648                    Type::Primitive(PType::Numeric).into_id_with_location(loc_i),
649                );
650                let rel2 = self.unify_types(
651                    Type::Array(elem_t).into_id_with_location(loc_e.clone()),
652                    arr_t,
653                );
654                let _ = self.merge_rel_result(rel1, rel2, arr_t, idx_t)?;
655                Ok(elem_t)
656            }
657            Expr::Proj(e, idx) => {
658                let tup = self.infer_type_unwrapping(*e);
659                // we directly inspect if the intermediate type is a tuple or not.
660                // this is because we can not infer the number of fields in the tuple from the fields access expression.
661                // This rule will be loosened when structural subtyping is implemented.
662                let vec_to_ans = |vec: &[_]| {
663                    if vec.len() < *idx as usize {
664                        Err(vec![Error::IndexOutOfRange {
665                            len: vec.len() as u16,
666                            idx: *idx as u16,
667                            loc: loc.clone(),
668                        }])
669                    } else {
670                        Ok(vec[*idx as usize])
671                    }
672                };
673                match tup.to_type() {
674                    Type::Tuple(vec) => vec_to_ans(&vec),
675                    Type::Intermediate(tv) => {
676                        let tv = tv.read().unwrap();
677                        if let Some(parent) = tv.parent {
678                            match parent.to_type() {
679                                Type::Tuple(vec) => vec_to_ans(&vec),
680                                _ => Err(vec![Error::IndexForNonTuple(loc, tup)]),
681                            }
682                        } else {
683                            Err(vec![Error::IndexForNonTuple(loc, tup)])
684                        }
685                    }
686                    _ => Err(vec![Error::IndexForNonTuple(loc, tup)]),
687                }
688            }
689            Expr::RecordLiteral(kvs) => {
690                let duplicate_keys = kvs
691                    .iter()
692                    .map(|RecordField { name, .. }| *name)
693                    .duplicates();
694                if duplicate_keys.clone().count() > 0 {
695                    Err(vec![Error::DuplicateKeyInRecord {
696                        key: duplicate_keys.collect(),
697                        loc,
698                    }])
699                } else {
700                    let kts: Vec<_> = kvs
701                        .iter()
702                        .map(|RecordField { name, expr }| {
703                            let ty = self.infer_type_unwrapping(*expr);
704                            RecordTypeField {
705                                key: *name,
706                                ty,
707                                has_default: true,
708                            }
709                        })
710                        .collect();
711                    Ok(Type::Record(kts).into_id_with_location(loc))
712                }
713            }
714            Expr::RecordUpdate(_, _) => {
715                // RecordUpdate should never reach type inference as it gets expanded
716                // to Block/Let/Assign expressions during syntax sugar conversion in convert_pronoun.rs
717                unreachable!("RecordUpdate should be expanded before type inference")
718            }
719            Expr::FieldAccess(expr, field) => {
720                let et = self.infer_type_unwrapping(*expr);
721                log::trace!("field access {} : {}", field, et.to_type());
722                let fields_to_ans = |fields: &[RecordTypeField]| {
723                    fields
724                        .iter()
725                        .find_map(
726                            |RecordTypeField { key, ty, .. }| {
727                                if *key == *field { Some(*ty) } else { None }
728                            },
729                        )
730                        .ok_or_else(|| {
731                            vec![Error::FieldNotExist {
732                                field: *field,
733                                loc: loc.clone(),
734                                et,
735                            }]
736                        })
737                };
738                // we directly inspect if the intermediate type is a record or not.
739                // this is because we can not infer the number of fields in the record from the fields access expression.
740                // This rule will be loosened when structural subtyping is implemented.
741                match et.to_type() {
742                    Type::Record(fields) => fields_to_ans(&fields),
743                    Type::Intermediate(tv) => {
744                        let tv = tv.read().unwrap();
745                        if let Some(parent) = tv.parent {
746                            match parent.to_type() {
747                                Type::Record(fields) => fields_to_ans(&fields),
748                                _ => Err(vec![Error::FieldForNonRecord(loc, et)]),
749                            }
750                        } else {
751                            Err(vec![Error::FieldForNonRecord(loc, et)])
752                        }
753                    }
754                    _ => Err(vec![Error::FieldForNonRecord(loc, et)]),
755                }
756            }
757            Expr::Feed(id, body) => {
758                //todo: add span to Feed expr for keeping the location of `self`.
759                let feedv = self.gen_intermediate_type_with_location(loc);
760
761                self.env.add_bind(&[(*id, (feedv, self.stage))]);
762                let bty = self.infer_type_unwrapping(*body);
763                let _rel = self.unify_types(bty, feedv)?;
764                if bty.to_type().contains_function() {
765                    Err(vec![Error::NonPrimitiveInFeed(body.to_location())])
766                } else {
767                    Ok(bty)
768                }
769            }
770            Expr::Lambda(p, rtype, body) => {
771                self.env.extend();
772                let dup = p.iter().duplicates_by(|id| id.id).map(|id| {
773                    let loc = Location::new(id.to_span(), self.file_path.clone());
774                    (id.id, loc)
775                });
776                if dup.clone().count() > 0 {
777                    return Err(vec![Error::DuplicateKeyInParams(dup.collect())]);
778                }
779                let pvec = p
780                    .iter()
781                    .map(|id| {
782                        let ity = self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc());
783                        self.env.add_bind(&[(id.id, (ity, self.stage))]);
784                        RecordTypeField {
785                            key: id.id,
786                            ty: ity,
787                            has_default: false,
788                        }
789                    })
790                    .collect::<Vec<_>>();
791                let ptype = if pvec.is_empty() {
792                    Type::Primitive(PType::Unit).into_id_with_location(loc.clone())
793                } else {
794                    Type::Record(pvec).into_id_with_location(loc.clone())
795                };
796                let bty = if let Some(r) = rtype {
797                    let bty = self.infer_type_unwrapping(*body);
798                    let _rel = self.unify_types(*r, bty)?;
799                    bty
800                } else {
801                    self.infer_type_unwrapping(*body)
802                };
803                self.env.to_outer();
804                Ok(Type::Function {
805                    arg: ptype,
806                    ret: bty,
807                }
808                .into_id_with_location(e.to_location()))
809            }
810            Expr::Let(tpat, body, then) => {
811                let bodyt = self.infer_type_levelup(*body);
812                let loc_p = tpat.to_loc();
813                let loc_b = body.to_location();
814                let pat_t = self.bind_pattern((tpat.clone(), loc_p), (bodyt, loc_b));
815                let _pat_t = self.unwrap_result(pat_t);
816                match then {
817                    Some(e) => self.infer_type(*e),
818                    None => Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
819                }
820            }
821            Expr::LetRec(id, body, then) => {
822                let idt = self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc());
823                self.env.add_bind(&[(id.id, (idt, self.stage))]);
824                //polymorphic inference is not allowed in recursive function.
825                let bodyt = self.infer_type_levelup(*body);
826                let _res = self.unify_types(idt, bodyt);
827                match then {
828                    Some(e) => self.infer_type(*e),
829                    None => Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
830                }
831            }
832            Expr::Assign(assignee, expr) => {
833                match assignee.to_expr() {
834                    Expr::Var(name) => {
835                        let assignee_t =
836                            self.unwrap_result(self.lookup(name, loc).map_err(|e| vec![e]));
837                        let e_t = self.infer_type_unwrapping(*expr);
838                        let _rel = self.unify_types(assignee_t, e_t)?;
839                        Ok(unit!())
840                    }
841                    Expr::FieldAccess(record, field_name) => {
842                        // Handle field assignment: record.field = value
843                        let record_type = self.infer_type_unwrapping(record);
844                        let value_type = self.infer_type_unwrapping(*expr);
845                        let tmptype = Type::Record(vec![RecordTypeField {
846                            key: field_name,
847                            ty: value_type,
848                            has_default: false,
849                        }])
850                        .into_id();
851                        if self.unify_types(record_type, tmptype)? == Relation::Supertype {
852                            unreachable!(
853                                "record field access for an empty record will not likely to happen."
854                            )
855                        };
856                        Ok(value_type)
857                    }
858                    Expr::ArrayAccess(_, _) => {
859                        unimplemented!("Assignment to array is not implemented yet.")
860                    }
861                    _ => {
862                        // This should be caught by parser, but add a generic error just in case
863                        Err(vec![Error::VariableNotFound(
864                            "invalid_assignment_target".to_symbol(),
865                            loc.clone(),
866                        )])
867                    }
868                }
869            }
870            Expr::Then(e, then) => {
871                let _ = self.infer_type(*e)?;
872                then.map_or(Ok(unit!()), |t| self.infer_type(t))
873            }
874            Expr::Var(name) => {
875                let res = self.unwrap_result(self.lookup(*name, loc).map_err(|e| vec![e]));
876                // log::trace!("{} {} /level{}", name.as_str(), res.to_type(), self.level);
877                Ok(self.instantiate(res))
878            }
879            Expr::Apply(fun, callee) => {
880                let loc_f = fun.to_location();
881                let fnl = self.infer_type_unwrapping(*fun);
882                let callee_t = match callee.len() {
883                    0 => Type::Primitive(PType::Unit).into_id_with_location(loc.clone()),
884                    1 => self.infer_type_unwrapping(callee[0]),
885                    _ => {
886                        let at_vec = self.infer_vec(callee.as_slice())?;
887                        let span = callee[0].to_span().start..callee.last().unwrap().to_span().end;
888                        let loc = Location::new(span, self.file_path.clone());
889                        Type::Tuple(at_vec).into_id_with_location(loc)
890                    }
891                };
892                let res_t = self.gen_intermediate_type_with_location(loc);
893                let fntype = Type::Function {
894                    arg: callee_t,
895                    ret: res_t,
896                }
897                .into_id_with_location(loc_f.clone());
898                match self.unify_types(fnl, fntype)? {
899                    Relation::Subtype => Err(vec![Error::NonSupertypeArgument {
900                        location: loc_f.clone(),
901                        expected: fnl,
902                        found: fntype,
903                    }]),
904                    _ => Ok(res_t),
905                }
906            }
907            Expr::If(cond, then, opt_else) => {
908                let condt = self.infer_type_unwrapping(*cond);
909                let cond_loc = cond.to_location();
910                let bt = self.unify_types(
911                    Type::Primitive(PType::Numeric).into_id_with_location(cond_loc),
912                    condt,
913                )?; //todo:boolean type
914                //todo: introduce row polymophism so that not narrowing the type of `then` and `else` too much.
915                let thent = self.infer_type_unwrapping(*then);
916                let elset = opt_else.map_or(Type::Primitive(PType::Unit).into_id(), |e| {
917                    self.infer_type_unwrapping(e)
918                });
919                let rel = self.unify_types(thent, elset)?;
920                Ok(thent)
921            }
922            Expr::Block(expr) => expr.map_or(
923                Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
924                |e| {
925                    self.env.extend(); //block creates local scope.
926                    let res = self.infer_type(e);
927                    self.env.to_outer();
928                    res
929                },
930            ),
931            Expr::Escape(e) => {
932                let loc_e = Location::new(e.to_span(), self.file_path.clone());
933                // Decrease stage for escape expression
934                self.stage = self.stage.decrement();
935                log::trace!("Unstaging escape expression, stage => {:?}", self.stage);
936                let res = self.infer_type_unwrapping(*e);
937                // Increase stage back
938                self.stage = self.stage.increment();
939                let intermediate = self.gen_intermediate_type_with_location(loc_e.clone());
940                let rel = self.unify_types(
941                    res,
942                    Type::Code(intermediate).into_id_with_location(loc_e.clone()),
943                )?;
944                Ok(intermediate)
945            }
946            Expr::Bracket(e) => {
947                let loc_e = Location::new(e.to_span(), self.file_path.clone());
948                // Increase stage for bracket expression
949                self.stage = self.stage.increment();
950                log::trace!("Staging bracket expression, stage => {:?}", self.stage);
951                let res = self.infer_type_unwrapping(*e);
952                // Decrease stage back
953                self.stage = self.stage.decrement();
954                Ok(Type::Code(res).into_id_with_location(loc_e))
955            }
956            _ => Ok(Type::Failure.into_id_with_location(loc)),
957        };
958        res.inspect(|ty| {
959            self.result_memo.insert(e.0, *ty);
960        })
961    }
962    fn infer_type_unwrapping(&mut self, e: ExprNodeId) -> TypeNodeId {
963        match self.infer_type(e) {
964            Ok(t) => t,
965            Err(err) => {
966                self.errors.extend(err);
967                Type::Failure
968                    .into_id_with_location(Location::new(e.to_span(), self.file_path.clone()))
969            }
970        }
971    }
972}
973
974pub fn infer_root(
975    e: ExprNodeId,
976    builtin_types: &[(Symbol, TypeNodeId)],
977    file_path: PathBuf,
978) -> InferContext {
979    let mut ctx = InferContext::new(builtin_types, file_path.clone());
980    let _t = ctx
981        .infer_type(e)
982        .unwrap_or(Type::Failure.into_id_with_location(e.to_location()));
983    ctx.substitute_all_intermediates();
984    ctx
985}
986
#[cfg(test)]
mod tests {
    use super::*;
    use crate::interner::ToSymbol;
    use crate::types::Type;
    use crate::utils::metadata::{Location, Span};

    /// Builds an inference context with no builtin types for the dummy file `test`.
    fn create_test_context() -> InferContext {
        let dummy_path = PathBuf::from("test");
        InferContext::new(&[], dummy_path)
    }

    /// Builds a zero-width location inside the dummy file `test`.
    fn create_test_location() -> Location {
        let empty_span = Span { start: 0, end: 0 };
        Location::new(empty_span, PathBuf::from("test"))
    }

    /// Creates a numeric primitive type node located at `loc`.
    fn numeric_type_at(loc: &Location) -> TypeNodeId {
        Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone())
    }

    /// A stage-0 binding resolves from stage 0 and yields `StageMismatch` from stage 1.
    #[test]
    fn test_stage_mismatch_detection() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Bind `x` as a numeric variable living at stage 0.
        let x = "x".to_symbol();
        let numeric = numeric_type_at(&loc);
        ctx.env.add_bind(&[(x, (numeric, EvalStage::Stage(0)))]);

        // Lookup from the binding's own stage.
        ctx.stage = EvalStage::Stage(0);
        let same_stage = ctx.lookup(x, loc.clone());
        assert!(
            same_stage.is_ok(),
            "Looking up variable from same stage should succeed"
        );

        // Lookup from a different stage.
        ctx.stage = EvalStage::Stage(1);
        let cross_stage = ctx.lookup(x, loc.clone());
        assert!(
            cross_stage.is_err(),
            "Looking up variable from different stage should fail"
        );

        // The failure must be the dedicated stage-mismatch error with matching payload.
        match cross_stage {
            Err(Error::StageMismatch {
                variable,
                expected_stage,
                found_stage,
                ..
            }) => {
                assert_eq!(variable, x);
                assert_eq!(expected_stage, EvalStage::Stage(1));
                assert_eq!(found_stage, EvalStage::Stage(0));
            }
            other => panic!("Expected StageMismatch error, got: {:?}", other),
        }
    }

    /// A binding at `EvalStage::Persistent` resolves from stages 0, 1 and 2.
    #[test]
    fn test_persistent_stage_access() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        let name = "persistent_var".to_symbol();
        let numeric = numeric_type_at(&loc);
        ctx.env.add_bind(&[(name, (numeric, EvalStage::Persistent))]);

        for stage in 0..3 {
            ctx.stage = EvalStage::Stage(stage);
            let looked_up = ctx.lookup(name, loc.clone());
            assert!(
                looked_up.is_ok(),
                "Persistent stage variables should be accessible from stage {}",
                stage
            );
        }
    }

    /// A binding at stage N resolves from stage N and fails from the other tested stages.
    #[test]
    fn test_same_stage_access() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // One numeric variable per stage, named after the stage it belongs to.
        for stage in 0..3 {
            let name = format!("var_stage_{}", stage).to_symbol();
            let numeric = numeric_type_at(&loc);
            ctx.env.add_bind(&[(name, (numeric, EvalStage::Stage(stage)))]);
        }

        for stage in 0..3 {
            let name = format!("var_stage_{}", stage).to_symbol();

            ctx.stage = EvalStage::Stage(stage);
            let own = ctx.lookup(name, loc.clone());
            assert!(
                own.is_ok(),
                "Variable should be accessible from its own stage {}",
                stage
            );

            for other_stage in 0..3 {
                if other_stage == stage {
                    continue;
                }
                ctx.stage = EvalStage::Stage(other_stage);
                let foreign = ctx.lookup(name, loc.clone());
                assert!(
                    foreign.is_err(),
                    "Variable from stage {} should not be accessible from stage {}",
                    stage,
                    other_stage
                );
            }
        }
    }

    /// `increment` followed by `decrement` takes the stage from 0 to 1 and back to 0.
    #[test]
    fn test_stage_transitions_bracket_escape() {
        let mut ctx = create_test_context();

        assert_eq!(ctx.stage, EvalStage::Stage(0), "Initial stage should be 0");

        // Mimic entering a bracket expression.
        let bracketed = ctx.stage.increment();
        ctx.stage = bracketed;
        assert_eq!(
            ctx.stage,
            EvalStage::Stage(1),
            "Stage should increment to 1 in bracket"
        );

        // Mimic an escape expression.
        let escaped = ctx.stage.decrement();
        ctx.stage = escaped;
        assert_eq!(
            ctx.stage,
            EvalStage::Stage(0),
            "Stage should decrement back to 0 after escape"
        );
    }

    /// The same name bound at stage 0 in an outer scope and at stage 1 in an inner scope:
    /// lookups from stage 0 fail, lookups from stage 1 succeed.
    #[test]
    fn test_multi_stage_environment() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Outer scope: `x` at stage 0.
        ctx.env.extend();
        let outer_x = "x".to_symbol();
        let numeric = numeric_type_at(&loc);
        ctx.stage = EvalStage::Stage(0);
        ctx.env
            .add_bind(&[(outer_x, (numeric, EvalStage::Stage(0)))]);

        // Inner scope: the same name, this time at stage 1.
        ctx.env.extend();
        let inner_x = "x".to_symbol();
        ctx.stage = EvalStage::Stage(1);
        ctx.env
            .add_bind(&[(inner_x, (numeric, EvalStage::Stage(1)))]);

        ctx.stage = EvalStage::Stage(0);
        let shadowed = ctx.lookup(outer_x, loc.clone());
        assert!(
            shadowed.is_err(),
            "Stage 0 variable should not be accessible from nested stage 0 context due to shadowing"
        );

        ctx.stage = EvalStage::Stage(1);
        let from_stage1 = ctx.lookup(inner_x, loc.clone());
        assert!(
            from_stage1.is_ok(),
            "Stage 1 variable should be accessible from stage 1"
        );

        ctx.stage = EvalStage::Stage(0);
        let from_stage0 = ctx.lookup(inner_x, loc.clone());
        assert!(
            from_stage0.is_err(),
            "Stage 1 variable should not be accessible from stage 0"
        );

        // Leave both scopes opened above.
        ctx.env.to_outer();
        ctx.env.to_outer();
    }
}