1use crate::ast::{Expr, Literal, RecordField};
2use crate::compiler::{EvalStage, intrinsics};
3use crate::interner::{ExprKey, ExprNodeId, Symbol, ToSymbol, TypeNodeId};
4use crate::pattern::{Pattern, TypedPattern};
5use crate::types::{IntermediateId, PType, RecordTypeField, Type, TypeSchemeId, TypeVar};
6use crate::utils::metadata::Location;
7use crate::utils::{environment::Environment, error::ReportableError};
8use crate::{function, integer, numeric, unit};
9use itertools::Itertools;
10use std::cell::RefCell;
11use std::collections::BTreeMap;
12use std::fmt;
13use std::path::PathBuf;
14use std::rc::Rc;
15use std::sync::{Arc, Mutex, RwLock};
16
17mod unification;
18use unification::{Error as UnificationError, Relation, unify_types};
19
/// Errors produced by type inference.
///
/// Every variant carries the `Location`(s) that `ReportableError::get_labels`
/// uses to point at the offending source spans.
#[derive(Clone, Debug)]
pub enum Error {
    /// Two types failed to unify; each side carries its type and location.
    TypeMismatch {
        left: (TypeNodeId, Location),
        right: (TypeNodeId, Location),
    },
    /// Two element sequences being unified have different lengths.
    LengthMismatch {
        left: (usize, Location),
        right: (usize, Location),
    },
    /// A pattern could not be bound against the given type.
    PatternMismatch((TypeNodeId, Location), (Pattern, Location)),
    /// `letrec` was given a non-function type.
    NonFunctionForLetRec(TypeNodeId, Location),
    /// The target of an application is not a function type.
    NonFunctionForApply(TypeNodeId, Location),
    /// Unifying a callee with the call-site function type yielded a subtype
    /// relation (reported by the `Apply` arm of `infer_type`).
    NonSupertypeArgument {
        location: Location,
        expected: TypeNodeId,
        found: TypeNodeId,
    },
    /// A circular type was detected between the two locations.
    CircularType(Location, Location),
    /// A tuple projection index is outside the tuple's length.
    IndexOutOfRange {
        len: u16,
        idx: u16,
        loc: Location,
    },
    /// A tuple projection was applied to a non-tuple type.
    IndexForNonTuple(Location, TypeNodeId),
    /// A field access was applied to a non-record type.
    FieldForNonRecord(Location, TypeNodeId),
    /// The accessed `field` is absent from the record type `et`.
    FieldNotExist {
        field: Symbol,
        loc: Location,
        et: TypeNodeId,
    },
    /// A record literal repeats the listed keys.
    DuplicateKeyInRecord {
        key: Vec<Symbol>,
        loc: Location,
    },
    /// A parameter list repeats names; one entry per duplicate occurrence.
    DuplicateKeyInParams(Vec<(Symbol, Location)>),
    /// Two record types could not be reconciled; each side lists its fields.
    IncompatibleKeyInRecord {
        left: (Vec<(Symbol, TypeNodeId)>, Location),
        right: (Vec<(Symbol, TypeNodeId)>, Location),
    },
    /// A variable name is not bound in any visible scope.
    VariableNotFound(Symbol, Location),
    /// A variable bound at `found_stage` was referenced from `expected_stage`.
    StageMismatch {
        variable: Symbol,
        expected_stage: EvalStage,
        found_stage: EvalStage,
        location: Location,
    },
    /// The body of a `Feed` expression has a type containing a function.
    NonPrimitiveInFeed(Location),
}
70impl fmt::Display for Error {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 write!(f, "Type Inference Error")
73 }
74}
75
76impl std::error::Error for Error {}
77impl ReportableError for Error {
78 fn get_message(&self) -> String {
79 match self {
80 Error::TypeMismatch { .. } => format!("Type mismatch"),
81 Error::PatternMismatch(..) => format!("Pattern mismatch"),
82 Error::LengthMismatch { .. } => format!("Length of the elements are different"),
83 Error::NonFunctionForLetRec(_, _) => format!("`letrec` can take only function type."),
84 Error::NonFunctionForApply(_, _) => {
85 format!("This is not applicable because it is not a function type.")
86 }
87 Error::CircularType(_, _) => format!("Circular loop of type definition detected."),
88 Error::IndexOutOfRange { len, idx, .. } => {
89 format!("Length of tuple elements is {len} but index was {idx}")
90 }
91 Error::IndexForNonTuple(_, _) => {
92 format!("Index access for non-tuple variable.")
93 }
94 Error::VariableNotFound(symbol, _) => {
95 format!("Variable \"{symbol}\" not found in this scope")
96 }
97 Error::StageMismatch {
98 variable,
99 expected_stage,
100 found_stage,
101 ..
102 } => {
103 format!(
104 "Variable {variable} is defined in stage {} but accessed from stage {}",
105 found_stage.format_for_error(),
106 expected_stage.format_for_error()
107 )
108 }
109 Error::NonPrimitiveInFeed(_) => {
110 format!("Function that uses `self` cannot return function type.")
111 }
112 Error::DuplicateKeyInParams { .. } => {
113 format!("Duplicate keys found in parameter list")
114 }
115 Error::DuplicateKeyInRecord { .. } => {
116 format!("Duplicate keys found in record type")
117 }
118 Error::FieldForNonRecord { .. } => {
119 format!("Field access for non-record variable.")
120 }
121 Error::FieldNotExist { field, .. } => {
122 format!("Field \"{field}\" does not exist in the record type")
123 }
124 Error::IncompatibleKeyInRecord { .. } => {
125 format!("Record type has incompatible keys.",)
126 }
127
128 Error::NonSupertypeArgument { .. } => {
129 format!("Arguments for functions are less than required.")
130 }
131 }
132 }
133 fn get_labels(&self) -> Vec<(Location, String)> {
134 match self {
135 Error::TypeMismatch {
136 left: (lty, locl),
137 right: (rty, locr),
138 } => vec![
139 (locl.clone(), lty.to_type().to_string_for_error()),
140 (locr.clone(), rty.to_type().to_string_for_error()),
141 ],
142 Error::PatternMismatch((ty, loct), (pat, locp)) => vec![
143 (loct.clone(), ty.to_type().to_string_for_error()),
144 (locp.clone(), pat.to_string()),
145 ],
146 Error::LengthMismatch {
147 left: (l, locl),
148 right: (r, locr),
149 } => vec![
150 (locl.clone(), format!("The length is {l}")),
151 (locr.clone(), format!("but the length for here is {r}")),
152 ],
153 Error::NonFunctionForLetRec(ty, loc) => {
154 vec![(loc.clone(), ty.to_type().to_string_for_error())]
155 }
156 Error::NonFunctionForApply(ty, loc) => {
157 vec![(loc.clone(), ty.to_type().to_string_for_error())]
158 }
159 Error::CircularType(loc1, loc2) => vec![
160 (loc1.clone(), format!("Circular type happens here")),
161 (loc2.clone(), format!("and here")),
162 ],
163 Error::IndexOutOfRange { loc, len, .. } => {
164 vec![(loc.clone(), format!("Length for this tuple is {len}"))]
165 }
166 Error::IndexForNonTuple(loc, ty) => {
167 vec![(
168 loc.clone(),
169 format!(
170 "This is not tuple type but {}",
171 ty.to_type().to_string_for_error()
172 ),
173 )]
174 }
175 Error::VariableNotFound(symbol, loc) => {
176 vec![(loc.clone(), format!("{symbol} is not defined"))]
177 }
178 Error::StageMismatch {
179 variable,
180 expected_stage,
181 found_stage,
182 location,
183 } => {
184 vec![(
185 location.clone(),
186 format!(
187 "Variable \"{variable}\" defined in stage {} cannot be accessed from stage {}",
188 found_stage.format_for_error(),
189 expected_stage.format_for_error()
190 ),
191 )]
192 }
193 Error::NonPrimitiveInFeed(loc) => {
194 vec![(loc.clone(), format!("This cannot be function type."))]
195 }
196 Error::DuplicateKeyInRecord { key, loc } => {
197 vec![(
198 loc.clone(),
199 format!(
200 "Duplicate keys \"{}\" found in record type",
201 key.iter()
202 .map(|s| s.to_string())
203 .collect::<Vec<_>>()
204 .join(", ")
205 ),
206 )]
207 }
208 Error::DuplicateKeyInParams(keys) => keys
209 .iter()
210 .map(|(key, loc)| {
211 (
212 loc.clone(),
213 format!("Duplicate key \"{key}\" found in parameter list"),
214 )
215 })
216 .collect(),
217 Error::FieldForNonRecord(location, ty) => {
218 vec![(
219 location.clone(),
220 format!(
221 "Field access for non-record type {}.",
222 ty.to_type().to_string_for_error()
223 ),
224 )]
225 }
226 Error::FieldNotExist { field, loc, et } => vec![(
227 loc.clone(),
228 format!(
229 "Field \"{}\" does not exist in the type {}",
230 field,
231 et.to_type().to_string_for_error()
232 ),
233 )],
234 Error::IncompatibleKeyInRecord {
235 left: (left, lloc),
236 right: (right, rloc),
237 } => {
238 vec![
239 (
240 lloc.clone(),
241 format!(
242 "the record here contains{}",
243 left.iter()
244 .map(|(key, ty)| format!(
245 " \"{key}\":{}",
246 ty.to_type().to_string_for_error()
247 ))
248 .collect::<Vec<_>>()
249 .join(", ")
250 ),
251 ),
252 (
253 rloc.clone(),
254 format!(
255 "but the record here contains {}",
256 right
257 .iter()
258 .map(|(key, ty)| format!(
259 " \"{key}\":{}",
260 ty.to_type().to_string_for_error()
261 ))
262 .collect::<Vec<_>>()
263 .join(", ")
264 ),
265 ),
266 ]
267 }
268
269 Error::NonSupertypeArgument {
270 location,
271 expected,
272 found,
273 } => {
274 vec![(
275 location.clone(),
276 format!(
277 "Type {} is not a supertype of the expected type {}",
278 found.to_type().to_string_for_error(),
279 expected.to_type().to_string_for_error()
280 ),
281 )]
282 }
283 }
284 }
285}
286
/// Mutable state threaded through type inference of one program.
#[derive(Clone, Debug)]
pub struct InferContext {
    /// Next id used by `gen_intermediate_type_with_location`.
    interm_idx: IntermediateId,
    /// Next id used by `gen_typescheme`.
    typescheme_idx: TypeSchemeId,
    /// Current let-nesting level. `infer_type_levelup` raises it while a
    /// `let`/`letrec` body is inferred. `generalize` turns an intermediate into a
    /// type scheme only when the intermediate's recorded level is greater than this.
    level: u64,
    /// Current evaluation stage. Brackets increment it and escapes decrement it.
    stage: EvalStage,
    /// Intermediate type created for each type scheme by `instantiate`.
    instantiated_map: BTreeMap<TypeSchemeId, TypeNodeId>,
    /// Type scheme associated with an intermediate variable; read by `get_typescheme`.
    generalize_map: BTreeMap<IntermediateId, TypeSchemeId>,
    /// Inferred type per expression. Filled by `infer_type` and rewritten by
    /// `substitute_all_intermediates`.
    result_memo: BTreeMap<ExprKey, TypeNodeId>,
    /// Source path used to build `Location`s for diagnostics.
    file_path: PathBuf,
    /// Variable bindings in scope, each paired with the stage it was bound at.
    pub env: Environment<(TypeNodeId, EvalStage)>,
    /// Errors recorded while inference continues past them.
    pub errors: Vec<Error>,
}
300impl InferContext {
301 fn new(builtins: &[(Symbol, TypeNodeId)], file_path: PathBuf) -> Self {
302 let mut res = Self {
303 interm_idx: Default::default(),
304 typescheme_idx: Default::default(),
305 level: Default::default(),
306 stage: EvalStage::Stage(0), instantiated_map: Default::default(),
308 generalize_map: Default::default(),
309 result_memo: Default::default(),
310 file_path,
311 env: Environment::<(TypeNodeId, EvalStage)>::default(),
312 errors: Default::default(),
313 };
314 res.env.extend();
315 let intrinsics = Self::intrinsic_types()
317 .into_iter()
318 .map(|(name, ty)| (name, (ty, EvalStage::Persistent)))
319 .collect::<Vec<_>>();
320 res.env.add_bind(&intrinsics);
321 let builtins = builtins
323 .iter()
324 .map(|(name, ty)| (*name, (*ty, EvalStage::Persistent)))
325 .collect::<Vec<_>>();
326 res.env.add_bind(&builtins);
327 res
328 }
329}
330impl InferContext {
331 fn intrinsic_types() -> Vec<(Symbol, TypeNodeId)> {
332 let binop_ty = function!(vec![numeric!(), numeric!()], numeric!());
333 let binop_names = [
334 intrinsics::ADD,
335 intrinsics::SUB,
336 intrinsics::MULT,
337 intrinsics::DIV,
338 intrinsics::MODULO,
339 intrinsics::POW,
340 intrinsics::GT,
341 intrinsics::LT,
342 intrinsics::GE,
343 intrinsics::LE,
344 intrinsics::EQ,
345 intrinsics::NE,
346 intrinsics::AND,
347 intrinsics::OR,
348 ];
349 let uniop_ty = function!(vec![numeric!()], numeric!());
350 let uniop_names = [
351 intrinsics::NEG,
352 intrinsics::MEM,
353 intrinsics::SIN,
354 intrinsics::COS,
355 intrinsics::ABS,
356 intrinsics::LOG,
357 intrinsics::SQRT,
358 ];
359
360 let binds = binop_names.map(|n| (n.to_symbol(), binop_ty));
361 let unibinds = uniop_names.map(|n| (n.to_symbol(), uniop_ty));
362 [
363 (
364 intrinsics::DELAY.to_symbol(),
365 function!(vec![numeric!(), numeric!(), numeric!()], numeric!()),
366 ),
367 (
368 intrinsics::TOFLOAT.to_symbol(),
369 function!(vec![integer!()], numeric!()),
370 ),
371 ]
372 .into_iter()
373 .chain(binds)
374 .chain(unibinds)
375 .collect()
376 }
377 fn unwrap_result(&mut self, res: Result<TypeNodeId, Vec<Error>>) -> TypeNodeId {
378 match res {
379 Ok(t) => t,
380 Err(mut e) => {
381 let loc = &e[0].get_labels()[0].0; self.errors.append(&mut e);
383 Type::Failure.into_id_with_location(loc.clone())
384 }
385 }
386 }
387 fn get_typescheme(&mut self, tvid: IntermediateId, loc: Location) -> TypeNodeId {
388 self.generalize_map.get(&tvid).cloned().map_or_else(
389 || self.gen_typescheme(loc),
390 |id| Type::TypeScheme(id).into_id(),
391 )
392 }
393 fn gen_typescheme(&mut self, loc: Location) -> TypeNodeId {
394 let res = Type::TypeScheme(self.typescheme_idx).into_id_with_location(loc);
395 self.typescheme_idx.0 += 1;
396 res
397 }
398
399 fn gen_intermediate_type_with_location(&mut self, loc: Location) -> TypeNodeId {
400 let res = Type::Intermediate(Arc::new(RwLock::new(TypeVar::new(
401 self.interm_idx,
402 self.level,
403 ))))
404 .into_id_with_location(loc);
405 self.interm_idx.0 += 1;
406 res
407 }
408 fn convert_unknown_to_intermediate(&mut self, t: TypeNodeId, loc: Location) -> TypeNodeId {
409 match t.to_type() {
410 Type::Unknown => self.gen_intermediate_type_with_location(loc.clone()),
411 _ => t.apply_fn(|t| self.convert_unknown_to_intermediate(t, loc.clone())),
412 }
413 }
414 fn convert_unify_error(&self, e: UnificationError) -> Error {
415 let gen_loc = |span| Location::new(span, self.file_path.clone());
416 match e {
417 UnificationError::TypeMismatch { left, right } => Error::TypeMismatch {
418 left: (left, gen_loc(left.to_span())),
419 right: (right, gen_loc(right.to_span())),
420 },
421 UnificationError::LengthMismatch {
422 left: (left, lspan),
423 right: (right, rspan),
424 } => Error::LengthMismatch {
425 left: (left.len(), gen_loc(lspan)),
426 right: (right.len(), gen_loc(rspan)),
427 },
428 UnificationError::CircularType { left, right } => {
429 Error::CircularType(gen_loc(left), gen_loc(right))
430 }
431 UnificationError::ImcompatibleRecords {
432 left: (left, lspan),
433 right: (right, rspan),
434 } => Error::IncompatibleKeyInRecord {
435 left: (left, gen_loc(lspan)),
436 right: (right, gen_loc(rspan)),
437 },
438 }
439 }
440 fn unify_types(&self, t1: TypeNodeId, t2: TypeNodeId) -> Result<Relation, Vec<Error>> {
441 unify_types(t1, t2)
442 .map_err(|e| e.into_iter().map(|e| self.convert_unify_error(e)).collect())
443 }
444 fn merge_rel_result(
446 &self,
447 rel1: Result<Relation, Vec<Error>>,
448 rel2: Result<Relation, Vec<Error>>,
449 t1: TypeNodeId,
450 t2: TypeNodeId,
451 ) -> Result<(), Vec<Error>> {
452 match (rel1, rel2) {
453 (Ok(Relation::Identical), Ok(Relation::Identical)) => Ok(()),
454 (Ok(_), Ok(_)) => Err(vec![Error::TypeMismatch {
455 left: (t1, Location::new(t1.to_span(), self.file_path.clone())),
456 right: (t2, Location::new(t2.to_span(), self.file_path.clone())),
457 }]),
458 (Err(e1), Err(e2)) => Err(e1.into_iter().chain(e2).collect()),
459 (Err(e), _) | (_, Err(e)) => Err(e),
460 }
461 }
462 pub fn substitute_type(t: TypeNodeId) -> TypeNodeId {
463 match t.to_type() {
464 Type::Intermediate(cell) => {
465 let TypeVar { parent, .. } = &*cell.read().unwrap() as &TypeVar;
466 match parent {
467 Some(p) => Self::substitute_type(*p),
468 None => Type::Unknown.into_id_with_location(t.to_loc()),
469 }
470 }
471 _ => t.apply_fn(Self::substitute_type),
472 }
473 }
474 fn substitute_all_intermediates(&mut self) {
475 let mut e_list = self
476 .result_memo
477 .iter()
478 .map(|(e, t)| (*e, Self::substitute_type(*t)))
479 .collect::<Vec<_>>();
480
481 e_list.iter_mut().for_each(|(e, t)| {
482 log::trace!("e: {:?} t: {}", e, t.to_type());
483 let _old = self.result_memo.insert(*e, *t);
484 })
485 }
486
487 fn generalize(&mut self, t: TypeNodeId) -> TypeNodeId {
488 match t.to_type() {
489 Type::Intermediate(tvar) => {
490 let &TypeVar { level, var, .. } = &*tvar.read().unwrap() as &TypeVar;
491 if level > self.level {
492 self.get_typescheme(var, t.to_loc())
493 } else {
494 t
495 }
496 }
497 _ => t.apply_fn(|t| self.generalize(t)),
498 }
499 }
500 fn instantiate(&mut self, t: TypeNodeId) -> TypeNodeId {
501 match t.to_type() {
502 Type::TypeScheme(id) => {
503 if let Some(tvar) = self.instantiated_map.get(&id) {
504 *tvar
505 } else {
506 let res = self.gen_intermediate_type_with_location(t.to_loc());
507 self.instantiated_map.insert(id, res);
508 res
509 }
510 }
511 _ => t.apply_fn(|t| self.instantiate(t)),
512 }
513 }
514
515 fn bind_pattern(
520 &mut self,
521 pat: (TypedPattern, Location),
522 body: (TypeNodeId, Location),
523 ) -> Result<TypeNodeId, Vec<Error>> {
524 let (TypedPattern { pat, ty, .. }, loc_p) = pat;
525 let (body_t, loc_b) = body.clone();
526 let mut bind_item = |pat| {
527 let newloc = ty.to_loc();
528 let ity = self.gen_intermediate_type_with_location(newloc.clone());
529 let p = TypedPattern::new(pat, ity);
530 self.bind_pattern((p, newloc.clone()), (ity, newloc))
531 };
532 let pat_t = match pat {
533 Pattern::Single(id) => {
534 let pat_t = self.convert_unknown_to_intermediate(ty, loc_p);
535 log::trace!("bind {} : {}", id, pat_t.to_type().to_string());
536 self.env.add_bind(&[(id, (pat_t, self.stage))]);
537 Ok::<TypeNodeId, Vec<Error>>(pat_t)
538 }
539 Pattern::Tuple(pats) => {
540 let elems = pats.iter().map(|p| bind_item(p.clone())).try_collect()?; let res = Type::Tuple(elems).into_id_with_location(loc_p);
542 let target = self.convert_unknown_to_intermediate(ty, loc_b);
543 let rel = self.unify_types(res, target)?;
544 Ok(res)
545 }
546 Pattern::Record(items) => {
547 let res = items
548 .iter()
549 .map(|(key, v)| {
550 bind_item(v.clone()).map(|ty| RecordTypeField {
551 key: *key,
552 ty,
553 has_default: false,
554 })
555 })
556 .try_collect()?; let res = Type::Record(res).into_id_with_location(loc_p);
558 let target = self.convert_unknown_to_intermediate(ty, loc_b);
559 let rel = self.unify_types(res, target)?;
560 Ok(res)
561 }
562 Pattern::Error => Err(vec![Error::PatternMismatch(
563 (
564 Type::Failure.into_id_with_location(loc_p.clone()),
565 loc_b.clone(),
566 ),
567 (pat, loc_p.clone()),
568 )]),
569 }?;
570 let rel = self.unify_types(pat_t, body_t)?;
571 Ok(self.generalize(pat_t))
572 }
573
574 pub fn lookup(&self, name: Symbol, loc: Location) -> Result<TypeNodeId, Error> {
575 use crate::utils::environment::LookupRes;
576 match self.env.lookup_cls(&name) {
577 LookupRes::Local((ty, bound_stage)) if self.stage == *bound_stage => Ok(*ty),
578 LookupRes::UpValue(_, (ty, bound_stage)) if self.stage == *bound_stage => Ok(*ty),
579 LookupRes::Global((ty, bound_stage))
580 if self.stage == *bound_stage || *bound_stage == EvalStage::Persistent =>
581 {
582 Ok(*ty)
583 }
584 LookupRes::None => Err(Error::VariableNotFound(name, loc)),
585 LookupRes::Local((_, bound_stage))
586 | LookupRes::UpValue(_, (_, bound_stage))
587 | LookupRes::Global((_, bound_stage)) => Err(Error::StageMismatch {
588 variable: name,
589 expected_stage: self.stage,
590 found_stage: *bound_stage,
591 location: loc,
592 }),
593 }
594 }
595 pub(crate) fn infer_type_literal(e: &Literal, loc: Location) -> Result<TypeNodeId, Error> {
596 let pt = match e {
597 Literal::Float(_) | Literal::Now | Literal::SampleRate => PType::Numeric,
598 Literal::Int(_s) => PType::Int,
599 Literal::String(_s) => PType::String,
600 Literal::SelfLit => panic!("\"self\" should not be shown at type inference stage"),
601 Literal::PlaceHolder => panic!("\"_\" should not be shown at type inference stage"),
602 };
603 Ok(Type::Primitive(pt).into_id_with_location(loc))
604 }
605 fn infer_vec(&mut self, e: &[ExprNodeId]) -> Result<Vec<TypeNodeId>, Vec<Error>> {
606 e.iter().map(|e| self.infer_type(*e)).try_collect()
607 }
608 fn infer_type_levelup(&mut self, e: ExprNodeId) -> TypeNodeId {
609 self.level += 1;
610 let res = self.infer_type_unwrapping(e);
611 self.level -= 1;
612 res
613 }
614 pub fn infer_type(&mut self, e: ExprNodeId) -> Result<TypeNodeId, Vec<Error>> {
615 if let Some(r) = self.result_memo.get(&e.0) {
616 return Ok(*r);
618 }
619 let loc = e.to_location();
620 let res: Result<TypeNodeId, Vec<Error>> = match &e.to_expr() {
621 Expr::Literal(l) => Self::infer_type_literal(l, loc).map_err(|e| vec![e]),
622 Expr::Tuple(e) => {
623 Ok(Type::Tuple(self.infer_vec(e.as_slice())?).into_id_with_location(loc))
624 }
625 Expr::ArrayLiteral(e) => {
626 let elem_types = self.infer_vec(e.as_slice())?;
627 let first = elem_types
628 .first()
629 .copied()
630 .unwrap_or(Type::Unknown.into_id_with_location(loc.clone()));
631 let elem_t = elem_types
633 .iter()
634 .try_fold(first, |acc, t| self.unify_types(acc, *t).map(|rel| *t))?;
635
636 Ok(Type::Array(elem_t).into_id_with_location(loc.clone()))
637 }
638 Expr::ArrayAccess(e, idx) => {
639 let arr_t = self.infer_type_unwrapping(*e);
640 let loc_e = e.to_location();
641 let idx_t = self.infer_type_unwrapping(*idx);
642 let loc_i = idx.to_location();
643
644 let elem_t = self.gen_intermediate_type_with_location(loc_e.clone());
645
646 let rel1 = self.unify_types(
647 idx_t,
648 Type::Primitive(PType::Numeric).into_id_with_location(loc_i),
649 );
650 let rel2 = self.unify_types(
651 Type::Array(elem_t).into_id_with_location(loc_e.clone()),
652 arr_t,
653 );
654 let _ = self.merge_rel_result(rel1, rel2, arr_t, idx_t)?;
655 Ok(elem_t)
656 }
657 Expr::Proj(e, idx) => {
658 let tup = self.infer_type_unwrapping(*e);
659 let vec_to_ans = |vec: &[_]| {
663 if vec.len() < *idx as usize {
664 Err(vec![Error::IndexOutOfRange {
665 len: vec.len() as u16,
666 idx: *idx as u16,
667 loc: loc.clone(),
668 }])
669 } else {
670 Ok(vec[*idx as usize])
671 }
672 };
673 match tup.to_type() {
674 Type::Tuple(vec) => vec_to_ans(&vec),
675 Type::Intermediate(tv) => {
676 let tv = tv.read().unwrap();
677 if let Some(parent) = tv.parent {
678 match parent.to_type() {
679 Type::Tuple(vec) => vec_to_ans(&vec),
680 _ => Err(vec![Error::IndexForNonTuple(loc, tup)]),
681 }
682 } else {
683 Err(vec![Error::IndexForNonTuple(loc, tup)])
684 }
685 }
686 _ => Err(vec![Error::IndexForNonTuple(loc, tup)]),
687 }
688 }
689 Expr::RecordLiteral(kvs) => {
690 let duplicate_keys = kvs
691 .iter()
692 .map(|RecordField { name, .. }| *name)
693 .duplicates();
694 if duplicate_keys.clone().count() > 0 {
695 Err(vec![Error::DuplicateKeyInRecord {
696 key: duplicate_keys.collect(),
697 loc,
698 }])
699 } else {
700 let kts: Vec<_> = kvs
701 .iter()
702 .map(|RecordField { name, expr }| {
703 let ty = self.infer_type_unwrapping(*expr);
704 RecordTypeField {
705 key: *name,
706 ty,
707 has_default: true,
708 }
709 })
710 .collect();
711 Ok(Type::Record(kts).into_id_with_location(loc))
712 }
713 }
714 Expr::RecordUpdate(_, _) => {
715 unreachable!("RecordUpdate should be expanded before type inference")
718 }
719 Expr::FieldAccess(expr, field) => {
720 let et = self.infer_type_unwrapping(*expr);
721 log::trace!("field access {} : {}", field, et.to_type());
722 let fields_to_ans = |fields: &[RecordTypeField]| {
723 fields
724 .iter()
725 .find_map(
726 |RecordTypeField { key, ty, .. }| {
727 if *key == *field { Some(*ty) } else { None }
728 },
729 )
730 .ok_or_else(|| {
731 vec![Error::FieldNotExist {
732 field: *field,
733 loc: loc.clone(),
734 et,
735 }]
736 })
737 };
738 match et.to_type() {
742 Type::Record(fields) => fields_to_ans(&fields),
743 Type::Intermediate(tv) => {
744 let tv = tv.read().unwrap();
745 if let Some(parent) = tv.parent {
746 match parent.to_type() {
747 Type::Record(fields) => fields_to_ans(&fields),
748 _ => Err(vec![Error::FieldForNonRecord(loc, et)]),
749 }
750 } else {
751 Err(vec![Error::FieldForNonRecord(loc, et)])
752 }
753 }
754 _ => Err(vec![Error::FieldForNonRecord(loc, et)]),
755 }
756 }
757 Expr::Feed(id, body) => {
758 let feedv = self.gen_intermediate_type_with_location(loc);
760
761 self.env.add_bind(&[(*id, (feedv, self.stage))]);
762 let bty = self.infer_type_unwrapping(*body);
763 let _rel = self.unify_types(bty, feedv)?;
764 if bty.to_type().contains_function() {
765 Err(vec![Error::NonPrimitiveInFeed(body.to_location())])
766 } else {
767 Ok(bty)
768 }
769 }
770 Expr::Lambda(p, rtype, body) => {
771 self.env.extend();
772 let dup = p.iter().duplicates_by(|id| id.id).map(|id| {
773 let loc = Location::new(id.to_span(), self.file_path.clone());
774 (id.id, loc)
775 });
776 if dup.clone().count() > 0 {
777 return Err(vec![Error::DuplicateKeyInParams(dup.collect())]);
778 }
779 let pvec = p
780 .iter()
781 .map(|id| {
782 let ity = self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc());
783 self.env.add_bind(&[(id.id, (ity, self.stage))]);
784 RecordTypeField {
785 key: id.id,
786 ty: ity,
787 has_default: false,
788 }
789 })
790 .collect::<Vec<_>>();
791 let ptype = if pvec.is_empty() {
792 Type::Primitive(PType::Unit).into_id_with_location(loc.clone())
793 } else {
794 Type::Record(pvec).into_id_with_location(loc.clone())
795 };
796 let bty = if let Some(r) = rtype {
797 let bty = self.infer_type_unwrapping(*body);
798 let _rel = self.unify_types(*r, bty)?;
799 bty
800 } else {
801 self.infer_type_unwrapping(*body)
802 };
803 self.env.to_outer();
804 Ok(Type::Function {
805 arg: ptype,
806 ret: bty,
807 }
808 .into_id_with_location(e.to_location()))
809 }
810 Expr::Let(tpat, body, then) => {
811 let bodyt = self.infer_type_levelup(*body);
812 let loc_p = tpat.to_loc();
813 let loc_b = body.to_location();
814 let pat_t = self.bind_pattern((tpat.clone(), loc_p), (bodyt, loc_b));
815 let _pat_t = self.unwrap_result(pat_t);
816 match then {
817 Some(e) => self.infer_type(*e),
818 None => Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
819 }
820 }
821 Expr::LetRec(id, body, then) => {
822 let idt = self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc());
823 self.env.add_bind(&[(id.id, (idt, self.stage))]);
824 let bodyt = self.infer_type_levelup(*body);
826 let _res = self.unify_types(idt, bodyt);
827 match then {
828 Some(e) => self.infer_type(*e),
829 None => Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
830 }
831 }
832 Expr::Assign(assignee, expr) => {
833 match assignee.to_expr() {
834 Expr::Var(name) => {
835 let assignee_t =
836 self.unwrap_result(self.lookup(name, loc).map_err(|e| vec![e]));
837 let e_t = self.infer_type_unwrapping(*expr);
838 let _rel = self.unify_types(assignee_t, e_t)?;
839 Ok(unit!())
840 }
841 Expr::FieldAccess(record, field_name) => {
842 let record_type = self.infer_type_unwrapping(record);
844 let value_type = self.infer_type_unwrapping(*expr);
845 let tmptype = Type::Record(vec![RecordTypeField {
846 key: field_name,
847 ty: value_type,
848 has_default: false,
849 }])
850 .into_id();
851 if self.unify_types(record_type, tmptype)? == Relation::Supertype {
852 unreachable!(
853 "record field access for an empty record will not likely to happen."
854 )
855 };
856 Ok(value_type)
857 }
858 Expr::ArrayAccess(_, _) => {
859 unimplemented!("Assignment to array is not implemented yet.")
860 }
861 _ => {
862 Err(vec![Error::VariableNotFound(
864 "invalid_assignment_target".to_symbol(),
865 loc.clone(),
866 )])
867 }
868 }
869 }
870 Expr::Then(e, then) => {
871 let _ = self.infer_type(*e)?;
872 then.map_or(Ok(unit!()), |t| self.infer_type(t))
873 }
874 Expr::Var(name) => {
875 let res = self.unwrap_result(self.lookup(*name, loc).map_err(|e| vec![e]));
876 Ok(self.instantiate(res))
878 }
879 Expr::Apply(fun, callee) => {
880 let loc_f = fun.to_location();
881 let fnl = self.infer_type_unwrapping(*fun);
882 let callee_t = match callee.len() {
883 0 => Type::Primitive(PType::Unit).into_id_with_location(loc.clone()),
884 1 => self.infer_type_unwrapping(callee[0]),
885 _ => {
886 let at_vec = self.infer_vec(callee.as_slice())?;
887 let span = callee[0].to_span().start..callee.last().unwrap().to_span().end;
888 let loc = Location::new(span, self.file_path.clone());
889 Type::Tuple(at_vec).into_id_with_location(loc)
890 }
891 };
892 let res_t = self.gen_intermediate_type_with_location(loc);
893 let fntype = Type::Function {
894 arg: callee_t,
895 ret: res_t,
896 }
897 .into_id_with_location(loc_f.clone());
898 match self.unify_types(fnl, fntype)? {
899 Relation::Subtype => Err(vec![Error::NonSupertypeArgument {
900 location: loc_f.clone(),
901 expected: fnl,
902 found: fntype,
903 }]),
904 _ => Ok(res_t),
905 }
906 }
907 Expr::If(cond, then, opt_else) => {
908 let condt = self.infer_type_unwrapping(*cond);
909 let cond_loc = cond.to_location();
910 let bt = self.unify_types(
911 Type::Primitive(PType::Numeric).into_id_with_location(cond_loc),
912 condt,
913 )?; let thent = self.infer_type_unwrapping(*then);
916 let elset = opt_else.map_or(Type::Primitive(PType::Unit).into_id(), |e| {
917 self.infer_type_unwrapping(e)
918 });
919 let rel = self.unify_types(thent, elset)?;
920 Ok(thent)
921 }
922 Expr::Block(expr) => expr.map_or(
923 Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
924 |e| {
925 self.env.extend(); let res = self.infer_type(e);
927 self.env.to_outer();
928 res
929 },
930 ),
931 Expr::Escape(e) => {
932 let loc_e = Location::new(e.to_span(), self.file_path.clone());
933 self.stage = self.stage.decrement();
935 log::trace!("Unstaging escape expression, stage => {:?}", self.stage);
936 let res = self.infer_type_unwrapping(*e);
937 self.stage = self.stage.increment();
939 let intermediate = self.gen_intermediate_type_with_location(loc_e.clone());
940 let rel = self.unify_types(
941 res,
942 Type::Code(intermediate).into_id_with_location(loc_e.clone()),
943 )?;
944 Ok(intermediate)
945 }
946 Expr::Bracket(e) => {
947 let loc_e = Location::new(e.to_span(), self.file_path.clone());
948 self.stage = self.stage.increment();
950 log::trace!("Staging bracket expression, stage => {:?}", self.stage);
951 let res = self.infer_type_unwrapping(*e);
952 self.stage = self.stage.decrement();
954 Ok(Type::Code(res).into_id_with_location(loc_e))
955 }
956 _ => Ok(Type::Failure.into_id_with_location(loc)),
957 };
958 res.inspect(|ty| {
959 self.result_memo.insert(e.0, *ty);
960 })
961 }
962 fn infer_type_unwrapping(&mut self, e: ExprNodeId) -> TypeNodeId {
963 match self.infer_type(e) {
964 Ok(t) => t,
965 Err(err) => {
966 self.errors.extend(err);
967 Type::Failure
968 .into_id_with_location(Location::new(e.to_span(), self.file_path.clone()))
969 }
970 }
971 }
972}
973
974pub fn infer_root(
975 e: ExprNodeId,
976 builtin_types: &[(Symbol, TypeNodeId)],
977 file_path: PathBuf,
978) -> InferContext {
979 let mut ctx = InferContext::new(builtin_types, file_path.clone());
980 let _t = ctx
981 .infer_type(e)
982 .unwrap_or(Type::Failure.into_id_with_location(e.to_location()));
983 ctx.substitute_all_intermediates();
984 ctx
985}
986
// Unit tests for stage-aware variable lookup (`InferContext::lookup`) and for
// stage increment/decrement on the context.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::interner::ToSymbol;
    use crate::types::Type;
    use crate::utils::metadata::{Location, Span};

    // Context with no extra builtins and a dummy file path.
    fn create_test_context() -> InferContext {
        InferContext::new(&[], PathBuf::from("test"))
    }

    // Zero-width location in the dummy file.
    fn create_test_location() -> Location {
        Location::new(Span { start: 0, end: 0 }, PathBuf::from("test"))
    }

    // A stage-0 binding is visible from stage 0. From stage 1 it yields
    // `StageMismatch` with the expected/found stages filled in.
    #[test]
    fn test_stage_mismatch_detection() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Bind `x: numeric` at stage 0.
        let var_name = "x".to_symbol();
        let var_type =
            Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
        ctx.env
            .add_bind(&[(var_name, (var_type, EvalStage::Stage(0)))]);

        // Same stage: lookup succeeds.
        ctx.stage = EvalStage::Stage(0);
        let result = ctx.lookup(var_name, loc.clone());
        assert!(
            result.is_ok(),
            "Looking up variable from same stage should succeed"
        );

        // Different stage: lookup fails.
        ctx.stage = EvalStage::Stage(1);
        let result = ctx.lookup(var_name, loc.clone());
        assert!(
            result.is_err(),
            "Looking up variable from different stage should fail"
        );

        if let Err(Error::StageMismatch {
            variable,
            expected_stage,
            found_stage,
            ..
        }) = result
        {
            assert_eq!(variable, var_name);
            assert_eq!(expected_stage, EvalStage::Stage(1));
            assert_eq!(found_stage, EvalStage::Stage(0));
        } else {
            panic!("Expected StageMismatch error, got: {:?}", result);
        }
    }

    // A binding at `EvalStage::Persistent` is visible from stages 0, 1 and 2.
    #[test]
    fn test_persistent_stage_access() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        let var_name = "persistent_var".to_symbol();
        let var_type =
            Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
        ctx.env
            .add_bind(&[(var_name, (var_type, EvalStage::Persistent))]);

        for stage in [0, 1, 2] {
            ctx.stage = EvalStage::Stage(stage);
            let result = ctx.lookup(var_name, loc.clone());
            assert!(
                result.is_ok(),
                "Persistent stage variables should be accessible from stage {}",
                stage
            );
        }
    }

    // Each of three stage-specific bindings is visible only from its own stage.
    #[test]
    fn test_same_stage_access() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Bind `var_stage_N` at stage N for N = 0, 1, 2.
        for stage in [0, 1, 2] {
            let var_name = format!("var_stage_{}", stage).to_symbol();
            let var_type =
                Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
            ctx.env
                .add_bind(&[(var_name, (var_type, EvalStage::Stage(stage)))]);
        }

        for stage in [0, 1, 2] {
            // Own stage: visible.
            ctx.stage = EvalStage::Stage(stage);
            let var_name = format!("var_stage_{}", stage).to_symbol();
            let result = ctx.lookup(var_name, loc.clone());
            assert!(
                result.is_ok(),
                "Variable should be accessible from its own stage {}",
                stage
            );

            // Any other stage: not visible.
            for other_stage in [0, 1, 2] {
                if other_stage != stage {
                    ctx.stage = EvalStage::Stage(other_stage);
                    let result = ctx.lookup(var_name, loc.clone());
                    assert!(
                        result.is_err(),
                        "Variable from stage {} should not be accessible from stage {}",
                        stage,
                        other_stage
                    );
                }
            }
        }
    }

    // Stage starts at 0; increment followed by decrement returns to 0.
    #[test]
    fn test_stage_transitions_bracket_escape() {
        let mut ctx = create_test_context();

        assert_eq!(ctx.stage, EvalStage::Stage(0), "Initial stage should be 0");

        ctx.stage = ctx.stage.increment();
        assert_eq!(
            ctx.stage,
            EvalStage::Stage(1),
            "Stage should increment to 1 in bracket"
        );

        ctx.stage = ctx.stage.decrement();
        assert_eq!(
            ctx.stage,
            EvalStage::Stage(0),
            "Stage should decrement back to 0 after escape"
        );
    }

    // Two nested scopes both bind `x`: the outer at stage 0, the inner at stage 1.
    // Lookups see the inner (stage 1) binding, so only stage-1 lookups succeed.
    #[test]
    fn test_multi_stage_environment() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Outer scope: `x` at stage 0.
        ctx.env.extend();
        let var_stage0 = "x".to_symbol();
        let var_type =
            Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
        ctx.stage = EvalStage::Stage(0);
        ctx.env
            .add_bind(&[(var_stage0, (var_type, EvalStage::Stage(0)))]);

        // Inner scope: `x` again, at stage 1 (same symbol, shadows the outer one).
        ctx.env.extend();
        let var_stage1 = "x".to_symbol();
        ctx.stage = EvalStage::Stage(1);
        ctx.env
            .add_bind(&[(var_stage1, (var_type, EvalStage::Stage(1)))]);

        ctx.stage = EvalStage::Stage(0);
        let result = ctx.lookup(var_stage0, loc.clone());
        assert!(
            result.is_err(),
            "Stage 0 variable should not be accessible from nested stage 0 context due to shadowing"
        );

        ctx.stage = EvalStage::Stage(1);
        let result = ctx.lookup(var_stage1, loc.clone());
        assert!(
            result.is_ok(),
            "Stage 1 variable should be accessible from stage 1"
        );

        ctx.stage = EvalStage::Stage(0);
        let result = ctx.lookup(var_stage1, loc.clone());
        assert!(
            result.is_err(),
            "Stage 1 variable should not be accessible from stage 0"
        );

        // Close both scopes opened above.
        ctx.env.to_outer();
        ctx.env.to_outer();
    }
}