1use crate::ast::program::TypeAliasMap;
2use crate::ast::{Expr, Literal, RecordField};
3use crate::compiler::{EvalStage, intrinsics};
4use crate::interner::{ExprKey, ExprNodeId, Symbol, ToSymbol, TypeNodeId};
5use crate::pattern::{Pattern, TypedId, TypedPattern};
6use crate::types::{IntermediateId, PType, RecordTypeField, Type, TypeSchemeId, TypeVar};
7use crate::utils::metadata::Location;
8use crate::utils::{environment::Environment, error::ReportableError};
9use crate::{function, integer, numeric, unit};
10use itertools::Itertools;
11use std::collections::{BTreeMap, HashMap};
12use std::path::PathBuf;
13use std::sync::{Arc, RwLock};
14use thiserror::Error;
15
16mod unification;
17pub(crate) use unification::Relation;
18use unification::{Error as UnificationError, unify_types};
19
/// Errors produced during type inference. Each variant carries the source
/// location(s) used by `ReportableError::get_labels` to annotate diagnostics.
#[derive(Clone, Debug, Error)]
#[error("Type Inference Error")]
pub enum Error {
    /// Two types failed to match; `left` is rendered as the expected type and
    /// `right` as the found type.
    TypeMismatch {
        left: (TypeNodeId, Location),
        right: (TypeNodeId, Location),
    },
    /// An escape was applied to a value whose type is not `Code(T)`.
    EscapeRequiresCodeType {
        found: (TypeNodeId, Location),
    },
    /// Two element sequences that must have the same length do not.
    LengthMismatch {
        left: (usize, Location),
        right: (usize, Location),
    },
    /// A pattern does not fit the type it is matched against.
    PatternMismatch((TypeNodeId, Location), (Pattern, Location)),
    /// `letrec` was given a binding whose type is not a function type.
    NonFunctionForLetRec(TypeNodeId, Location),
    /// A value whose type is not a function was applied to arguments.
    NonFunctionForApply(TypeNodeId, Location),
    /// The argument type `found` is not a supertype of `expected`.
    NonSupertypeArgument {
        location: Location,
        expected: TypeNodeId,
        found: TypeNodeId,
    },
    /// A circular type definition involving the two locations.
    CircularType(Location, Location),
    /// Tuple index `idx` is out of range for a tuple of `len` elements.
    IndexOutOfRange {
        len: u16,
        idx: u16,
        loc: Location,
    },
    /// Index access on a value whose type is not a tuple.
    IndexForNonTuple(Location, TypeNodeId),
    /// Field access on a value whose type is not a record.
    FieldForNonRecord(Location, TypeNodeId),
    /// `field` does not exist in the record type `et`.
    FieldNotExist {
        field: Symbol,
        loc: Location,
        et: TypeNodeId,
    },
    /// A record type declares the listed keys more than once.
    DuplicateKeyInRecord {
        key: Vec<Symbol>,
        loc: Location,
    },
    /// Parameter names that appear more than once, each with its location.
    DuplicateKeyInParams(Vec<(Symbol, Location)>),
    /// Two record types whose key sets are incompatible; each side lists its
    /// `(key, type)` pairs.
    IncompatibleKeyInRecord {
        left: (Vec<(Symbol, TypeNodeId)>, Location),
        right: (Vec<(Symbol, TypeNodeId)>, Location),
    },
    /// A variable name is not bound in the current scope.
    VariableNotFound(Symbol, Location),
    /// A referenced module path does not exist.
    ModuleNotFound {
        module_path: Vec<Symbol>,
        location: Location,
    },
    /// `member` is not defined in the module at `module_path`.
    MemberNotFound {
        module_path: Vec<Symbol>,
        member: Symbol,
        location: Location,
    },
    /// `member` exists in the module at `module_path` but is private.
    PrivateMemberAccess {
        module_path: Vec<Symbol>,
        member: Symbol,
        location: Location,
    },
    /// `variable` is defined in `found_stage` but referenced from
    /// `expected_stage`.
    StageMismatch {
        variable: Symbol,
        expected_stage: EvalStage,
        found_stage: EvalStage,
        location: Location,
    },
    /// A function that uses `self` returns a function type.
    NonPrimitiveInFeed(Location),
    /// `constructor` is not one of the variants of `union_type`.
    ConstructorNotInUnion {
        constructor: Symbol,
        union_type: TypeNodeId,
        location: Location,
    },
    /// A union type was required but `found` was given.
    ExpectedUnionType {
        found: TypeNodeId,
        location: Location,
    },
    /// A match expression does not cover `missing_constructors`.
    NonExhaustiveMatch {
        missing_constructors: Vec<Symbol>,
        location: Location,
    },
    /// A type alias (or a non-`rec` type declaration) refers back to itself;
    /// `cycle` lists the names along the cycle.
    RecursiveTypeAlias {
        type_name: Symbol,
        cycle: Vec<Symbol>,
        location: Location,
    },
    /// `type_name` in the module at `module_path` is private.
    PrivateTypeAccess {
        module_path: Vec<Symbol>,
        type_name: Symbol,
        location: Location,
    },
    /// The public function `function_name` exposes `private_type` in its
    /// signature.
    PrivateTypeLeak {
        function_name: Symbol,
        private_type: Symbol,
        location: Location,
    },
}
125
126impl ReportableError for Error {
127 fn get_message(&self) -> String {
128 match self {
129 Error::TypeMismatch { .. } => format!("Type mismatch"),
130 Error::EscapeRequiresCodeType { found: (ty, ..) } => {
131 format!(
132 "Escape requires a code value, but found {}",
133 ty.to_type().to_string_for_error()
134 )
135 }
136 Error::PatternMismatch(..) => format!("Pattern mismatch"),
137 Error::LengthMismatch { .. } => format!("Length of the elements are different"),
138 Error::NonFunctionForLetRec(_, _) => format!("`letrec` can take only function type."),
139 Error::NonFunctionForApply(_, _) => {
140 format!("This is not applicable because it is not a function type.")
141 }
142 Error::CircularType(_, _) => format!("Circular loop of type definition detected."),
143 Error::IndexOutOfRange { len, idx, .. } => {
144 format!("Length of tuple elements is {len} but index was {idx}")
145 }
146 Error::IndexForNonTuple(_, _) => {
147 format!("Index access for non-tuple variable.")
148 }
149 Error::VariableNotFound(symbol, _) => {
150 format!("Variable \"{symbol}\" not found in this scope")
151 }
152 Error::ModuleNotFound { module_path, .. } => {
153 let path_str = module_path
154 .iter()
155 .map(|s| s.to_string())
156 .collect::<Vec<_>>()
157 .join("::");
158 format!("Module \"{path_str}\" not found")
159 }
160 Error::MemberNotFound {
161 module_path,
162 member,
163 ..
164 } => {
165 let path_str = module_path
166 .iter()
167 .map(|s| s.to_string())
168 .collect::<Vec<_>>()
169 .join("::");
170 format!("Member \"{member}\" not found in module \"{path_str}\"")
171 }
172 Error::PrivateMemberAccess {
173 module_path,
174 member,
175 ..
176 } => {
177 let path_str = module_path
178 .iter()
179 .map(|s| s.to_string())
180 .collect::<Vec<_>>()
181 .join("::");
182 format!("Member \"{member}\" in module \"{path_str}\" is private")
183 }
184 Error::StageMismatch {
185 variable,
186 expected_stage,
187 found_stage,
188 ..
189 } => {
190 format!(
191 "Variable {variable} is defined in stage {} but accessed from stage {}",
192 found_stage.format_for_error(),
193 expected_stage.format_for_error()
194 )
195 }
196 Error::NonPrimitiveInFeed(_) => {
197 format!("Function that uses `self` cannot return function type.")
198 }
199 Error::DuplicateKeyInParams { .. } => {
200 format!("Duplicate keys found in parameter list")
201 }
202 Error::DuplicateKeyInRecord { .. } => {
203 format!("Duplicate keys found in record type")
204 }
205 Error::FieldForNonRecord { .. } => {
206 format!("Field access for non-record variable.")
207 }
208 Error::FieldNotExist { field, .. } => {
209 format!("Field \"{field}\" does not exist in the record type")
210 }
211 Error::IncompatibleKeyInRecord { .. } => {
212 format!("Record type has incompatible keys.",)
213 }
214
215 Error::NonSupertypeArgument { .. } => {
216 format!("Arguments for functions are less than required.")
217 }
218 Error::ConstructorNotInUnion { constructor, .. } => {
219 format!("Constructor \"{constructor}\" is not a variant of the union type")
220 }
221 Error::ExpectedUnionType { found, .. } => {
222 format!(
223 "Expected a union type but found {}",
224 found.to_type().to_string_for_error()
225 )
226 }
227 Error::NonExhaustiveMatch {
228 missing_constructors,
229 ..
230 } => {
231 let missing = missing_constructors
232 .iter()
233 .map(|s| s.to_string())
234 .collect::<Vec<_>>()
235 .join(", ");
236 format!("Match expression is not exhaustive. Missing patterns: {missing}")
237 }
238 Error::RecursiveTypeAlias {
239 type_name, cycle, ..
240 } => {
241 let cycle_str = cycle
242 .iter()
243 .map(|s| s.to_string())
244 .collect::<Vec<_>>()
245 .join(" -> ");
246 format!(
247 "Recursive type alias '{type_name}' detected. Cycle: {cycle_str} -> {type_name}. Use 'type rec' to declare recursive types."
248 )
249 }
250 Error::PrivateTypeAccess {
251 module_path,
252 type_name,
253 ..
254 } => {
255 let path_str = module_path
256 .iter()
257 .map(|s| s.to_string())
258 .collect::<Vec<_>>()
259 .join("::");
260 format!(
261 "Type '{type_name}' in module '{path_str}' is private and cannot be accessed from outside"
262 )
263 }
264 Error::PrivateTypeLeak {
265 function_name,
266 private_type,
267 ..
268 } => {
269 format!(
270 "Public function '{function_name}' cannot expose private type '{private_type}' in its signature"
271 )
272 }
273 }
274 }
275 fn get_labels(&self) -> Vec<(Location, String)> {
276 match self {
277 Error::TypeMismatch {
278 left: (lty, locl),
279 right: (rty, locr),
280 } => {
281 let expected = lty.get_root().to_type().to_string_for_error();
282 let found = rty.get_root().to_type().to_string_for_error();
283 let is_dummy = |loc: &Location| {
284 loc.path.as_os_str().is_empty() || (loc.span.start == 0 && loc.span.end == 0)
285 };
286 let normalize_loc = |primary: &Location, fallback: &Location| {
287 let mut loc = if is_dummy(primary) {
288 fallback.clone()
289 } else {
290 primary.clone()
291 };
292
293 if loc.path.as_os_str().is_empty() {
294 loc.path = if !primary.path.as_os_str().is_empty() {
295 primary.path.clone()
296 } else {
297 fallback.path.clone()
298 };
299 }
300
301 if loc.span.start == 0 && loc.span.end == 0 {
302 if !(primary.span.start == 0 && primary.span.end == 0) {
303 loc.span = primary.span.clone();
304 } else if !(fallback.span.start == 0 && fallback.span.end == 0) {
305 loc.span = fallback.span.clone();
306 } else {
307 loc.span = 0..1;
308 }
309 }
310 loc
311 };
312
313 let left_loc = normalize_loc(locl, locr);
314 let right_loc = normalize_loc(locr, &left_loc);
315 if left_loc == right_loc {
316 vec![(
317 left_loc,
318 format!("expected type: {expected}, found type: {found}"),
319 )]
320 } else {
321 vec![
322 (left_loc, format!("expected type: {expected}")),
323 (right_loc, format!("found type: {found}")),
324 ]
325 }
326 }
327 Error::EscapeRequiresCodeType { found: (ty, loc) } => vec![(
328 loc.clone(),
329 format!(
330 "escape expects `Code(T)`, but found {}. Escaping nested code containers such as arrays of quoted values is not supported",
331 ty.to_type().to_string_for_error()
332 ),
333 )],
334 Error::PatternMismatch((ty, loct), (pat, locp)) => vec![
335 (loct.clone(), ty.to_type().to_string_for_error()),
336 (locp.clone(), pat.to_string()),
337 ],
338 Error::LengthMismatch {
339 left: (l, locl),
340 right: (r, locr),
341 } => vec![
342 (locl.clone(), format!("The length is {l}")),
343 (locr.clone(), format!("but the length for here is {r}")),
344 ],
345 Error::NonFunctionForLetRec(ty, loc) => {
346 vec![(loc.clone(), ty.to_type().to_string_for_error())]
347 }
348 Error::NonFunctionForApply(ty, loc) => {
349 vec![(loc.clone(), ty.to_type().to_string_for_error())]
350 }
351 Error::CircularType(loc1, loc2) => vec![
352 (loc1.clone(), format!("Circular type happens here")),
353 (loc2.clone(), format!("and here")),
354 ],
355 Error::IndexOutOfRange { loc, len, .. } => {
356 vec![(loc.clone(), format!("Length for this tuple is {len}"))]
357 }
358 Error::IndexForNonTuple(loc, ty) => {
359 vec![(
360 loc.clone(),
361 format!(
362 "This is not tuple type but {}",
363 ty.to_type().to_string_for_error()
364 ),
365 )]
366 }
367 Error::VariableNotFound(symbol, loc) => {
368 vec![(loc.clone(), format!("{symbol} is not defined"))]
369 }
370 Error::ModuleNotFound {
371 module_path,
372 location,
373 } => {
374 let path_str = module_path
375 .iter()
376 .map(|s| s.to_string())
377 .collect::<Vec<_>>()
378 .join("::");
379 vec![(location.clone(), format!("Module \"{path_str}\" not found"))]
380 }
381 Error::MemberNotFound {
382 module_path,
383 member,
384 location,
385 } => {
386 let path_str = module_path
387 .iter()
388 .map(|s| s.to_string())
389 .collect::<Vec<_>>()
390 .join("::");
391 vec![(
392 location.clone(),
393 format!("\"{member}\" is not a member of \"{path_str}\""),
394 )]
395 }
396 Error::PrivateMemberAccess {
397 module_path,
398 member,
399 location,
400 } => {
401 let path_str = module_path
402 .iter()
403 .map(|s| s.to_string())
404 .collect::<Vec<_>>()
405 .join("::");
406 vec![(
407 location.clone(),
408 format!("\"{member}\" in \"{path_str}\" is private and cannot be accessed"),
409 )]
410 }
411 Error::StageMismatch {
412 variable,
413 expected_stage,
414 found_stage,
415 location,
416 } => {
417 vec![(
418 location.clone(),
419 format!(
420 "Variable \"{variable}\" defined in stage {} cannot be accessed from stage {}",
421 found_stage.format_for_error(),
422 expected_stage.format_for_error()
423 ),
424 )]
425 }
426 Error::NonPrimitiveInFeed(loc) => {
427 vec![(loc.clone(), format!("This cannot be function type."))]
428 }
429 Error::DuplicateKeyInRecord { key, loc } => {
430 vec![(
431 loc.clone(),
432 format!(
433 "Duplicate keys \"{}\" found in record type",
434 key.iter()
435 .map(|s| s.to_string())
436 .collect::<Vec<_>>()
437 .join(", ")
438 ),
439 )]
440 }
441 Error::DuplicateKeyInParams(keys) => keys
442 .iter()
443 .map(|(key, loc)| {
444 (
445 loc.clone(),
446 format!("Duplicate key \"{key}\" found in parameter list"),
447 )
448 })
449 .collect(),
450 Error::FieldForNonRecord(location, ty) => {
451 vec![(
452 location.clone(),
453 format!(
454 "Field access for non-record type {}.",
455 ty.to_type().to_string_for_error()
456 ),
457 )]
458 }
459 Error::FieldNotExist { field, loc, et } => vec![(
460 loc.clone(),
461 format!(
462 "Field \"{}\" does not exist in the type {}",
463 field,
464 et.to_type().to_string_for_error()
465 ),
466 )],
467 Error::IncompatibleKeyInRecord {
468 left: (left, lloc),
469 right: (right, rloc),
470 } => {
471 vec![
472 (
473 lloc.clone(),
474 format!(
475 "the record here contains{}",
476 left.iter()
477 .map(|(key, ty)| format!(
478 " \"{key}\":{}",
479 ty.to_type().to_string_for_error()
480 ))
481 .collect::<Vec<_>>()
482 .join(", ")
483 ),
484 ),
485 (
486 rloc.clone(),
487 format!(
488 "but the record here contains {}",
489 right
490 .iter()
491 .map(|(key, ty)| format!(
492 " \"{key}\":{}",
493 ty.to_type().to_string_for_error()
494 ))
495 .collect::<Vec<_>>()
496 .join(", ")
497 ),
498 ),
499 ]
500 }
501
502 Error::NonSupertypeArgument {
503 location,
504 expected,
505 found,
506 } => {
507 vec![(
508 location.clone(),
509 format!(
510 "Type {} is not a supertype of the expected type {}",
511 found.to_type().to_string_for_error(),
512 expected.to_type().to_string_for_error()
513 ),
514 )]
515 }
516 Error::ConstructorNotInUnion {
517 constructor,
518 union_type,
519 location,
520 } => {
521 vec![(
522 location.clone(),
523 format!(
524 "Constructor \"{constructor}\" is not a variant of {}",
525 union_type.to_type().to_string_for_error()
526 ),
527 )]
528 }
529 Error::ExpectedUnionType { found, location } => {
530 vec![(
531 location.clone(),
532 format!(
533 "Expected a union type but found {}",
534 found.to_type().to_string_for_error()
535 ),
536 )]
537 }
538 Error::NonExhaustiveMatch {
539 missing_constructors,
540 location,
541 } => {
542 let missing = missing_constructors
543 .iter()
544 .map(|s| format!("\"{s}\""))
545 .collect::<Vec<_>>()
546 .join(", ");
547 vec![(location.clone(), format!("Missing patterns: {missing}"))]
548 }
549 Error::RecursiveTypeAlias {
550 type_name,
551 cycle,
552 location,
553 } => {
554 let cycle_str = cycle
555 .iter()
556 .map(|s| s.to_string())
557 .collect::<Vec<_>>()
558 .join(" -> ");
559 vec![(
560 location.clone(),
561 format!(
562 "Type alias '{type_name}' creates a cycle: {cycle_str} -> {type_name}. Consider using 'type rec' instead of 'type alias'."
563 ),
564 )]
565 }
566 Error::PrivateTypeAccess {
567 module_path,
568 type_name,
569 location,
570 } => {
571 let path_str = module_path
572 .iter()
573 .map(|s| s.to_string())
574 .collect::<Vec<_>>()
575 .join("::");
576 vec![(
577 location.clone(),
578 format!("Type '{type_name}' in module '{path_str}' is private"),
579 )]
580 }
581 Error::PrivateTypeLeak { location, .. } => {
582 vec![(
583 location.clone(),
584 "private type leaked in public function signature".to_string(),
585 )]
586 }
587 }
588 }
589}
590
/// What the checker records about each constructor of a user-declared sum type.
#[derive(Clone, Debug)]
pub struct ConstructorInfo {
    /// The `UserSum` type this constructor belongs to.
    pub sum_type: TypeNodeId,
    /// Zero-based position of the variant within its type declaration.
    pub tag_index: usize,
    /// The variant's payload type, or `None` when the variant has no payload.
    /// For `type rec` declarations, self-references are wrapped in `Boxed`.
    pub payload_type: Option<TypeNodeId>,
}
601
/// Maps each constructor name to the information about its sum type and tag.
pub type ConstructorEnv = HashMap<Symbol, ConstructorInfo>;
604
/// Outcome of looking up a field on a type (presumably used by record field
/// access later in this file — TODO confirm against its callers).
enum FieldLookup {
    /// The field exists and has the given type.
    Found(TypeNodeId),
    /// The type is a record, but it has no field with the requested name.
    RecordWithoutField,
    /// The type is not a record at all.
    NotRecord,
}
614
/// Mutable state threaded through type inference of one file.
#[derive(Clone, Debug)]
pub struct InferContext {
    /// Counter for fresh intermediate type variables (inferred from the type name — TODO confirm).
    interm_idx: IntermediateId,
    /// Counter for fresh type-scheme variables; `gen_typescheme` is assumed to advance it.
    typescheme_idx: TypeSchemeId,
    /// Current binding level, presumably used for let-generalization (TODO confirm).
    level: u64,
    /// Evaluation stage currently being inferred; `new` starts at `EvalStage::Stage(0)`.
    stage: EvalStage,
    /// Presumably the type-scheme -> fresh-type mapping of the current instantiation.
    instantiated_map: BTreeMap<TypeSchemeId, TypeNodeId>,
    /// Presumably the intermediate-var -> type-scheme mapping used during generalization.
    generalize_map: BTreeMap<IntermediateId, TypeSchemeId>,
    /// Inferred types keyed by expression (assumed to be a memo cache — verify).
    result_memo: BTreeMap<ExprKey, TypeNodeId>,
    /// Stack of scopes binding single-letter explicit type parameters to type nodes;
    /// pushed and popped by `with_explicit_type_param_scope_from_types`.
    explicit_type_param_scopes: Vec<BTreeMap<Symbol, TypeNodeId>>,
    /// Path of the file being checked.
    file_path: PathBuf,
    /// Variable environment mapping names to `(type, stage)`; intrinsics and
    /// builtins are bound with `EvalStage::Persistent` by `new`.
    pub env: Environment<(TypeNodeId, EvalStage)>,
    /// Constructors of user-declared sum types, filled by `register_type_declarations`.
    pub constructor_env: ConstructorEnv,
    /// Type aliases by name, filled by `register_type_aliases`.
    pub type_aliases: HashMap<Symbol, TypeNodeId>,
    /// Module metadata supplied by the caller, if any.
    module_info: Option<crate::ast::program::ModuleInfo>,
    /// Match expressions paired with a type (presumably for later exhaustiveness checks — verify).
    match_expressions: Vec<(ExprNodeId, TypeNodeId)>,
    /// Errors accumulated during inference.
    pub errors: Vec<Error>,
    /// Initialized to `usize::MAX` by `new`; meaning not visible here — TODO document.
    pub infer_root_id: usize,
}
/// Names forming a detected type-alias cycle, in traversal order starting at the
/// first repeated alias (produced by `detect_cycle_helper`).
struct TypeCycle(pub Vec<Symbol>);
640
641impl InferContext {
642 pub fn new(
643 builtins: &[(Symbol, TypeNodeId)],
644 file_path: PathBuf,
645 type_declarations: Option<&crate::ast::program::TypeDeclarationMap>,
646 type_aliases: Option<&crate::ast::program::TypeAliasMap>,
647 module_info: Option<crate::ast::program::ModuleInfo>,
648 ) -> Self {
649 let mut res = Self {
650 interm_idx: Default::default(),
651 typescheme_idx: Default::default(),
652 level: Default::default(),
653 stage: EvalStage::Stage(0), instantiated_map: Default::default(),
655 generalize_map: Default::default(),
656 result_memo: Default::default(),
657 explicit_type_param_scopes: Default::default(),
658 file_path,
659 env: Environment::<(TypeNodeId, EvalStage)>::default(),
660 constructor_env: Default::default(),
661 type_aliases: Default::default(),
662 module_info,
663 match_expressions: Default::default(),
664 errors: Default::default(),
665 infer_root_id: usize::MAX,
666 };
667 res.env.extend();
668 let intrinsics = Self::intrinsic_types()
670 .into_iter()
671 .map(|(name, ty)| (name, (ty, EvalStage::Persistent)))
672 .collect::<Vec<_>>();
673 res.env.add_bind(&intrinsics);
674 let builtins = builtins
676 .iter()
677 .map(|(name, ty)| (*name, (*ty, EvalStage::Persistent)))
678 .collect::<Vec<_>>();
679 res.env.add_bind(&builtins);
680 if let Some(type_decls) = type_declarations {
682 res.register_type_declarations(type_decls);
683 }
684 if let Some(type_aliases) = type_aliases {
686 res.register_type_aliases(type_aliases);
687 }
688 res
689 }
690
691 fn is_explicit_type_param_name(name: Symbol) -> bool {
692 let s = name.as_str();
693 s.len() == 1 && s.as_bytes()[0].is_ascii_lowercase()
694 }
695
696 fn collect_explicit_type_params_in_type(ty: TypeNodeId, out: &mut BTreeMap<Symbol, Location>) {
697 match ty.to_type() {
698 Type::TypeAlias(name) if Self::is_explicit_type_param_name(name) => {
699 out.entry(name).or_insert_with(|| ty.to_loc());
700 }
701 Type::Array(elem) | Type::Ref(elem) | Type::Code(elem) | Type::Boxed(elem) => {
702 Self::collect_explicit_type_params_in_type(elem, out);
703 }
704 Type::Tuple(elems) | Type::Union(elems) => elems
705 .iter()
706 .for_each(|elem| Self::collect_explicit_type_params_in_type(*elem, out)),
707 Type::Record(fields) => fields
708 .iter()
709 .for_each(|field| Self::collect_explicit_type_params_in_type(field.ty, out)),
710 Type::Function { arg, ret } => {
711 Self::collect_explicit_type_params_in_type(arg, out);
712 Self::collect_explicit_type_params_in_type(ret, out);
713 }
714 _ => {}
715 }
716 }
717
718 fn with_explicit_type_param_scope_from_types<T>(
719 &mut self,
720 types: &[TypeNodeId],
721 f: impl FnOnce(&mut Self) -> T,
722 ) -> T {
723 let mut collected = BTreeMap::<Symbol, Location>::new();
724 types
725 .iter()
726 .for_each(|ty| Self::collect_explicit_type_params_in_type(*ty, &mut collected));
727 let map = collected
728 .into_iter()
729 .map(|(name, loc)| {
730 let ty = self
731 .lookup_explicit_type_param(name)
732 .unwrap_or_else(|| self.gen_typescheme(loc));
733 (name, ty)
734 })
735 .collect::<BTreeMap<_, _>>();
736 self.explicit_type_param_scopes.push(map);
737 let res = f(self);
738 let _ = self.explicit_type_param_scopes.pop();
739 res
740 }
741
742 fn lookup_explicit_type_param(&self, name: Symbol) -> Option<TypeNodeId> {
743 self.explicit_type_param_scopes
744 .iter()
745 .rev()
746 .find_map(|scope| scope.get(&name).copied())
747 }
748
749 fn register_type_declarations(
752 &mut self,
753 type_declarations: &crate::ast::program::TypeDeclarationMap,
754 ) {
755 let mut sum_types: std::collections::HashMap<Symbol, TypeNodeId> =
758 std::collections::HashMap::new();
759
760 for (type_name, decl_info) in type_declarations {
761 let variants = &decl_info.variants;
762 let variant_data: Vec<(Symbol, Option<TypeNodeId>)> =
763 variants.iter().map(|v| (v.name, v.payload)).collect();
764
765 let sum_type = Type::UserSum {
766 name: *type_name,
767 variants: variant_data.clone(),
768 }
769 .into_id();
770
771 sum_types.insert(*type_name, sum_type);
772 self.env
774 .add_bind(&[(*type_name, (sum_type, EvalStage::Persistent))]);
775 }
776
777 for (type_name, decl_info) in type_declarations {
779 if !decl_info.is_recursive {
780 continue;
781 }
782
783 let variants = &decl_info.variants;
784 let sum_type_id = sum_types[type_name];
785
786 let variant_data: Vec<(Symbol, Option<TypeNodeId>)> = variants
788 .iter()
789 .map(|v| {
790 let wrapped_payload = v.payload.map(|payload_type| {
791 Self::wrap_recursive_refs_static(payload_type, *type_name, sum_type_id)
792 });
793 (v.name, wrapped_payload)
794 })
795 .collect();
796
797 let new_sum_type = Type::UserSum {
799 name: *type_name,
800 variants: variant_data.clone(),
801 }
802 .into_id();
803
804 self.env
806 .add_bind(&[(*type_name, (new_sum_type, EvalStage::Persistent))]);
807
808 for (tag_index, (variant_name, payload_type)) in variant_data.iter().enumerate() {
810 self.constructor_env.insert(
811 *variant_name,
812 ConstructorInfo {
813 sum_type: new_sum_type,
814 tag_index,
815 payload_type: *payload_type,
816 },
817 );
818 }
819 }
820
821 for (type_name, decl_info) in type_declarations {
823 if decl_info.is_recursive {
824 continue;
825 }
826
827 let sum_type = sum_types[type_name];
828 let variants = &decl_info.variants;
829
830 for (tag_index, variant) in variants.iter().enumerate() {
831 self.constructor_env.insert(
832 variant.name,
833 ConstructorInfo {
834 sum_type,
835 tag_index,
836 payload_type: variant.payload,
837 },
838 );
839 }
840 }
841
842 self.check_type_declaration_recursion(type_declarations);
844 }
845
846 fn wrap_recursive_refs_static(
850 ty: TypeNodeId,
851 self_name: Symbol,
852 sum_type_id: TypeNodeId,
853 ) -> TypeNodeId {
854 match ty.to_type() {
855 Type::TypeAlias(name) if name == self_name => {
856 Type::Boxed(sum_type_id).into_id()
858 }
859 Type::Tuple(elements) => {
860 let wrapped_elements: Vec<TypeNodeId> = elements
862 .iter()
863 .map(|&elem| Self::wrap_recursive_refs_static(elem, self_name, sum_type_id))
864 .collect();
865 Type::Tuple(wrapped_elements).into_id()
866 }
867 Type::Record(fields) => {
868 let wrapped_fields: Vec<RecordTypeField> = fields
870 .iter()
871 .map(|field| RecordTypeField {
872 key: field.key,
873 ty: Self::wrap_recursive_refs_static(field.ty, self_name, sum_type_id),
874 has_default: field.has_default,
875 })
876 .collect();
877 Type::Record(wrapped_fields).into_id()
878 }
879 Type::Union(elements) => {
880 let wrapped_elements: Vec<TypeNodeId> = elements
882 .iter()
883 .map(|&elem| Self::wrap_recursive_refs_static(elem, self_name, sum_type_id))
884 .collect();
885 Type::Union(wrapped_elements).into_id()
886 }
887 _ => ty,
889 }
890 }
891
892 fn check_type_declaration_recursion(
895 &mut self,
896 type_declarations: &crate::ast::program::TypeDeclarationMap,
897 ) {
898 for (type_name, decl_info) in type_declarations {
899 if decl_info.is_recursive {
901 continue;
902 }
903 if let Some(location) =
904 self.is_type_declaration_recursive(*type_name, &decl_info.variants)
905 {
906 self.errors.push(Error::RecursiveTypeAlias {
907 type_name: *type_name,
908 cycle: vec![*type_name],
909 location,
910 });
911 }
912 }
913 }
914
915 fn is_type_declaration_recursive(
918 &self,
919 type_name: Symbol,
920 variants: &[crate::ast::program::VariantDef],
921 ) -> Option<Location> {
922 variants.iter().find_map(|variant| {
923 variant
924 .payload
925 .filter(|&payload_type| self.type_references_name(payload_type, type_name))
926 .map(|payload_type| payload_type.to_loc())
927 })
928 }
929
930 fn type_references_name(&self, type_id: TypeNodeId, target_name: Symbol) -> bool {
932 match type_id.to_type() {
933 Type::TypeAlias(name) if name == target_name => true,
934 Type::TypeAlias(name) => {
935 if let Some(resolved_type) = self.type_aliases.get(&name) {
937 self.type_references_name(*resolved_type, target_name)
938 } else {
939 false
940 }
941 }
942 Type::Function { arg, ret } => {
943 self.type_references_name(arg, target_name)
944 || self.type_references_name(ret, target_name)
945 }
946 Type::Tuple(elements) | Type::Union(elements) => elements
947 .iter()
948 .any(|t| self.type_references_name(*t, target_name)),
949 Type::Array(elem) | Type::Code(elem) => self.type_references_name(elem, target_name),
950 Type::Boxed(inner) => self.type_references_name(inner, target_name),
951 Type::Record(fields) => fields
952 .iter()
953 .any(|f| self.type_references_name(f.ty, target_name)),
954 Type::UserSum { name, .. } if name == target_name => true,
955 Type::UserSum { variants, .. } => variants
956 .iter()
957 .filter_map(|(_, payload)| *payload)
958 .any(|p| self.type_references_name(p, target_name)),
959 _ => false,
960 }
961 }
962
963 fn register_type_aliases(&mut self, type_aliases: &crate::ast::program::TypeAliasMap) {
965 for (alias_name, target_type) in type_aliases {
967 self.type_aliases.insert(*alias_name, *target_type);
968 self.env
970 .add_bind(&[(*alias_name, (*target_type, EvalStage::Persistent))]);
971 }
972
973 self.check_type_alias_cycles(type_aliases);
975 }
976
977 fn check_type_alias_cycles(&mut self, type_aliases: &TypeAliasMap) {
979 let errors: Vec<_> = type_aliases
980 .iter()
981 .filter_map(|(alias_name, target_type)| {
982 Self::detect_type_alias_cycle(*alias_name, type_aliases).map(|cycle| {
983 Error::RecursiveTypeAlias {
984 type_name: *alias_name,
985 cycle,
986 location: target_type.to_loc(),
987 }
988 })
989 })
990 .collect();
991
992 self.errors.extend(errors);
993 }
994
995 fn detect_type_alias_cycle(start: Symbol, type_aliases: &TypeAliasMap) -> Option<Vec<Symbol>> {
998 Self::detect_cycle_helper(start, vec![], type_aliases).map(|t| t.0)
999 }
1000
1001 fn detect_cycle_helper(
1003 current: Symbol,
1004 path: Vec<Symbol>,
1005 type_aliases: &TypeAliasMap,
1006 ) -> Option<TypeCycle> {
1007 if let Some(cycle_start) = path.iter().position(|&s| s == current) {
1009 return Some(TypeCycle(path[cycle_start..].to_vec()));
1010 }
1011
1012 let new_path = [path, vec![current]].concat();
1013
1014 type_aliases.get(¤t).and_then(|target_type| {
1015 Self::find_type_aliases_in_type(*target_type)
1016 .into_iter()
1017 .find_map(|ref_alias| {
1018 Self::detect_cycle_helper(ref_alias, new_path.clone(), type_aliases)
1019 })
1020 })
1021 }
1022
1023 fn find_type_aliases_in_type(type_id: TypeNodeId) -> Vec<Symbol> {
1025 match type_id.to_type() {
1026 Type::TypeAlias(name) => vec![name],
1027 Type::Function { arg, ret } => {
1028 let mut aliases = Self::find_type_aliases_in_type(arg);
1029 aliases.extend(Self::find_type_aliases_in_type(ret));
1030 aliases
1031 }
1032 Type::Tuple(elements) | Type::Union(elements) => elements
1033 .iter()
1034 .flat_map(|t| Self::find_type_aliases_in_type(*t))
1035 .collect(),
1036 Type::Array(elem) | Type::Code(elem) => Self::find_type_aliases_in_type(elem),
1037 Type::Record(fields) => fields
1038 .iter()
1039 .flat_map(|f| Self::find_type_aliases_in_type(f.ty))
1040 .collect(),
1041 Type::UserSum { variants, .. } => variants
1042 .iter()
1043 .filter_map(|(_, payload)| *payload)
1044 .flat_map(Self::find_type_aliases_in_type)
1045 .collect(),
1046 _ => vec![],
1047 }
1048 }
1049
1050 pub fn resolve_type_alias(&self, type_id: TypeNodeId) -> TypeNodeId {
1052 match type_id.to_type() {
1053 Type::TypeAlias(alias_name) => {
1054 let resolved_alias_name = self.resolve_type_alias_symbol_fallback(alias_name);
1055 if let Some(resolved_type) = self.type_aliases.get(&resolved_alias_name) {
1056 self.resolve_type_alias(*resolved_type)
1058 } else {
1059 type_id }
1061 }
1062 _ => type_id.apply_fn(|t| self.resolve_type_alias(t)),
1063 }
1064 }
1065}
1066impl InferContext {
/// Largest tuple arity accepted by element-wise tuple arithmetic; wider tuples
/// produce the error built by `make_tuple_binop_arity_error`.
const TUPLE_BINOP_MAX_ARITY: usize = 16;
1068
1069 fn intrinsic_types() -> Vec<(Symbol, TypeNodeId)> {
1070 let binop_ty = function!(vec![numeric!(), numeric!()], numeric!());
1071 let binop_names = [
1072 intrinsics::ADD,
1073 intrinsics::SUB,
1074 intrinsics::MULT,
1075 intrinsics::DIV,
1076 intrinsics::MODULO,
1077 intrinsics::POW,
1078 intrinsics::GT,
1079 intrinsics::LT,
1080 intrinsics::GE,
1081 intrinsics::LE,
1082 intrinsics::EQ,
1083 intrinsics::NE,
1084 intrinsics::AND,
1085 intrinsics::OR,
1086 ];
1087 let uniop_ty = function!(vec![numeric!()], numeric!());
1088 let uniop_names = [
1089 intrinsics::NEG,
1090 intrinsics::MEM,
1091 intrinsics::SIN,
1092 intrinsics::COS,
1093 intrinsics::ABS,
1094 intrinsics::LOG,
1095 intrinsics::SQRT,
1096 ];
1097
1098 let binds = binop_names.map(|n| (n.to_symbol(), binop_ty));
1099 let unibinds = uniop_names.map(|n| (n.to_symbol(), uniop_ty));
1100 [
1101 (
1102 intrinsics::DELAY.to_symbol(),
1103 function!(vec![numeric!(), numeric!(), numeric!()], numeric!()),
1104 ),
1105 (
1106 intrinsics::TOFLOAT.to_symbol(),
1107 function!(vec![integer!()], numeric!()),
1108 ),
1109 ]
1110 .into_iter()
1111 .chain(binds)
1112 .chain(unibinds)
1113 .collect()
1114 }
1115
1116 fn is_tuple_arithmetic_binop_label(label: Symbol) -> bool {
1117 matches!(
1118 label.as_str(),
1119 intrinsics::ADD | intrinsics::SUB | intrinsics::MULT | intrinsics::DIV
1120 )
1121 }
1122
1123 fn try_get_tuple_arithmetic_binop_label(&self, fun: ExprNodeId) -> Option<Symbol> {
1124 match fun.to_expr() {
1125 Expr::Var(name) if Self::is_tuple_arithmetic_binop_label(name) => Some(name),
1126 _ => None,
1127 }
1128 }
1129
1130 fn resolve_for_tuple_binop(&self, ty: TypeNodeId) -> TypeNodeId {
1131 let resolved_alias = self.resolve_type_alias(ty);
1132 Self::substitute_type(resolved_alias)
1133 }
1134
1135 fn type_loc_or_expr_loc(&self, ty: TypeNodeId, expr_loc: &Location) -> Location {
1136 let ty_loc = ty.to_loc();
1137 if ty_loc.path.as_os_str().is_empty() {
1138 expr_loc.clone()
1139 } else {
1140 ty_loc
1141 }
1142 }
1143
1144 fn is_numeric_scalar_for_tuple_binop(&self, ty: TypeNodeId) -> bool {
1145 matches!(
1146 self.resolve_for_tuple_binop(ty).to_type(),
1147 Type::Primitive(PType::Numeric) | Type::Primitive(PType::Int)
1148 )
1149 }
1150
1151 fn make_tuple_binop_arity_error(&self, actual_arity: usize, loc: &Location) -> Error {
1152 Error::TypeMismatch {
1153 left: (
1154 Type::Tuple(vec![numeric!(); Self::TUPLE_BINOP_MAX_ARITY])
1155 .into_id_with_location(loc.clone()),
1156 loc.clone(),
1157 ),
1158 right: (
1159 Type::Tuple(vec![numeric!(); actual_arity]).into_id_with_location(loc.clone()),
1160 loc.clone(),
1161 ),
1162 }
1163 }
1164
1165 fn infer_tuple_arithmetic_binop_type_rec(
1166 &mut self,
1167 lhs_ty: TypeNodeId,
1168 rhs_ty: TypeNodeId,
1169 loc: &Location,
1170 errs: &mut Vec<Error>,
1171 ) -> Option<TypeNodeId> {
1172 let lhs_resolved = self.resolve_for_tuple_binop(lhs_ty);
1173 let rhs_resolved = self.resolve_for_tuple_binop(rhs_ty);
1174
1175 match (lhs_resolved.to_type(), rhs_resolved.to_type()) {
1176 (Type::Tuple(lhs_elems), Type::Tuple(rhs_elems)) => {
1177 if lhs_elems.len() != rhs_elems.len() {
1178 errs.push(Error::TypeMismatch {
1179 left: (lhs_ty, loc.clone()),
1180 right: (rhs_ty, loc.clone()),
1181 });
1182 return None;
1183 }
1184 if lhs_elems.len() > Self::TUPLE_BINOP_MAX_ARITY {
1185 errs.push(self.make_tuple_binop_arity_error(lhs_elems.len(), loc));
1186 return None;
1187 }
1188
1189 let result_elems = lhs_elems
1190 .iter()
1191 .zip(rhs_elems.iter())
1192 .filter_map(|(lt, rt)| {
1193 self.infer_tuple_arithmetic_binop_type_rec(*lt, *rt, loc, errs)
1194 })
1195 .collect::<Vec<_>>();
1196
1197 if result_elems.len() != lhs_elems.len() {
1198 None
1199 } else {
1200 Some(Type::Tuple(result_elems).into_id_with_location(loc.clone()))
1201 }
1202 }
1203 (Type::Tuple(tuple_elems), _) => {
1204 if tuple_elems.len() > Self::TUPLE_BINOP_MAX_ARITY {
1205 errs.push(self.make_tuple_binop_arity_error(tuple_elems.len(), loc));
1206 return None;
1207 }
1208 if !self.is_numeric_scalar_for_tuple_binop(rhs_ty) {
1209 let rhs_loc = self.type_loc_or_expr_loc(rhs_ty, loc);
1210 errs.push(Error::TypeMismatch {
1211 left: (numeric!(), rhs_loc.clone()),
1212 right: (rhs_ty, rhs_loc),
1213 });
1214 return None;
1215 }
1216
1217 let result_elems = tuple_elems
1218 .iter()
1219 .filter_map(|elem_ty| {
1220 self.infer_tuple_arithmetic_binop_type_rec(*elem_ty, rhs_ty, loc, errs)
1221 })
1222 .collect::<Vec<_>>();
1223
1224 if result_elems.len() != tuple_elems.len() {
1225 None
1226 } else {
1227 Some(Type::Tuple(result_elems).into_id_with_location(loc.clone()))
1228 }
1229 }
1230 (_, Type::Tuple(tuple_elems)) => {
1231 if tuple_elems.len() > Self::TUPLE_BINOP_MAX_ARITY {
1232 errs.push(self.make_tuple_binop_arity_error(tuple_elems.len(), loc));
1233 return None;
1234 }
1235 if !self.is_numeric_scalar_for_tuple_binop(lhs_ty) {
1236 let lhs_loc = self.type_loc_or_expr_loc(lhs_ty, loc);
1237 errs.push(Error::TypeMismatch {
1238 left: (numeric!(), lhs_loc.clone()),
1239 right: (lhs_ty, lhs_loc),
1240 });
1241 return None;
1242 }
1243
1244 let result_elems = tuple_elems
1245 .iter()
1246 .filter_map(|elem_ty| {
1247 self.infer_tuple_arithmetic_binop_type_rec(lhs_ty, *elem_ty, loc, errs)
1248 })
1249 .collect::<Vec<_>>();
1250
1251 if result_elems.len() != tuple_elems.len() {
1252 None
1253 } else {
1254 Some(Type::Tuple(result_elems).into_id_with_location(loc.clone()))
1255 }
1256 }
1257 _ => {
1258 let mut valid = true;
1259 if !self.is_numeric_scalar_for_tuple_binop(lhs_ty) {
1260 let lhs_loc = self.type_loc_or_expr_loc(lhs_ty, loc);
1261 errs.push(Error::TypeMismatch {
1262 left: (numeric!(), lhs_loc.clone()),
1263 right: (lhs_ty, lhs_loc),
1264 });
1265 valid = false;
1266 }
1267 if !self.is_numeric_scalar_for_tuple_binop(rhs_ty) {
1268 let rhs_loc = self.type_loc_or_expr_loc(rhs_ty, loc);
1269 errs.push(Error::TypeMismatch {
1270 left: (numeric!(), rhs_loc.clone()),
1271 right: (rhs_ty, rhs_loc),
1272 });
1273 valid = false;
1274 }
1275 if valid { Some(numeric!()) } else { None }
1276 }
1277 }
1278 }
1279
1280 fn infer_tuple_arithmetic_binop_type(
1281 &mut self,
1282 lhs_ty: TypeNodeId,
1283 rhs_ty: TypeNodeId,
1284 loc: Location,
1285 ) -> Result<TypeNodeId, Vec<Error>> {
1286 let mut errs = vec![];
1287 let result_ty = self.infer_tuple_arithmetic_binop_type_rec(lhs_ty, rhs_ty, &loc, &mut errs);
1288 if !errs.is_empty() {
1289 return Err(errs);
1290 }
1291 result_ty.ok_or_else(|| {
1292 vec![Error::TypeMismatch {
1293 left: (lhs_ty, loc.clone()),
1294 right: (rhs_ty, loc),
1295 }]
1296 })
1297 }
1298
1299 fn is_auto_spread_endpoint_type(&self, ty: TypeNodeId) -> bool {
1300 matches!(
1301 self.resolve_for_tuple_binop(ty).to_type(),
1302 Type::Primitive(PType::Numeric)
1303 | Type::Primitive(PType::Int)
1304 | Type::Intermediate(_)
1305 | Type::TypeScheme(_)
1306 | Type::Unknown
1307 | Type::Failure
1308 )
1309 }
1310
1311 fn auto_spread_param_endpoint_type(&self, param_ty: TypeNodeId) -> Option<TypeNodeId> {
1312 let resolved = self.resolve_for_tuple_binop(param_ty);
1313 match resolved.to_type() {
1314 Type::Record(fields) if fields.len() == 1 => Some(fields[0].ty),
1315 _ => Some(resolved),
1316 }
1317 }
1318
1319 fn is_numeric_to_numeric_function_for_auto_spread(&self, fn_ty: TypeNodeId) -> bool {
1320 let resolved = self.resolve_for_tuple_binop(fn_ty);
1321 matches!(
1322 resolved.to_type(),
1323 Type::Function { arg, ret }
1324 if self
1325 .auto_spread_param_endpoint_type(arg)
1326 .is_some_and(|endpoint| self.is_auto_spread_endpoint_type(endpoint))
1327 && self.is_auto_spread_endpoint_type(ret)
1328 )
1329 }
1330
1331 fn infer_auto_spread_type_rec(
1332 &mut self,
1333 arg_ty: TypeNodeId,
1334 loc: &Location,
1335 errs: &mut Vec<Error>,
1336 ) -> Option<TypeNodeId> {
1337 let resolved = self.resolve_for_tuple_binop(arg_ty);
1338 match resolved.to_type() {
1339 Type::Tuple(elems) => {
1340 if elems.len() > Self::TUPLE_BINOP_MAX_ARITY {
1341 errs.push(self.make_tuple_binop_arity_error(elems.len(), loc));
1342 return None;
1343 }
1344 let mapped = elems
1345 .iter()
1346 .filter_map(|elem_ty| self.infer_auto_spread_type_rec(*elem_ty, loc, errs))
1347 .collect::<Vec<_>>();
1348 if mapped.len() != elems.len() {
1349 None
1350 } else {
1351 Some(Type::Tuple(mapped).into_id_with_location(loc.clone()))
1352 }
1353 }
1354 _ => {
1355 if self.is_numeric_scalar_for_tuple_binop(arg_ty) {
1356 Some(numeric!())
1357 } else {
1358 let arg_loc = self.type_loc_or_expr_loc(arg_ty, loc);
1359 errs.push(Error::TypeMismatch {
1360 left: (numeric!(), arg_loc.clone()),
1361 right: (arg_ty, arg_loc),
1362 });
1363 None
1364 }
1365 }
1366 }
1367 }
1368
1369 fn infer_auto_spread_type(
1370 &mut self,
1371 fn_ty: TypeNodeId,
1372 arg_ty: TypeNodeId,
1373 loc: Location,
1374 ) -> Result<TypeNodeId, Vec<Error>> {
1375 let mut errs = vec![];
1376 let result_ty = self.infer_auto_spread_type_rec(arg_ty, &loc, &mut errs);
1377 if !errs.is_empty() {
1378 return Err(errs);
1379 }
1380 result_ty.ok_or_else(|| {
1381 vec![Error::TypeMismatch {
1382 left: (arg_ty, loc.clone()),
1383 right: (arg_ty, loc),
1384 }]
1385 })
1386 }
1387
1388 fn get_constructor_type_from_union(
1392 &self,
1393 union_ty: TypeNodeId,
1394 constructor_name: Symbol,
1395 ) -> TypeNodeId {
1396 if let Some(constructor_info) = self.constructor_env.get(&constructor_name) {
1399 return constructor_info.payload_type.unwrap_or_else(|| unit!());
1400 }
1401
1402 let resolved = Self::substitute_type(union_ty);
1403 match resolved.to_type() {
1404 Type::Union(variants) => {
1405 for variant_ty in variants.iter() {
1407 let variant_resolved = Self::substitute_type(*variant_ty);
1408 let variant_name = Self::type_constructor_name(&variant_resolved.to_type());
1409 if variant_name == Some(constructor_name) {
1410 return *variant_ty;
1411 }
1412 }
1413 Type::Unknown.into_id_with_location(union_ty.to_loc())
1415 }
1416 Type::UserSum { name: _, variants } => {
1417 if let Some((_, payload_ty)) =
1419 variants.iter().find(|(name, _)| *name == constructor_name)
1420 {
1421 payload_ty.unwrap_or_else(|| unit!())
1423 } else {
1424 Type::Unknown.into_id_with_location(union_ty.to_loc())
1425 }
1426 }
1427 other => {
1429 let type_name = Self::type_constructor_name(&other);
1430 if type_name == Some(constructor_name) {
1431 resolved
1432 } else {
1433 Type::Unknown.into_id_with_location(union_ty.to_loc())
1434 }
1435 }
1436 }
1437 }
1438
1439 fn type_constructor_name(ty: &Type) -> Option<Symbol> {
1442 match ty {
1443 Type::Primitive(PType::Numeric) => Some("float".to_symbol()),
1444 Type::Primitive(PType::String) => Some("string".to_symbol()),
1445 Type::Primitive(PType::Int) => Some("int".to_symbol()),
1446 Type::Primitive(PType::Unit) => Some("unit".to_symbol()),
1447 _ => None,
1449 }
1450 }
1451
1452 fn add_pattern_bindings(&mut self, pattern: &crate::ast::MatchPattern, ty: TypeNodeId) {
1455 use crate::ast::MatchPattern;
1456 let resolved_ty = ty.get_root().to_type();
1458 match pattern {
1459 MatchPattern::Variable(var) => {
1460 self.env.add_bind(&[(*var, (ty, self.stage))]);
1461 }
1462 MatchPattern::Wildcard => {
1463 }
1465 MatchPattern::Literal(_) => {
1466 }
1468 MatchPattern::Tuple(patterns) => {
1469 if let Type::Tuple(elem_types) = resolved_ty {
1472 for (pat, elem_ty) in patterns.iter().zip(elem_types.iter()) {
1473 self.add_pattern_bindings(pat, *elem_ty);
1474 }
1475 } else {
1476 if patterns.len() == 1 {
1480 self.add_pattern_bindings(&patterns[0], ty);
1481 }
1482 }
1483 }
1484 MatchPattern::Constructor(_, inner) => {
1485 if let Some(inner_pat) = inner {
1487 self.add_pattern_bindings(inner_pat, ty);
1488 }
1489 }
1490 }
1491 }
1492
1493 fn check_pattern_against_type(
1496 &mut self,
1497 pattern: &crate::ast::MatchPattern,
1498 ty: TypeNodeId,
1499 loc: &Location,
1500 ) {
1501 use crate::ast::MatchPattern;
1502 match pattern {
1503 MatchPattern::Literal(lit) => {
1504 let pat_ty = match lit {
1506 crate::ast::Literal::Int(_) | crate::ast::Literal::Float(_) => {
1507 Type::Primitive(PType::Numeric).into_id_with_location(loc.clone())
1508 }
1509 _ => Type::Failure.into_id_with_location(loc.clone()),
1510 };
1511 let _ = self.unify_types(ty, pat_ty);
1512 }
1513 MatchPattern::Wildcard => {
1514 }
1516 MatchPattern::Variable(var) => {
1517 self.env.add_bind(&[(*var, (ty, self.stage))]);
1519 }
1520 MatchPattern::Constructor(constructor_name, inner) => {
1521 let binding_ty = self.get_constructor_type_from_union(ty, *constructor_name);
1523 if let Some(inner_pat) = inner {
1524 self.add_pattern_bindings(inner_pat, binding_ty);
1525 }
1526 }
1527 MatchPattern::Tuple(patterns) => {
1528 let resolved_ty = ty.get_root().to_type();
1530 if let Type::Tuple(elem_types) = resolved_ty {
1531 for (pat, elem_ty) in patterns.iter().zip(elem_types.iter()) {
1532 self.check_pattern_against_type(pat, *elem_ty, loc);
1533 }
1534 }
1535 }
1536 }
1537 }
1538
1539 fn unwrap_result(&mut self, res: Result<TypeNodeId, Vec<Error>>) -> TypeNodeId {
1540 match res {
1541 Ok(t) => t,
1542 Err(mut e) => {
1543 let loc = &e[0].get_labels()[0].0; self.errors.append(&mut e);
1545 Type::Failure.into_id_with_location(loc.clone())
1546 }
1547 }
1548 }
1549 fn get_typescheme(&mut self, tvid: IntermediateId, loc: Location) -> TypeNodeId {
1550 self.generalize_map.get(&tvid).cloned().map_or_else(
1551 || self.gen_typescheme(loc),
1552 |id| Type::TypeScheme(id).into_id(),
1553 )
1554 }
1555 fn gen_typescheme(&mut self, loc: Location) -> TypeNodeId {
1556 let res = Type::TypeScheme(self.typescheme_idx).into_id_with_location(loc);
1557 self.typescheme_idx.0 += 1;
1558 res
1559 }
1560
1561 fn gen_intermediate_type_with_location(&mut self, loc: Location) -> TypeNodeId {
1562 let res = Type::Intermediate(Arc::new(RwLock::new(TypeVar::new(
1563 self.interm_idx,
1564 self.level,
1565 ))))
1566 .into_id_with_location(loc);
1567 self.interm_idx.0 += 1;
1568 res
1569 }
1570
1571 fn resolve_type_alias_symbol_fallback(&self, name: Symbol) -> Symbol {
1572 if name.as_str().contains('$') {
1573 return name;
1574 }
1575
1576 if let Some(ref module_info) = self.module_info
1577 && let Some(mapped) = module_info.use_alias_map.get(&name)
1578 {
1579 return *mapped;
1580 }
1581
1582 if self.type_aliases.contains_key(&name) {
1583 return name;
1584 }
1585
1586 if let Some(ref module_info) = self.module_info
1588 && module_info.type_declarations.contains_key(&name)
1589 {
1590 return name;
1591 }
1592
1593 let suffix = format!("${}", name.as_str());
1595 let mut candidates: Vec<Symbol> = self
1596 .type_aliases
1597 .keys()
1598 .copied()
1599 .filter(|symbol| symbol.as_str().ends_with(&suffix))
1600 .collect();
1601
1602 if let Some(ref module_info) = self.module_info {
1603 candidates.extend(
1604 module_info
1605 .type_declarations
1606 .keys()
1607 .copied()
1608 .filter(|symbol| symbol.as_str().ends_with(&suffix)),
1609 );
1610 }
1611
1612 if candidates.len() == 1 {
1613 candidates[0]
1614 } else {
1615 name
1616 }
1617 }
1618
1619 fn convert_unknown_to_intermediate(&mut self, t: TypeNodeId, loc: Location) -> TypeNodeId {
1620 match t.to_type() {
1621 Type::Unknown => self.gen_intermediate_type_with_location(loc.clone()),
1622 Type::TypeAlias(name) => {
1623 if Self::is_explicit_type_param_name(name) {
1624 return self
1625 .lookup_explicit_type_param(name)
1626 .unwrap_or_else(|| self.gen_typescheme(loc.clone()));
1627 }
1628 let resolved_name = self.resolve_type_alias_symbol_fallback(name);
1629
1630 log::trace!(
1631 "Resolving TypeAlias: {} -> {}",
1632 name.as_str(),
1633 resolved_name.as_str()
1634 );
1635
1636 if let Some(ref module_info) = self.module_info
1638 && let Some(&is_public) = module_info.visibility_map.get(&resolved_name)
1639 && !is_public
1640 {
1641 let type_path: Vec<&str> = resolved_name.as_str().split('$').collect();
1643 if type_path.len() > 1 {
1644 let module_path: Vec<crate::interner::Symbol> = type_path
1646 [..type_path.len() - 1]
1647 .iter()
1648 .map(ToSymbol::to_symbol)
1649 .collect();
1650 let type_name = type_path.last().unwrap().to_symbol();
1651
1652 self.errors.push(Error::PrivateTypeAccess {
1654 module_path,
1655 type_name,
1656 location: loc.clone(),
1657 });
1658 }
1659 }
1660
1661 match self.lookup(resolved_name, loc.clone()) {
1663 Ok(resolved_ty) => {
1664 let resolved_ty = self.resolve_type_alias(resolved_ty);
1665 let resolved_ty =
1666 self.convert_unknown_to_intermediate(resolved_ty, loc.clone());
1667 log::trace!(
1668 "Resolved TypeAlias {} to {}",
1669 resolved_name.as_str(),
1670 resolved_ty.to_type()
1671 );
1672 resolved_ty
1673 }
1674 Err(_) => {
1675 log::warn!(
1676 "TypeAlias {} not found, treating as Unknown",
1677 resolved_name.as_str()
1678 );
1679 self.gen_intermediate_type_with_location(loc.clone())
1681 }
1682 }
1683 }
1684 _ => t.apply_fn(|t| self.convert_unknown_to_intermediate(t, loc.clone())),
1685 }
1686 }
1687
1688 fn provisional_lambda_function_type(
1689 &mut self,
1690 params: &[TypedId],
1691 rtype: Option<TypeNodeId>,
1692 loc: Location,
1693 ) -> TypeNodeId {
1694 let param_fields = params
1695 .iter()
1696 .map(|param| {
1697 let annotated_ty =
1698 self.convert_unknown_to_intermediate(param.ty, param.ty.to_loc());
1699 RecordTypeField {
1700 key: param.id,
1701 ty: self.resolve_type_alias(annotated_ty),
1702 has_default: param.default_value.is_some(),
1703 }
1704 })
1705 .collect::<Vec<_>>();
1706
1707 let arg_ty = match param_fields.len() {
1708 0 => Type::Primitive(PType::Unit).into_id_with_location(loc.clone()),
1709 1 => param_fields[0].ty,
1710 _ => Type::Record(param_fields).into_id_with_location(loc.clone()),
1711 };
1712
1713 let ret_ty = rtype
1714 .map(|ret| {
1715 let annotated_ret = self.convert_unknown_to_intermediate(ret, ret.to_loc());
1716 self.resolve_type_alias(annotated_ret)
1717 })
1718 .unwrap_or_else(|| self.gen_intermediate_type_with_location(loc.clone()));
1719
1720 Type::Function {
1721 arg: arg_ty,
1722 ret: ret_ty,
1723 }
1724 .into_id_with_location(loc)
1725 }
1726
1727 fn provisional_letrec_binding_type(
1728 &mut self,
1729 id: &TypedId,
1730 body: ExprNodeId,
1731 loc: Location,
1732 ) -> TypeNodeId {
1733 match body.to_expr() {
1734 Expr::Lambda(params, rtype, _) => {
1735 let has_explicit_lambda_signature =
1736 params.iter().any(|param| !matches!(param.ty.to_type(), Type::Unknown))
1737 || rtype.is_some();
1738
1739 if has_explicit_lambda_signature || matches!(id.ty.to_type(), Type::Unknown) {
1740 self.provisional_lambda_function_type(params.as_slice(), rtype, loc)
1741 } else {
1742 self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc())
1743 }
1744 }
1745 _ if !matches!(id.ty.to_type(), Type::Unknown) => {
1746 self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc())
1747 }
1748 _ => self.convert_unknown_to_intermediate(id.ty, id.ty.to_loc()),
1749 }
1750 }
1751
1752 fn is_public(&self, name: &Symbol) -> bool {
1754 let resolved_name = self.resolve_type_alias_symbol_fallback(*name);
1755 self.module_info
1756 .as_ref()
1757 .and_then(|info| info.visibility_map.get(&resolved_name))
1758 .is_some_and(|vis| *vis)
1759 }
1760
1761 fn is_private(&self, name: &Symbol) -> bool {
1762 !self.is_public(name)
1763 }
1764
1765 fn check_private_type_leak(&mut self, name: Symbol, ty: TypeNodeId, loc: Location) {
1767 if !self.is_public(&name) {
1769 return; }
1771
1772 if let Some(type_name) = self.contains_private_type(ty) {
1774 self.errors.push(Error::PrivateTypeLeak {
1775 function_name: name,
1776 private_type: type_name,
1777 location: loc,
1778 });
1779 }
1780 }
1781
1782 fn contains_private_type(&self, ty: TypeNodeId) -> Option<Symbol> {
1785 let resolved = Self::substitute_type(ty);
1786 match resolved.to_type() {
1787 Type::TypeAlias(name) => {
1788 if Self::is_explicit_type_param_name(name) {
1789 return None;
1790 }
1791 let resolved_name = self.resolve_type_alias_symbol_fallback(name);
1792 if self.is_private(&resolved_name) {
1794 return Some(resolved_name);
1795 }
1796
1797 let name_str = name.as_str();
1799 if name_str.contains("::") {
1800 let parts: Vec<&str> = name_str.split("::").collect();
1801 if parts.len() >= 2 {
1802 let module_path: Vec<Symbol> = parts[..parts.len() - 1]
1803 .iter()
1804 .map(|s| s.to_symbol())
1805 .collect();
1806 let type_name = parts[parts.len() - 1].to_symbol();
1807
1808 let module_path_str = module_path
1809 .iter()
1810 .map(|s| s.as_str())
1811 .collect::<Vec<_>>()
1812 .join("::");
1813 let mangled_name =
1814 format!("{}::{}", module_path_str, type_name.as_str()).to_symbol();
1815
1816 if self.is_private(&mangled_name) {
1817 return Some(type_name);
1818 }
1819 }
1820 }
1821 None
1822 }
1823 Type::Function { arg, ret } => {
1824 if let Some(private_type) = self.contains_private_type(arg) {
1826 return Some(private_type);
1827 }
1828 self.contains_private_type(ret)
1830 }
1831 Type::Tuple(ref elements) => {
1832 for elem_ty in elements.iter() {
1833 if let Some(private_type) = self.contains_private_type(*elem_ty) {
1834 return Some(private_type);
1835 }
1836 }
1837 None
1838 }
1839 Type::Array(elem_ty) => self.contains_private_type(elem_ty),
1840 Type::Record(ref fields) => {
1841 for field in fields.iter() {
1842 if let Some(private_type) = self.contains_private_type(field.ty) {
1843 return Some(private_type);
1844 }
1845 }
1846 None
1847 }
1848 Type::Union(ref variants) => {
1849 for variant_ty in variants.iter() {
1850 if let Some(private_type) = self.contains_private_type(*variant_ty) {
1851 return Some(private_type);
1852 }
1853 }
1854 None
1855 }
1856 Type::Ref(inner_ty) => self.contains_private_type(inner_ty),
1857 Type::Code(inner_ty) => self.contains_private_type(inner_ty),
1858 Type::Boxed(inner_ty) => self.contains_private_type(inner_ty),
1859 Type::UserSum { name, variants } => {
1860 if self.is_private(&name) {
1862 return Some(name);
1863 }
1864
1865 for (_variant_name, payload_ty_opt) in variants.iter() {
1867 if let Some(payload_ty) = payload_ty_opt
1868 && let Some(private_type) = self.contains_private_type(*payload_ty)
1869 {
1870 return Some(private_type);
1871 }
1872 }
1873 None
1874 }
1875 Type::Intermediate(_)
1876 | Type::Primitive(_)
1877 | Type::TypeScheme(_)
1878 | Type::Any
1879 | Type::Failure
1880 | Type::Unknown => None,
1881 }
1882 }
1883
1884 fn convert_unify_error(&self, e: UnificationError) -> Error {
1885 let gen_loc = |span| Location::new(span, self.file_path.clone());
1886 match e {
1887 UnificationError::TypeMismatch {
1888 left: (left, lspan),
1889 right: (right, rspan),
1890 } => Error::TypeMismatch {
1891 left: (left, gen_loc(lspan)),
1892 right: (right, gen_loc(rspan)),
1893 },
1894 UnificationError::LengthMismatch {
1895 left: (left, lspan),
1896 right: (right, rspan),
1897 } => Error::LengthMismatch {
1898 left: (left.len(), gen_loc(lspan)),
1899 right: (right.len(), gen_loc(rspan)),
1900 },
1901 UnificationError::CircularType { left, right } => {
1902 Error::CircularType(gen_loc(left), gen_loc(right))
1903 }
1904 UnificationError::ImcompatibleRecords {
1905 left: (left, lspan),
1906 right: (right, rspan),
1907 } => Error::IncompatibleKeyInRecord {
1908 left: (left, gen_loc(lspan)),
1909 right: (right, gen_loc(rspan)),
1910 },
1911 }
1912 }
1913 fn unify_types(&self, t1: TypeNodeId, t2: TypeNodeId) -> Result<Relation, Vec<Error>> {
1914 let resolved_t1 = self.resolve_type_alias(t1);
1916 let resolved_t2 = self.resolve_type_alias(t2);
1917
1918 unify_types(resolved_t1, resolved_t2)
1919 .map_err(|e| e.into_iter().map(|e| self.convert_unify_error(e)).collect())
1920 }
1921 fn merge_rel_result(
1923 &self,
1924 rel1: Result<Relation, Vec<Error>>,
1925 rel2: Result<Relation, Vec<Error>>,
1926 t1: TypeNodeId,
1927 t2: TypeNodeId,
1928 ) -> Result<(), Vec<Error>> {
1929 match (rel1, rel2) {
1930 (Ok(Relation::Identical), Ok(Relation::Identical)) => Ok(()),
1931 (Ok(_), Ok(_)) => Err(vec![Error::TypeMismatch {
1932 left: (t1, Location::new(t1.to_span(), self.file_path.clone())),
1933 right: (t2, Location::new(t2.to_span(), self.file_path.clone())),
1934 }]),
1935 (Err(e1), Err(e2)) => Err(e1.into_iter().chain(e2).collect()),
1936 (Err(e), _) | (_, Err(e)) => Err(e),
1937 }
1938 }
1939 pub fn substitute_type(t: TypeNodeId) -> TypeNodeId {
1940 match t.to_type() {
1941 Type::Intermediate(cell) => {
1942 let TypeVar { parent, .. } = &*cell.read().unwrap() as &TypeVar;
1943 match parent {
1944 Some(p) => Self::substitute_type(*p),
1945 None => Type::Unknown.into_id_with_location(t.to_loc()),
1946 }
1947 }
1948 _ => t.apply_fn(Self::substitute_type),
1949 }
1950 }
1951 fn substitute_all_intermediates(&mut self) {
1952 let mut e_list = self
1953 .result_memo
1954 .iter()
1955 .map(|(e, t)| (*e, Self::substitute_type(*t)))
1956 .collect::<Vec<_>>();
1957
1958 e_list.iter_mut().for_each(|(e, t)| {
1959 log::trace!("e: {:?} t: {}", e, t.to_type());
1960 let _old = self.result_memo.insert(*e, *t);
1961 })
1962 }
1963
1964 fn generalize(&mut self, t: TypeNodeId) -> TypeNodeId {
1965 match t.to_type() {
1966 Type::Intermediate(tvar) => {
1967 let &TypeVar { level, var, .. } = &*tvar.read().unwrap() as &TypeVar;
1968 if level > self.level {
1969 self.get_typescheme(var, t.to_loc())
1970 } else {
1971 t
1972 }
1973 }
1974 _ => t.apply_fn(|t| self.generalize(t)),
1975 }
1976 }
1977
1978 fn instantiate(&mut self, t: TypeNodeId) -> TypeNodeId {
1979 match t.to_type() {
1980 Type::TypeScheme(id) => {
1981 log::debug!("instantiate typescheme id: {id:?}");
1982 if let Some(tvar) = self.instantiated_map.get(&id) {
1983 *tvar
1984 } else {
1985 let res = self.gen_intermediate_type_with_location(t.to_loc());
1986 self.instantiated_map.insert(id, res);
1987 res
1988 }
1989 }
1990 _ => t.apply_fn(|t| self.instantiate(t)),
1991 }
1992 }
1993
1994 fn instantiate_fresh(&mut self, t: TypeNodeId) -> TypeNodeId {
1995 self.instantiated_map.clear();
1996 let res = self.instantiate(t);
1997 self.instantiated_map.clear();
1998 res
1999 }
2000
2001 fn bind_pattern(
2006 &mut self,
2007 pat: (TypedPattern, Location),
2008 body: (TypeNodeId, Location),
2009 ) -> Result<TypeNodeId, Vec<Error>> {
2010 let (TypedPattern { pat, ty, .. }, loc_p) = pat;
2011 let (body_t, loc_b) = body.clone();
2012 let should_generalize =
2013 !matches!(&pat, Pattern::Single(id) if *id == "record_update_temp".to_symbol());
2014 let mut bind_item = |pat| {
2015 let newloc = ty.to_loc();
2016 let ity = self.gen_intermediate_type_with_location(newloc.clone());
2017 let p = TypedPattern::new(pat, ity);
2018 self.bind_pattern((p, newloc.clone()), (ity, newloc))
2019 };
2020 let pat_t = match pat {
2021 Pattern::Single(id) => {
2022 let pat_t = self.convert_unknown_to_intermediate(ty, loc_p);
2023 log::trace!("bind {} : {}", id, pat_t.to_type());
2024 self.env.add_bind(&[(id, (pat_t, self.stage))]);
2025 Ok::<TypeNodeId, Vec<Error>>(pat_t)
2026 }
2027 Pattern::Placeholder => {
2028 let pat_t = self.convert_unknown_to_intermediate(ty, loc_p);
2030 log::trace!("bind _ (placeholder) : {}", pat_t.to_type());
2031 Ok::<TypeNodeId, Vec<Error>>(pat_t)
2032 }
2033 Pattern::Tuple(pats) => {
2034 let elems = pats.iter().map(|p| bind_item(p.clone())).try_collect()?; let res = Type::Tuple(elems).into_id_with_location(loc_p);
2036 let target = self.convert_unknown_to_intermediate(ty, loc_b);
2037 let rel = self.unify_types(res, target)?;
2038 Ok(res)
2039 }
2040 Pattern::Record(items) => {
2041 let res = items
2042 .iter()
2043 .map(|(key, v)| {
2044 bind_item(v.clone()).map(|ty| RecordTypeField {
2045 key: *key,
2046 ty,
2047 has_default: false,
2048 })
2049 })
2050 .try_collect()?; let res = Type::Record(res).into_id_with_location(loc_p);
2052 let target = self.convert_unknown_to_intermediate(ty, loc_b);
2053 let rel = self.unify_types(res, target)?;
2054 Ok(res)
2055 }
2056 Pattern::Error => Err(vec![Error::PatternMismatch(
2057 (
2058 Type::Failure.into_id_with_location(loc_p.clone()),
2059 loc_b.clone(),
2060 ),
2061 (pat, loc_p.clone()),
2062 )]),
2063 }?;
2064 let rel = self.unify_types(pat_t, body_t)?;
2065 if should_generalize {
2066 Ok(self.generalize(pat_t))
2067 } else {
2068 Ok(pat_t)
2069 }
2070 }
2071
2072 pub fn lookup(&self, name: Symbol, loc: Location) -> Result<TypeNodeId, Error> {
2073 use crate::utils::environment::LookupRes;
2074 let lookup_res = self.env.lookup_cls(&name);
2075 match lookup_res {
2076 LookupRes::Local((ty, bound_stage)) if self.stage == *bound_stage => Ok(*ty),
2077 LookupRes::UpValue(_, (ty, bound_stage)) if self.stage == *bound_stage => Ok(*ty),
2078 LookupRes::Global((ty, bound_stage))
2079 if self.stage == *bound_stage || *bound_stage == EvalStage::Persistent =>
2080 {
2081 Ok(*ty)
2082 }
2083 LookupRes::None => Err(Error::VariableNotFound(name, loc)),
2084 LookupRes::Local((_, bound_stage))
2085 | LookupRes::UpValue(_, (_, bound_stage))
2086 | LookupRes::Global((_, bound_stage)) => Err(Error::StageMismatch {
2087 variable: name,
2088 expected_stage: self.stage,
2089 found_stage: *bound_stage,
2090 location: loc,
2091 }),
2092 }
2093 }
2094
2095 fn peel_to_inner(&self, ty: TypeNodeId) -> TypeNodeId {
2098 let resolved = self.resolve_type_alias(ty);
2099 match resolved.to_type() {
2100 Type::Intermediate(tv) => {
2101 let tv = tv.read().unwrap();
2102 let next = tv.parent.unwrap_or(tv.bound.lower);
2103 if next.0 == resolved.0 {
2104 resolved
2105 } else {
2106 self.peel_to_inner(next)
2107 }
2108 }
2109 Type::Tuple(elems) if elems.len() == 1 => self.peel_to_inner(elems[0]),
2110 _ => resolved,
2111 }
2112 }
2113
2114 fn lookup_field_in_type(&self, ty: TypeNodeId, field: Symbol) -> FieldLookup {
2117 let peeled = self.peel_to_inner(ty);
2118 match peeled.to_type() {
2119 Type::Record(fields) => fields
2120 .iter()
2121 .find(|f| f.key == field)
2122 .map(|f| FieldLookup::Found(f.ty))
2123 .unwrap_or(FieldLookup::RecordWithoutField),
2124 _ => FieldLookup::NotRecord,
2125 }
2126 }
2127
2128 fn infer_field_access(
2136 &mut self,
2137 et: TypeNodeId,
2138 field: Symbol,
2139 loc: Location,
2140 ) -> Result<TypeNodeId, Vec<Error>> {
2141 if let Type::Intermediate(tv) = et.to_type() {
2143 let is_unresolved = {
2144 let tv = tv.read().unwrap();
2145 let lower_is_record_like = match tv.bound.lower.to_type() {
2146 Type::Record(_) => true,
2147 Type::Tuple(elems) => elems.len() == 1,
2148 _ => false,
2149 };
2150 tv.parent.is_none() && !lower_is_record_like
2151 };
2152 if is_unresolved {
2153 let field_ty = self.gen_intermediate_type_with_location(loc.clone());
2154 let expected = Type::Record(vec![RecordTypeField {
2155 key: field,
2156 ty: field_ty,
2157 has_default: false,
2158 }])
2159 .into_id_with_location(loc);
2160 let _rel = self.unify_types(et, expected)?;
2161 return Ok(field_ty);
2162 }
2163 }
2164
2165 match self.lookup_field_in_type(et, field) {
2167 FieldLookup::Found(field_ty) => Ok(field_ty),
2168 FieldLookup::RecordWithoutField => self.extend_record_with_field(et, field, loc),
2169 FieldLookup::NotRecord => Err(vec![Error::FieldForNonRecord(loc, et)]),
2170 }
2171 }
2172
2173 fn extend_record_with_field(
2177 &mut self,
2178 et: TypeNodeId,
2179 field: Symbol,
2180 loc: Location,
2181 ) -> Result<TypeNodeId, Vec<Error>> {
2182 if let Type::Intermediate(tv) = et.to_type() {
2183 let existing_fields = {
2184 let tv = tv.read().unwrap();
2185 match tv.parent.map(|p| p.to_type()) {
2186 Some(Type::Record(fields)) => Some(fields),
2187 _ => match tv.bound.lower.to_type() {
2188 Type::Record(fields) => Some(fields),
2189 _ => None,
2190 },
2191 }
2192 };
2193 if let Some(mut fields) = existing_fields {
2194 let field_ty = self.gen_intermediate_type_with_location(loc.clone());
2195 if fields.iter().all(|f| f.key != field) {
2196 fields.push(RecordTypeField {
2197 key: field,
2198 ty: field_ty,
2199 has_default: false,
2200 });
2201 }
2202 let extended = Type::Record(fields).into_id_with_location(loc);
2203 {
2210 let mut guard = tv.write().unwrap();
2211 guard.parent = Some(extended);
2212 }
2213 return Ok(field_ty);
2214 }
2215 }
2216 Err(vec![Error::FieldNotExist { field, loc, et }])
2217 }
2218
2219 pub(crate) fn infer_type_literal(e: &Literal, loc: Location) -> Result<TypeNodeId, Error> {
2220 let pt = match e {
2221 Literal::Float(_) | Literal::Now | Literal::SampleRate => PType::Numeric,
2222 Literal::Int(_s) => PType::Int,
2223 Literal::String(_s) => PType::String,
2224 Literal::SelfLit => panic!("\"self\" should not be shown at type inference stage"),
2225 Literal::PlaceHolder => panic!("\"_\" should not be shown at type inference stage"),
2226 };
2227 Ok(Type::Primitive(pt).into_id_with_location(loc))
2228 }
2229 fn infer_vec(&mut self, e: &[ExprNodeId]) -> Result<Vec<TypeNodeId>, Vec<Error>> {
2230 e.iter().map(|e| self.infer_type(*e)).try_collect()
2231 }
2232 fn infer_type_levelup(&mut self, e: ExprNodeId) -> TypeNodeId {
2233 self.level += 1;
2234 let res = self.infer_type_unwrapping(e);
2235 self.level -= 1;
2236 res
2237 }
2238 pub fn infer_type(&mut self, e: ExprNodeId) -> Result<TypeNodeId, Vec<Error>> {
2239 if let Some(r) = self.result_memo.get(&e.0) {
2240 return Ok(*r);
2242 }
2243 let loc = e.to_location();
2244 let res: Result<TypeNodeId, Vec<Error>> = match &e.to_expr() {
2245 Expr::Literal(l) => Self::infer_type_literal(l, loc).map_err(|e| vec![e]),
2246 Expr::Tuple(e) => {
2247 Ok(Type::Tuple(self.infer_vec(e.as_slice())?).into_id_with_location(loc))
2248 }
2249 Expr::ArrayLiteral(e) => {
2250 let elem_types = self.infer_vec(e.as_slice())?;
2251 let first = elem_types
2252 .first()
2253 .copied()
2254 .unwrap_or_else(|| self.gen_intermediate_type_with_location(loc.clone()));
2255 let elem_t = elem_types
2257 .iter()
2258 .try_fold(first, |acc, t| self.unify_types(acc, *t).map(|rel| *t))?;
2259
2260 Ok(Type::Array(elem_t).into_id_with_location(loc.clone()))
2261 }
2262 Expr::ArrayAccess(e, idx) => {
2263 let arr_t = self.infer_type_unwrapping(*e);
2264 let loc_e = e.to_location();
2265 let idx_t = self.infer_type_unwrapping(*idx);
2266 let loc_i = idx.to_location();
2267
2268 let elem_t = self.gen_intermediate_type_with_location(loc_e.clone());
2269
2270 let rel1 = self.unify_types(
2271 idx_t,
2272 Type::Primitive(PType::Numeric).into_id_with_location(loc_i),
2273 );
2274 let rel2 = self.unify_types(
2275 Type::Array(elem_t).into_id_with_location(loc_e.clone()),
2276 arr_t,
2277 );
2278 self.merge_rel_result(rel1, rel2, arr_t, idx_t)?;
2279 Ok(elem_t)
2280 }
2281 Expr::Proj(e, idx) => {
2282 let tup = self.infer_type_unwrapping(*e);
2283 let vec_to_ans = |vec: &[_]| {
2287 if vec.len() < *idx as usize {
2288 Err(vec![Error::IndexOutOfRange {
2289 len: vec.len() as u16,
2290 idx: *idx as u16,
2291 loc: loc.clone(),
2292 }])
2293 } else {
2294 Ok(vec[*idx as usize])
2295 }
2296 };
2297 match tup.to_type() {
2298 Type::Tuple(vec) => vec_to_ans(&vec),
2299 Type::Intermediate(tv) => {
2300 let tv = tv.read().unwrap();
2301 if let Some(parent) = tv.parent {
2302 match parent.to_type() {
2303 Type::Tuple(vec) => vec_to_ans(&vec),
2304 _ => Err(vec![Error::IndexForNonTuple(loc, tup)]),
2305 }
2306 } else {
2307 Err(vec![Error::IndexForNonTuple(loc, tup)])
2308 }
2309 }
2310 _ => Err(vec![Error::IndexForNonTuple(loc, tup)]),
2311 }
2312 }
2313 Expr::RecordLiteral(kvs) => {
2314 let duplicate_keys = kvs
2315 .iter()
2316 .map(|RecordField { name, .. }| *name)
2317 .duplicates();
2318 if duplicate_keys.clone().count() > 0 {
2319 Err(vec![Error::DuplicateKeyInRecord {
2320 key: duplicate_keys.collect(),
2321 loc,
2322 }])
2323 } else {
2324 let kts: Vec<_> = kvs
2325 .iter()
2326 .map(|RecordField { name, expr }| {
2327 let ty = self.infer_type_unwrapping(*expr);
2328 RecordTypeField {
2329 key: *name,
2330 ty,
2331 has_default: true,
2332 }
2333 })
2334 .collect();
2335 Ok(Type::Record(kts).into_id_with_location(loc))
2336 }
2337 }
2338 Expr::RecordUpdate(_, _) => {
2339 unreachable!("RecordUpdate should be expanded before type inference")
2342 }
2343 Expr::FieldAccess(expr, field) => {
2344 let et = self.infer_type_unwrapping(*expr);
2345 log::trace!("field access {} : {}", field, et.to_type());
2346 self.infer_field_access(et, *field, loc)
2347 }
2348 Expr::Feed(id, body) => {
2349 let feedv = self.gen_intermediate_type_with_location(loc);
2351
2352 self.env.add_bind(&[(*id, (feedv, self.stage))]);
2353 let bty = self.infer_type_unwrapping(*body);
2354 let _rel = self.unify_types(bty, feedv)?;
2355 if bty.to_type().contains_function() {
2356 Err(vec![Error::NonPrimitiveInFeed(body.to_location())])
2357 } else {
2358 Ok(bty)
2359 }
2360 }
2361 Expr::Lambda(p, rtype, body) => {
2362 let mut scoped_types = p
2363 .iter()
2364 .map(|id| id.ty)
2365 .filter(|ty| ty.to_type() != Type::Unknown)
2366 .collect::<Vec<_>>();
2367 rtype.iter().copied().for_each(|ty| scoped_types.push(ty));
2368 self.with_explicit_type_param_scope_from_types(&scoped_types, |this| {
2369 this.env.extend();
2370 let lambda_res = (|| -> Result<TypeNodeId, Vec<Error>> {
2371 this.instantiated_map.clear();
2372 let dup = p.iter().duplicates_by(|id| id.id).map(|id| {
2373 let loc = Location::new(id.to_span(), this.file_path.clone());
2374 (id.id, loc)
2375 });
2376 if dup.clone().count() > 0 {
2377 return Err(vec![Error::DuplicateKeyInParams(dup.collect())]);
2378 }
2379 let pvec = p
2380 .iter()
2381 .map(|id| {
2382 let annotated_ty =
2383 this.convert_unknown_to_intermediate(id.ty, id.ty.to_loc());
2384 let annotated_ty = this.resolve_type_alias(annotated_ty);
2385 let ity = this.instantiate(annotated_ty);
2386 this.env.add_bind(&[(id.id, (ity, this.stage))]);
2387 RecordTypeField {
2388 key: id.id,
2389 ty: ity,
2390 has_default: id.default_value.is_some(),
2391 }
2392 })
2393 .collect::<Vec<_>>();
2394 let ptype = if pvec.is_empty() {
2395 Type::Primitive(PType::Unit).into_id_with_location(loc.clone())
2396 } else if pvec.len() == 1 {
2397 pvec[0].ty
2398 } else {
2399 Type::Record(pvec).into_id_with_location(loc.clone())
2400 };
2401 let bty = if let Some(r) = rtype {
2402 let annotated_ret =
2403 this.convert_unknown_to_intermediate(*r, r.to_loc());
2404 let annotated_ret = this.resolve_type_alias(annotated_ret);
2405 let expected_ret = this.instantiate(annotated_ret);
2406 let bty = this.infer_type_unwrapping(*body);
2407 let _rel = this.unify_types(expected_ret, bty)?;
2408 bty
2409 } else {
2410 this.infer_type_unwrapping(*body)
2411 };
2412 this.instantiated_map.clear();
2413 Ok(Type::Function {
2414 arg: ptype,
2415 ret: bty,
2416 }
2417 .into_id_with_location(e.to_location()))
2418 })();
2419 this.env.to_outer();
2420 this.instantiated_map.clear();
2421 lambda_res
2422 })
2423 }
2424 Expr::Let(tpat, body, then) => {
2425 let bodyt = self.infer_type_levelup(*body);
2426
2427 let loc_p = tpat.to_loc();
2428 let loc_b = body.to_location();
2429
2430 if let Pattern::Single(name) = &tpat.pat {
2433 log::trace!(
2434 "Checking private type leak for Let binding: {}",
2435 name.as_str()
2436 );
2437 log::trace!("Original type before resolution: {:?}", tpat.ty.to_type());
2438 self.check_private_type_leak(*name, tpat.ty, loc_p.clone());
2439 }
2440
2441 let pat_t = self.with_explicit_type_param_scope_from_types(&[tpat.ty], |this| {
2442 this.bind_pattern((tpat.clone(), loc_p), (bodyt, loc_b))
2443 });
2444 let _pat_t = self.unwrap_result(pat_t);
2445 match then {
2446 Some(e) => self.infer_type(*e),
2447 None => Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
2448 }
2449 }
2450 Expr::LetRec(id, body, then) => {
2451 let body_expr = *body;
2452 let mut scoped_types = vec![id.ty];
2453 if let Expr::Lambda(params, rtype, _) = body_expr.to_expr() {
2454 params
2455 .iter()
2456 .filter(|param| param.ty.to_type() != Type::Unknown)
2457 .for_each(|param| scoped_types.push(param.ty));
2458 rtype.iter().copied().for_each(|ret| scoped_types.push(ret));
2459 }
2460
2461 self.with_explicit_type_param_scope_from_types(&scoped_types, |this| {
2462 let idt = this.provisional_letrec_binding_type(id, body_expr, loc.clone());
2463 this.env.add_bind(&[(id.id, (idt, this.stage))]);
2464 let bodyt = this.infer_type_levelup(body_expr);
2467
2468 let _res = this.unify_types(idt, bodyt);
2469
2470 this.check_private_type_leak(id.id, id.ty, loc.clone());
2472 });
2473
2474 match then {
2475 Some(e) => self.infer_type(*e),
2476 None => Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
2477 }
2478 }
2479 Expr::Assign(assignee, expr) => {
2480 match assignee.to_expr() {
2481 Expr::Var(name) => {
2482 let assignee_t =
2483 self.unwrap_result(self.lookup(name, loc).map_err(|e| vec![e]));
2484 let e_t = self.infer_type_unwrapping(*expr);
2485 let _rel = self.unify_types(assignee_t, e_t)?;
2486 Ok(unit!())
2487 }
2488 Expr::FieldAccess(record, field_name) => {
2489 let _record_type = self.infer_type_unwrapping(record);
2491 let value_type = self.infer_type_unwrapping(*expr);
2492 let field_type = self.infer_type_unwrapping(*assignee);
2493 let _rel = self.unify_types(field_type, value_type)?;
2494 Ok(unit!())
2495 }
2496 Expr::ArrayAccess(_, _) => {
2497 unimplemented!("Assignment to array is not implemented yet.")
2498 }
2499 _ => {
2500 Err(vec![Error::VariableNotFound(
2502 "invalid_assignment_target".to_symbol(),
2503 loc.clone(),
2504 )])
2505 }
2506 }
2507 }
2508 Expr::Then(e, then) => {
2509 let _ = self.infer_type(*e)?;
2510 then.map_or(Ok(unit!()), |t| self.infer_type(t))
2511 }
2512 Expr::Var(name) => {
2513 if let Some(constructor_info) = self.constructor_env.get(name) {
2515 if let Some(payload_ty) = constructor_info.payload_type {
2516 let fn_type = Type::Function {
2518 arg: payload_ty,
2519 ret: constructor_info.sum_type,
2520 }
2521 .into_id_with_location(loc.clone());
2522 return Ok(fn_type);
2523 } else {
2524 return Ok(constructor_info.sum_type);
2526 }
2527 }
2528 let res = self.unwrap_result(self.lookup(*name, loc).map_err(|e| vec![e]));
2530 Ok(self.instantiate_fresh(res))
2531 }
2532 Expr::QualifiedVar(path) => {
2533 unreachable!("Qualified Var should be removed in the previous step.")
2534 }
2535 Expr::Apply(fun, callee) => {
2536 let loc_f = fun.to_location();
2537 if callee.len() == 2 && self.try_get_tuple_arithmetic_binop_label(*fun).is_some() {
2538 let lhs_ty = self.infer_type_unwrapping(callee[0]);
2539 let rhs_ty = self.infer_type_unwrapping(callee[1]);
2540 let lhs_is_tuple = matches!(
2541 self.resolve_for_tuple_binop(lhs_ty).to_type(),
2542 Type::Tuple(_)
2543 );
2544 let rhs_is_tuple = matches!(
2545 self.resolve_for_tuple_binop(rhs_ty).to_type(),
2546 Type::Tuple(_)
2547 );
2548 if lhs_is_tuple || rhs_is_tuple {
2549 return self.infer_tuple_arithmetic_binop_type(
2550 lhs_ty,
2551 rhs_ty,
2552 loc_f.clone(),
2553 );
2554 }
2555 }
2556
2557 if callee.len() == 1 {
2558 let fnl = self.infer_type_unwrapping(*fun);
2559 let arg_ty = self.infer_type_unwrapping(callee[0]);
2560 let arg_is_tuple = matches!(
2561 self.resolve_for_tuple_binop(arg_ty).to_type(),
2562 Type::Tuple(_)
2563 );
2564 if arg_is_tuple && self.is_numeric_to_numeric_function_for_auto_spread(fnl) {
2565 return self.infer_auto_spread_type(fnl, arg_ty, loc_f.clone());
2566 }
2567
2568 let try_record_default_pack = || -> Result<Option<TypeNodeId>, Vec<Error>> {
2569 let fn_ty = self.peel_to_inner(fnl);
2570 let arg_ty_resolved = self.peel_to_inner(arg_ty);
2571 let (fn_arg, fn_ret) = match fn_ty.to_type() {
2572 Type::Function { arg, ret } => (arg, ret),
2573 _ => return Ok(None),
2574 };
2575 let fn_arg_resolved = self.peel_to_inner(fn_arg);
2576 let (param_fields, provided_fields) =
2577 match (fn_arg_resolved.to_type(), arg_ty_resolved.to_type()) {
2578 (Type::Record(param_fields), Type::Record(provided_fields)) => {
2579 (param_fields, provided_fields)
2580 }
2581 _ => return Ok(None),
2582 };
2583
2584 let mut matched_any = false;
2585 for param in param_fields.iter() {
2586 if let Some(provided) =
2587 provided_fields.iter().find(|field| field.key == param.key)
2588 {
2589 matched_any = true;
2590 let _ = self.unify_types(param.ty, provided.ty)?;
2591 } else if !param.has_default {
2592 return Ok(None);
2593 }
2594 }
2595
2596 Ok(matched_any.then_some(fn_ret))
2597 };
2598
2599 if let Some(ret_ty) = try_record_default_pack()? {
2600 return Ok(ret_ty);
2601 }
2602 }
2603
2604 let fnl = self.infer_type_unwrapping(*fun);
2605 let callee_t = match callee.len() {
2606 0 => Type::Primitive(PType::Unit).into_id_with_location(loc.clone()),
2607 1 => self.infer_type_unwrapping(callee[0]),
2608 _ => {
2609 let at_vec = self.infer_vec(callee.as_slice())?;
2610
2611 let span = callee[0].to_span().start..callee.last().unwrap().to_span().end;
2612 let loc = Location::new(span, self.file_path.clone());
2613 Type::Tuple(at_vec).into_id_with_location(loc)
2614 }
2615 };
2616 let res_t = self.gen_intermediate_type_with_location(loc);
2617 let fntype = Type::Function {
2618 arg: callee_t,
2619 ret: res_t,
2620 }
2621 .into_id_with_location(loc_f.clone());
2622 match self.unify_types(fnl, fntype)? {
2623 Relation::Subtype => Err(vec![Error::NonSupertypeArgument {
2624 location: loc_f.clone(),
2625 expected: fnl,
2626 found: fntype,
2627 }]),
2628 _ => Ok(res_t),
2629 }
2630 }
2631 Expr::If(cond, then, opt_else) => {
2632 let condt = self.infer_type_unwrapping(*cond);
2633 let cond_loc = cond.to_location();
2634 let bt = self.unify_types(
2635 Type::Primitive(PType::Numeric).into_id_with_location(cond_loc),
2636 condt,
2637 )?; let thent = self.infer_type_unwrapping(*then);
2640 let elset = opt_else.map_or(Type::Primitive(PType::Unit).into_id(), |e| {
2641 self.infer_type_unwrapping(e)
2642 });
2643 let rel = self.unify_types(thent, elset)?;
2644 Ok(thent)
2645 }
2646 Expr::Block(expr) => expr.map_or(
2647 Ok(Type::Primitive(PType::Unit).into_id_with_location(loc)),
2648 |e| {
2649 self.env.extend(); let res = self.infer_type(e);
2651 self.env.to_outer();
2652 res
2653 },
2654 ),
2655 Expr::Escape(e) => {
2656 let loc_e = loc.clone();
2657 let prev_stage = self.stage;
2658 self.stage = prev_stage.decrement();
2660 log::trace!("Unstaging escape expression, stage => {:?}", self.stage);
2661 let res = self.infer_type_unwrapping(*e);
2662 self.stage = prev_stage;
2664 if matches!(res.get_root().to_type(), Type::Primitive(PType::Unit)) {
2665 return Ok(Type::Primitive(PType::Unit).into_id_with_location(loc_e));
2666 }
2667 if !matches!(res.get_root().to_type(), Type::Code(_))
2668 && res.get_root().to_type().contains_code()
2669 {
2670 return Err(vec![Error::EscapeRequiresCodeType {
2671 found: (res.get_root(), loc_e),
2672 }]);
2673 }
2674 let intermediate = self.gen_intermediate_type_with_location(loc_e.clone());
2675 let rel = self.unify_types(
2676 res,
2677 Type::Code(intermediate).into_id_with_location(loc_e.clone()),
2678 )?;
2679 Ok(intermediate)
2680 }
2681 Expr::Bracket(e) => {
2682 let loc_e = loc.clone();
2683 let prev_stage = self.stage;
2684 self.stage = prev_stage.increment();
2686 log::trace!("Staging bracket expression, stage => {:?}", self.stage);
2687 let res = self.infer_type_unwrapping(*e);
2688 self.stage = prev_stage;
2690 Ok(Type::Code(res).into_id_with_location(loc_e))
2691 }
2692 Expr::Match(scrutinee, arms) => {
2693 let scrut_ty = self.infer_type_unwrapping(*scrutinee);
2695
2696 let arm_tys: Vec<TypeNodeId> = arms
2698 .iter()
2699 .map(|arm| {
2700 match &arm.pattern {
2701 crate::ast::MatchPattern::Literal(lit) => {
2702 let pat_ty = match lit {
2704 crate::ast::Literal::Int(_) | crate::ast::Literal::Float(_) => {
2705 Type::Primitive(PType::Numeric)
2706 .into_id_with_location(loc.clone())
2707 }
2708 _ => Type::Failure.into_id_with_location(loc.clone()),
2709 };
2710 let _ = self.unify_types(scrut_ty, pat_ty);
2711 self.infer_type_unwrapping(arm.body)
2712 }
2713 crate::ast::MatchPattern::Wildcard => {
2714 self.infer_type_unwrapping(arm.body)
2716 }
2717 crate::ast::MatchPattern::Variable(_) => {
2718 self.infer_type_unwrapping(arm.body)
2721 }
2722 crate::ast::MatchPattern::Constructor(constructor_name, binding) => {
2723 let binding_ty = self
2726 .get_constructor_type_from_union(scrut_ty, *constructor_name);
2727
2728 if let Some(inner_pattern) = binding {
2729 self.env.extend();
2731 self.add_pattern_bindings(inner_pattern, binding_ty);
2732 let body_ty = self.infer_type_unwrapping(arm.body);
2733 self.env.to_outer();
2734 body_ty
2735 } else {
2736 self.infer_type_unwrapping(arm.body)
2737 }
2738 }
2739 crate::ast::MatchPattern::Tuple(patterns) => {
2740 self.env.extend();
2744
2745 let resolved_scrut_ty = scrut_ty.get_root().to_type();
2747 if let Type::Tuple(elem_types) = resolved_scrut_ty {
2748 for (pat, elem_ty) in patterns.iter().zip(elem_types.iter()) {
2751 self.check_pattern_against_type(pat, *elem_ty, &loc);
2752 }
2753 } else {
2754 for pat in patterns.iter() {
2757 self.check_pattern_against_type(pat, scrut_ty, &loc);
2758 }
2759 }
2760
2761 let body_ty = self.infer_type_unwrapping(arm.body);
2762 self.env.to_outer();
2763 body_ty
2764 }
2765 }
2766 })
2767 .collect();
2768
2769 self.match_expressions.push((e, scrut_ty));
2771
2772 if arm_tys.is_empty() {
2773 Ok(Type::Primitive(PType::Unit).into_id_with_location(loc))
2774 } else {
2775 let first = arm_tys[0];
2776 for ty in arm_tys.iter().skip(1) {
2777 let _ = self.unify_types(first, *ty);
2778 }
2779 Ok(first)
2780 }
2781 }
2782 _ => Ok(Type::Failure.into_id_with_location(loc)),
2783 };
2784 res.inspect(|ty| {
2785 self.result_memo.insert(e.0, *ty);
2786 })
2787 }
2788 fn infer_type_unwrapping(&mut self, e: ExprNodeId) -> TypeNodeId {
2789 match self.infer_type(e) {
2790 Ok(t) => t,
2791 Err(err) => {
2792 let failure_ty = Type::Failure
2793 .into_id_with_location(Location::new(e.to_span(), self.file_path.clone()));
2794 self.errors.extend(err);
2795 self.result_memo.insert(e.0, failure_ty);
2796 failure_ty
2797 }
2798 }
2799 }
2800
2801 fn check_match_exhaustiveness(
2804 &self,
2805 scrutinee_ty: TypeNodeId,
2806 arms: &[crate::ast::MatchArm],
2807 ) -> Option<Vec<Symbol>> {
2808 let required_constructors = self.get_all_constructors(scrutinee_ty);
2810
2811 if required_constructors.is_empty() {
2813 return None;
2814 }
2815
2816 let has_wildcard = arms.iter().any(|arm| {
2818 matches!(
2819 &arm.pattern,
2820 crate::ast::MatchPattern::Wildcard
2821 | crate::ast::MatchPattern::Variable(_)
2822 | crate::ast::MatchPattern::Tuple(_)
2823 )
2824 });
2825
2826 if has_wildcard {
2828 return None;
2829 }
2830
2831 let covered_constructors: Vec<Symbol> = arms
2833 .iter()
2834 .filter_map(|arm| {
2835 if let crate::ast::MatchPattern::Constructor(name, _) = &arm.pattern {
2836 Some(*name)
2837 } else {
2838 None
2839 }
2840 })
2841 .collect();
2842
2843 let missing: Vec<Symbol> = required_constructors
2845 .into_iter()
2846 .filter(|req| !covered_constructors.contains(req))
2847 .collect();
2848
2849 if missing.is_empty() {
2850 None
2851 } else {
2852 Some(missing)
2853 }
2854 }
2855
2856 fn get_all_constructors(&self, ty: TypeNodeId) -> Vec<Symbol> {
2860 let resolved = self.resolve_type_alias(ty);
2862 let substituted = Self::substitute_type(resolved);
2863
2864 match substituted.to_type() {
2865 Type::Union(variants) => {
2866 variants
2868 .iter()
2869 .filter_map(|v| {
2870 let v_resolved = Self::substitute_type(*v);
2871 Self::type_constructor_name(&v_resolved.to_type())
2872 })
2873 .collect()
2874 }
2875 Type::UserSum { name: _, variants } => {
2876 variants.iter().map(|(name, _)| *name).collect()
2878 }
2879 _ => {
2880 Vec::new()
2882 }
2883 }
2884 }
2885
2886 pub fn check_all_match_exhaustiveness(&mut self) {
2889 let match_expressions = std::mem::take(&mut self.match_expressions);
2890
2891 let errors: Vec<_> = match_expressions
2892 .into_iter()
2893 .filter_map(|(match_expr, scrut_ty)| {
2894 if let Expr::Match(_scrutinee, arms) = &match_expr.to_expr() {
2895 let resolved_scrut_ty = self.resolve_type_alias(scrut_ty);
2896 let substituted_scrut_ty = Self::substitute_type(resolved_scrut_ty);
2897
2898 self.check_match_exhaustiveness(substituted_scrut_ty, arms)
2899 .map(|missing| Error::NonExhaustiveMatch {
2900 missing_constructors: missing,
2901 location: match_expr.to_location(),
2902 })
2903 } else {
2904 None
2905 }
2906 })
2907 .collect();
2908
2909 self.errors.extend(errors);
2910 }
2911}
2912
2913pub fn infer_root(
2914 e: ExprNodeId,
2915 builtin_types: &[(Symbol, TypeNodeId)],
2916 file_path: PathBuf,
2917 type_declarations: Option<&crate::ast::program::TypeDeclarationMap>,
2918 type_aliases: Option<&crate::ast::program::TypeAliasMap>,
2919 module_info: Option<crate::ast::program::ModuleInfo>,
2920) -> InferContext {
2921 use std::sync::atomic::{AtomicUsize, Ordering};
2922 static INFER_ROOT_COUNTER: AtomicUsize = AtomicUsize::new(0);
2923 let call_id = INFER_ROOT_COUNTER.fetch_add(1, Ordering::Relaxed);
2924 let mut ctx = InferContext::new(
2925 builtin_types,
2926 file_path.clone(),
2927 type_declarations,
2928 type_aliases,
2929 module_info,
2930 );
2931 ctx.infer_root_id = call_id;
2932 let _t = ctx
2933 .infer_type(e)
2934 .unwrap_or(Type::Failure.into_id_with_location(e.to_location()));
2935 ctx.substitute_all_intermediates();
2936 ctx.check_all_match_exhaustiveness();
2937 ctx
2938}
2939
// Tests for stage-aware variable lookup, `EvalStage` arithmetic, and end-to-end compilation
// of small source programs through `compiler::Context::emit_mir`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::interner::ToSymbol;
    use crate::types::Type;
    use crate::utils::metadata::{Location, Span};

    // Context with no builtin bindings, file path "test", and no type declarations,
    // aliases, or module info.
    fn create_test_context() -> InferContext {
        InferContext::new(&[], PathBuf::from("test"), None, None, None)
    }

    // Zero-width location at offset 0 in file "test".
    fn create_test_location() -> Location {
        Location::new(Span { start: 0, end: 0 }, PathBuf::from("test"))
    }

    // A variable bound at Stage(0) resolves from Stage(0). Looking it up from Stage(1) yields
    // `Error::StageMismatch`, where `expected_stage` is the current stage and `found_stage`
    // is the binding's stage.
    #[test]
    fn test_stage_mismatch_detection() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        let var_name = "x".to_symbol();
        let var_type =
            Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
        ctx.env
            .add_bind(&[(var_name, (var_type, EvalStage::Stage(0)))]);

        ctx.stage = EvalStage::Stage(0);
        let result = ctx.lookup(var_name, loc.clone());
        assert!(
            result.is_ok(),
            "Looking up variable from same stage should succeed"
        );

        ctx.stage = EvalStage::Stage(1);
        let result = ctx.lookup(var_name, loc.clone());
        assert!(
            result.is_err(),
            "Looking up variable from different stage should fail"
        );

        if let Err(Error::StageMismatch {
            variable,
            expected_stage,
            found_stage,
            ..
        }) = result
        {
            assert_eq!(variable, var_name);
            assert_eq!(expected_stage, EvalStage::Stage(1));
            assert_eq!(found_stage, EvalStage::Stage(0));
        } else {
            panic!("Expected StageMismatch error, got: {result:?}");
        }
    }

    // A binding tagged `EvalStage::Persistent` is accessible from stages 0, 1 and 2.
    #[test]
    fn test_persistent_stage_access() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        let var_name = "persistent_var".to_symbol();
        let var_type =
            Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
        ctx.env
            .add_bind(&[(var_name, (var_type, EvalStage::Persistent))]);

        for stage in [0, 1, 2] {
            ctx.stage = EvalStage::Stage(stage);
            let result = ctx.lookup(var_name, loc.clone());
            assert!(
                result.is_ok(),
                "Persistent stage variables should be accessible from stage {stage}"
            );
        }
    }

    // One variable is bound per stage (0..=2). Each resolves only from its own stage and
    // fails from either of the other two.
    #[test]
    fn test_same_stage_access() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        for stage in [0, 1, 2] {
            let var_name = format!("var_stage_{stage}").to_symbol();
            let var_type =
                Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
            ctx.env
                .add_bind(&[(var_name, (var_type, EvalStage::Stage(stage)))]);
        }

        for stage in [0, 1, 2] {
            ctx.stage = EvalStage::Stage(stage);
            let var_name = format!("var_stage_{stage}").to_symbol();
            let result = ctx.lookup(var_name, loc.clone());
            assert!(
                result.is_ok(),
                "Variable should be accessible from its own stage {stage}"
            );

            // The inner loop leaves `ctx.stage` changed; the outer loop resets it each pass.
            for other_stage in [0, 1, 2] {
                if other_stage != stage {
                    ctx.stage = EvalStage::Stage(other_stage);
                    let result = ctx.lookup(var_name, loc.clone());
                    assert!(
                        result.is_err(),
                        "Variable from stage {stage} should not be accessible from stage {other_stage}",
                    );
                }
            }
        }
    }

    // Exercises `EvalStage::increment`/`decrement` directly (the stage changes Bracket and
    // Escape inference perform); no expressions are inferred here.
    #[test]
    fn test_stage_transitions_bracket_escape() {
        let mut ctx = create_test_context();

        assert_eq!(ctx.stage, EvalStage::Stage(0), "Initial stage should be 0");

        ctx.stage = ctx.stage.increment();
        assert_eq!(
            ctx.stage,
            EvalStage::Stage(1),
            "Stage should increment to 1 in bracket"
        );

        ctx.stage = ctx.stage.decrement();
        assert_eq!(
            ctx.stage,
            EvalStage::Stage(0),
            "Stage should decrement back to 0 after escape"
        );
    }

    // Nested scopes binding the same symbol "x" at different stages. `var_stage0` and
    // `var_stage1` are the same symbol, so the inner Stage(1) binding shadows the outer
    // Stage(0) one. Per the assertion messages, lookups from stage 0 therefore fail, and only
    // stage 1 succeeds.
    #[test]
    fn test_multi_stage_environment() {
        let mut ctx = create_test_context();
        let loc = create_test_location();

        // Outer scope: bind "x" at Stage(0).
        ctx.env.extend();
        let var_stage0 = "x".to_symbol();
        let var_type =
            Type::Primitive(crate::types::PType::Numeric).into_id_with_location(loc.clone());
        ctx.stage = EvalStage::Stage(0);
        ctx.env
            .add_bind(&[(var_stage0, (var_type, EvalStage::Stage(0)))]);

        // Inner scope: re-bind the same symbol "x" at Stage(1).
        ctx.env.extend();
        let var_stage1 = "x".to_symbol();
        ctx.stage = EvalStage::Stage(1);
        ctx.env
            .add_bind(&[(var_stage1, (var_type, EvalStage::Stage(1)))]);

        ctx.stage = EvalStage::Stage(0);
        let result = ctx.lookup(var_stage0, loc.clone());
        assert!(
            result.is_err(),
            "Stage 0 variable should not be accessible from nested stage 0 context due to shadowing"
        );

        ctx.stage = EvalStage::Stage(1);
        let result = ctx.lookup(var_stage1, loc.clone());
        assert!(
            result.is_ok(),
            "Stage 1 variable should be accessible from stage 1"
        );

        ctx.stage = EvalStage::Stage(0);
        let result = ctx.lookup(var_stage1, loc.clone());
        assert!(
            result.is_err(),
            "Stage 1 variable should not be accessible from stage 0"
        );

        // Pop both scopes pushed above.
        ctx.env.to_outer();
        ctx.env.to_outer();
    }

    // Compiles a call through a module path (`mymath::add`) with no external functions or
    // macros, and only checks that `emit_mir` succeeds.
    // NOTE(review): despite the name, the mangled symbol is not asserted. The body is
    // identical to `test_qualified_var_mir_generation`.
    #[test]
    fn test_qualified_var_mangling() {
        use crate::compiler;

        let src = r#"
mod mymath {
    pub fn add(x, y) {
        x + y
    }
}

fn dsp() {
    mymath::add(1.0, 2.0)
}
"#;
        let empty_ext_fns: Vec<compiler::ExtFunTypeInfo> = vec![];
        let empty_macros: Vec<Box<dyn crate::plugin::MacroFunction>> = vec![];
        let ctx = compiler::Context::new(
            empty_ext_fns,
            empty_macros,
            Some(std::path::PathBuf::from("test")),
            compiler::Config::default(),
        );
        let result = ctx.emit_mir(src);

        assert!(result.is_ok(), "Compilation failed: {:?}", result.err());
    }

    // Same program as `test_qualified_var_mangling`; checks that MIR generation succeeds.
    #[test]
    fn test_qualified_var_mir_generation() {
        use crate::compiler;

        let src = r#"
mod mymath {
    pub fn add(x, y) {
        x + y
    }
}

fn dsp() {
    mymath::add(1.0, 2.0)
}
"#;
        let empty_ext_fns: Vec<compiler::ExtFunTypeInfo> = vec![];
        let empty_macros: Vec<Box<dyn crate::plugin::MacroFunction>> = vec![];
        let ctx = compiler::Context::new(
            empty_ext_fns,
            empty_macros,
            Some(std::path::PathBuf::from("test")),
            compiler::Config::default(),
        );
        let result = ctx.emit_mir(src);

        assert!(result.is_ok(), "MIR generation failed: {:?}", result.err());
    }

    // A macro-stage function returns code of record alias `Note`, and `dsp` accesses a field
    // (`val`) that `Note` does not declare. Compilation must fail, and at least one error
    // message must mention the missing field or a non-record field access.
    #[test]
    fn test_macro_return_record_missing_field_reports_type_error() {
        use crate::compiler;

        let src = r#"
pub type alias Note = {v:float, gate:float}

#stage(macro)
fn make_note()->`Note{
    `({v = 60.0, gate = 1.0})
}

