1use std::collections::HashMap;
27use std::sync::atomic::{AtomicU32, Ordering};
28
29use bock_air::stubs::{TypeInfo, Value};
30use bock_air::{AIRNode, EnumVariantPayload, NodeId, NodeKind};
31use bock_ast::{BinOp, Literal, TypeConstraint, TypeExpr, TypePath, UnaryOp};
32use bock_errors::{DiagnosticBag, DiagnosticCode, Span};
33
34use crate::traits::{resolve_impl, ImplTable, TraitRef};
35use crate::{unify, EffectRef, FnType, GenericType, PrimitiveType, Substitution, Type, TypeVarId};
36
37const E_TYPE_MISMATCH: DiagnosticCode = DiagnosticCode {
40 prefix: 'E',
41 number: 4001,
42};
43const E_UNDEFINED_VAR: DiagnosticCode = DiagnosticCode {
44 prefix: 'E',
45 number: 4002,
46};
47const E_ARITY_MISMATCH: DiagnosticCode = DiagnosticCode {
48 prefix: 'E',
49 number: 4003,
50};
51const E_NOT_CALLABLE: DiagnosticCode = DiagnosticCode {
52 prefix: 'E',
53 number: 4004,
54};
55const E_WHERE_CLAUSE: DiagnosticCode = DiagnosticCode {
56 prefix: 'E',
57 number: 4005,
58};
59
60struct TypeVarGen {
64 counter: AtomicU32,
65}
66
67impl TypeVarGen {
68 fn new() -> Self {
69 Self {
70 counter: AtomicU32::new(0),
71 }
72 }
73
74 fn next(&self) -> TypeVarId {
75 self.counter.fetch_add(1, Ordering::SeqCst)
76 }
77}
78
79pub struct TypeEnv {
85 scopes: Vec<HashMap<String, Type>>,
86}
87
88impl TypeEnv {
89 #[must_use]
91 pub fn new() -> Self {
92 Self {
93 scopes: vec![HashMap::new()],
94 }
95 }
96
97 pub fn push_scope(&mut self) {
99 self.scopes.push(HashMap::new());
100 }
101
102 pub fn pop_scope(&mut self) {
106 debug_assert!(self.scopes.len() > 1, "cannot pop the global scope");
107 self.scopes.pop();
108 }
109
110 pub fn define(&mut self, name: impl Into<String>, ty: Type) {
112 let scope = self.scopes.last_mut().expect("at least one scope");
113 scope.insert(name.into(), ty);
114 }
115
116 #[must_use]
118 pub fn lookup(&self, name: &str) -> Option<&Type> {
119 self.scopes.iter().rev().find_map(|s| s.get(name))
120 }
121}
122
123impl Default for TypeEnv {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129#[derive(Debug, Clone)]
133struct FnSig {
134 generic_params: Vec<String>,
136 generic_var_ids: Vec<TypeVarId>,
140 param_types: Vec<Type>,
143 return_type: Type,
145 where_clause: Vec<TypeConstraint>,
147}
148
149pub struct TypeChecker {
161 pub env: TypeEnv,
163 pub subst: Substitution,
165 pub diags: DiagnosticBag,
167 var_gen: TypeVarGen,
169 types: HashMap<NodeId, Type>,
171 fn_sigs: HashMap<String, FnSig>,
173 return_ty_stack: Vec<Type>,
175 pub impl_table: Option<ImplTable>,
178 method_types: HashMap<String, HashMap<String, Type>>,
180 effect_op_types: HashMap<String, Vec<(String, Type)>>,
183 effect_components: HashMap<String, Vec<String>>,
185 record_field_types: HashMap<String, Vec<(String, Type)>>,
188 record_generic_params: HashMap<String, Vec<String>>,
191 type_aliases: HashMap<String, Type>,
194 trait_method_types: HashMap<String, HashMap<String, Type>>,
199 type_var_bounds: HashMap<TypeVarId, Vec<String>>,
204}
205
206impl TypeChecker {
207 #[must_use]
209 pub fn new() -> Self {
210 Self {
211 env: TypeEnv::new(),
212 subst: Substitution::new(),
213 diags: DiagnosticBag::new(),
214 var_gen: TypeVarGen::new(),
215 types: HashMap::new(),
216 fn_sigs: HashMap::new(),
217 return_ty_stack: Vec::new(),
218 impl_table: None,
219 method_types: HashMap::new(),
220 effect_op_types: HashMap::new(),
221 effect_components: HashMap::new(),
222 record_field_types: HashMap::new(),
223 record_generic_params: HashMap::new(),
224 type_aliases: HashMap::new(),
225 trait_method_types: HashMap::new(),
226 type_var_bounds: HashMap::new(),
227 }
228 }
229
230 fn fresh_var(&self) -> Type {
234 Type::TypeVar(self.var_gen.next())
235 }
236
237 fn record(&mut self, node: &mut AIRNode, ty: Type) -> Type {
242 let resolved = self.subst.apply(&ty);
243 self.types.insert(node.id, resolved.clone());
244 node.type_info = Some(TypeInfo {
245 resolved_type: None,
246 });
247 if matches!(resolved, Type::Primitive(_)) {
249 node.metadata
250 .insert("copy_type".into(), Value::Bool(true));
251 }
252 resolved
253 }
254
255 #[must_use]
257 pub fn type_of(&self, id: NodeId) -> Option<&Type> {
258 self.types.get(&id)
259 }
260
261 #[must_use]
265 pub fn record_field_types(&self) -> &HashMap<String, Vec<(String, Type)>> {
266 &self.record_field_types
267 }
268
269 #[must_use]
271 pub fn record_generic_params(&self) -> &HashMap<String, Vec<String>> {
272 &self.record_generic_params
273 }
274
275 #[must_use]
277 pub fn effect_op_types(&self) -> &HashMap<String, Vec<(String, Type)>> {
278 &self.effect_op_types
279 }
280
281 #[must_use]
283 pub fn effect_components(&self) -> &HashMap<String, Vec<String>> {
284 &self.effect_components
285 }
286
287 #[must_use]
289 pub fn method_types(&self) -> &HashMap<String, HashMap<String, Type>> {
290 &self.method_types
291 }
292
293 #[must_use]
295 pub fn trait_method_types(&self) -> &HashMap<String, HashMap<String, Type>> {
296 &self.trait_method_types
297 }
298
299 #[must_use]
301 pub fn type_aliases(&self) -> &HashMap<String, Type> {
302 &self.type_aliases
303 }
304
305 pub fn insert_record_field_types(&mut self, name: String, fields: Vec<(String, Type)>) {
309 self.record_field_types.insert(name, fields);
310 }
311
312 pub fn insert_record_generic_params(&mut self, name: String, params: Vec<String>) {
314 self.record_generic_params.insert(name, params);
315 }
316
317 pub fn insert_trait_method_types(&mut self, name: String, methods: HashMap<String, Type>) {
319 self.trait_method_types.insert(name, methods);
320 }
321
322 pub fn insert_effect_op_types(&mut self, name: String, ops: Vec<(String, Type)>) {
324 self.effect_op_types.insert(name, ops);
325 }
326
327 pub fn insert_effect_components(&mut self, name: String, components: Vec<String>) {
329 self.effect_components.insert(name, components);
330 }
331
332 pub fn insert_type_alias(&mut self, name: String, underlying: Type) {
334 self.type_aliases.insert(name, underlying);
335 }
336
337 pub fn insert_method_types(&mut self, type_name: String, methods: HashMap<String, Type>) {
339 self.method_types.insert(type_name, methods);
340 }
341
342 pub fn seed_imported_generic_fn(&mut self, name: &str, fn_ty: &FnType) -> Type {
354 let mut original_ids = Vec::new();
357 collect_type_var_ids_fn(fn_ty, &mut original_ids);
358
359 if original_ids.is_empty() {
360 let ty = Type::Function(fn_ty.clone());
362 self.env.define(name, ty.clone());
363 return ty;
364 }
365
366 let remap: HashMap<TypeVarId, Type> = original_ids
368 .iter()
369 .map(|&id| (id, self.fresh_var()))
370 .collect();
371
372 let fresh_ids: Vec<TypeVarId> = original_ids
373 .iter()
374 .map(|id| match &remap[id] {
375 Type::TypeVar(fresh) => *fresh,
376 _ => unreachable!(),
377 })
378 .collect();
379
380 let remapped = Type::Function(FnType {
382 params: fn_ty
383 .params
384 .iter()
385 .map(|t| self.replace_type_vars(t, &remap))
386 .collect(),
387 ret: Box::new(self.replace_type_vars(&fn_ty.ret, &remap)),
388 effects: fn_ty.effects.clone(),
389 });
390
391 self.env.define(name, remapped.clone());
393
394 let generic_params: Vec<String> = (0..original_ids.len())
396 .map(|i| format!("T{i}"))
397 .collect();
398
399 if let Type::Function(ref f) = remapped {
401 self.fn_sigs.insert(
402 name.to_string(),
403 FnSig {
404 generic_params,
405 generic_var_ids: fresh_ids,
406 param_types: f.params.clone(),
407 return_type: (*f.ret).clone(),
408 where_clause: vec![],
409 },
410 );
411 }
412
413 remapped
414 }
415
416 fn unify_or_error(&mut self, a: &Type, b: &Type, span: Span, context: &str) -> Type {
421 let a = self.resolve_alias(&self.subst.apply(a));
422 let b = self.resolve_alias(&self.subst.apply(b));
423 match unify(&a, &b, &mut self.subst) {
424 Ok(()) => self.subst.apply(&a),
425 Err(e) => {
426 let diag = self.diags.error(
427 E_TYPE_MISMATCH,
428 format!("type mismatch in {context}: {e}"),
429 span,
430 );
431 if let Some(hint) = conversion_hint(&a, &b) {
432 diag.note(hint);
433 }
434 Type::Error
435 }
436 }
437 }
438
439 pub fn check_module(&mut self, module: &mut AIRNode) {
447 let items = match &module.kind {
449 NodeKind::Module { items, .. } => items.clone(),
450 _ => return,
451 };
452
453 for item in &items {
455 self.collect_sig(item);
456 }
457
458 if let NodeKind::Module { items, .. } = &mut module.kind {
461 for item in items.iter_mut() {
462 self.check_item(item);
463 }
464 }
465
466 self.record(module, Type::Primitive(PrimitiveType::Void));
467 }
468
469 fn collect_sig(&mut self, node: &AIRNode) {
471 match &node.kind {
472 NodeKind::FnDecl {
473 name,
474 generic_params,
475 params,
476 return_type,
477 effect_clause,
478 where_clause,
479 ..
480 } => {
481 let gp_names: Vec<String> =
482 generic_params.iter().map(|g| g.name.name.clone()).collect();
483
484 let gp_map: HashMap<String, Type> = gp_names
486 .iter()
487 .map(|n| (n.clone(), self.fresh_var()))
488 .collect();
489
490 let gp_var_ids: Vec<TypeVarId> = gp_names
493 .iter()
494 .map(|n| match &gp_map[n] {
495 Type::TypeVar(id) => *id,
496 _ => unreachable!(),
497 })
498 .collect();
499
500 let param_types: Vec<Type> = params
502 .iter()
503 .map(|p| self.air_type_node_to_type(p.kind.param_ty_node(), &gp_map))
504 .collect();
505
506 let ret_ty = return_type
507 .as_deref()
508 .map(|n| self.air_type_node_to_type(n, &gp_map))
509 .unwrap_or(Type::Primitive(PrimitiveType::Void));
510
511 let effects: Vec<EffectRef> = effect_clause
513 .iter()
514 .map(|tp| {
515 let name = tp
516 .segments
517 .iter()
518 .map(|s| s.name.as_str())
519 .collect::<Vec<_>>()
520 .join(".");
521 EffectRef::new(name)
522 })
523 .collect();
524
525 let fn_ty = Type::Function(FnType {
527 params: param_types.clone(),
528 ret: Box::new(ret_ty.clone()),
529 effects: effects.clone(),
530 });
531 self.env.define(name.name.clone(), fn_ty);
532
533 self.fn_sigs.insert(
534 name.name.clone(),
535 FnSig {
536 generic_params: gp_names,
537 generic_var_ids: gp_var_ids,
538 param_types,
539 return_type: ret_ty,
540 where_clause: where_clause.clone(),
541 },
542 );
543 }
544 NodeKind::ConstDecl { name, ty, .. } => {
545 let const_ty = self.air_type_node_to_type(ty, &HashMap::new());
546 self.env.define(name.name.clone(), const_ty);
547 }
548 NodeKind::EnumDecl { name, variants, generic_params, .. } => {
549 let enum_name = name.name.clone();
550
551 let gp_names: Vec<String> = generic_params
553 .iter()
554 .map(|g| g.name.name.clone())
555 .collect();
556
557 let named_ty = Type::Named(crate::NamedType {
567 name: enum_name.clone(),
568 });
569 let (gp_map, gp_var_ids, generic_ret_ty) = if gp_names.is_empty() {
570 (HashMap::new(), vec![], named_ty.clone())
571 } else {
572 let gp_map: HashMap<String, Type> = gp_names
573 .iter()
574 .map(|n| (n.clone(), self.fresh_var()))
575 .collect();
576 let gp_var_ids: Vec<TypeVarId> = gp_names
577 .iter()
578 .map(|n| match &gp_map[n] {
579 Type::TypeVar(id) => *id,
580 _ => unreachable!(),
581 })
582 .collect();
583 let type_args: Vec<Type> =
584 gp_names.iter().map(|n| gp_map[n].clone()).collect();
585 let generic_ret_ty = Type::Generic(GenericType {
586 constructor: enum_name.clone(),
587 args: type_args,
588 });
589 (gp_map, gp_var_ids, generic_ret_ty)
590 };
591
592 self.env.define(enum_name.clone(), named_ty.clone());
595
596 if !gp_names.is_empty() {
598 self.record_generic_params
599 .insert(enum_name.clone(), gp_names.clone());
600 }
601
602 for variant in variants {
604 if let NodeKind::EnumVariant {
605 name: vname,
606 payload,
607 } = &variant.kind
608 {
609 match payload {
610 EnumVariantPayload::Unit => {
611 self.env.define(vname.name.clone(), named_ty.clone());
615 }
616 EnumVariantPayload::Tuple(param_nodes) => {
617 let param_tys: Vec<Type> = param_nodes
619 .iter()
620 .map(|p| self.air_type_node_to_type(p, &gp_map))
621 .collect();
622 let fn_ty = Type::Function(FnType {
623 params: param_tys.clone(),
624 ret: Box::new(generic_ret_ty.clone()),
625 effects: vec![],
626 });
627 self.env.define(vname.name.clone(), fn_ty);
628
629 if !gp_names.is_empty() {
632 self.fn_sigs.insert(
633 vname.name.clone(),
634 FnSig {
635 generic_params: gp_names.clone(),
636 generic_var_ids: gp_var_ids.clone(),
637 param_types: param_tys,
638 return_type: generic_ret_ty.clone(),
639 where_clause: vec![],
640 },
641 );
642 }
643 }
644 EnumVariantPayload::Struct(fields) => {
645 self.env.define(vname.name.clone(), named_ty.clone());
648 let field_types: Vec<(String, Type)> = fields
651 .iter()
652 .map(|f| {
653 let ty = self.type_expr_to_type(
654 &f.ty,
655 &gp_map,
656 );
657 (f.name.name.clone(), ty)
658 })
659 .collect();
660 self.record_field_types
661 .insert(vname.name.clone(), field_types);
662 if !gp_names.is_empty() {
665 self.record_generic_params
666 .insert(vname.name.clone(), gp_names.clone());
667 }
668 }
669 }
670 }
671 }
672 }
673 NodeKind::ImplBlock {
674 target, methods, ..
675 } => {
676 let target_name = match &target.kind {
677 NodeKind::TypeNamed { path, .. } => type_path_to_name(path),
678 _ => return,
679 };
680 let target_ty = Type::Named(crate::NamedType {
681 name: target_name.clone(),
682 });
683 for method in methods {
684 if let NodeKind::FnDecl {
685 name,
686 params,
687 return_type,
688 ..
689 } = &method.kind
690 {
691 let gp_map: HashMap<String, Type> = HashMap::new();
692
693 let param_types: Vec<Type> = params
694 .iter()
695 .map(|p| {
696 if let NodeKind::Param {
698 pattern, ty: None, ..
699 } = &p.kind
700 {
701 if let NodeKind::BindPat { name, .. } = &pattern.kind {
702 if name.name == "self" {
703 return target_ty.clone();
704 }
705 }
706 }
707 self.air_type_node_to_type(p, &gp_map)
708 })
709 .collect();
710
711 let ret_ty = return_type
712 .as_deref()
713 .map(|n| self.air_type_node_to_type(n, &gp_map))
714 .unwrap_or(Type::Primitive(PrimitiveType::Void));
715
716 let fn_ty = Type::Function(FnType {
717 params: param_types,
718 ret: Box::new(ret_ty),
719 effects: vec![],
720 });
721
722 self.method_types
723 .entry(target_name.clone())
724 .or_default()
725 .insert(name.name.clone(), fn_ty);
726 }
727 }
728 }
729 NodeKind::EffectDecl {
730 name,
731 operations,
732 components,
733 ..
734 } => {
735 let mut ops = Vec::new();
738 for op in operations {
739 if let NodeKind::FnDecl {
740 name: op_name,
741 params,
742 return_type,
743 ..
744 } = &op.kind
745 {
746 let param_types: Vec<Type> = params
747 .iter()
748 .map(|p| {
749 self.air_type_node_to_type(p.kind.param_ty_node(), &HashMap::new())
750 })
751 .collect();
752 let ret_ty = return_type
753 .as_deref()
754 .map(|n| self.air_type_node_to_type(n, &HashMap::new()))
755 .unwrap_or(Type::Primitive(PrimitiveType::Void));
756 let fn_ty = Type::Function(FnType {
757 params: param_types,
758 ret: Box::new(ret_ty),
759 effects: vec![],
760 });
761 ops.push((op_name.name.clone(), fn_ty));
762 }
763 }
764 self.effect_op_types.insert(name.name.clone(), ops);
765
766 let comp_names: Vec<String> =
767 components.iter().map(type_path_to_name).collect();
768 if !comp_names.is_empty() {
769 self.effect_components
770 .insert(name.name.clone(), comp_names);
771 }
772 }
773 NodeKind::RecordDecl {
774 name, fields, generic_params, ..
775 } => {
776 let record_name = name.name.clone();
777 let gp_names: Vec<String> = generic_params
778 .iter()
779 .map(|g| g.name.name.clone())
780 .collect();
781 let field_types: Vec<(String, Type)> = fields
782 .iter()
783 .map(|f| {
784 let ty = self.type_expr_to_type(&f.ty, &HashMap::new());
785 (f.name.name.clone(), ty)
786 })
787 .collect();
788 self.record_field_types
789 .insert(record_name.clone(), field_types);
790 if !gp_names.is_empty() {
791 self.record_generic_params
792 .insert(record_name.clone(), gp_names);
793 }
794 self.env.define(
796 record_name.clone(),
797 Type::Named(crate::NamedType {
798 name: record_name,
799 }),
800 );
801 }
802 NodeKind::TypeAlias {
803 name, ty, ..
804 } => {
805 let underlying = self.air_type_node_to_type(ty, &HashMap::new());
806 self.type_aliases
807 .insert(name.name.clone(), underlying);
808 }
809 NodeKind::ClassDecl {
810 name,
811 fields,
812 methods,
813 base,
814 generic_params,
815 ..
816 } => {
817 let class_name = name.name.clone();
818
819 let gp_names: Vec<String> = generic_params
821 .iter()
822 .map(|g| g.name.name.clone())
823 .collect();
824 if !gp_names.is_empty() {
825 self.record_generic_params
826 .insert(class_name.clone(), gp_names);
827 }
828
829 let field_types: Vec<(String, Type)> = fields
831 .iter()
832 .map(|f| {
833 let ty = self.type_expr_to_type(&f.ty, &HashMap::new());
834 (f.name.name.clone(), ty)
835 })
836 .collect();
837 self.record_field_types
838 .insert(class_name.clone(), field_types);
839
840 let class_ty = Type::Named(crate::NamedType {
842 name: class_name.clone(),
843 });
844 self.env.define(class_name.clone(), class_ty.clone());
845
846 if let Some(base_path) = base {
848 let base_name = type_path_to_name(base_path);
849 if let Some(base_methods) = self.method_types.get(&base_name).cloned() {
850 self.method_types
851 .entry(class_name.clone())
852 .or_default()
853 .extend(base_methods);
854 }
855 }
856
857 for method in methods {
859 if let NodeKind::FnDecl {
860 name: method_name,
861 params,
862 return_type,
863 ..
864 } = &method.kind
865 {
866 let gp_map: HashMap<String, Type> = HashMap::new();
867
868 let param_types: Vec<Type> = params
869 .iter()
870 .map(|p| {
871 if let NodeKind::Param {
872 pattern, ty: None, ..
873 } = &p.kind
874 {
875 if let NodeKind::BindPat { name, .. } = &pattern.kind {
876 if name.name == "self" {
877 return class_ty.clone();
878 }
879 }
880 }
881 self.air_type_node_to_type(p, &gp_map)
882 })
883 .collect();
884
885 let ret_ty = return_type
886 .as_deref()
887 .map(|n| self.air_type_node_to_type(n, &gp_map))
888 .unwrap_or(Type::Primitive(PrimitiveType::Void));
889
890 let fn_ty = Type::Function(FnType {
891 params: param_types,
892 ret: Box::new(ret_ty),
893 effects: vec![],
894 });
895
896 self.method_types
897 .entry(class_name.clone())
898 .or_default()
899 .insert(method_name.name.clone(), fn_ty);
900 }
901 }
902 }
903 NodeKind::TraitDecl {
904 name, methods, ..
905 } => {
906 let trait_name = name.name.clone();
907 let self_ty = Type::Named(crate::NamedType {
908 name: "Self".to_string(),
909 });
910 let mut trait_methods = HashMap::new();
911 for method in methods {
912 if let NodeKind::FnDecl {
913 name: method_name,
914 params,
915 return_type,
916 ..
917 } = &method.kind
918 {
919 let gp_map: HashMap<String, Type> = HashMap::new();
920 let param_types: Vec<Type> = params
921 .iter()
922 .map(|p| {
923 if let NodeKind::Param {
924 pattern, ty: None, ..
925 } = &p.kind
926 {
927 if let NodeKind::BindPat { name, .. } = &pattern.kind {
928 if name.name == "self" {
929 return self_ty.clone();
930 }
931 }
932 }
933 self.air_type_node_to_type(p, &gp_map)
934 })
935 .collect();
936 let ret_ty = return_type
937 .as_deref()
938 .map(|n| self.air_type_node_to_type(n, &gp_map))
939 .unwrap_or(Type::Primitive(PrimitiveType::Void));
940 let fn_ty = Type::Function(FnType {
941 params: param_types,
942 ret: Box::new(ret_ty),
943 effects: vec![],
944 });
945 trait_methods.insert(method_name.name.clone(), fn_ty);
946 }
947 }
948 if !trait_methods.is_empty() {
949 self.trait_method_types.insert(trait_name, trait_methods);
950 }
951 }
952 _ => {}
953 }
954 }
955
956 fn resolve_alias(&self, ty: &Type) -> Type {
959 match ty {
960 Type::Named(nt) => {
961 if let Some(underlying) = self.type_aliases.get(&nt.name) {
962 underlying.clone()
963 } else {
964 ty.clone()
965 }
966 }
967 _ => ty.clone(),
968 }
969 }
970
971 fn check_item(&mut self, node: &mut AIRNode) {
973 match &node.kind {
974 NodeKind::FnDecl { .. } => {
975 self.check_fn_decl(node);
976 }
977 NodeKind::ConstDecl { .. } => {
978 self.check_const_decl(node);
979 }
980 _ => {
982 self.record(node, Type::Primitive(PrimitiveType::Void));
983 }
984 }
985 }
986
987 fn check_fn_decl(&mut self, node: &mut AIRNode) {
989 let (_name, generic_params, params, return_type, effect_clause, where_clause, _body) =
991 match node.kind.clone() {
992 NodeKind::FnDecl {
993 name,
994 generic_params,
995 params,
996 return_type,
997 effect_clause,
998 where_clause,
999 body,
1000 ..
1001 } => (
1002 name,
1003 generic_params,
1004 params,
1005 return_type,
1006 effect_clause,
1007 where_clause,
1008 body,
1009 ),
1010 _ => return,
1011 };
1012
1013 self.env.push_scope();
1014
1015 let gp_map: HashMap<String, Type> = generic_params
1017 .iter()
1018 .map(|g| (g.name.name.clone(), self.fresh_var()))
1019 .collect();
1020
1021 for gp in &generic_params {
1024 if let Some(Type::TypeVar(id)) = gp_map.get(&gp.name.name) {
1025 let bound_names: Vec<String> = gp
1026 .bounds
1027 .iter()
1028 .map(type_path_to_name)
1029 .collect();
1030 if !bound_names.is_empty() {
1031 self.type_var_bounds.entry(*id).or_default().extend(bound_names);
1032 }
1033 }
1034 }
1035 for clause in &where_clause {
1036 if let Some(Type::TypeVar(id)) = gp_map.get(&clause.param.name) {
1037 let bound_names: Vec<String> = clause
1038 .bounds
1039 .iter()
1040 .map(type_path_to_name)
1041 .collect();
1042 if !bound_names.is_empty() {
1043 self.type_var_bounds.entry(*id).or_default().extend(bound_names);
1044 }
1045 }
1046 }
1047
1048 let param_types: Vec<Type> = params
1050 .iter()
1051 .map(|p| {
1052 let ty = self.air_type_node_to_type(p.kind.param_ty_node(), &gp_map);
1053 let pat_name = p.kind.param_pat_name();
1054 if let Some(n) = pat_name {
1055 self.env.define(n, ty.clone());
1056 }
1057 ty
1058 })
1059 .collect();
1060
1061 let ret_ty = return_type
1062 .as_deref()
1063 .map(|n| self.air_type_node_to_type(n, &gp_map))
1064 .unwrap_or(Type::Primitive(PrimitiveType::Void));
1065
1066 {
1069 let mut visited = std::collections::HashSet::new();
1070 for effect_tp in &effect_clause {
1071 let ename = type_path_to_name(effect_tp);
1072 self.inject_effect_ops_into_env(&ename, &mut visited);
1073 }
1074 }
1075
1076 self.check_where_clause(&where_clause, &gp_map, node.span);
1079
1080 self.return_ty_stack.push(ret_ty.clone());
1082
1083 if let NodeKind::FnDecl { body, .. } = &mut node.kind {
1085 self.check_node(body, &ret_ty);
1086 }
1087
1088 self.return_ty_stack.pop();
1089 self.env.pop_scope();
1090
1091 let effects: Vec<EffectRef> = effect_clause
1092 .iter()
1093 .map(|tp| {
1094 let name = tp
1095 .segments
1096 .iter()
1097 .map(|s| s.name.as_str())
1098 .collect::<Vec<_>>()
1099 .join(".");
1100 EffectRef::new(name)
1101 })
1102 .collect();
1103
1104 let fn_ty = Type::Function(FnType {
1105 params: param_types,
1106 ret: Box::new(ret_ty),
1107 effects,
1108 });
1109 self.record(node, fn_ty);
1110 }
1111
1112 fn check_const_decl(&mut self, node: &mut AIRNode) {
1114 let (name, ty_node, _value_node) = match node.kind.clone() {
1115 NodeKind::ConstDecl {
1116 name, ty, value, ..
1117 } => (name, ty, value),
1118 _ => return,
1119 };
1120 let expected_ty = self.air_type_node_to_type(&ty_node, &HashMap::new());
1121 if let NodeKind::ConstDecl { value, .. } = &mut node.kind {
1122 self.check_node(value, &expected_ty);
1123 }
1124 self.env.define(name.name, expected_ty.clone());
1125 self.record(node, expected_ty);
1126 }
1127
1128 fn check_where_clause(
1133 &mut self,
1134 clauses: &[TypeConstraint],
1135 gp_map: &HashMap<String, Type>,
1136 span: Span,
1137 ) {
1138 for clause in clauses {
1139 if !gp_map.contains_key(&clause.param.name) {
1140 self.diags.error(
1141 E_WHERE_CLAUSE,
1142 format!(
1143 "where-clause references unknown type parameter `{}`",
1144 clause.param.name
1145 ),
1146 span,
1147 );
1148 }
1149 }
1150 }
1151
1152 fn inject_effect_ops_into_env(
1157 &mut self,
1158 effect_name: &str,
1159 visited: &mut std::collections::HashSet<String>,
1160 ) {
1161 if !visited.insert(effect_name.to_string()) {
1162 return;
1163 }
1164 if let Some(ops) = self.effect_op_types.get(effect_name).cloned() {
1165 for (op_name, fn_ty) in ops {
1166 self.env.define(op_name, fn_ty);
1167 }
1168 }
1169 if let Some(components) = self.effect_components.get(effect_name).cloned() {
1170 for comp in &components {
1171 self.inject_effect_ops_into_env(comp, visited);
1172 }
1173 }
1174 }
1175
1176 fn check_trait_bounds_at_call(
1187 &mut self,
1188 fn_name: &str,
1189 sig: &FnSig,
1190 fresh_map: &HashMap<TypeVarId, Type>,
1191 span: Span,
1192 ) {
1193 let impl_table = match &self.impl_table {
1194 Some(t) => t,
1195 None => return, };
1197
1198 let name_to_fresh: HashMap<&str, &Type> = sig
1201 .generic_params
1202 .iter()
1203 .zip(sig.generic_var_ids.iter())
1204 .filter_map(|(name, orig_id)| {
1205 fresh_map
1206 .get(orig_id)
1207 .map(|fresh_ty| (name.as_str(), fresh_ty))
1208 })
1209 .collect();
1210
1211 for clause in &sig.where_clause {
1212 let param_name = &clause.param.name;
1213 let concrete_ty = match name_to_fresh.get(param_name.as_str()) {
1214 Some(fresh) => self.subst.apply(fresh),
1215 None => continue, };
1217
1218 for bound_path in &clause.bounds {
1219 let trait_name = bound_path
1220 .segments
1221 .iter()
1222 .map(|s| s.name.as_str())
1223 .collect::<Vec<_>>()
1224 .join(".");
1225 let trait_ref = TraitRef::new(&trait_name);
1226 if resolve_impl(&trait_ref, &concrete_ty, impl_table).is_none() {
1227 self.diags.error(
1228 E_WHERE_CLAUSE,
1229 format!(
1230 "type `{concrete_ty:?}` does not satisfy bound `{trait_name}` \
1231 required by function `{fn_name}`",
1232 ),
1233 span,
1234 );
1235 }
1236 }
1237 }
1238 }
1239
1240 fn infer_node(&mut self, node: &mut AIRNode) -> Type {
1247 let span = node.span;
1248 let ty = match &node.kind {
1249 NodeKind::Literal { lit } => self.infer_literal(lit),
1251
1252 NodeKind::Identifier { name } => {
1254 let name = name.name.clone();
1255 match self.env.lookup(&name) {
1256 Some(ty) => {
1257 let ty = ty.clone();
1258 self.subst.apply(&ty)
1259 }
1260 None => {
1261 self.diags.error(
1262 E_UNDEFINED_VAR,
1263 format!("undefined variable `{name}`"),
1264 span,
1265 );
1266 Type::Error
1267 }
1268 }
1269 }
1270
1271 NodeKind::BinaryOp { op, .. } => {
1273 let op = *op;
1274 let (lt, rt) = if let NodeKind::BinaryOp { left, right, .. } = &mut node.kind {
1276 let lt = self.infer_node(left);
1277 let rt = self.infer_node(right);
1278 (lt, rt)
1279 } else {
1280 unreachable!()
1281 };
1282 self.infer_binop(op, <, &rt, span)
1283 }
1284
1285 NodeKind::UnaryOp { op, .. } => {
1287 let op = *op;
1288 let operand_ty = if let NodeKind::UnaryOp { operand, .. } = &mut node.kind {
1289 self.infer_node(operand)
1290 } else {
1291 unreachable!()
1292 };
1293 self.infer_unop(op, &operand_ty, span)
1294 }
1295
1296 NodeKind::FieldAccess { field, .. } => {
1298 let field_name = field.name.clone();
1299 let obj_ty = if let NodeKind::FieldAccess { object, .. } = &mut node.kind {
1300 self.infer_node(object)
1301 } else {
1302 unreachable!()
1303 };
1304 let obj_ty = self.subst.apply(&obj_ty);
1305 match &obj_ty {
1306 Type::Error => Type::Error,
1307 Type::Named(nt) => {
1308 if let Some(methods) = self.method_types.get(&nt.name) {
1310 if let Some(fn_ty) = methods.get(&field_name) {
1311 return self.record(node, fn_ty.clone());
1312 }
1313 }
1314 if let Some(fields) = self.record_field_types.get(&nt.name) {
1316 if let Some((_, field_ty)) =
1317 fields.iter().find(|(n, _)| n == &field_name)
1318 {
1319 return self.record(node, field_ty.clone());
1320 }
1321 }
1322 self.fresh_var()
1323 }
1324 Type::Generic(g) => {
1325 if let Some(methods) = self.method_types.get(&g.constructor) {
1328 if let Some(fn_ty) = methods.get(&field_name) {
1329 let resolved = if let Some(params) =
1330 self.record_generic_params.get(&g.constructor)
1331 {
1332 substitute_type_params(fn_ty, params, &g.args)
1333 } else {
1334 fn_ty.clone()
1335 };
1336 return self.record(node, resolved);
1337 }
1338 }
1339 if let Some(fields) = self.record_field_types.get(&g.constructor) {
1340 if let Some((_, field_ty)) =
1341 fields.iter().find(|(n, _)| n == &field_name)
1342 {
1343 let resolved = if let Some(params) =
1344 self.record_generic_params.get(&g.constructor)
1345 {
1346 substitute_type_params(field_ty, params, &g.args)
1347 } else {
1348 field_ty.clone()
1349 };
1350 return self.record(node, resolved);
1351 }
1352 }
1353 if let Some(fn_ty) =
1355 self.resolve_builtin_method_fn_type(&obj_ty, &field_name)
1356 {
1357 fn_ty
1358 } else {
1359 self.fresh_var()
1360 }
1361 }
1362 Type::TypeVar(id) => {
1363 if let Some(bounds) = self.type_var_bounds.get(id).cloned() {
1366 let self_params = vec!["Self".to_string()];
1367 let self_args = vec![obj_ty.clone()];
1368 for trait_name in &bounds {
1369 if let Some(methods) = self.trait_method_types.get(trait_name).cloned() {
1370 if let Some(fn_ty) = methods.get(&field_name) {
1371 let resolved = substitute_type_params(
1372 fn_ty,
1373 &self_params,
1374 &self_args,
1375 );
1376 return self.record(node, resolved);
1377 }
1378 }
1379 }
1380 }
1381 if let Some(fn_ty) = self.resolve_builtin_method_fn_type(&obj_ty, &field_name) {
1383 fn_ty
1384 } else {
1385 self.fresh_var()
1386 }
1387 }
1388 _ => {
1389 if let Some(fn_ty) = self.resolve_builtin_method_fn_type(&obj_ty, &field_name) {
1391 fn_ty
1392 } else {
1393 self.fresh_var()
1395 }
1396 }
1397 }
1398 }
1399
1400 NodeKind::Index { .. } => {
1402 let (obj_ty, idx_ty) = if let NodeKind::Index { object, index } = &mut node.kind {
1403 let o = self.infer_node(object);
1404 let i = self.infer_node(index);
1405 (o, i)
1406 } else {
1407 unreachable!()
1408 };
1409 self.unify_or_error(&idx_ty, &Type::Primitive(PrimitiveType::Int), span, "index");
1411 match &obj_ty {
1413 Type::Error => Type::Error,
1414 Type::Generic(g) if g.constructor == "List" && g.args.len() == 1 => {
1415 g.args[0].clone()
1416 }
1417 _ => self.fresh_var(),
1418 }
1419 }
1420
1421 NodeKind::Call { .. } => {
1423 let (callee_clone, args_clone, _type_args_clone) = if let NodeKind::Call {
1425 callee,
1426 args,
1427 type_args,
1428 } = &node.kind
1429 {
1430 (*callee.clone(), args.clone(), type_args.clone())
1431 } else {
1432 unreachable!()
1433 };
1434
1435 let callee_name = if let NodeKind::Identifier { name } = &callee_clone.kind {
1437 Some(name.name.clone())
1438 } else {
1439 None
1440 };
1441
1442 let callee_ty = if let NodeKind::Call { callee, .. } = &mut node.kind {
1444 self.infer_node(callee)
1445 } else {
1446 unreachable!()
1447 };
1448
1449 let mut call_site_info: Option<(String, FnSig, HashMap<TypeVarId, Type>)> = None;
1453 let effective_ty = match (&callee_name, &callee_ty) {
1454 (Some(name), Type::Function(f)) => {
1455 if let Some(sig) = self.fn_sigs.get(name).cloned() {
1456 if !sig.generic_params.is_empty() {
1457 let fresh_map: HashMap<TypeVarId, Type> = sig
1458 .generic_var_ids
1459 .iter()
1460 .map(|&id| (id, self.fresh_var()))
1461 .collect();
1462 let ety = Type::Function(FnType {
1463 params: f
1464 .params
1465 .iter()
1466 .map(|t| self.replace_type_vars(t, &fresh_map))
1467 .collect(),
1468 ret: Box::new(self.replace_type_vars(&f.ret, &fresh_map)),
1469 effects: f.effects.clone(),
1470 });
1471 call_site_info = Some((name.clone(), sig, fresh_map));
1472 ety
1473 } else {
1474 callee_ty.clone()
1475 }
1476 } else {
1477 callee_ty.clone()
1478 }
1479 }
1480 _ => callee_ty.clone(),
1481 };
1482
1483 let ret_ty = self.check_call(callee_clone.span, &effective_ty, &args_clone, span);
1484
1485 if let NodeKind::Call { args, .. } = &mut node.kind {
1487 match &effective_ty {
1488 Type::Function(f) => {
1489 for (arg, param_ty) in args.iter_mut().zip(f.params.iter()) {
1490 let pt = self.subst.apply(param_ty);
1491 self.check_node(&mut arg.value, &pt);
1492 }
1493 }
1494 _ => {
1495 for arg in args.iter_mut() {
1496 self.infer_node(&mut arg.value);
1497 }
1498 }
1499 }
1500 }
1501
1502 if let Some((fn_name, sig, fresh_map)) = &call_site_info {
1505 self.check_trait_bounds_at_call(fn_name, sig, fresh_map, span);
1506 }
1507
1508 ret_ty
1509 }
1510
1511 NodeKind::MethodCall { method, .. } => {
1513 let method_name = method.name.clone();
1514 let receiver_ty =
1515 if let NodeKind::MethodCall { receiver, args, .. } = &mut node.kind {
1516 let rt = self.infer_node(receiver);
1517 for arg in args.iter_mut() {
1518 self.infer_node(&mut arg.value);
1519 }
1520 rt
1521 } else {
1522 unreachable!()
1523 };
1524 self.resolve_method_return_type(&receiver_ty, &method_name)
1525 }
1526
1527 NodeKind::Lambda { .. } => {
1529 let (param_tys, body_ty) = self.infer_lambda(node);
1531 Type::Function(FnType {
1532 params: param_tys,
1533 ret: Box::new(body_ty),
1534 effects: vec![],
1535 })
1536 }
1537
1538 NodeKind::Pipe { .. } => {
1540 let (lty, rty) = if let NodeKind::Pipe { left, right } = &mut node.kind {
1542 let l = self.infer_node(left);
1543 let r = self.infer_node(right);
1544 (l, r)
1545 } else {
1546 unreachable!()
1547 };
1548 match &rty {
1550 Type::Function(f) if f.params.len() == 1 => {
1551 let param_ty = self.subst.apply(&f.params[0]);
1552 self.unify_or_error(<y, ¶m_ty, span, "pipe");
1553 self.subst.apply(&f.ret)
1554 }
1555 Type::Error => Type::Error,
1556 _ => self.fresh_var(),
1557 }
1558 }
1559
1560 NodeKind::If { .. } => self.infer_if(node),
1562
1563 NodeKind::Match { .. } => self.infer_match(node),
1565
1566 NodeKind::Block { .. } => self.infer_block(node),
1568
1569 NodeKind::LetBinding { .. } => {
1571 self.check_let_binding(node);
1572 Type::Primitive(PrimitiveType::Void)
1573 }
1574
1575 NodeKind::Return { .. } => {
1577 let expected = self.return_ty_stack.last().cloned();
1578 if let NodeKind::Return { value } = &mut node.kind {
1579 match (value, &expected) {
1580 (Some(v), Some(e)) => {
1581 let et = e.clone();
1582 self.check_node(v, &et);
1583 }
1584 (Some(v), None) => {
1585 self.infer_node(v);
1586 }
1587 _ => {}
1588 }
1589 }
1590 Type::Primitive(PrimitiveType::Never)
1591 }
1592
1593 NodeKind::ListLiteral { .. } => {
1595 let elem_ty = self.fresh_var();
1596 if let NodeKind::ListLiteral { elems } = &mut node.kind {
1597 for elem in elems.iter_mut() {
1598 let et = elem_ty.clone();
1599 self.check_node(elem, &et);
1600 }
1601 }
1602 Type::Generic(GenericType {
1603 constructor: "List".into(),
1604 args: vec![self.subst.apply(&elem_ty)],
1605 })
1606 }
1607
1608 NodeKind::TupleLiteral { .. } => {
1610 let elem_tys: Vec<Type> = if let NodeKind::TupleLiteral { elems } = &mut node.kind {
1611 elems.iter_mut().map(|e| self.infer_node(e)).collect()
1612 } else {
1613 vec![]
1614 };
1615 Type::Tuple(elem_tys)
1616 }
1617
1618 NodeKind::MapLiteral { .. } => {
1620 let k_ty = self.fresh_var();
1621 let v_ty = self.fresh_var();
1622 if let NodeKind::MapLiteral { entries } = &mut node.kind {
1623 for entry in entries.iter_mut() {
1624 let kt = k_ty.clone();
1625 let vt = v_ty.clone();
1626 self.check_node(&mut entry.key, &kt);
1627 self.check_node(&mut entry.value, &vt);
1628 }
1629 }
1630 Type::Generic(GenericType {
1631 constructor: "Map".into(),
1632 args: vec![self.subst.apply(&k_ty), self.subst.apply(&v_ty)],
1633 })
1634 }
1635
1636 NodeKind::SetLiteral { .. } => {
1638 let elem_ty = self.fresh_var();
1639 if let NodeKind::SetLiteral { elems } = &mut node.kind {
1640 for elem in elems.iter_mut() {
1641 let et = elem_ty.clone();
1642 self.check_node(elem, &et);
1643 }
1644 }
1645 Type::Generic(GenericType {
1646 constructor: "Set".into(),
1647 args: vec![self.subst.apply(&elem_ty)],
1648 })
1649 }
1650
1651 NodeKind::Interpolation { .. } => {
1653 if let NodeKind::Interpolation { parts } = &mut node.kind {
1654 for part in parts.iter_mut() {
1655 if let bock_air::AirInterpolationPart::Expr(e) = part {
1656 self.infer_node(e);
1657 }
1658 }
1659 }
1660 Type::Primitive(PrimitiveType::String)
1661 }
1662
1663 NodeKind::ResultConstruct { variant, .. } => {
1665 let variant = *variant;
1668 let has_value =
1669 matches!(&node.kind, NodeKind::ResultConstruct { value: Some(_), .. });
1670 let inner_ty = if has_value {
1671 if let NodeKind::ResultConstruct { value: Some(v), .. } = &mut node.kind {
1672 self.infer_node(v)
1673 } else {
1674 unreachable!()
1675 }
1676 } else {
1677 Type::Primitive(PrimitiveType::Void)
1678 };
1679 let err_ty = self.fresh_var();
1680 let ok_ty = self.fresh_var();
1681 match variant {
1682 bock_air::ResultVariant::Ok => {
1683 self.unify_or_error(&inner_ty, &ok_ty, span, "Ok construct");
1684 Type::Result(Box::new(ok_ty), Box::new(err_ty))
1685 }
1686 bock_air::ResultVariant::Err => {
1687 self.unify_or_error(&inner_ty, &err_ty, span, "Err construct");
1688 Type::Result(Box::new(ok_ty), Box::new(err_ty))
1689 }
1690 }
1691 }
1692
1693 NodeKind::Propagate { .. } => {
1695 let inner_ty = if let NodeKind::Propagate { expr } = &mut node.kind {
1696 self.infer_node(expr)
1697 } else {
1698 unreachable!()
1699 };
1700 match &inner_ty {
1702 Type::Result(ok, _) => *ok.clone(),
1703 Type::Optional(inner) => *inner.clone(),
1704 Type::Error => Type::Error,
1705 _ => self.fresh_var(),
1706 }
1707 }
1708
1709 NodeKind::Await { .. } => {
1711 if let NodeKind::Await { expr } = &mut node.kind {
1712 self.infer_node(expr);
1713 }
1714 self.fresh_var()
1715 }
1716
1717 NodeKind::Borrow { .. } | NodeKind::MutableBorrow { .. } => {
1719 match &mut node.kind {
1721 NodeKind::Borrow { expr } | NodeKind::MutableBorrow { expr } => {
1722 self.infer_node(expr)
1723 }
1724 _ => unreachable!(),
1725 }
1726 }
1727
1728 NodeKind::Move { .. } => {
1729 if let NodeKind::Move { expr } = &mut node.kind {
1730 self.infer_node(expr)
1731 } else {
1732 unreachable!()
1733 }
1734 }
1735
1736 NodeKind::Assign { .. } => {
1738 let (tty, vty) = if let NodeKind::Assign { target, value, .. } = &mut node.kind {
1739 let t = self.infer_node(target);
1740 let v = self.infer_node(value);
1741 (t, v)
1742 } else {
1743 unreachable!()
1744 };
1745 self.unify_or_error(&tty, &vty, span, "assignment");
1746 Type::Primitive(PrimitiveType::Void)
1747 }
1748
1749 NodeKind::Range { .. } => {
1751 let (lty, hty) = if let NodeKind::Range { lo, hi, .. } = &mut node.kind {
1752 let l = self.infer_node(lo);
1753 let h = self.infer_node(hi);
1754 (l, h)
1755 } else {
1756 unreachable!()
1757 };
1758 self.unify_or_error(<y, &hty, span, "range bounds");
1759 Type::Generic(GenericType {
1760 constructor: "Range".into(),
1761 args: vec![lty],
1762 })
1763 }
1764
1765 NodeKind::For { .. } => {
1767 self.env.push_scope();
1768 if let NodeKind::For {
1769 pattern,
1770 iterable,
1771 body,
1772 } = &mut node.kind
1773 {
1774 let iter_ty = self.infer_node(iterable);
1775 let elem_ty = match &iter_ty {
1777 Type::Generic(g) if g.constructor == "List" && g.args.len() == 1 => {
1778 g.args[0].clone()
1779 }
1780 Type::Generic(g) if g.constructor == "Range" && g.args.len() == 1 => {
1781 g.args[0].clone()
1782 }
1783 _ => self.fresh_var(),
1784 };
1785 self.bind_pattern_type(pattern, &elem_ty);
1786 self.infer_node(body);
1787 }
1788 self.env.pop_scope();
1789 Type::Primitive(PrimitiveType::Void)
1790 }
1791
1792 NodeKind::While { .. } => {
1793 if let NodeKind::While { condition, body } = &mut node.kind {
1794 let bool_ty = Type::Primitive(PrimitiveType::Bool);
1795 self.check_node(condition, &bool_ty);
1796 self.infer_node(body);
1797 }
1798 Type::Primitive(PrimitiveType::Void)
1799 }
1800
1801 NodeKind::Loop { .. } => {
1802 if let NodeKind::Loop { body } = &mut node.kind {
1803 self.infer_node(body);
1804 }
1805 self.fresh_var()
1807 }
1808
1809 NodeKind::Break { .. } => {
1810 if let NodeKind::Break { value: Some(v) } = &mut node.kind {
1811 self.infer_node(v);
1812 }
1813 Type::Primitive(PrimitiveType::Never)
1814 }
1815
1816 NodeKind::Continue => Type::Primitive(PrimitiveType::Never),
1817
1818 NodeKind::Guard { .. } => {
1819 if let NodeKind::Guard {
1820 let_pattern,
1821 condition,
1822 else_block,
1823 } = &mut node.kind
1824 {
1825 if let_pattern.is_some() {
1826 let cond_ty = self.infer_node(condition);
1830 if let Some(pat) = let_pattern {
1831 self.bind_pattern_type(pat, &cond_ty);
1832 }
1833 } else {
1834 let bool_ty = Type::Primitive(PrimitiveType::Bool);
1835 self.check_node(condition, &bool_ty);
1836 }
1837 self.infer_node(else_block);
1838 }
1839 Type::Primitive(PrimitiveType::Void)
1840 }
1841
1842 NodeKind::Compose { .. } => {
1844 if let NodeKind::Compose { left, right } = &mut node.kind {
1845 self.infer_node(left);
1846 self.infer_node(right);
1847 }
1848 self.fresh_var() }
1850
1851 NodeKind::Placeholder => self.fresh_var(),
1853
1854 NodeKind::Unreachable => Type::Primitive(PrimitiveType::Never),
1856
1857 NodeKind::HandlingBlock { .. } => {
1859 if let NodeKind::HandlingBlock { handlers, body } = &mut node.kind {
1860 for hp in handlers.iter_mut() {
1861 self.infer_node(&mut hp.handler);
1862 }
1863 self.infer_node(body)
1864 } else {
1865 unreachable!()
1866 }
1867 }
1868
1869 NodeKind::RecordConstruct { path, .. } => {
1871 let name = path
1872 .segments
1873 .last()
1874 .map(|s| s.name.clone())
1875 .unwrap_or_default();
1876
1877 let generic_params = self.record_generic_params.get(&name).cloned();
1880 let fresh_type_args: Option<Vec<Type>> = generic_params
1881 .as_ref()
1882 .map(|params| params.iter().map(|_| self.fresh_var()).collect());
1883
1884 if let NodeKind::RecordConstruct { fields, spread, .. } = &mut node.kind {
1885 let declared_fields = self.record_field_types.get(&name).cloned();
1887 for f in fields.iter_mut() {
1888 if let Some(v) = &mut f.value {
1889 if let Some(ref decl) = declared_fields {
1890 if let Some((_, expected_ty)) =
1891 decl.iter().find(|(n, _)| n == &f.name.name)
1892 {
1893 let et = if let (Some(ref params), Some(ref args)) =
1896 (&generic_params, &fresh_type_args)
1897 {
1898 substitute_type_params(expected_ty, params, args)
1899 } else {
1900 expected_ty.clone()
1901 };
1902 self.check_node(v, &et);
1903 } else {
1904 self.infer_node(v);
1905 }
1906 } else {
1907 self.infer_node(v);
1908 }
1909 }
1910 }
1911 if let Some(s) = spread {
1912 self.infer_node(s);
1913 }
1914 }
1915
1916 if let Some(type_args) = fresh_type_args {
1918 Type::Generic(GenericType {
1919 constructor: name,
1920 args: type_args,
1921 })
1922 } else {
1923 self.env
1926 .lookup(&name)
1927 .cloned()
1928 .unwrap_or(Type::Named(crate::NamedType { name }))
1929 }
1930 }
1931
1932 NodeKind::Error => Type::Error,
1934
1935 _ => self.fresh_var(),
1937 };
1938
1939 self.record(node, ty)
1940 }
1941
1942 fn check_node(&mut self, node: &mut AIRNode, expected: &Type) {
1945 let span = node.span;
1946 match &node.kind {
1947 NodeKind::ListLiteral { .. } => {
1949 if let Type::Generic(g) = expected {
1950 if g.constructor == "List" && g.args.len() == 1 {
1951 let elem_ty = g.args[0].clone();
1952 if let NodeKind::ListLiteral { elems } = &mut node.kind {
1953 for elem in elems.iter_mut() {
1954 let et = elem_ty.clone();
1955 self.check_node(elem, &et);
1956 }
1957 }
1958 self.record(node, expected.clone());
1959 return;
1960 }
1961 }
1962 let inferred = self.infer_node(node);
1964 self.unify_or_error(&inferred, expected, span, "list literal");
1965 }
1966
1967 NodeKind::Lambda { .. } => {
1969 if let Type::Function(f_expected) = expected {
1970 let param_types = f_expected.params.clone();
1971 let ret_ty = *f_expected.ret.clone();
1972
1973 self.env.push_scope();
1974 if let NodeKind::Lambda { params, body } = &mut node.kind {
1975 for (param, pty) in params.iter_mut().zip(param_types.iter()) {
1976 if let Some(name) = param.kind.param_pat_name() {
1977 self.env.define(name, pty.clone());
1978 }
1979 self.record(param, pty.clone());
1980 }
1981 self.check_node(body, &ret_ty);
1982 }
1983 self.env.pop_scope();
1984 self.record(node, expected.clone());
1985 } else {
1986 let inferred = self.infer_node(node);
1987 self.unify_or_error(&inferred, expected, span, "lambda");
1988 }
1989 }
1990
1991 NodeKind::Match { .. } => {
1993 let scrutinee_ty = if let NodeKind::Match { scrutinee, .. } = &mut node.kind {
1996 self.infer_node(scrutinee)
1997 } else {
1998 unreachable!()
1999 };
2000
2001 if let NodeKind::Match { arms, .. } = &mut node.kind {
2002 for arm in arms.iter_mut() {
2003 self.env.push_scope();
2004 if let NodeKind::MatchArm {
2005 pattern,
2006 guard,
2007 body,
2008 } = &mut arm.kind
2009 {
2010 self.bind_pattern_type(pattern, &scrutinee_ty.clone());
2011 if let Some(g) = guard {
2012 let bt = Type::Primitive(PrimitiveType::Bool);
2013 self.check_node(g, &bt);
2014 }
2015 let et = expected.clone();
2016 self.check_node(body, &et);
2017 }
2018 self.env.pop_scope();
2019 self.record(arm, expected.clone());
2020 }
2021 }
2022 self.record(node, expected.clone());
2023 }
2024
2025 NodeKind::If { .. } => {
2027 if let NodeKind::If {
2028 condition,
2029 then_block,
2030 else_block,
2031 ..
2032 } = &mut node.kind
2033 {
2034 let bt = Type::Primitive(PrimitiveType::Bool);
2035 self.check_node(condition, &bt);
2036 let et = expected.clone();
2037 self.check_node(then_block, &et);
2038 if let Some(eb) = else_block {
2039 let et2 = expected.clone();
2040 self.check_node(eb, &et2);
2041 }
2042 }
2043 self.record(node, expected.clone());
2044 }
2045
2046 NodeKind::Block { .. } => {
2048 if let NodeKind::Block { stmts, tail } = &mut node.kind {
2049 self.env.push_scope();
2050 for stmt in stmts.iter_mut() {
2051 self.infer_node(stmt);
2052 }
2053 if let Some(tail_expr) = tail {
2054 let et = expected.clone();
2055 self.check_node(tail_expr, &et);
2056 } else {
2057 let void_ty = Type::Primitive(PrimitiveType::Void);
2059 self.unify_or_error(&void_ty, expected, node.span, "block");
2060 }
2061 self.env.pop_scope();
2062 }
2063 self.record(node, expected.clone());
2064 }
2065
2066 _ => {
2068 let inferred = self.infer_node(node);
2069 let expected = self.subst.apply(expected);
2070 self.unify_or_error(&inferred, &expected, span, "expression");
2071 }
2072 }
2073 }
2074
2075 fn infer_if(&mut self, node: &mut AIRNode) -> Type {
2078 let span = node.span;
2079 if let NodeKind::If {
2080 condition,
2081 then_block,
2082 else_block,
2083 ..
2084 } = &mut node.kind
2085 {
2086 let bool_ty = Type::Primitive(PrimitiveType::Bool);
2087 self.check_node(condition, &bool_ty);
2088 let then_ty = self.infer_node(then_block);
2089 if let Some(eb) = else_block {
2090 let else_ty = self.infer_node(eb);
2091 let never = Type::Primitive(PrimitiveType::Never);
2092 let (a, b) = if then_ty == never {
2094 (&else_ty, &then_ty)
2095 } else {
2096 (&then_ty, &else_ty)
2097 };
2098 self.unify_or_error(a, b, span, "if-else branches")
2099 } else {
2100 Type::Primitive(PrimitiveType::Void)
2102 }
2103 } else {
2104 unreachable!()
2105 }
2106 }
2107
2108 fn infer_match(&mut self, node: &mut AIRNode) -> Type {
2111 let span = node.span;
2112 let never = Type::Primitive(PrimitiveType::Never);
2113 let scrutinee_ty = if let NodeKind::Match { scrutinee, .. } = &mut node.kind {
2115 self.infer_node(scrutinee)
2116 } else {
2117 unreachable!()
2118 };
2119
2120 let mut arm_types: Vec<Type> = Vec::new();
2122 if let NodeKind::Match { arms, .. } = &mut node.kind {
2123 for arm in arms.iter_mut() {
2124 self.env.push_scope();
2125 let arm_ty = if let NodeKind::MatchArm {
2126 pattern,
2127 guard,
2128 body,
2129 } = &mut arm.kind
2130 {
2131 self.bind_pattern_type(pattern, &scrutinee_ty.clone());
2132 if let Some(g) = guard {
2133 let bt = Type::Primitive(PrimitiveType::Bool);
2134 self.check_node(g, &bt);
2135 }
2136 self.infer_node(body)
2137 } else {
2138 self.fresh_var()
2139 };
2140 self.env.pop_scope();
2141 self.record(arm, arm_ty.clone());
2142 arm_types.push(arm_ty);
2143 }
2144 }
2145
2146 let non_never: Vec<&Type> = arm_types.iter().filter(|t| **t != never).collect();
2148 if non_never.is_empty() {
2149 never
2151 } else {
2152 let result_ty = self.fresh_var();
2153 for t in &non_never {
2154 let rt = result_ty.clone();
2155 self.unify_or_error(t, &rt, span, "match arm");
2156 }
2157 self.subst.apply(&result_ty)
2158 }
2159 }
2160
2161 fn infer_block(&mut self, node: &mut AIRNode) -> Type {
2164 self.env.push_scope();
2165 let ty = if let NodeKind::Block { stmts, tail } = &mut node.kind {
2166 for stmt in stmts.iter_mut() {
2167 self.infer_node(stmt);
2168 }
2169 if let Some(tail_expr) = tail {
2170 self.infer_node(tail_expr)
2171 } else {
2172 Type::Primitive(PrimitiveType::Void)
2173 }
2174 } else {
2175 unreachable!()
2176 };
2177 self.env.pop_scope();
2178 ty
2179 }
2180
2181 fn check_let_binding(&mut self, node: &mut AIRNode) {
2184 let (ty_node, _value_clone) = match &node.kind {
2185 NodeKind::LetBinding { ty, value, .. } => (ty.clone(), *value.clone()),
2186 _ => return,
2187 };
2188
2189 if let Some(ty_ann) = &ty_node {
2190 let expected = self.air_type_node_to_type(ty_ann, &HashMap::new());
2191 if let NodeKind::LetBinding { value, pattern, .. } = &mut node.kind {
2192 self.check_node(value, &expected);
2193 self.bind_pattern_type(pattern, &expected);
2194 }
2195 } else {
2196 let inferred = if let NodeKind::LetBinding { value, .. } = &mut node.kind {
2198 self.infer_node(value)
2199 } else {
2200 unreachable!()
2201 };
2202 let resolved = self.subst.apply(&inferred);
2203 if let NodeKind::LetBinding { pattern, .. } = &mut node.kind {
2204 self.bind_pattern_type(pattern, &resolved);
2205 }
2206 }
2207 }
2208
2209 fn infer_lambda(&mut self, node: &mut AIRNode) -> (Vec<Type>, Type) {
2213 self.env.push_scope();
2214 let (param_tys, body_ty) = if let NodeKind::Lambda { params, body } = &mut node.kind {
2215 let param_tys: Vec<Type> = params
2216 .iter_mut()
2217 .map(|p| {
2218 let ty = self.fresh_var();
2219 if let Some(name) = p.kind.param_pat_name() {
2220 self.env.define(name, ty.clone());
2221 }
2222 ty
2223 })
2224 .collect();
2225 let body_ty = self.infer_node(body);
2226 (param_tys, body_ty)
2227 } else {
2228 unreachable!()
2229 };
2230 self.env.pop_scope();
2231 (param_tys, body_ty)
2232 }
2233
2234 fn check_call(
2239 &mut self,
2240 callee_span: Span,
2241 callee_ty: &Type,
2242 args: &[bock_air::AirArg],
2243 call_span: Span,
2244 ) -> Type {
2245 match callee_ty {
2246 Type::Error => Type::Error,
2247 Type::Function(f) => {
2248 if f.params.len() != args.len() {
2250 self.diags.error(
2251 E_ARITY_MISMATCH,
2252 format!(
2253 "function expects {} argument(s), got {}",
2254 f.params.len(),
2255 args.len()
2256 ),
2257 call_span,
2258 );
2259 return Type::Error;
2260 }
2261 self.subst.apply(&f.ret)
2262 }
2263 _ => {
2264 if let Type::Named(nt) = callee_ty {
2267 if let Some(sig) = self.fn_sigs.get(&nt.name).cloned() {
2268 return self.instantiate_and_check(&nt.name, &sig, args, call_span);
2269 }
2270 }
2271 if matches!(callee_ty, Type::TypeVar(_)) {
2276 return self.fresh_var();
2277 }
2278 self.diags.error(
2279 E_NOT_CALLABLE,
2280 format!("expected a function type, got {callee_ty:?}"),
2281 callee_span,
2282 );
2283 Type::Error
2284 }
2285 }
2286 }
2287
2288 fn instantiate_and_check(
2295 &mut self,
2296 fn_name: &str,
2297 sig: &FnSig,
2298 args: &[bock_air::AirArg],
2299 span: Span,
2300 ) -> Type {
2301 if sig.param_types.len() != args.len() {
2302 self.diags.error(
2303 E_ARITY_MISMATCH,
2304 format!(
2305 "function expects {} argument(s), got {}",
2306 sig.param_types.len(),
2307 args.len()
2308 ),
2309 span,
2310 );
2311 return Type::Error;
2312 }
2313
2314 let fresh_map: HashMap<TypeVarId, Type> = sig
2317 .generic_var_ids
2318 .iter()
2319 .map(|&id| (id, self.fresh_var()))
2320 .collect();
2321
2322 let _param_tys: Vec<Type> = sig
2325 .param_types
2326 .iter()
2327 .map(|t| self.replace_type_vars(t, &fresh_map))
2328 .collect();
2329
2330 self.check_trait_bounds_at_call(fn_name, sig, &fresh_map, span);
2332
2333 self.replace_type_vars(&sig.return_type, &fresh_map)
2335 }
2336
2337 fn resolve_method_return_type(&self, receiver_ty: &Type, method: &str) -> Type {
2345 let receiver_ty = self.subst.apply(receiver_ty);
2346 match &receiver_ty {
2347 Type::Error => Type::Error,
2348 Type::Generic(g) if g.constructor == "List" && g.args.len() == 1 => {
2350 let elem_ty = &g.args[0];
2351 match method {
2352 "len" | "length" | "count" => Type::Primitive(PrimitiveType::Int),
2353 "first" | "last" | "find" | "get" => {
2354 Type::Optional(Box::new(elem_ty.clone()))
2355 }
2356 "index_of" => {
2357 Type::Optional(Box::new(Type::Primitive(PrimitiveType::Int)))
2358 }
2359 "contains" | "is_empty" | "any" | "all" => {
2360 Type::Primitive(PrimitiveType::Bool)
2361 }
2362 "push" | "append" | "pop" | "insert" | "remove" | "concat"
2363 | "reverse" | "sort" | "filter" | "dedup" | "take" | "skip"
2364 | "flat_map" | "slice" | "flatten" => receiver_ty.clone(),
2365 "clear" | "for_each" => Type::Primitive(PrimitiveType::Void),
2366 "join" | "display" => Type::Primitive(PrimitiveType::String),
2367 "enumerate" => Type::Generic(GenericType {
2368 constructor: "List".into(),
2369 args: vec![Type::Tuple(vec![
2370 Type::Primitive(PrimitiveType::Int),
2371 elem_ty.clone(),
2372 ])],
2373 }),
2374 "to_set" => Type::Generic(GenericType {
2375 constructor: "Set".into(),
2376 args: vec![elem_ty.clone()],
2377 }),
2378 _ => self.fresh_var(),
2379 }
2380 }
2381 Type::Generic(g) if g.constructor == "Map" && g.args.len() == 2 => {
2383 let key_ty = &g.args[0];
2384 let val_ty = &g.args[1];
2385 match method {
2386 "len" | "length" | "count" => Type::Primitive(PrimitiveType::Int),
2387 "contains_key" | "is_empty" => Type::Primitive(PrimitiveType::Bool),
2388 "get" => Type::Optional(Box::new(val_ty.clone())),
2389 "set" | "delete" | "merge" | "filter" => receiver_ty.clone(),
2390 "for_each" => Type::Primitive(PrimitiveType::Void),
2391 "keys" => Type::Generic(GenericType {
2392 constructor: "List".into(),
2393 args: vec![key_ty.clone()],
2394 }),
2395 "values" => Type::Generic(GenericType {
2396 constructor: "List".into(),
2397 args: vec![val_ty.clone()],
2398 }),
2399 "entries" | "to_list" => Type::Generic(GenericType {
2400 constructor: "List".into(),
2401 args: vec![Type::Tuple(vec![key_ty.clone(), val_ty.clone()])],
2402 }),
2403 _ => self.fresh_var(),
2404 }
2405 }
2406 Type::Primitive(PrimitiveType::String) => match method {
2408 "len" | "length" | "count" | "byte_len" => {
2409 Type::Primitive(PrimitiveType::Int)
2410 }
2411 "contains" | "starts_with" | "ends_with" | "is_empty"
2412 | "regex_match" => Type::Primitive(PrimitiveType::Bool),
2413 "to_upper" | "to_lower" | "trim" | "trim_start" | "trim_end"
2414 | "reverse" | "slice" | "substring" | "replace" | "to_string"
2415 | "display" | "repeat" | "pad_start" | "pad_end" | "format"
2416 | "regex_replace" | "join" => Type::Primitive(PrimitiveType::String),
2417 "split" | "regex_find" => Type::Generic(GenericType {
2418 constructor: "List".into(),
2419 args: vec![Type::Primitive(PrimitiveType::String)],
2420 }),
2421 "chars" => Type::Generic(GenericType {
2422 constructor: "List".into(),
2423 args: vec![Type::Primitive(PrimitiveType::Char)],
2424 }),
2425 "bytes" => Type::Generic(GenericType {
2426 constructor: "List".into(),
2427 args: vec![Type::Primitive(PrimitiveType::Int)],
2428 }),
2429 "index_of" => Type::Optional(Box::new(Type::Primitive(PrimitiveType::Int))),
2430 "char_at" => {
2431 Type::Optional(Box::new(Type::Primitive(PrimitiveType::Char)))
2432 }
2433 _ => self.fresh_var(),
2434 },
2435 Type::Primitive(PrimitiveType::Int) => match method {
2437 "abs" | "min" | "max" | "clamp" | "shift_left" | "shift_right"
2438 | "compare" | "hash_code" => Type::Primitive(PrimitiveType::Int),
2439 "to_float" => Type::Primitive(PrimitiveType::Float),
2440 "to_string" | "display" => Type::Primitive(PrimitiveType::String),
2441 "equals" => Type::Primitive(PrimitiveType::Bool),
2442 _ => self.fresh_var(),
2443 },
2444 Type::Primitive(PrimitiveType::Float) => match method {
2446 "abs" | "floor" | "ceil" | "round" | "sqrt" | "min" | "max" | "clamp" => {
2447 Type::Primitive(PrimitiveType::Float)
2448 }
2449 "to_int" => Type::Primitive(PrimitiveType::Int),
2450 "to_string" | "display" => Type::Primitive(PrimitiveType::String),
2451 "is_nan" | "is_infinite" | "equals" => Type::Primitive(PrimitiveType::Bool),
2452 "compare" | "hash_code" => Type::Primitive(PrimitiveType::Int),
2453 _ => self.fresh_var(),
2454 },
2455 Type::Primitive(PrimitiveType::Bool) => match method {
2457 "negate" => Type::Primitive(PrimitiveType::Bool),
2458 "to_int" => Type::Primitive(PrimitiveType::Int),
2459 "to_string" | "display" => Type::Primitive(PrimitiveType::String),
2460 "compare" | "hash_code" => Type::Primitive(PrimitiveType::Int),
2461 "equals" => Type::Primitive(PrimitiveType::Bool),
2462 _ => self.fresh_var(),
2463 },
2464 Type::Primitive(PrimitiveType::Char) => match method {
2466 "to_upper" | "to_lower" => Type::Primitive(PrimitiveType::Char),
2467 "is_alpha" | "is_digit" | "is_whitespace" | "equals" => {
2468 Type::Primitive(PrimitiveType::Bool)
2469 }
2470 "to_int" | "compare" | "hash_code" => Type::Primitive(PrimitiveType::Int),
2471 "to_string" | "display" => Type::Primitive(PrimitiveType::String),
2472 _ => self.fresh_var(),
2473 },
2474 Type::Generic(g) if g.constructor == "Set" && g.args.len() == 1 => {
2476 let elem_ty = &g.args[0];
2477 match method {
2478 "len" | "length" | "count" => Type::Primitive(PrimitiveType::Int),
2479 "contains" | "is_empty" | "is_subset" | "is_superset"
2480 | "is_disjoint" => Type::Primitive(PrimitiveType::Bool),
2481 "add" | "remove" | "union" | "intersection" | "difference"
2482 | "symmetric_difference" | "filter" | "map" => receiver_ty.clone(),
2483 "for_each" => Type::Primitive(PrimitiveType::Void),
2484 "to_list" => Type::Generic(GenericType {
2485 constructor: "List".into(),
2486 args: vec![elem_ty.clone()],
2487 }),
2488 _ => self.fresh_var(),
2489 }
2490 }
2491 Type::Optional(inner_ty) => match method {
2493 "is_some" | "is_none" => Type::Primitive(PrimitiveType::Bool),
2494 "unwrap" | "unwrap_or" => *inner_ty.clone(),
2495 _ => self.fresh_var(),
2496 },
2497 Type::Result(ok_ty, _err_ty) => match method {
2499 "is_ok" | "is_err" => Type::Primitive(PrimitiveType::Bool),
2500 "unwrap" | "unwrap_or" => *ok_ty.clone(),
2501 _ => self.fresh_var(),
2502 },
2503 Type::Named(nt) => {
2505 if let Some(methods) = self.method_types.get(&nt.name) {
2506 if let Some(Type::Function(f)) = methods.get(method) {
2507 return self.subst.apply(&f.ret);
2508 }
2509 }
2510 self.fresh_var()
2511 }
2512 Type::Generic(g) => {
2515 if let Some(methods) = self.method_types.get(&g.constructor) {
2516 if let Some(Type::Function(f)) = methods.get(method) {
2517 let ret_ty = self.subst.apply(&f.ret);
2518 if let Some(params) = self.record_generic_params.get(&g.constructor) {
2519 return substitute_type_params(&ret_ty, params, &g.args);
2520 }
2521 return ret_ty;
2522 }
2523 }
2524 self.fresh_var()
2525 }
2526 _ => self.fresh_var(),
2527 }
2528 }
2529
2530 fn resolve_builtin_method_fn_type(&self, receiver_ty: &Type, method: &str) -> Option<Type> {
2540 let receiver_ty = self.subst.apply(receiver_ty);
2541 let mk = |recv: &Type, params: Vec<Type>, ret: Type| -> Option<Type> {
2542 let mut all_params = vec![recv.clone()];
2543 all_params.extend(params);
2544 Some(Type::Function(FnType {
2545 params: all_params,
2546 ret: Box::new(ret),
2547 effects: vec![],
2548 }))
2549 };
2550 match &receiver_ty {
2551 Type::Generic(g) if g.constructor == "List" && g.args.len() == 1 => {
2552 let elem = &g.args[0];
2553 let r = &receiver_ty;
2554 match method {
2555 "len" | "length" | "count" => mk(r, vec![], Type::Primitive(PrimitiveType::Int)),
2556 "is_empty" => mk(r, vec![], Type::Primitive(PrimitiveType::Bool)),
2557 "contains" => mk(r, vec![elem.clone()], Type::Primitive(PrimitiveType::Bool)),
2558 "first" | "last" => mk(r, vec![], Type::Optional(Box::new(elem.clone()))),
2559 "find" => {
2560 let cb = Type::Function(FnType {
2561 params: vec![elem.clone()],
2562 ret: Box::new(Type::Primitive(PrimitiveType::Bool)),
2563 effects: vec![],
2564 });
2565 mk(r, vec![cb], Type::Optional(Box::new(elem.clone())))
2566 }
2567 "get" => mk(
2568 r,
2569 vec![Type::Primitive(PrimitiveType::Int)],
2570 Type::Optional(Box::new(elem.clone())),
2571 ),
2572 "index_of" => mk(
2573 r,
2574 vec![elem.clone()],
2575 Type::Optional(Box::new(Type::Primitive(PrimitiveType::Int))),
2576 ),
2577 "push" | "append" => mk(r, vec![elem.clone()], receiver_ty.clone()),
2578 "pop" => mk(r, vec![], receiver_ty.clone()),
2579 "insert" => mk(
2580 r,
2581 vec![Type::Primitive(PrimitiveType::Int), elem.clone()],
2582 receiver_ty.clone(),
2583 ),
2584 "remove" => mk(
2585 r,
2586 vec![Type::Primitive(PrimitiveType::Int)],
2587 receiver_ty.clone(),
2588 ),
2589 "concat" => mk(r, vec![receiver_ty.clone()], receiver_ty.clone()),
2590 "clear" => mk(r, vec![], Type::Primitive(PrimitiveType::Void)),
2591 "reverse" | "sort" | "dedup" | "flatten" => {
2592 mk(r, vec![], receiver_ty.clone())
2593 }
2594 "take" | "skip" => mk(
2595 r,
2596 vec![Type::Primitive(PrimitiveType::Int)],
2597 receiver_ty.clone(),
2598 ),
2599 "slice" => mk(
2600 r,
2601 vec![
2602 Type::Primitive(PrimitiveType::Int),
2603 Type::Primitive(PrimitiveType::Int),
2604 ],
2605 receiver_ty.clone(),
2606 ),
2607 "filter" => {
2608 let cb = Type::Function(FnType {
2609 params: vec![elem.clone()],
2610 ret: Box::new(Type::Primitive(PrimitiveType::Bool)),
2611 effects: vec![],
2612 });
2613 mk(r, vec![cb], receiver_ty.clone())
2614 }
2615 "map" => {
2616 let u = self.fresh_var();
2617 let cb = Type::Function(FnType {
2618 params: vec![elem.clone()],
2619 ret: Box::new(u.clone()),
2620 effects: vec![],
2621 });
2622 let ret = Type::Generic(GenericType {
2623 constructor: "List".into(),
2624 args: vec![u],
2625 });
2626 mk(r, vec![cb], ret)
2627 }
2628 "flat_map" => {
2629 let u = self.fresh_var();
2630 let inner_list = Type::Generic(GenericType {
2631 constructor: "List".into(),
2632 args: vec![u.clone()],
2633 });
2634 let cb = Type::Function(FnType {
2635 params: vec![elem.clone()],
2636 ret: Box::new(inner_list),
2637 effects: vec![],
2638 });
2639 let ret = Type::Generic(GenericType {
2640 constructor: "List".into(),
2641 args: vec![u],
2642 });
2643 mk(r, vec![cb], ret)
2644 }
2645 "fold" => {
2646 let acc = self.fresh_var();
2647 let cb = Type::Function(FnType {
2648 params: vec![acc.clone(), elem.clone()],
2649 ret: Box::new(acc.clone()),
2650 effects: vec![],
2651 });
2652 mk(r, vec![acc.clone(), cb], acc)
2653 }
2654 "reduce" => {
2655 let cb = Type::Function(FnType {
2656 params: vec![elem.clone(), elem.clone()],
2657 ret: Box::new(elem.clone()),
2658 effects: vec![],
2659 });
2660 mk(r, vec![cb], elem.clone())
2661 }
2662 "for_each" => {
2663 let cb = Type::Function(FnType {
2664 params: vec![elem.clone()],
2665 ret: Box::new(Type::Primitive(PrimitiveType::Void)),
2666 effects: vec![],
2667 });
2668 mk(r, vec![cb], Type::Primitive(PrimitiveType::Void))
2669 }
2670 "any" | "all" => {
2671 let cb = Type::Function(FnType {
2672 params: vec![elem.clone()],
2673 ret: Box::new(Type::Primitive(PrimitiveType::Bool)),
2674 effects: vec![],
2675 });
2676 mk(r, vec![cb], Type::Primitive(PrimitiveType::Bool))
2677 }
2678 "enumerate" => {
2679 let pair = Type::Tuple(vec![
2680 Type::Primitive(PrimitiveType::Int),
2681 elem.clone(),
2682 ]);
2683 mk(r, vec![], Type::Generic(GenericType {
2684 constructor: "List".into(),
2685 args: vec![pair],
2686 }))
2687 }
2688 "zip" => {
2689 let f = self.fresh_var();
2690 let other_list = Type::Generic(GenericType {
2691 constructor: "List".into(),
2692 args: vec![f.clone()],
2693 });
2694 let pair = Type::Tuple(vec![elem.clone(), f]);
2695 mk(r, vec![other_list], Type::Generic(GenericType {
2696 constructor: "List".into(),
2697 args: vec![pair],
2698 }))
2699 }
2700 "join" => mk(
2701 r,
2702 vec![Type::Primitive(PrimitiveType::String)],
2703 Type::Primitive(PrimitiveType::String),
2704 ),
2705 "to_set" => mk(r, vec![], Type::Generic(GenericType {
2706 constructor: "Set".into(),
2707 args: vec![elem.clone()],
2708 })),
2709 _ => None,
2710 }
2711 }
2712 Type::Generic(g) if g.constructor == "Map" && g.args.len() == 2 => {
2713 let key = &g.args[0];
2714 let val = &g.args[1];
2715 let r = &receiver_ty;
2716 match method {
2717 "len" | "length" | "count" => mk(r, vec![], Type::Primitive(PrimitiveType::Int)),
2718 "is_empty" => mk(r, vec![], Type::Primitive(PrimitiveType::Bool)),
2719 "contains_key" => {
2720 mk(r, vec![key.clone()], Type::Primitive(PrimitiveType::Bool))
2721 }
2722 "get" => mk(r, vec![key.clone()], Type::Optional(Box::new(val.clone()))),
2723 "set" => mk(r, vec![key.clone(), val.clone()], receiver_ty.clone()),
2724 "delete" => mk(r, vec![key.clone()], receiver_ty.clone()),
2725 "merge" => mk(r, vec![receiver_ty.clone()], receiver_ty.clone()),
2726 "keys" => mk(
2727 r,
2728 vec![],
2729 Type::Generic(GenericType {
2730 constructor: "List".into(),
2731 args: vec![key.clone()],
2732 }),
2733 ),
2734 "values" => mk(
2735 r,
2736 vec![],
2737 Type::Generic(GenericType {
2738 constructor: "List".into(),
2739 args: vec![val.clone()],
2740 }),
2741 ),
2742 "entries" | "to_list" => mk(
2743 r,
2744 vec![],
2745 Type::Generic(GenericType {
2746 constructor: "List".into(),
2747 args: vec![Type::Tuple(vec![key.clone(), val.clone()])],
2748 }),
2749 ),
2750 "map_values" => {
2751 let u = self.fresh_var();
2752 let cb = Type::Function(FnType {
2753 params: vec![val.clone()],
2754 ret: Box::new(u.clone()),
2755 effects: vec![],
2756 });
2757 mk(r, vec![cb], Type::Generic(GenericType {
2758 constructor: "Map".into(),
2759 args: vec![key.clone(), u],
2760 }))
2761 }
2762 "filter" => {
2763 let cb = Type::Function(FnType {
2764 params: vec![key.clone(), val.clone()],
2765 ret: Box::new(Type::Primitive(PrimitiveType::Bool)),
2766 effects: vec![],
2767 });
2768 mk(r, vec![cb], receiver_ty.clone())
2769 }
2770 "for_each" => {
2771 let cb = Type::Function(FnType {
2772 params: vec![key.clone(), val.clone()],
2773 ret: Box::new(Type::Primitive(PrimitiveType::Void)),
2774 effects: vec![],
2775 });
2776 mk(r, vec![cb], Type::Primitive(PrimitiveType::Void))
2777 }
2778 _ => None,
2779 }
2780 }
2781 Type::Generic(g) if g.constructor == "Set" && g.args.len() == 1 => {
2782 let elem = &g.args[0];
2783 let r = &receiver_ty;
2784 match method {
2785 "len" | "length" | "count" => mk(r, vec![], Type::Primitive(PrimitiveType::Int)),
2786 "is_empty" => mk(r, vec![], Type::Primitive(PrimitiveType::Bool)),
2787 "contains" => mk(r, vec![elem.clone()], Type::Primitive(PrimitiveType::Bool)),
2788 "add" | "remove" => mk(r, vec![elem.clone()], receiver_ty.clone()),
2789 "union" | "intersection" | "difference"
2790 | "symmetric_difference" => {
2791 mk(r, vec![receiver_ty.clone()], receiver_ty.clone())
2792 }
2793 "is_subset" | "is_superset" | "is_disjoint" => {
2794 mk(r, vec![receiver_ty.clone()], Type::Primitive(PrimitiveType::Bool))
2795 }
2796 "filter" => {
2797 let cb = Type::Function(FnType {
2798 params: vec![elem.clone()],
2799 ret: Box::new(Type::Primitive(PrimitiveType::Bool)),
2800 effects: vec![],
2801 });
2802 mk(r, vec![cb], receiver_ty.clone())
2803 }
2804 "map" => {
2805 let cb = Type::Function(FnType {
2806 params: vec![elem.clone()],
2807 ret: Box::new(elem.clone()),
2808 effects: vec![],
2809 });
2810 mk(r, vec![cb], receiver_ty.clone())
2811 }
2812 "for_each" => {
2813 let cb = Type::Function(FnType {
2814 params: vec![elem.clone()],
2815 ret: Box::new(Type::Primitive(PrimitiveType::Void)),
2816 effects: vec![],
2817 });
2818 mk(r, vec![cb], Type::Primitive(PrimitiveType::Void))
2819 }
2820 "to_list" => mk(r, vec![], Type::Generic(GenericType {
2821 constructor: "List".into(),
2822 args: vec![elem.clone()],
2823 })),
2824 _ => None,
2825 }
2826 }
2827 Type::Primitive(PrimitiveType::String) => {
2828 let r = &receiver_ty;
2829 let str_ty = Type::Primitive(PrimitiveType::String);
2830 let int_ty = Type::Primitive(PrimitiveType::Int);
2831 match method {
2832 "len" | "length" | "count" | "byte_len" => {
2833 mk(r, vec![], int_ty)
2834 }
2835 "is_empty" => mk(r, vec![], Type::Primitive(PrimitiveType::Bool)),
2836 "contains" | "starts_with" | "ends_with" => {
2837 mk(r, vec![str_ty.clone()], Type::Primitive(PrimitiveType::Bool))
2838 }
2839 "regex_match" => {
2840 mk(r, vec![str_ty.clone()], Type::Primitive(PrimitiveType::Bool))
2841 }
2842 "to_upper" | "to_lower" | "trim" | "trim_start" | "trim_end"
2843 | "reverse" | "to_string" | "display" => {
2844 mk(r, vec![], str_ty)
2845 }
2846 "repeat" => mk(
2847 r,
2848 vec![Type::Primitive(PrimitiveType::Int)],
2849 str_ty,
2850 ),
2851 "slice" | "substring" => mk(
2852 r,
2853 vec![
2854 Type::Primitive(PrimitiveType::Int),
2855 Type::Primitive(PrimitiveType::Int),
2856 ],
2857 str_ty,
2858 ),
2859 "replace" | "regex_replace" => mk(
2860 r,
2861 vec![str_ty.clone(), str_ty.clone()],
2862 str_ty,
2863 ),
2864 "pad_start" | "pad_end" => mk(
2865 r,
2866 vec![Type::Primitive(PrimitiveType::Int), str_ty.clone()],
2867 str_ty,
2868 ),
2869 "format" => mk(r, vec![], str_ty),
2870 "join" => mk(
2871 r,
2872 vec![Type::Generic(GenericType {
2873 constructor: "List".into(),
2874 args: vec![str_ty.clone()],
2875 })],
2876 str_ty,
2877 ),
2878 "split" => mk(
2879 r,
2880 vec![str_ty],
2881 Type::Generic(GenericType {
2882 constructor: "List".into(),
2883 args: vec![Type::Primitive(PrimitiveType::String)],
2884 }),
2885 ),
2886 "regex_find" => mk(
2887 r,
2888 vec![Type::Primitive(PrimitiveType::String)],
2889 Type::Generic(GenericType {
2890 constructor: "List".into(),
2891 args: vec![Type::Primitive(PrimitiveType::String)],
2892 }),
2893 ),
2894 "chars" => mk(
2895 r,
2896 vec![],
2897 Type::Generic(GenericType {
2898 constructor: "List".into(),
2899 args: vec![Type::Primitive(PrimitiveType::Char)],
2900 }),
2901 ),
2902 "bytes" => mk(
2903 r,
2904 vec![],
2905 Type::Generic(GenericType {
2906 constructor: "List".into(),
2907 args: vec![Type::Primitive(PrimitiveType::Int)],
2908 }),
2909 ),
2910 "index_of" => mk(
2911 r,
2912 vec![Type::Primitive(PrimitiveType::String)],
2913 Type::Optional(Box::new(Type::Primitive(PrimitiveType::Int))),
2914 ),
2915 "char_at" => mk(
2916 r,
2917 vec![Type::Primitive(PrimitiveType::Int)],
2918 Type::Optional(Box::new(Type::Primitive(PrimitiveType::Char))),
2919 ),
2920 _ => None,
2921 }
2922 }
2923 Type::Primitive(PrimitiveType::Int) => {
2924 let r = &receiver_ty;
2925 let int_ty = Type::Primitive(PrimitiveType::Int);
2926 match method {
2927 "abs" => mk(r, vec![], int_ty),
2928 "min" | "max" | "shift_left" | "shift_right" | "compare" => mk(
2929 r,
2930 vec![int_ty.clone()],
2931 int_ty,
2932 ),
2933 "clamp" => mk(
2934 r,
2935 vec![int_ty.clone(), int_ty.clone()],
2936 int_ty,
2937 ),
2938 "equals" => mk(
2939 r,
2940 vec![Type::Primitive(PrimitiveType::Int)],
2941 Type::Primitive(PrimitiveType::Bool),
2942 ),
2943 "hash_code" => mk(r, vec![], Type::Primitive(PrimitiveType::Int)),
2944 "to_float" => mk(r, vec![], Type::Primitive(PrimitiveType::Float)),
2945 "to_string" | "display" => {
2946 mk(r, vec![], Type::Primitive(PrimitiveType::String))
2947 }
2948 _ => None,
2949 }
2950 }
2951 Type::Primitive(PrimitiveType::Float) => {
2952 let r = &receiver_ty;
2953 let float_ty = Type::Primitive(PrimitiveType::Float);
2954 match method {
2955 "abs" | "floor" | "ceil" | "round" | "sqrt" => {
2956 mk(r, vec![], float_ty)
2957 }
2958 "min" | "max" => mk(
2959 r,
2960 vec![float_ty.clone()],
2961 float_ty,
2962 ),
2963 "clamp" => mk(
2964 r,
2965 vec![float_ty.clone(), float_ty.clone()],
2966 float_ty,
2967 ),
2968 "to_int" => mk(r, vec![], Type::Primitive(PrimitiveType::Int)),
2969 "to_string" | "display" => {
2970 mk(r, vec![], Type::Primitive(PrimitiveType::String))
2971 }
2972 "is_nan" | "is_infinite" | "equals" => {
2973 mk(r, vec![], Type::Primitive(PrimitiveType::Bool))
2974 }
2975 "compare" | "hash_code" => {
2976 mk(r, vec![], Type::Primitive(PrimitiveType::Int))
2977 }
2978 _ => None,
2979 }
2980 }
2981 Type::Primitive(PrimitiveType::Bool) => {
2982 let r = &receiver_ty;
2983 match method {
2984 "negate" | "equals" => mk(r, vec![], Type::Primitive(PrimitiveType::Bool)),
2985 "to_int" | "compare" | "hash_code" => {
2986 mk(r, vec![], Type::Primitive(PrimitiveType::Int))
2987 }
2988 "to_string" | "display" => {
2989 mk(r, vec![], Type::Primitive(PrimitiveType::String))
2990 }
2991 _ => None,
2992 }
2993 }
2994 Type::Primitive(PrimitiveType::Char) => {
2995 let r = &receiver_ty;
2996 match method {
2997 "to_upper" | "to_lower" => mk(r, vec![], Type::Primitive(PrimitiveType::Char)),
2998 "is_alpha" | "is_digit" | "is_whitespace" | "equals" => {
2999 mk(r, vec![], Type::Primitive(PrimitiveType::Bool))
3000 }
3001 "to_int" | "compare" | "hash_code" => {
3002 mk(r, vec![], Type::Primitive(PrimitiveType::Int))
3003 }
3004 "to_string" | "display" => {
3005 mk(r, vec![], Type::Primitive(PrimitiveType::String))
3006 }
3007 _ => None,
3008 }
3009 }
3010 Type::Optional(inner_ty) => {
3012 let r = &receiver_ty;
3013 let inner = *inner_ty.clone();
3014 match method {
3015 "is_some" | "is_none" => {
3016 mk(r, vec![], Type::Primitive(PrimitiveType::Bool))
3017 }
3018 "unwrap" => mk(r, vec![], inner),
3019 "unwrap_or" => mk(r, vec![inner.clone()], inner),
3020 "map" => {
3021 let u = self.fresh_var();
3022 let cb = Type::Function(FnType {
3023 params: vec![inner],
3024 ret: Box::new(u.clone()),
3025 effects: vec![],
3026 });
3027 mk(r, vec![cb], Type::Optional(Box::new(u)))
3028 }
3029 "flat_map" => {
3030 let u = self.fresh_var();
3031 let opt_u = Type::Optional(Box::new(u));
3032 let cb = Type::Function(FnType {
3033 params: vec![inner],
3034 ret: Box::new(opt_u.clone()),
3035 effects: vec![],
3036 });
3037 mk(r, vec![cb], opt_u)
3038 }
3039 _ => None,
3040 }
3041 }
3042 Type::Result(ok_ty, err_ty) => {
3044 let r = &receiver_ty;
3045 let ok = *ok_ty.clone();
3046 let err = *err_ty.clone();
3047 match method {
3048 "is_ok" | "is_err" => {
3049 mk(r, vec![], Type::Primitive(PrimitiveType::Bool))
3050 }
3051 "unwrap" => mk(r, vec![], ok),
3052 "unwrap_or" => mk(r, vec![ok.clone()], ok),
3053 "map" => {
3054 let u = self.fresh_var();
3055 let cb = Type::Function(FnType {
3056 params: vec![ok],
3057 ret: Box::new(u.clone()),
3058 effects: vec![],
3059 });
3060 mk(r, vec![cb], Type::Result(Box::new(u), Box::new(err)))
3061 }
3062 "map_err" => {
3063 let e2 = self.fresh_var();
3064 let cb = Type::Function(FnType {
3065 params: vec![err],
3066 ret: Box::new(e2.clone()),
3067 effects: vec![],
3068 });
3069 mk(r, vec![cb], Type::Result(Box::new(ok), Box::new(e2)))
3070 }
3071 _ => None,
3072 }
3073 }
3074 _ => None,
3075 }
3076 }
3077
3078 fn replace_type_vars(&self, ty: &Type, map: &HashMap<TypeVarId, Type>) -> Type {
3084 match ty {
3085 Type::TypeVar(id) => map.get(id).cloned().unwrap_or_else(|| ty.clone()),
3086 Type::Function(f) => Type::Function(FnType {
3087 params: f
3088 .params
3089 .iter()
3090 .map(|t| self.replace_type_vars(t, map))
3091 .collect(),
3092 ret: Box::new(self.replace_type_vars(&f.ret, map)),
3093 effects: f.effects.clone(),
3094 }),
3095 Type::Generic(g) => Type::Generic(GenericType {
3096 constructor: g.constructor.clone(),
3097 args: g
3098 .args
3099 .iter()
3100 .map(|t| self.replace_type_vars(t, map))
3101 .collect(),
3102 }),
3103 Type::Tuple(elems) => Type::Tuple(
3104 elems
3105 .iter()
3106 .map(|t| self.replace_type_vars(t, map))
3107 .collect(),
3108 ),
3109 Type::Optional(inner) => Type::Optional(Box::new(self.replace_type_vars(inner, map))),
3110 Type::Result(ok, err) => Type::Result(
3111 Box::new(self.replace_type_vars(ok, map)),
3112 Box::new(self.replace_type_vars(err, map)),
3113 ),
3114 _ => ty.clone(),
3115 }
3116 }
3117
3118 fn infer_binop(&mut self, op: BinOp, lt: &Type, rt: &Type, span: Span) -> Type {
3121 match op {
3122 BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem | BinOp::Pow => {
3124 self.unify_or_error(lt, rt, span, "arithmetic operands");
3125 self.subst.apply(lt)
3126 }
3127
3128 BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge => {
3130 self.unify_or_error(lt, rt, span, "comparison operands");
3131 Type::Primitive(PrimitiveType::Bool)
3132 }
3133
3134 BinOp::And | BinOp::Or => {
3136 let bool_ty = Type::Primitive(PrimitiveType::Bool);
3137 self.unify_or_error(lt, &bool_ty, span, "logical operand");
3138 self.unify_or_error(rt, &bool_ty, span, "logical operand");
3139 bool_ty
3140 }
3141
3142 BinOp::BitAnd | BinOp::BitOr | BinOp::BitXor => {
3144 self.unify_or_error(lt, rt, span, "bitwise operands");
3145 self.subst.apply(lt)
3146 }
3147
3148 BinOp::Compose => self.fresh_var(),
3150
3151 BinOp::Is => Type::Primitive(PrimitiveType::Bool),
3153 }
3154 }
3155
3156 fn infer_unop(&mut self, op: UnaryOp, operand_ty: &Type, span: Span) -> Type {
3157 match op {
3158 UnaryOp::Neg => {
3159 self.subst.apply(operand_ty)
3161 }
3162 UnaryOp::Not => {
3163 let bool_ty = Type::Primitive(PrimitiveType::Bool);
3165 self.unify_or_error(operand_ty, &bool_ty, span, "logical not operand");
3166 bool_ty
3167 }
3168 UnaryOp::BitNot => {
3169 self.subst.apply(operand_ty)
3171 }
3172 }
3173 }
3174
3175 fn infer_literal(&self, lit: &Literal) -> Type {
3178 match lit {
3179 Literal::Int(s) => {
3180 let (_, suffix) = bock_ast::strip_type_suffix(s);
3181 match suffix {
3182 Some("i8") => Type::Primitive(PrimitiveType::Int8),
3183 Some("i16") => Type::Primitive(PrimitiveType::Int16),
3184 Some("i32") => Type::Primitive(PrimitiveType::Int32),
3185 Some("i64") => Type::Primitive(PrimitiveType::Int64),
3186 Some("i128") => Type::Primitive(PrimitiveType::Int128),
3187 Some("u8") => Type::Primitive(PrimitiveType::UInt8),
3188 Some("u16") => Type::Primitive(PrimitiveType::UInt16),
3189 Some("u32") => Type::Primitive(PrimitiveType::UInt32),
3190 Some("u64") => Type::Primitive(PrimitiveType::UInt64),
3191 _ => Type::Primitive(PrimitiveType::Int),
3192 }
3193 }
3194 Literal::Float(s) => {
3195 let (_, suffix) = bock_ast::strip_type_suffix(s);
3196 match suffix {
3197 Some("f32") => Type::Primitive(PrimitiveType::Float32),
3198 Some("f64") => Type::Primitive(PrimitiveType::Float64),
3199 _ => Type::Primitive(PrimitiveType::Float),
3200 }
3201 }
3202 Literal::Bool(_) => Type::Primitive(PrimitiveType::Bool),
3203 Literal::Char(_) => Type::Primitive(PrimitiveType::Char),
3204 Literal::String(_) => Type::Primitive(PrimitiveType::String),
3205 Literal::Unit => Type::Primitive(PrimitiveType::Void),
3206 }
3207 }
3208
3209 fn bind_pattern_type(&mut self, pattern: &mut AIRNode, ty: &Type) {
3214 match &pattern.kind {
3215 NodeKind::WildcardPat | NodeKind::RestPat => {
3216 self.record(pattern, ty.clone());
3217 }
3218 NodeKind::BindPat { name, .. } => {
3219 let name = name.name.clone();
3220 self.env.define(name, ty.clone());
3221 self.record(pattern, ty.clone());
3222 }
3223 NodeKind::LiteralPat { lit } => {
3224 let lit_ty = self.infer_literal(lit);
3225 self.unify_or_error(&lit_ty, ty, pattern.span, "literal pattern");
3226 self.record(pattern, lit_ty);
3227 }
3228 NodeKind::TuplePat { .. } => {
3229 if let NodeKind::TuplePat { elems } = &mut pattern.kind {
3230 if let Type::Tuple(elem_tys) = ty {
3231 for (e, et) in elems.iter_mut().zip(elem_tys.iter()) {
3232 let et = et.clone();
3233 self.bind_pattern_type(e, &et);
3234 }
3235 } else {
3236 for e in elems.iter_mut() {
3237 let fv = self.fresh_var();
3238 self.bind_pattern_type(e, &fv);
3239 }
3240 }
3241 }
3242 self.record(pattern, ty.clone());
3243 }
3244 NodeKind::ConstructorPat { .. } => {
3245 let ctor_name = if let NodeKind::ConstructorPat { path, .. } = &pattern.kind {
3247 type_path_to_name(path)
3248 } else {
3249 String::new()
3250 };
3251 let resolved_ty = self.subst.apply(ty);
3252 if let NodeKind::ConstructorPat { fields, .. } = &mut pattern.kind {
3253 match (ctor_name.as_str(), &resolved_ty) {
3254 ("Some", Type::Optional(inner)) if fields.len() == 1 => {
3256 let inner_ty = self.subst.apply(inner);
3257 self.bind_pattern_type(&mut fields[0], &inner_ty);
3258 }
3259 ("Ok", Type::Result(ok, _)) if fields.len() == 1 => {
3261 let ok_ty = self.subst.apply(ok);
3262 self.bind_pattern_type(&mut fields[0], &ok_ty);
3263 }
3264 ("Err", Type::Result(_, err)) if fields.len() == 1 => {
3266 let err_ty = self.subst.apply(err);
3267 self.bind_pattern_type(&mut fields[0], &err_ty);
3268 }
3269 _ => {
3271 for f in fields.iter_mut() {
3272 let fv = self.fresh_var();
3273 self.bind_pattern_type(f, &fv);
3274 }
3275 }
3276 }
3277 }
3278 self.record(pattern, ty.clone());
3279 }
3280 NodeKind::OrPat { .. } => {
3281 if let NodeKind::OrPat { alternatives } = &mut pattern.kind {
3282 for alt in alternatives.iter_mut() {
3283 let t = ty.clone();
3284 self.bind_pattern_type(alt, &t);
3285 }
3286 }
3287 self.record(pattern, ty.clone());
3288 }
3289 NodeKind::ListPat { .. } => {
3290 let elem_ty = match ty {
3291 Type::Generic(g) if g.constructor == "List" && g.args.len() == 1 => {
3292 g.args[0].clone()
3293 }
3294 _ => self.fresh_var(),
3295 };
3296 if let NodeKind::ListPat { elems, rest } = &mut pattern.kind {
3297 for e in elems.iter_mut() {
3298 let et = elem_ty.clone();
3299 self.bind_pattern_type(e, &et);
3300 }
3301 if let Some(r) = rest {
3302 let list_ty = Type::Generic(GenericType {
3303 constructor: "List".into(),
3304 args: vec![elem_ty],
3305 });
3306 self.bind_pattern_type(r, &list_ty);
3307 }
3308 }
3309 self.record(pattern, ty.clone());
3310 }
3311 NodeKind::RecordPat { .. } => {
3312 if let NodeKind::RecordPat { fields, .. } = &mut pattern.kind {
3313 for f in fields.iter_mut() {
3314 let fv = self.fresh_var();
3315 if let Some(sub_pat) = &mut f.pattern {
3316 self.bind_pattern_type(sub_pat, &fv);
3318 } else {
3319 self.env.define(f.name.name.clone(), fv);
3321 }
3322 }
3323 }
3324 self.record(pattern, ty.clone());
3325 }
3326 _ => {
3327 self.record(pattern, ty.clone());
3328 }
3329 }
3330 }
3331
3332 fn air_type_node_to_type(&mut self, node: &AIRNode, gp_map: &HashMap<String, Type>) -> Type {
3337 match &node.kind {
3338 NodeKind::TypeNamed { path, args } => {
3339 let name = type_path_to_name(path);
3340 if let Some(ty) = gp_map.get(&name) {
3342 return ty.clone();
3343 }
3344 if let Some(prim) = name_to_primitive(&name) {
3346 return Type::Primitive(prim);
3347 }
3348 if args.is_empty() {
3350 if let Some(underlying) = self.type_aliases.get(&name) {
3352 return underlying.clone();
3353 }
3354 Type::Named(crate::NamedType { name })
3355 } else {
3356 let converted_args: Vec<Type> = args
3357 .iter()
3358 .map(|a| self.air_type_node_to_type(a, gp_map))
3359 .collect();
3360 match (name.as_str(), converted_args.len()) {
3364 ("Result", 2) => Type::Result(
3365 Box::new(converted_args[0].clone()),
3366 Box::new(converted_args[1].clone()),
3367 ),
3368 ("Optional", 1) => Type::Optional(Box::new(converted_args[0].clone())),
3369 _ => Type::Generic(GenericType {
3370 constructor: name,
3371 args: converted_args,
3372 }),
3373 }
3374 }
3375 }
3376 NodeKind::TypeTuple { elems } => {
3377 let elem_tys: Vec<Type> = elems
3378 .iter()
3379 .map(|e| self.air_type_node_to_type(e, gp_map))
3380 .collect();
3381 Type::Tuple(elem_tys)
3382 }
3383 NodeKind::TypeFunction { params, ret, .. } => {
3384 let param_tys: Vec<Type> = params
3385 .iter()
3386 .map(|p| self.air_type_node_to_type(p, gp_map))
3387 .collect();
3388 let ret_ty = self.air_type_node_to_type(ret, gp_map);
3389 Type::Function(FnType {
3390 params: param_tys,
3391 ret: Box::new(ret_ty),
3392 effects: vec![],
3393 })
3394 }
3395 NodeKind::TypeOptional { inner } => {
3396 Type::Optional(Box::new(self.air_type_node_to_type(inner, gp_map)))
3397 }
3398 NodeKind::TypeSelf => Type::Named(crate::NamedType {
3399 name: "Self".into(),
3400 }),
3401 NodeKind::Param { ty, .. } => {
3402 if let Some(ty_node) = ty {
3403 self.air_type_node_to_type(ty_node, gp_map)
3404 } else {
3405 self.fresh_var()
3406 }
3407 }
3408 _ => self.fresh_var(),
3409 }
3410 }
3411
3412 fn type_expr_to_type(
3417 &self,
3418 ty: &TypeExpr,
3419 gp_map: &HashMap<String, Type>,
3420 ) -> Type {
3421 match ty {
3422 TypeExpr::Named { path, args, .. } => {
3423 let name = type_path_to_name(path);
3424 if let Some(t) = gp_map.get(&name) {
3425 return t.clone();
3426 }
3427 if let Some(prim) = name_to_primitive(&name) {
3428 return Type::Primitive(prim);
3429 }
3430 if args.is_empty() {
3431 if let Some(underlying) = self.type_aliases.get(&name) {
3433 return underlying.clone();
3434 }
3435 Type::Named(crate::NamedType { name })
3436 } else {
3437 let converted_args: Vec<Type> = args
3438 .iter()
3439 .map(|a| self.type_expr_to_type(a, gp_map))
3440 .collect();
3441 match (name.as_str(), converted_args.len()) {
3442 ("Result", 2) => Type::Result(
3443 Box::new(converted_args[0].clone()),
3444 Box::new(converted_args[1].clone()),
3445 ),
3446 ("Optional", 1) => {
3447 Type::Optional(Box::new(converted_args[0].clone()))
3448 }
3449 _ => Type::Generic(GenericType {
3450 constructor: name,
3451 args: converted_args,
3452 }),
3453 }
3454 }
3455 }
3456 TypeExpr::Tuple { elems, .. } => {
3457 Type::Tuple(elems.iter().map(|e| self.type_expr_to_type(e, gp_map)).collect())
3458 }
3459 TypeExpr::Function {
3460 params, ret, ..
3461 } => {
3462 let param_tys: Vec<Type> = params
3463 .iter()
3464 .map(|p| self.type_expr_to_type(p, gp_map))
3465 .collect();
3466 let ret_ty = self.type_expr_to_type(ret, gp_map);
3467 Type::Function(FnType {
3468 params: param_tys,
3469 ret: Box::new(ret_ty),
3470 effects: vec![],
3471 })
3472 }
3473 TypeExpr::Optional { inner, .. } => {
3474 Type::Optional(Box::new(self.type_expr_to_type(inner, gp_map)))
3475 }
3476 TypeExpr::SelfType { .. } => Type::Named(crate::NamedType {
3477 name: "Self".into(),
3478 }),
3479 }
3480 }
3481
3482 pub fn infer_expr(&mut self, expr: &AIRNode) -> Type {
3490 if let Some(ty) = self.types.get(&expr.id) {
3491 return ty.clone();
3492 }
3493 let mut cloned = expr.clone();
3495 self.infer_node(&mut cloned)
3496 }
3497
3498 pub fn check_expr(&mut self, expr: &AIRNode, expected: &Type) {
3503 if let Some(ty) = self.types.get(&expr.id) {
3504 let ty = ty.clone();
3505 self.unify_or_error(&ty, expected, expr.span, "expression");
3506 return;
3507 }
3508 let mut cloned = expr.clone();
3509 self.check_node(&mut cloned, expected);
3510 }
3511}
3512
3513impl Default for TypeChecker {
3514 fn default() -> Self {
3515 Self::new()
3516 }
3517}
3518
3519trait NodeKindExt {
3523 fn param_ty_node(&self) -> &AIRNode;
3525 fn param_pat_name(&self) -> Option<String>;
3527}
3528
3529impl NodeKindExt for NodeKind {
3530 fn param_ty_node(&self) -> &AIRNode {
3531 match self {
3540 NodeKind::Param { ty, pattern, .. } => ty.as_deref().unwrap_or(pattern),
3541 _ => unreachable!("param_ty_node called on non-Param node"),
3543 }
3544 }
3545
3546 fn param_pat_name(&self) -> Option<String> {
3547 match self {
3548 NodeKind::Param { pattern, .. } => match &pattern.kind {
3549 NodeKind::BindPat { name, .. } => Some(name.name.clone()),
3550 NodeKind::WildcardPat => None,
3551 _ => None,
3552 },
3553 _ => None,
3554 }
3555 }
3556}
3557
3558fn collect_type_var_ids_fn(fn_ty: &FnType, out: &mut Vec<TypeVarId>) {
3564 for param in &fn_ty.params {
3565 collect_type_var_ids(param, out);
3566 }
3567 collect_type_var_ids(&fn_ty.ret, out);
3568}
3569
3570fn collect_type_var_ids(ty: &Type, out: &mut Vec<TypeVarId>) {
3572 match ty {
3573 Type::TypeVar(id) => {
3574 if !out.contains(id) {
3575 out.push(*id);
3576 }
3577 }
3578 Type::Function(f) => {
3579 for p in &f.params {
3580 collect_type_var_ids(p, out);
3581 }
3582 collect_type_var_ids(&f.ret, out);
3583 }
3584 Type::Generic(g) => {
3585 for a in &g.args {
3586 collect_type_var_ids(a, out);
3587 }
3588 }
3589 Type::Tuple(elems) => {
3590 for e in elems {
3591 collect_type_var_ids(e, out);
3592 }
3593 }
3594 Type::Optional(inner) => collect_type_var_ids(inner, out),
3595 Type::Result(ok, err) => {
3596 collect_type_var_ids(ok, out);
3597 collect_type_var_ids(err, out);
3598 }
3599 _ => {}
3600 }
3601}
3602
3603fn substitute_type_params(ty: &Type, param_names: &[String], args: &[Type]) -> Type {
3611 match ty {
3612 Type::Named(nt) => {
3613 if let Some(idx) = param_names.iter().position(|n| n == &nt.name) {
3614 if idx < args.len() {
3615 return args[idx].clone();
3616 }
3617 }
3618 ty.clone()
3619 }
3620 Type::Generic(g) => Type::Generic(GenericType {
3621 constructor: g.constructor.clone(),
3622 args: g
3623 .args
3624 .iter()
3625 .map(|a| substitute_type_params(a, param_names, args))
3626 .collect(),
3627 }),
3628 Type::Optional(inner) => {
3629 Type::Optional(Box::new(substitute_type_params(inner, param_names, args)))
3630 }
3631 Type::Result(ok, err) => Type::Result(
3632 Box::new(substitute_type_params(ok, param_names, args)),
3633 Box::new(substitute_type_params(err, param_names, args)),
3634 ),
3635 Type::Tuple(elems) => Type::Tuple(
3636 elems
3637 .iter()
3638 .map(|e| substitute_type_params(e, param_names, args))
3639 .collect(),
3640 ),
3641 Type::Function(f) => Type::Function(FnType {
3642 params: f
3643 .params
3644 .iter()
3645 .map(|p| substitute_type_params(p, param_names, args))
3646 .collect(),
3647 ret: Box::new(substitute_type_params(&f.ret, param_names, args)),
3648 effects: f.effects.clone(),
3649 }),
3650 _ => ty.clone(),
3651 }
3652}
3653
3654fn type_path_to_name(path: &TypePath) -> String {
3658 path.segments
3659 .iter()
3660 .map(|s| s.name.as_str())
3661 .collect::<Vec<_>>()
3662 .join(".")
3663}
3664
3665fn name_to_primitive(name: &str) -> Option<PrimitiveType> {
3667 match name {
3668 "Int" => Some(PrimitiveType::Int),
3669 "Float" => Some(PrimitiveType::Float),
3670 "Bool" => Some(PrimitiveType::Bool),
3671 "String" => Some(PrimitiveType::String),
3672 "Char" => Some(PrimitiveType::Char),
3673 "Void" => Some(PrimitiveType::Void),
3674 "Never" => Some(PrimitiveType::Never),
3675 "Byte" => Some(PrimitiveType::Byte),
3676 "Bytes" => Some(PrimitiveType::Bytes),
3677 "Int8" => Some(PrimitiveType::Int8),
3678 "Int16" => Some(PrimitiveType::Int16),
3679 "Int32" => Some(PrimitiveType::Int32),
3680 "Int64" => Some(PrimitiveType::Int64),
3681 "Int128" => Some(PrimitiveType::Int128),
3682 "UInt8" => Some(PrimitiveType::UInt8),
3683 "UInt16" => Some(PrimitiveType::UInt16),
3684 "UInt32" => Some(PrimitiveType::UInt32),
3685 "UInt64" => Some(PrimitiveType::UInt64),
3686 "Float32" => Some(PrimitiveType::Float32),
3687 "Float64" => Some(PrimitiveType::Float64),
3688 "BigInt" => Some(PrimitiveType::BigInt),
3689 "BigFloat" => Some(PrimitiveType::BigFloat),
3690 "Decimal" => Some(PrimitiveType::Decimal),
3691 _ => None,
3692 }
3693}
3694
3695fn conversion_hint(lhs: &Type, rhs: &Type) -> Option<String> {
3701 let l = as_primitive(lhs)?;
3702 let r = as_primitive(rhs)?;
3703 use PrimitiveType as P;
3704 let is_int = |p: &P| matches!(p, P::Int | P::Int8 | P::Int16 | P::Int32 | P::Int64 | P::Int128 | P::UInt8 | P::UInt16 | P::UInt32 | P::UInt64 | P::BigInt);
3705 let is_float = |p: &P| matches!(p, P::Float | P::Float32 | P::Float64 | P::BigFloat);
3706 if (is_int(&l) && is_float(&r)) || (is_float(&l) && is_int(&r)) {
3707 return Some(
3708 "mixed Int/Float — call `.to_float()` on the Int, or `.to_int()` on the Float (truncates), to make the types match".into(),
3709 );
3710 }
3711 if matches!(l, P::String) || matches!(r, P::String) {
3712 return Some("use `.to_string()` to convert the non-`String` operand".into());
3713 }
3714 None
3715}
3716
3717fn as_primitive(ty: &Type) -> Option<PrimitiveType> {
3719 match ty {
3720 Type::Primitive(p) => Some(p.clone()),
3721 _ => None,
3722 }
3723}
3724
3725#[cfg(test)]
3728mod tests {
3729 use super::*;
3730 use bock_air::{AIRNode, NodeIdGen, NodeKind};
3731 use bock_ast::{BinOp, Ident, Literal, TypePath};
3732 use bock_errors::{FileId, Span};
3733
3734 fn span() -> Span {
3735 Span {
3736 file: FileId(0),
3737 start: 0,
3738 end: 0,
3739 }
3740 }
3741
3742 fn ident(name: &str) -> Ident {
3743 Ident {
3744 name: name.into(),
3745 span: span(),
3746 }
3747 }
3748
3749 fn make_node(gen: &NodeIdGen, kind: NodeKind) -> AIRNode {
3750 AIRNode::new(gen.next(), span(), kind)
3751 }
3752
3753 fn int_lit(gen: &NodeIdGen) -> AIRNode {
3754 make_node(
3755 gen,
3756 NodeKind::Literal {
3757 lit: Literal::Int("42".into()),
3758 },
3759 )
3760 }
3761
3762 fn bool_lit(gen: &NodeIdGen, v: bool) -> AIRNode {
3763 make_node(
3764 gen,
3765 NodeKind::Literal {
3766 lit: Literal::Bool(v),
3767 },
3768 )
3769 }
3770
3771 fn str_lit(gen: &NodeIdGen) -> AIRNode {
3772 make_node(
3773 gen,
3774 NodeKind::Literal {
3775 lit: Literal::String("hello".into()),
3776 },
3777 )
3778 }
3779
3780 fn float_lit(gen: &NodeIdGen) -> AIRNode {
3781 make_node(
3782 gen,
3783 NodeKind::Literal {
3784 lit: Literal::Float("3.14".into()),
3785 },
3786 )
3787 }
3788
3789 fn type_named_node(gen: &NodeIdGen, name: &str) -> AIRNode {
3790 make_node(
3791 gen,
3792 NodeKind::TypeNamed {
3793 path: TypePath {
3794 segments: vec![ident(name)],
3795 span: span(),
3796 },
3797 args: vec![],
3798 },
3799 )
3800 }
3801
3802 #[test]
3805 fn infer_int_literal() {
3806 let gen = NodeIdGen::new();
3807 let mut checker = TypeChecker::new();
3808 let node = int_lit(&gen);
3809 let ty = checker.infer_expr(&node);
3810 assert_eq!(ty, Type::Primitive(PrimitiveType::Int));
3811 }
3812
3813 #[test]
3814 fn infer_float_literal() {
3815 let gen = NodeIdGen::new();
3816 let mut checker = TypeChecker::new();
3817 let node = float_lit(&gen);
3818 let ty = checker.infer_expr(&node);
3819 assert_eq!(ty, Type::Primitive(PrimitiveType::Float));
3820 }
3821
3822 #[test]
3823 fn infer_bool_literal() {
3824 let gen = NodeIdGen::new();
3825 let mut checker = TypeChecker::new();
3826 let node = bool_lit(&gen, true);
3827 let ty = checker.infer_expr(&node);
3828 assert_eq!(ty, Type::Primitive(PrimitiveType::Bool));
3829 }
3830
3831 #[test]
3832 fn infer_string_literal() {
3833 let gen = NodeIdGen::new();
3834 let mut checker = TypeChecker::new();
3835 let node = str_lit(&gen);
3836 let ty = checker.infer_expr(&node);
3837 assert_eq!(ty, Type::Primitive(PrimitiveType::String));
3838 }
3839
3840 #[test]
3843 fn infer_defined_variable() {
3844 let gen = NodeIdGen::new();
3845 let mut checker = TypeChecker::new();
3846 checker.env.define("x", Type::Primitive(PrimitiveType::Int));
3847 let node = make_node(&gen, NodeKind::Identifier { name: ident("x") });
3848 let ty = checker.infer_expr(&node);
3849 assert_eq!(ty, Type::Primitive(PrimitiveType::Int));
3850 }
3851
3852 #[test]
3853 fn infer_undefined_variable_emits_error() {
3854 let gen = NodeIdGen::new();
3855 let mut checker = TypeChecker::new();
3856 let node = make_node(
3857 &gen,
3858 NodeKind::Identifier {
3859 name: ident("unknown"),
3860 },
3861 );
3862 let ty = checker.infer_expr(&node);
3863 assert_eq!(ty, Type::Error);
3864 assert!(checker.diags.has_errors());
3865 }
3866
3867 #[test]
3870 fn infer_int_addition() {
3871 let gen = NodeIdGen::new();
3872 let mut checker = TypeChecker::new();
3873 let left = int_lit(&gen);
3874 let right = int_lit(&gen);
3875 let node = make_node(
3876 &gen,
3877 NodeKind::BinaryOp {
3878 op: BinOp::Add,
3879 left: Box::new(left),
3880 right: Box::new(right),
3881 },
3882 );
3883 let ty = checker.infer_expr(&node);
3884 assert_eq!(ty, Type::Primitive(PrimitiveType::Int));
3885 }
3886
3887 #[test]
3888 fn infer_comparison_returns_bool() {
3889 let gen = NodeIdGen::new();
3890 let mut checker = TypeChecker::new();
3891 let left = int_lit(&gen);
3892 let right = int_lit(&gen);
3893 let node = make_node(
3894 &gen,
3895 NodeKind::BinaryOp {
3896 op: BinOp::Lt,
3897 left: Box::new(left),
3898 right: Box::new(right),
3899 },
3900 );
3901 let ty = checker.infer_expr(&node);
3902 assert_eq!(ty, Type::Primitive(PrimitiveType::Bool));
3903 }
3904
3905 #[test]
3906 fn infer_logical_and_requires_bool() {
3907 let gen = NodeIdGen::new();
3908 let mut checker = TypeChecker::new();
3909 let left = bool_lit(&gen, true);
3910 let right = bool_lit(&gen, false);
3911 let node = make_node(
3912 &gen,
3913 NodeKind::BinaryOp {
3914 op: BinOp::And,
3915 left: Box::new(left),
3916 right: Box::new(right),
3917 },
3918 );
3919 let ty = checker.infer_expr(&node);
3920 assert_eq!(ty, Type::Primitive(PrimitiveType::Bool));
3921 }
3922
3923 #[test]
3924 fn type_mismatch_in_binop_emits_error() {
3925 let gen = NodeIdGen::new();
3926 let mut checker = TypeChecker::new();
3927 let left = int_lit(&gen);
3928 let right = bool_lit(&gen, true);
3929 let node = make_node(
3930 &gen,
3931 NodeKind::BinaryOp {
3932 op: BinOp::Add,
3933 left: Box::new(left),
3934 right: Box::new(right),
3935 },
3936 );
3937 checker.infer_expr(&node);
3938 assert!(checker.diags.has_errors());
3939 }
3940
3941 #[test]
3944 fn infer_neg_int() {
3945 let gen = NodeIdGen::new();
3946 let mut checker = TypeChecker::new();
3947 let operand = int_lit(&gen);
3948 let node = make_node(
3949 &gen,
3950 NodeKind::UnaryOp {
3951 op: UnaryOp::Neg,
3952 operand: Box::new(operand),
3953 },
3954 );
3955 let ty = checker.infer_expr(&node);
3956 assert_eq!(ty, Type::Primitive(PrimitiveType::Int));
3957 }
3958
3959 #[test]
3960 fn infer_not_bool() {
3961 let gen = NodeIdGen::new();
3962 let mut checker = TypeChecker::new();
3963 let operand = bool_lit(&gen, true);
3964 let node = make_node(
3965 &gen,
3966 NodeKind::UnaryOp {
3967 op: UnaryOp::Not,
3968 operand: Box::new(operand),
3969 },
3970 );
3971 let ty = checker.infer_expr(&node);
3972 assert_eq!(ty, Type::Primitive(PrimitiveType::Bool));
3973 }
3974
3975 #[test]
3978 fn check_list_literal_against_list_int() {
3979 let gen = NodeIdGen::new();
3980 let mut checker = TypeChecker::new();
3981 let expected = Type::Generic(GenericType {
3982 constructor: "List".into(),
3983 args: vec![Type::Primitive(PrimitiveType::Int)],
3984 });
3985 let node = make_node(
3986 &gen,
3987 NodeKind::ListLiteral {
3988 elems: vec![int_lit(&gen), int_lit(&gen)],
3989 },
3990 );
3991 checker.check_expr(&node, &expected);
3992 assert!(!checker.diags.has_errors());
3993 }
3994
3995 #[test]
3996 fn list_element_mismatch_emits_error() {
3997 let gen = NodeIdGen::new();
3998 let mut checker = TypeChecker::new();
3999 let expected = Type::Generic(GenericType {
4000 constructor: "List".into(),
4001 args: vec![Type::Primitive(PrimitiveType::Int)],
4002 });
4003 let node = make_node(
4004 &gen,
4005 NodeKind::ListLiteral {
4006 elems: vec![int_lit(&gen), bool_lit(&gen, true)],
4007 },
4008 );
4009 checker.check_expr(&node, &expected);
4010 assert!(checker.diags.has_errors());
4011 }
4012
4013 #[test]
4016 fn infer_list_literal() {
4017 let gen = NodeIdGen::new();
4018 let mut checker = TypeChecker::new();
4019 let node = make_node(
4020 &gen,
4021 NodeKind::ListLiteral {
4022 elems: vec![int_lit(&gen), int_lit(&gen)],
4023 },
4024 );
4025 let ty = checker.infer_expr(&node);
4026 assert!(matches!(&ty, Type::Generic(g) if g.constructor == "List"
4027 && g.args.len() == 1
4028 && g.args[0] == Type::Primitive(PrimitiveType::Int)));
4029 }
4030
4031 #[test]
4034 fn infer_tuple_literal() {
4035 let gen = NodeIdGen::new();
4036 let mut checker = TypeChecker::new();
4037 let node = make_node(
4038 &gen,
4039 NodeKind::TupleLiteral {
4040 elems: vec![int_lit(&gen), bool_lit(&gen, false)],
4041 },
4042 );
4043 let ty = checker.infer_expr(&node);
4044 assert_eq!(
4045 ty,
4046 Type::Tuple(vec![
4047 Type::Primitive(PrimitiveType::Int),
4048 Type::Primitive(PrimitiveType::Bool),
4049 ])
4050 );
4051 }
4052
4053 #[test]
4056 fn infer_block_tail_expression() {
4057 let gen = NodeIdGen::new();
4058 let mut checker = TypeChecker::new();
4059 let tail = int_lit(&gen);
4060 let node = make_node(
4061 &gen,
4062 NodeKind::Block {
4063 stmts: vec![],
4064 tail: Some(Box::new(tail)),
4065 },
4066 );
4067 let ty = checker.infer_expr(&node);
4068 assert_eq!(ty, Type::Primitive(PrimitiveType::Int));
4069 }
4070
4071 #[test]
4072 fn infer_block_no_tail_is_void() {
4073 let gen = NodeIdGen::new();
4074 let mut checker = TypeChecker::new();
4075 let node = make_node(
4076 &gen,
4077 NodeKind::Block {
4078 stmts: vec![],
4079 tail: None,
4080 },
4081 );
4082 let ty = checker.infer_expr(&node);
4083 assert_eq!(ty, Type::Primitive(PrimitiveType::Void));
4084 }
4085
4086 #[test]
4089 fn let_binding_infers_and_binds() {
4090 let gen = NodeIdGen::new();
4091 let mut checker = TypeChecker::new();
4092 let pat = make_node(
4093 &gen,
4094 NodeKind::BindPat {
4095 name: ident("x"),
4096 is_mut: false,
4097 },
4098 );
4099 let val = int_lit(&gen);
4100 let let_node = make_node(
4101 &gen,
4102 NodeKind::LetBinding {
4103 is_mut: false,
4104 pattern: Box::new(pat),
4105 ty: None,
4106 value: Box::new(val),
4107 },
4108 );
4109 let ident_x = make_node(&gen, NodeKind::Identifier { name: ident("x") });
4111 let block = make_node(
4112 &gen,
4113 NodeKind::Block {
4114 stmts: vec![let_node],
4115 tail: Some(Box::new(ident_x)),
4116 },
4117 );
4118 let ty = checker.infer_expr(&block);
4119 assert_eq!(ty, Type::Primitive(PrimitiveType::Int));
4120 assert!(!checker.diags.has_errors());
4121 }
4122
4123 #[test]
4126 fn fresh_var_for_generic_params() {
4127 let mut checker = TypeChecker::new();
4128 let t_var = checker.fresh_var(); let t_id = match &t_var {
4132 Type::TypeVar(id) => *id,
4133 _ => unreachable!(),
4134 };
4135 let sig = FnSig {
4136 generic_params: vec!["T".into()],
4137 generic_var_ids: vec![t_id],
4138 param_types: vec![Type::Generic(GenericType {
4139 constructor: "List".into(),
4140 args: vec![t_var.clone()],
4141 })],
4142 return_type: Type::Optional(Box::new(t_var)),
4143 where_clause: vec![],
4144 };
4145
4146 let gen = NodeIdGen::new();
4147 let arg = make_node(
4148 &gen,
4149 NodeKind::ListLiteral {
4150 elems: vec![int_lit(&gen)],
4151 },
4152 );
4153 let args: Vec<bock_air::AirArg> = vec![bock_air::AirArg {
4154 label: None,
4155 value: arg,
4156 }];
4157
4158 let ret = checker.instantiate_and_check("first", &sig, &args, span());
4159 assert!(!checker.diags.has_errors());
4162 assert!(matches!(ret, Type::Optional(_)));
4163 }
4164
4165 fn register_generic_fn(
4167 checker: &mut TypeChecker,
4168 name: &str,
4169 generic_names: &[&str],
4170 build_sig: impl FnOnce(&[Type]) -> (Vec<Type>, Type),
4171 ) {
4172 let vars: Vec<Type> = generic_names.iter().map(|_| checker.fresh_var()).collect();
4173 let var_ids: Vec<TypeVarId> = vars
4174 .iter()
4175 .map(|t| match t {
4176 Type::TypeVar(id) => *id,
4177 _ => unreachable!(),
4178 })
4179 .collect();
4180 let (param_types, return_type) = build_sig(&vars);
4181 let fn_ty = Type::Function(FnType {
4182 params: param_types.clone(),
4183 ret: Box::new(return_type.clone()),
4184 effects: vec![],
4185 });
4186 checker.env.define(name, fn_ty);
4187 checker.fn_sigs.insert(
4188 name.into(),
4189 FnSig {
4190 generic_params: generic_names.iter().map(|s| (*s).into()).collect(),
4191 generic_var_ids: var_ids,
4192 param_types,
4193 return_type,
4194 where_clause: vec![],
4195 },
4196 );
4197 }
4198
4199 #[test]
4200 fn generic_first_infers_int() {
4201 let gen = NodeIdGen::new();
4203 let mut checker = TypeChecker::new();
4204 register_generic_fn(&mut checker, "first", &["T"], |vars| {
4205 let t = vars[0].clone();
4206 let params = vec![Type::Generic(GenericType {
4207 constructor: "List".into(),
4208 args: vec![t.clone()],
4209 })];
4210 (params, t)
4211 });
4212
4213 let callee = make_node(
4214 &gen,
4215 NodeKind::Identifier {
4216 name: ident("first"),
4217 },
4218 );
4219 let list_arg = make_node(
4220 &gen,
4221 NodeKind::ListLiteral {
4222 elems: vec![int_lit(&gen), int_lit(&gen), int_lit(&gen)],
4223 },
4224 );
4225 let call = make_node(
4226 &gen,
4227 NodeKind::Call {
4228 callee: Box::new(callee),
4229 type_args: vec![],
4230 args: vec![bock_air::AirArg {
4231 label: None,
4232 value: list_arg,
4233 }],
4234 },
4235 );
4236
4237 let ty = checker.infer_expr(&call);
4238 assert_eq!(ty, Type::Primitive(PrimitiveType::Int));
4239 assert!(!checker.diags.has_errors());
4240 }
4241
4242 #[test]
4243 fn generic_identity_infers_string() {
4244 let gen = NodeIdGen::new();
4246 let mut checker = TypeChecker::new();
4247 register_generic_fn(&mut checker, "identity", &["T"], |vars| {
4248 let t = vars[0].clone();
4249 (vec![t.clone()], t)
4250 });
4251
4252 let callee = make_node(
4253 &gen,
4254 NodeKind::Identifier {
4255 name: ident("identity"),
4256 },
4257 );
4258 let call = make_node(
4259 &gen,
4260 NodeKind::Call {
4261 callee: Box::new(callee),
4262 type_args: vec![],
4263 args: vec![bock_air::AirArg {
4264 label: None,
4265 value: str_lit(&gen),
4266 }],
4267 },
4268 );
4269
4270 let ty = checker.infer_expr(&call);
4271 assert_eq!(ty, Type::Primitive(PrimitiveType::String));
4272 assert!(!checker.diags.has_errors());
4273 }
4274
4275 #[test]
4276 fn generic_two_params_swap() {
4277 let gen = NodeIdGen::new();
4279 let mut checker = TypeChecker::new();
4280 register_generic_fn(&mut checker, "swap", &["A", "B"], |vars| {
4281 let a = vars[0].clone();
4282 let b = vars[1].clone();
4283 let params = vec![a.clone(), b.clone()];
4284 let ret = Type::Tuple(vec![b, a]);
4285 (params, ret)
4286 });
4287
4288 let callee = make_node(
4289 &gen,
4290 NodeKind::Identifier {
4291 name: ident("swap"),
4292 },
4293 );
4294 let call = make_node(
4295 &gen,
4296 NodeKind::Call {
4297 callee: Box::new(callee),
4298 type_args: vec![],
4299 args: vec![
4300 bock_air::AirArg {
4301 label: None,
4302 value: int_lit(&gen),
4303 },
4304 bock_air::AirArg {
4305 label: None,
4306 value: str_lit(&gen),
4307 },
4308 ],
4309 },
4310 );
4311
4312 let ty = checker.infer_expr(&call);
4313 assert_eq!(
4314 ty,
4315 Type::Tuple(vec![
4316 Type::Primitive(PrimitiveType::String),
4317 Type::Primitive(PrimitiveType::Int),
4318 ])
4319 );
4320 assert!(!checker.diags.has_errors());
4321 }
4322
4323 #[test]
4324 fn method_call_on_known_type_returns_correct_type() {
4325 let gen = NodeIdGen::new();
4327 let mut checker = TypeChecker::new();
4328 let list = make_node(
4329 &gen,
4330 NodeKind::ListLiteral {
4331 elems: vec![int_lit(&gen), int_lit(&gen), int_lit(&gen)],
4332 },
4333 );
4334 let method_call = make_node(
4335 &gen,
4336 NodeKind::MethodCall {
4337 receiver: Box::new(list),
4338 method: ident("len"),
4339 type_args: vec![],
4340 args: vec![],
4341 },
4342 );
4343 let ty = checker.infer_expr(&method_call);
4344 assert_eq!(ty, Type::Primitive(PrimitiveType::Int));
4345 assert!(!checker.diags.has_errors());
4346 }
4347
4348 #[test]
4349 fn method_call_string_contains_returns_bool() {
4350 let gen = NodeIdGen::new();
4352 let mut checker = TypeChecker::new();
4353 let receiver = str_lit(&gen);
4354 let method_call = make_node(
4355 &gen,
4356 NodeKind::MethodCall {
4357 receiver: Box::new(receiver),
4358 method: ident("contains"),
4359 type_args: vec![],
4360 args: vec![bock_air::AirArg {
4361 label: None,
4362 value: str_lit(&gen),
4363 }],
4364 },
4365 );
4366 let ty = checker.infer_expr(&method_call);
4367 assert_eq!(ty, Type::Primitive(PrimitiveType::Bool));
4368 assert!(!checker.diags.has_errors());
4369 }
4370
4371 #[test]
4374 fn infer_interpolation_is_string() {
4375 let gen = NodeIdGen::new();
4376 let mut checker = TypeChecker::new();
4377 let node = make_node(
4378 &gen,
4379 NodeKind::Interpolation {
4380 parts: vec![
4381 bock_air::AirInterpolationPart::Literal("hello ".into()),
4382 bock_air::AirInterpolationPart::Expr(Box::new(int_lit(&gen))),
4383 ],
4384 },
4385 );
4386 let ty = checker.infer_expr(&node);
4387 assert_eq!(ty, Type::Primitive(PrimitiveType::String));
4388 }
4389
4390 #[test]
4393 fn infer_unreachable_is_never() {
4394 let gen = NodeIdGen::new();
4395 let mut checker = TypeChecker::new();
4396 let node = make_node(&gen, NodeKind::Unreachable);
4397 let ty = checker.infer_expr(&node);
4398 assert_eq!(ty, Type::Primitive(PrimitiveType::Never));
4399 }
4400
4401 #[test]
4404 fn check_module_simple_fn() {
4405 let gen = NodeIdGen::new();
4406 let mut checker = TypeChecker::new();
4407
4408 let x_pat = make_node(
4410 &gen,
4411 NodeKind::BindPat {
4412 name: ident("x"),
4413 is_mut: false,
4414 },
4415 );
4416 let y_pat = make_node(
4417 &gen,
4418 NodeKind::BindPat {
4419 name: ident("y"),
4420 is_mut: false,
4421 },
4422 );
4423
4424 let int_ty = type_named_node(&gen, "Int");
4425
4426 let x_param = make_node(
4427 &gen,
4428 NodeKind::Param {
4429 pattern: Box::new(x_pat),
4430 ty: Some(Box::new(int_ty.clone())),
4431 default: None,
4432 },
4433 );
4434 let y_param = make_node(
4435 &gen,
4436 NodeKind::Param {
4437 pattern: Box::new(y_pat),
4438 ty: Some(Box::new(int_ty.clone())),
4439 default: None,
4440 },
4441 );
4442
4443 let x_ref = make_node(&gen, NodeKind::Identifier { name: ident("x") });
4444 let y_ref = make_node(&gen, NodeKind::Identifier { name: ident("y") });
4445 let add_expr = make_node(
4446 &gen,
4447 NodeKind::BinaryOp {
4448 op: BinOp::Add,
4449 left: Box::new(x_ref),
4450 right: Box::new(y_ref),
4451 },
4452 );
4453
4454 let body = make_node(
4455 &gen,
4456 NodeKind::Block {
4457 stmts: vec![],
4458 tail: Some(Box::new(add_expr)),
4459 },
4460 );
4461
4462 let ret_ty = type_named_node(&gen, "Int");
4463
4464 let fn_node = make_node(
4465 &gen,
4466 NodeKind::FnDecl {
4467 annotations: vec![],
4468 visibility: bock_ast::Visibility::Public,
4469 is_async: false,
4470 name: ident("add"),
4471 generic_params: vec![],
4472 params: vec![x_param, y_param],
4473 return_type: Some(Box::new(ret_ty)),
4474 effect_clause: vec![],
4475 where_clause: vec![],
4476 body: Box::new(body),
4477 },
4478 );
4479
4480 let mut module = make_node(
4481 &gen,
4482 NodeKind::Module {
4483 path: None,
4484 annotations: vec![],
4485 imports: vec![],
4486 items: vec![fn_node],
4487 },
4488 );
4489
4490 checker.check_module(&mut module);
4491 assert!(
4492 !checker.diags.has_errors(),
4493 "errors: {:?}",
4494 checker.diags.iter().collect::<Vec<_>>()
4495 );
4496 }
4497
4498 #[test]
4501 fn check_lambda_from_context() {
4502 let gen = NodeIdGen::new();
4503 let mut checker = TypeChecker::new();
4504
4505 let x_pat = make_node(
4507 &gen,
4508 NodeKind::BindPat {
4509 name: ident("x"),
4510 is_mut: false,
4511 },
4512 );
4513 let x_param = make_node(
4514 &gen,
4515 NodeKind::Param {
4516 pattern: Box::new(x_pat),
4517 ty: None,
4518 default: None,
4519 },
4520 );
4521 let x_ref = make_node(&gen, NodeKind::Identifier { name: ident("x") });
4522 let one = make_node(
4523 &gen,
4524 NodeKind::Literal {
4525 lit: Literal::Int("1".into()),
4526 },
4527 );
4528 let body = make_node(
4529 &gen,
4530 NodeKind::BinaryOp {
4531 op: BinOp::Add,
4532 left: Box::new(x_ref),
4533 right: Box::new(one),
4534 },
4535 );
4536
4537 let lambda = make_node(
4538 &gen,
4539 NodeKind::Lambda {
4540 params: vec![x_param],
4541 body: Box::new(body),
4542 },
4543 );
4544
4545 let expected = Type::Function(FnType {
4546 params: vec![Type::Primitive(PrimitiveType::Int)],
4547 ret: Box::new(Type::Primitive(PrimitiveType::Int)),
4548 effects: vec![],
4549 });
4550
4551 checker.check_expr(&lambda, &expected);
4552 assert!(!checker.diags.has_errors());
4553 }
4554
4555 #[test]
4558 fn error_type_prevents_cascade() {
4559 let gen = NodeIdGen::new();
4560 let mut checker = TypeChecker::new();
4561
4562 let undef = make_node(
4564 &gen,
4565 NodeKind::Identifier {
4566 name: ident("undefined_var"),
4567 },
4568 );
4569 let one = int_lit(&gen);
4570 let add = make_node(
4571 &gen,
4572 NodeKind::BinaryOp {
4573 op: BinOp::Add,
4574 left: Box::new(undef),
4575 right: Box::new(one),
4576 },
4577 );
4578 let ty = checker.infer_expr(&add);
4579 assert_eq!(checker.diags.error_count(), 1);
4581 assert_eq!(ty, Type::Error);
4582 }
4583
4584 #[test]
4587 fn where_clause_unknown_param_emits_error() {
4588 let mut checker = TypeChecker::new();
4589 let clauses = vec![TypeConstraint {
4590 id: 0,
4591 span: span(),
4592 param: ident("X"), bounds: vec![TypePath {
4594 segments: vec![ident("Equatable")],
4595 span: span(),
4596 }],
4597 }];
4598 checker.check_where_clause(&clauses, &HashMap::new(), span());
4599 assert!(checker.diags.has_errors());
4600 }
4601
4602 fn type_named_node_with_args(gen: &NodeIdGen, name: &str, args: Vec<AIRNode>) -> AIRNode {
4605 make_node(
4606 gen,
4607 NodeKind::TypeNamed {
4608 path: TypePath {
4609 segments: vec![ident(name)],
4610 span: span(),
4611 },
4612 args,
4613 },
4614 )
4615 }
4616
4617 #[test]
4618 fn result_annotation_produces_type_result() {
4619 let gen = NodeIdGen::new();
4620 let mut checker = TypeChecker::new();
4621 let int_node = type_named_node(&gen, "Int");
4622 let string_node = type_named_node(&gen, "String");
4623 let result_node = type_named_node_with_args(&gen, "Result", vec![int_node, string_node]);
4624 let ty = checker.air_type_node_to_type(&result_node, &HashMap::new());
4625 assert_eq!(
4626 ty,
4627 Type::Result(
4628 Box::new(Type::Primitive(PrimitiveType::Int)),
4629 Box::new(Type::Primitive(PrimitiveType::String)),
4630 )
4631 );
4632 }
4633
4634 #[test]
4635 fn optional_annotation_produces_type_optional() {
4636 let gen = NodeIdGen::new();
4637 let mut checker = TypeChecker::new();
4638 let int_node = type_named_node(&gen, "Int");
4639 let optional_node = type_named_node_with_args(&gen, "Optional", vec![int_node]);
4640 let ty = checker.air_type_node_to_type(&optional_node, &HashMap::new());
4641 assert_eq!(
4642 ty,
4643 Type::Optional(Box::new(Type::Primitive(PrimitiveType::Int)))
4644 );
4645 }
4646
4647 #[test]
4648 fn result_annotation_unifies_with_ok_construction() {
4649 let annotated = Type::Result(
4652 Box::new(Type::Primitive(PrimitiveType::Int)),
4653 Box::new(Type::Primitive(PrimitiveType::String)),
4654 );
4655 let constructed = Type::Result(
4656 Box::new(Type::Primitive(PrimitiveType::Int)),
4657 Box::new(Type::TypeVar(99)),
4658 );
4659 let mut subst = crate::Substitution::new();
4660 assert!(crate::unify(&annotated, &constructed, &mut subst).is_ok());
4661 assert_eq!(subst.lookup(99), Type::Primitive(PrimitiveType::String));
4662 }
4663
4664 #[test]
4665 fn optional_annotation_unifies_with_some_construction() {
4666 let annotated = Type::Optional(Box::new(Type::Primitive(PrimitiveType::Int)));
4669 let constructed = Type::Optional(Box::new(Type::Primitive(PrimitiveType::Int)));
4670 let mut subst = crate::Substitution::new();
4671 assert!(crate::unify(&annotated, &constructed, &mut subst).is_ok());
4672 }
4673
4674 fn register_generic_fn_with_bounds(
4678 checker: &mut TypeChecker,
4679 name: &str,
4680 generic_names: &[&str],
4681 bounds: Vec<TypeConstraint>,
4682 build_sig: impl FnOnce(&[Type]) -> (Vec<Type>, Type),
4683 ) {
4684 let vars: Vec<Type> = generic_names.iter().map(|_| checker.fresh_var()).collect();
4685 let var_ids: Vec<TypeVarId> = vars
4686 .iter()
4687 .map(|t| match t {
4688 Type::TypeVar(id) => *id,
4689 _ => unreachable!(),
4690 })
4691 .collect();
4692 let (param_types, return_type) = build_sig(&vars);
4693 let fn_ty = Type::Function(FnType {
4694 params: param_types.clone(),
4695 ret: Box::new(return_type.clone()),
4696 effects: vec![],
4697 });
4698 checker.env.define(name, fn_ty);
4699 checker.fn_sigs.insert(
4700 name.into(),
4701 FnSig {
4702 generic_params: generic_names.iter().map(|s| (*s).into()).collect(),
4703 generic_var_ids: var_ids,
4704 param_types,
4705 return_type,
4706 where_clause: bounds,
4707 },
4708 );
4709 }
4710
4711 fn make_constraint(param: &str, bound_names: &[&str]) -> TypeConstraint {
4713 use bock_ast::TypeConstraint;
4714 TypeConstraint {
4715 id: 0,
4716 span: span(),
4717 param: ident(param),
4718 bounds: bound_names
4719 .iter()
4720 .map(|b| TypePath {
4721 segments: vec![ident(b)],
4722 span: span(),
4723 })
4724 .collect(),
4725 }
4726 }
4727
4728 fn make_impl_table(impls: &[(&str, Type)]) -> ImplTable {
4730 let mut table = ImplTable::new();
4731 for (trait_name, ty) in impls {
4732 table.register_trait_impl(*trait_name, ty);
4733 }
4734 table
4735 }
4736
4737 #[test]
4738 fn trait_bound_satisfied_no_error() {
4739 let gen = NodeIdGen::new();
4742 let mut checker = TypeChecker::new();
4743
4744 checker.impl_table = Some(make_impl_table(&[(
4746 "Comparable",
4747 Type::Primitive(PrimitiveType::Int),
4748 )]));
4749
4750 let bounds = vec![make_constraint("T", &["Comparable"])];
4751 register_generic_fn_with_bounds(&mut checker, "sort", &["T"], bounds, |vars| {
4752 let t = vars[0].clone();
4753 let list_t = Type::Generic(GenericType {
4754 constructor: "List".into(),
4755 args: vec![t.clone()],
4756 });
4757 (vec![list_t.clone()], list_t)
4758 });
4759
4760 let callee = make_node(
4761 &gen,
4762 NodeKind::Identifier {
4763 name: ident("sort"),
4764 },
4765 );
4766 let list_arg = make_node(
4767 &gen,
4768 NodeKind::ListLiteral {
4769 elems: vec![int_lit(&gen), int_lit(&gen)],
4770 },
4771 );
4772 let call = make_node(
4773 &gen,
4774 NodeKind::Call {
4775 callee: Box::new(callee),
4776 type_args: vec![],
4777 args: vec![bock_air::AirArg {
4778 label: None,
4779 value: list_arg,
4780 }],
4781 },
4782 );
4783
4784 checker.infer_expr(&call);
4785 assert!(
4786 !checker.diags.has_errors(),
4787 "expected no errors for Int: Comparable"
4788 );
4789 }
4790
4791 #[test]
4792 fn trait_bound_violated_emits_diagnostic() {
4793 let gen = NodeIdGen::new();
4796 let mut checker = TypeChecker::new();
4797
4798 checker.impl_table = Some(make_impl_table(&[(
4800 "Comparable",
4801 Type::Primitive(PrimitiveType::Int),
4802 )]));
4803
4804 let bounds = vec![make_constraint("T", &["Comparable"])];
4805 register_generic_fn_with_bounds(&mut checker, "sort", &["T"], bounds, |vars| {
4806 let t = vars[0].clone();
4807 let list_t = Type::Generic(GenericType {
4808 constructor: "List".into(),
4809 args: vec![t.clone()],
4810 });
4811 (vec![list_t.clone()], list_t)
4812 });
4813
4814 let callee = make_node(
4815 &gen,
4816 NodeKind::Identifier {
4817 name: ident("sort"),
4818 },
4819 );
4820 let list_arg = make_node(
4821 &gen,
4822 NodeKind::ListLiteral {
4823 elems: vec![bool_lit(&gen, true), bool_lit(&gen, false)],
4824 },
4825 );
4826 let call = make_node(
4827 &gen,
4828 NodeKind::Call {
4829 callee: Box::new(callee),
4830 type_args: vec![],
4831 args: vec![bock_air::AirArg {
4832 label: None,
4833 value: list_arg,
4834 }],
4835 },
4836 );
4837
4838 checker.infer_expr(&call);
4839 assert!(
4840 checker.diags.has_errors(),
4841 "expected error: Bool does not implement Comparable"
4842 );
4843 assert_eq!(checker.diags.error_count(), 1);
4844 }
4845
4846 #[test]
4847 fn multiple_trait_bounds_both_satisfied() {
4848 let gen = NodeIdGen::new();
4852 let mut checker = TypeChecker::new();
4853
4854 checker.impl_table = Some(make_impl_table(&[
4855 ("Comparable", Type::Primitive(PrimitiveType::Int)),
4856 ("Displayable", Type::Primitive(PrimitiveType::Int)),
4857 ]));
4858
4859 let bounds = vec![make_constraint("T", &["Comparable", "Displayable"])];
4860 register_generic_fn_with_bounds(&mut checker, "display_sorted", &["T"], bounds, |vars| {
4861 let t = vars[0].clone();
4862 let list_t = Type::Generic(GenericType {
4863 constructor: "List".into(),
4864 args: vec![t],
4865 });
4866 (vec![list_t], Type::Primitive(PrimitiveType::Void))
4867 });
4868
4869 let callee = make_node(
4870 &gen,
4871 NodeKind::Identifier {
4872 name: ident("display_sorted"),
4873 },
4874 );
4875 let list_arg = make_node(
4876 &gen,
4877 NodeKind::ListLiteral {
4878 elems: vec![int_lit(&gen)],
4879 },
4880 );
4881 let call = make_node(
4882 &gen,
4883 NodeKind::Call {
4884 callee: Box::new(callee),
4885 type_args: vec![],
4886 args: vec![bock_air::AirArg {
4887 label: None,
4888 value: list_arg,
4889 }],
4890 },
4891 );
4892
4893 checker.infer_expr(&call);
4894 assert!(
4895 !checker.diags.has_errors(),
4896 "expected no errors: Int satisfies both bounds"
4897 );
4898 }
4899
4900 #[test]
4901 fn multiple_trait_bounds_one_missing() {
4902 let gen = NodeIdGen::new();
4906 let mut checker = TypeChecker::new();
4907
4908 checker.impl_table = Some(make_impl_table(&[(
4910 "Comparable",
4911 Type::Primitive(PrimitiveType::Int),
4912 )]));
4913
4914 let bounds = vec![make_constraint("T", &["Comparable", "Displayable"])];
4915 register_generic_fn_with_bounds(&mut checker, "display_sorted", &["T"], bounds, |vars| {
4916 let t = vars[0].clone();
4917 let list_t = Type::Generic(GenericType {
4918 constructor: "List".into(),
4919 args: vec![t],
4920 });
4921 (vec![list_t], Type::Primitive(PrimitiveType::Void))
4922 });
4923
4924 let callee = make_node(
4925 &gen,
4926 NodeKind::Identifier {
4927 name: ident("display_sorted"),
4928 },
4929 );
4930 let list_arg = make_node(
4931 &gen,
4932 NodeKind::ListLiteral {
4933 elems: vec![int_lit(&gen)],
4934 },
4935 );
4936 let call = make_node(
4937 &gen,
4938 NodeKind::Call {
4939 callee: Box::new(callee),
4940 type_args: vec![],
4941 args: vec![bock_air::AirArg {
4942 label: None,
4943 value: list_arg,
4944 }],
4945 },
4946 );
4947
4948 checker.infer_expr(&call);
4949 assert!(
4950 checker.diags.has_errors(),
4951 "expected error: Int missing Displayable"
4952 );
4953 assert_eq!(checker.diags.error_count(), 1);
4954 }
4955
4956 #[test]
4957 fn no_impl_table_skips_bound_checking() {
4958 let gen = NodeIdGen::new();
4960 let mut checker = TypeChecker::new();
4961 let bounds = vec![make_constraint("T", &["Comparable"])];
4964 register_generic_fn_with_bounds(&mut checker, "sort", &["T"], bounds, |vars| {
4965 let t = vars[0].clone();
4966 (vec![t.clone()], t)
4967 });
4968
4969 let callee = make_node(
4970 &gen,
4971 NodeKind::Identifier {
4972 name: ident("sort"),
4973 },
4974 );
4975 let call = make_node(
4976 &gen,
4977 NodeKind::Call {
4978 callee: Box::new(callee),
4979 type_args: vec![],
4980 args: vec![bock_air::AirArg {
4981 label: None,
4982 value: int_lit(&gen),
4983 }],
4984 },
4985 );
4986
4987 checker.infer_expr(&call);
4988 assert!(!checker.diags.has_errors());
4990 }
4991
4992 #[test]
4995 fn infer_char_literal() {
4996 let gen = NodeIdGen::new();
4997 let mut checker = TypeChecker::new();
4998 let node = make_node(
4999 &gen,
5000 NodeKind::Literal {
5001 lit: Literal::Char("a".into()),
5002 },
5003 );
5004 let ty = checker.infer_expr(&node);
5005 assert_eq!(ty, Type::Primitive(PrimitiveType::Char));
5006 }
5007
5008 #[test]
5011 fn fn_type_carries_effects() {
5012 let gen = NodeIdGen::new();
5013 let mut checker = TypeChecker::new();
5014
5015 let body = make_node(
5017 &gen,
5018 NodeKind::Block {
5019 stmts: vec![],
5020 tail: None,
5021 },
5022 );
5023 let fn_decl = make_node(
5024 &gen,
5025 NodeKind::FnDecl {
5026 annotations: vec![],
5027 visibility: bock_ast::Visibility::Public,
5028 is_async: false,
5029 name: ident("greet"),
5030 generic_params: vec![],
5031 params: vec![],
5032 return_type: None,
5033 effect_clause: vec![
5034 TypePath {
5035 segments: vec![ident("Log")],
5036 span: span(),
5037 },
5038 TypePath {
5039 segments: vec![ident("Clock")],
5040 span: span(),
5041 },
5042 ],
5043 where_clause: vec![],
5044 body: Box::new(body),
5045 },
5046 );
5047
5048 let module = make_node(
5049 &gen,
5050 NodeKind::Module {
5051 path: None,
5052 annotations: vec![],
5053 imports: vec![],
5054 items: vec![fn_decl],
5055 },
5056 );
5057
5058 let mut module = module;
5059 checker.check_module(&mut module);
5060
5061 let fn_ty = checker
5063 .env
5064 .lookup("greet")
5065 .expect("greet should be defined");
5066 match fn_ty {
5067 Type::Function(f) => {
5068 assert_eq!(f.effects.len(), 2);
5069 assert_eq!(f.effects[0].name, "Log");
5070 assert_eq!(f.effects[1].name, "Clock");
5071 }
5072 other => panic!("expected Function type, got {other:?}"),
5073 }
5074 }
5075
5076 #[test]
5079 fn method_call_float_abs_returns_float() {
5080 let gen = NodeIdGen::new();
5081 let mut checker = TypeChecker::new();
5082 let receiver = float_lit(&gen);
5083 let method_call = make_node(
5084 &gen,
5085 NodeKind::MethodCall {
5086 receiver: Box::new(receiver),
5087 method: ident("abs"),
5088 type_args: vec![],
5089 args: vec![],
5090 },
5091 );
5092 let ty = checker.infer_expr(&method_call);
5093 assert_eq!(ty, Type::Primitive(PrimitiveType::Float));
5094 assert!(!checker.diags.has_errors());
5095 }
5096
5097 #[test]
5098 fn method_call_float_to_int_returns_int() {
5099 let gen = NodeIdGen::new();
5100 let mut checker = TypeChecker::new();
5101 let receiver = float_lit(&gen);
5102 let method_call = make_node(
5103 &gen,
5104 NodeKind::MethodCall {
5105 receiver: Box::new(receiver),
5106 method: ident("to_int"),
5107 type_args: vec![],
5108 args: vec![],
5109 },
5110 );
5111 let ty = checker.infer_expr(&method_call);
5112 assert_eq!(ty, Type::Primitive(PrimitiveType::Int));
5113 }
5114
5115 #[test]
5116 fn method_call_bool_negate_returns_bool() {
5117 let gen = NodeIdGen::new();
5118 let mut checker = TypeChecker::new();
5119 let receiver = bool_lit(&gen, true);
5120 let method_call = make_node(
5121 &gen,
5122 NodeKind::MethodCall {
5123 receiver: Box::new(receiver),
5124 method: ident("negate"),
5125 type_args: vec![],
5126 args: vec![],
5127 },
5128 );
5129 let ty = checker.infer_expr(&method_call);
5130 assert_eq!(ty, Type::Primitive(PrimitiveType::Bool));
5131 }
5132
5133 #[test]
5134 fn method_call_char_is_alpha_returns_bool() {
5135 let gen = NodeIdGen::new();
5136 let mut checker = TypeChecker::new();
5137 let receiver = make_node(
5138 &gen,
5139 NodeKind::Literal {
5140 lit: Literal::Char("a".into()),
5141 },
5142 );
5143 let method_call = make_node(
5144 &gen,
5145 NodeKind::MethodCall {
5146 receiver: Box::new(receiver),
5147 method: ident("is_alpha"),
5148 type_args: vec![],
5149 args: vec![],
5150 },
5151 );
5152 let ty = checker.infer_expr(&method_call);
5153 assert_eq!(ty, Type::Primitive(PrimitiveType::Bool));
5154 }
5155
5156 #[test]
5157 fn method_call_char_to_upper_returns_char() {
5158 let gen = NodeIdGen::new();
5159 let mut checker = TypeChecker::new();
5160 let receiver = make_node(
5161 &gen,
5162 NodeKind::Literal {
5163 lit: Literal::Char("a".into()),
5164 },
5165 );
5166 let method_call = make_node(
5167 &gen,
5168 NodeKind::MethodCall {
5169 receiver: Box::new(receiver),
5170 method: ident("to_upper"),
5171 type_args: vec![],
5172 args: vec![],
5173 },
5174 );
5175 let ty = checker.infer_expr(&method_call);
5176 assert_eq!(ty, Type::Primitive(PrimitiveType::Char));
5177 }
5178
5179 #[test]
5180 fn method_call_unknown_method_returns_fresh_var() {
5181 let gen = NodeIdGen::new();
5182 let mut checker = TypeChecker::new();
5183 let receiver = int_lit(&gen);
5184 let method_call = make_node(
5185 &gen,
5186 NodeKind::MethodCall {
5187 receiver: Box::new(receiver),
5188 method: ident("nonexistent"),
5189 type_args: vec![],
5190 args: vec![],
5191 },
5192 );
5193 let ty = checker.infer_expr(&method_call);
5194 assert!(matches!(ty, Type::TypeVar(_)));
5196 }
5197}