fn dsp(){
    let note = make_note!()
    note.val
}
"#;

        let empty_ext_fns: Vec<compiler::ExtFunTypeInfo> = vec![];
        let empty_macros: Vec<Box<dyn crate::plugin::MacroFunction>> = vec![];
        let ctx = compiler::Context::new(
            empty_ext_fns,
            empty_macros,
            Some(std::path::PathBuf::from("test")),
            compiler::Config::default(),
        );
        let result = ctx.emit_mir(src);

        assert!(
            result.is_err(),
            "Compilation should fail for missing record field access"
        );

        let errors = result.err().unwrap();
        assert!(
            errors.iter().any(|e| {
                let message = e.get_message();
                message.contains("Field \"val\"")
                    || message.contains("Field access for non-record variable")
            }),
            "Expected field access type error for \"val\", got: {:?}",
            errors.iter().map(|e| e.get_message()).collect::<Vec<_>>()
        );
    }

    // A recursive function whose parameter is annotated `[Event]` must keep the full `Event`
    // record shape in its MIR signature (the `active` field and `end:number` appear). It must
    // also cause `split_head` to be specialized as `split_head$arity5`. Presumably 5 is the
    // flattened scalar count of `Event` (2 + 2 + 1); TODO confirm.
    // Uses the builtin plugin's external functions and macros.
    #[test]
    fn test_recursive_function_preserves_record_array_width_from_param_annotation() {
        use crate::compiler;
        use crate::plugin;

        let src = r#"
pub type alias Arc = {start:float, end:float}
pub type alias Event = {arc:Arc, active:Arc, val:float}

fn value_at_phase(events:[Event], phase:float, current:float)->float{
    if (len(events) > 0.0){
        let (head,rest) = events |> split_head
        if (phase >= head.arc.start){
            value_at_phase(rest, phase, head.val)
        }else{
            current
        }
    }else{
        current
    }
}

fn dsp(){
    let events = [{
        arc = {start = 0.0, end = 1.0},
        active = {start = 0.0, end = 1.0},
        val = 1.0,
    }]
    value_at_phase(events, 0.0, 0.0)
}
"#;

        let ext_fns = plugin::get_extfun_types(&[plugin::get_builtin_fns_as_plugins()])
            .collect::<Vec<_>>();
        let macros = plugin::get_macro_functions(&[plugin::get_builtin_fns_as_plugins()])
            .collect::<Vec<_>>();
        let ctx = compiler::Context::new(
            ext_fns,
            macros,
            Some(std::path::PathBuf::from("test")),
            compiler::Config::default(),
        );
        let result = ctx.emit_mir(src);

        assert!(result.is_ok(), "MIR generation failed: {:?}", result.err());

        let mir = result.unwrap();
        let printed = format!("{mir}");
        let signature_line = printed
            .lines()
            .find(|line| line.starts_with("fn value_at_phase ["))
            .expect("value_at_phase should be present in MIR");

        assert!(
            signature_line.contains("active") && signature_line.contains("end:number"),
            "record-array parameter width should keep full Event shape, got: {signature_line}"
        );
        assert!(
            printed.contains("split_head$arity5"),
            "split_head should specialize for full Event width, got MIR:\n{printed}"
        );
    }

    // Same check as above, but `value_at_phase` lives in an imported module (`pattern_like`)
    // and is invoked from a macro-stage function through `lift`. The test first asserts that
    // parsing exposes the module-prefixed alias `pattern_like$Event`. It then checks that the
    // MIR for `pattern_like$value_at_phase` keeps the full `Event` shape.
    // NOTE(review): fixture files are written under
    // `<CARGO_MANIFEST_DIR>/../../../tmp/staging_import_record_array_regression` and are never
    // removed.
    #[test]
    fn test_imported_staging_preserves_record_array_width() {
        use crate::compiler;
        use crate::plugin;
        use std::fs;

        let repo_root = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../..");
        let fixture_root = repo_root.join("tmp/staging_import_record_array_regression");
        let fixture_lib_dir = fixture_root.join("lib");
        let fixture_main = fixture_root.join("main.mmm");
        let fixture_module = fixture_lib_dir.join("pattern_like.mmm");

        // The imported module is placed in `lib/` next to `main.mmm`.
        fs::create_dir_all(&fixture_lib_dir).expect("fixture lib dir should be created");
        fs::write(
            &fixture_module,
            r#"
use osc::phasor

pub type alias Arc = {start:float, end:float}
pub type alias Event = {arc:Arc, active:Arc, val:float}

#stage(main)
fn value_at_phase(events:[Event], phase:float, current:float)->float{
    if ((events |> len) > 0.0){
        let (head,rest) = events |> split_head
        if (phase >= head.arc.start){
            value_at_phase(rest, phase, head.val)
        }else{
            current
        }
    }else{
        current
    }
}

#stage(macro)
pub fn run_value()->`float{
    let events = [{
        arc = {start = 0.0, end = 1.0},
        active = {start = 0.0, end = 1.0},
        val = 60.0,
    }]
    `{
        let phase = phasor(0.5, 0.0)
        let v = value_at_phase($(events |> lift), phase, 0.0)
        v
    }
}
"#,
        )
        .expect("fixture module should be written");
        fs::write(
            &fixture_main,
            r#"
use pattern_like::*

fn dsp(){
    let value = run_value!()
    (value, value)
}
"#,
        )
        .expect("fixture main should be written");

        let src = fs::read_to_string(&fixture_main).expect("fixture main should be readable");
        let (_ast, module_info, parse_errs) = crate::compiler::parser::parse_to_expr(
            &src,
            Some(fixture_main.clone()),
        );
        assert!(parse_errs.is_empty(), "fixture should parse cleanly");
        // Imported aliases are keyed with the module prefix (`module$Name`).
        assert!(
            module_info
                .type_aliases
                .keys()
                .any(|name| name.as_str() == "pattern_like$Event"),
            "imported module type aliases should contain pattern_like$Event, got: {:?}",
            module_info
                .type_aliases
                .keys()
                .map(|name| name.as_str().to_string())
                .collect::<Vec<_>>()
        );
        let ext_fns = plugin::get_extfun_types(&[plugin::get_builtin_fns_as_plugins()])
            .collect::<Vec<_>>();
        let macros = plugin::get_macro_functions(&[plugin::get_builtin_fns_as_plugins()])
            .collect::<Vec<_>>();
        let ctx = compiler::Context::new(
            ext_fns,
            macros,
            Some(fixture_main.clone()),
            compiler::Config::default(),
        );
        let result = ctx.emit_mir(&src);

        assert!(result.is_ok(), "MIR generation failed: {:?}", result.err());

        let printed = format!("{}", result.unwrap());
        let signature_line = printed
            .lines()
            .find(|line| line.starts_with("fn pattern_like$value_at_phase ["))
            .expect("imported value_at_phase should be present in MIR");

        assert!(
            signature_line.contains("active") && signature_line.contains("end:number"),
            "imported staged record-array parameter width should keep full Event shape, got: {signature_line}"
        );
        assert!(
            printed.contains("split_head$arity5"),
            "imported staged split_head should specialize for full Event width, got MIR:\n{printed}"
        );
    }
}