1use std::collections::HashMap;
12
13pub use bock_air::stubs::EffectRef;
14
15pub mod checker;
16pub use checker::{TypeChecker, TypeEnv};
17
18pub mod traits;
19pub use traits::{
20 check_supertrait_obligations, resolve_impl, resolve_method, ImplId, ImplTable, ResolvedMethod,
21 TraitRef,
22};
23
24pub mod ownership;
25pub use ownership::{analyze_ownership, AIRModule, OwnershipInfo, OwnershipState};
26
27pub mod effects;
28pub use effects::{infer_effects, track_effects, Strictness};
29
30pub mod capabilities;
31pub use capabilities::{compute_capabilities, verify_capabilities, CapabilitySet};
32
33pub mod exports;
34pub use exports::{collect_exports, type_to_type_ref};
35
36pub mod seed_imports;
37pub use seed_imports::seed_imports;
38
39pub mod vocab;
40
/// The primitive types known to the type checker.
///
/// The enum carries no data, so it is `Copy`: values can be passed around and
/// compared without an explicit `clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrimitiveType {
    // Numeric primitives. Widths and precision are implied by the variant names
    // only; nothing in this file defines them.
    Int,
    Float,
    Int8,
    Int16,
    Int32,
    Int64,
    Int128,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Float32,
    Float64,
    BigInt,
    BigFloat,
    Decimal,
    // Boolean, character, text and byte primitives.
    Bool,
    Char,
    String,
    Byte,
    Bytes,
    /// The `Void` primitive; `unify` gives it no special treatment.
    Void,
    /// `unify` accepts `Never` against any other type without creating a binding.
    Never,
}
76
/// A type referred to by name only, with no type arguments.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so values can be used as
/// keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NamedType {
    /// The type's name. `unify` compares it with a generic type's constructor
    /// name when a named type meets a generic one.
    pub name: String,
}
90
91#[derive(Debug, Clone, PartialEq)]
97pub struct GenericType {
98 pub constructor: String,
100 pub args: Vec<Type>,
102}
103
104#[derive(Debug, Clone, PartialEq)]
108pub struct FnType {
109 pub params: Vec<Type>,
111 pub ret: Box<Type>,
113 pub effects: Vec<EffectRef>,
115}
116
/// The predicate attached to a refined type.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so values can be used as
/// keys in hash-based collections. Equality is plain string equality on
/// [`Predicate::source`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Predicate {
    /// Predicate text (the tests in this file use forms such as `self > 0`).
    /// `unify` ignores it when unifying two refined types.
    pub source: String,
}
128
129#[derive(Debug, Clone, PartialEq, Default)]
136pub struct StructuralConstraints {
137 pub fields: Vec<(String, Type)>,
139}
140
/// Identifier of a type variable; the key under which [`Substitution`] stores
/// a variable's binding.
pub type TypeVarId = u32;
145
146#[derive(Debug, Clone, PartialEq)]
155pub enum Type {
156 Primitive(PrimitiveType),
158 Named(NamedType),
160 Generic(GenericType),
162 Tuple(Vec<Type>),
164 Function(FnType),
166 Optional(Box<Type>),
168 Result(Box<Type>, Box<Type>),
170 TypeVar(TypeVarId),
172 Refined(Box<Type>, Predicate),
174 Flexible(StructuralConstraints),
176 Error,
178}
179
180#[derive(Debug, Clone, PartialEq)]
184pub enum TypeError {
185 Mismatch {
187 left: Type,
189 right: Type,
191 },
192 OccursCheck {
195 var: TypeVarId,
197 ty: Type,
199 },
200 TupleArity { expected: usize, found: usize },
202 FnArity { expected: usize, found: usize },
204}
205
206impl std::fmt::Display for TypeError {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 match self {
209 TypeError::Mismatch { left, right } => {
210 write!(f, "type mismatch: {left:?} vs {right:?}")
211 }
212 TypeError::OccursCheck { var, ty } => {
213 write!(f, "occurs check failed: ?{var} in {ty:?}")
214 }
215 TypeError::TupleArity { expected, found } => {
216 write!(
217 f,
218 "tuple arity mismatch: expected {expected}, found {found}"
219 )
220 }
221 TypeError::FnArity { expected, found } => {
222 write!(
223 f,
224 "function arity mismatch: expected {expected}, found {found}"
225 )
226 }
227 }
228 }
229}
230
231impl std::error::Error for TypeError {}
232
233#[derive(Debug, Clone, Default)]
240pub struct Substitution {
241 map: HashMap<TypeVarId, Type>,
242}
243
244impl Substitution {
245 #[must_use]
247 pub fn new() -> Self {
248 Self::default()
249 }
250
251 #[must_use]
255 pub fn lookup(&self, mut id: TypeVarId) -> Type {
256 loop {
257 match self.map.get(&id) {
258 None => return Type::TypeVar(id),
259 Some(Type::TypeVar(next)) => {
260 id = *next;
261 }
262 Some(ty) => return ty.clone(),
263 }
264 }
265 }
266
267 pub fn bind(&mut self, id: TypeVarId, ty: Type) {
272 debug_assert!(
273 !self.map.contains_key(&id),
274 "TypeVar ?{id} is already bound"
275 );
276 self.map.insert(id, ty);
277 }
278
279 #[must_use]
282 pub fn apply(&self, ty: &Type) -> Type {
283 match ty {
284 Type::TypeVar(id) => {
285 let resolved = self.lookup(*id);
286 if resolved == *ty {
287 resolved
288 } else {
289 self.apply(&resolved)
290 }
291 }
292 Type::Primitive(_) | Type::Error => ty.clone(),
293 Type::Named(_) => ty.clone(),
294 Type::Generic(g) => Type::Generic(GenericType {
295 constructor: g.constructor.clone(),
296 args: g.args.iter().map(|a| self.apply(a)).collect(),
297 }),
298 Type::Tuple(elems) => Type::Tuple(elems.iter().map(|e| self.apply(e)).collect()),
299 Type::Function(f) => Type::Function(FnType {
300 params: f.params.iter().map(|p| self.apply(p)).collect(),
301 ret: Box::new(self.apply(&f.ret)),
302 effects: f.effects.clone(),
303 }),
304 Type::Optional(inner) => Type::Optional(Box::new(self.apply(inner))),
305 Type::Result(ok, err) => {
306 Type::Result(Box::new(self.apply(ok)), Box::new(self.apply(err)))
307 }
308 Type::Refined(base, pred) => Type::Refined(Box::new(self.apply(base)), pred.clone()),
309 Type::Flexible(constraints) => Type::Flexible(StructuralConstraints {
310 fields: constraints
311 .fields
312 .iter()
313 .map(|(name, ty)| (name.clone(), self.apply(ty)))
314 .collect(),
315 }),
316 }
317 }
318
319 #[must_use]
321 pub fn is_unbound(&self, id: TypeVarId) -> bool {
322 matches!(self.lookup(id), Type::TypeVar(_))
323 }
324}
325
326fn occurs(id: TypeVarId, ty: &Type, subst: &Substitution) -> bool {
331 match ty {
332 Type::TypeVar(other) => {
333 let resolved = subst.lookup(*other);
334 match resolved {
335 Type::TypeVar(rid) => rid == id,
336 _ => occurs(id, &resolved, subst),
337 }
338 }
339 Type::Primitive(_) | Type::Named(_) | Type::Error => false,
340 Type::Generic(g) => g.args.iter().any(|a| occurs(id, a, subst)),
341 Type::Tuple(elems) => elems.iter().any(|e| occurs(id, e, subst)),
342 Type::Function(f) => {
343 f.params.iter().any(|p| occurs(id, p, subst)) || occurs(id, &f.ret, subst)
344 }
345 Type::Optional(inner) => occurs(id, inner, subst),
346 Type::Result(ok, err) => occurs(id, ok, subst) || occurs(id, err, subst),
347 Type::Refined(base, _) => occurs(id, base, subst),
348 Type::Flexible(c) => c.fields.iter().any(|(_, t)| occurs(id, t, subst)),
349 }
350}
351
352pub fn unify(a: &Type, b: &Type, subst: &mut Substitution) -> Result<(), TypeError> {
366 let a = subst.apply(a);
368 let b = subst.apply(b);
369
370 match (&a, &b) {
371 (Type::Error, _) | (_, Type::Error) => Ok(()),
373
374 (Type::Primitive(PrimitiveType::Never), _)
376 | (_, Type::Primitive(PrimitiveType::Never)) => Ok(()),
377
378 _ if a == b => Ok(()),
380
381 (Type::TypeVar(id), other) | (other, Type::TypeVar(id)) => {
383 let id = *id;
384 if occurs(id, other, subst) {
385 return Err(TypeError::OccursCheck {
386 var: id,
387 ty: other.clone(),
388 });
389 }
390 subst.bind(id, other.clone());
391 Ok(())
392 }
393
394 (Type::Optional(a_inner), Type::Optional(b_inner)) => unify(a_inner, b_inner, subst),
398
399 (Type::Result(a_ok, a_err), Type::Result(b_ok, b_err)) => {
400 unify(a_ok, b_ok, subst)?;
401 unify(a_err, b_err, subst)
402 }
403
404 (Type::Tuple(a_elems), Type::Tuple(b_elems)) => {
405 if a_elems.len() != b_elems.len() {
406 return Err(TypeError::TupleArity {
407 expected: a_elems.len(),
408 found: b_elems.len(),
409 });
410 }
411 for (ae, be) in a_elems.iter().zip(b_elems.iter()) {
412 unify(ae, be, subst)?;
413 }
414 Ok(())
415 }
416
417 (Type::Function(fa), Type::Function(fb)) => {
418 if fa.params.len() != fb.params.len() {
419 return Err(TypeError::FnArity {
420 expected: fa.params.len(),
421 found: fb.params.len(),
422 });
423 }
424 for (ap, bp) in fa.params.iter().zip(fb.params.iter()) {
425 unify(ap, bp, subst)?;
426 }
427 unify(&fa.ret, &fb.ret, subst)
428 }
429
430 (Type::Generic(ga), Type::Generic(gb)) => {
431 if ga.constructor != gb.constructor {
432 return Err(TypeError::Mismatch {
433 left: a.clone(),
434 right: b.clone(),
435 });
436 }
437 if ga.args.len() != gb.args.len() {
438 return Err(TypeError::Mismatch {
439 left: a.clone(),
440 right: b.clone(),
441 });
442 }
443 for (aa, ba) in ga.args.iter().zip(gb.args.iter()) {
444 unify(aa, ba, subst)?;
445 }
446 Ok(())
447 }
448
449 (Type::Refined(base_a, _), Type::Refined(base_b, _)) => unify(base_a, base_b, subst),
451
452 (Type::Named(nt), Type::Generic(g)) | (Type::Generic(g), Type::Named(nt))
457 if nt.name == g.constructor =>
458 {
459 Ok(())
460 }
461
462 _ => Err(TypeError::Mismatch {
464 left: a.clone(),
465 right: b.clone(),
466 }),
467 }
468}
469
470#[must_use]
477pub fn types_equal(a: &Type, b: &Type, subst: &Substitution) -> bool {
478 let mut scratch = subst.clone();
479 unify(a, b, &mut scratch).is_ok()
480}
481
#[cfg(test)]
mod tests {
    //! Unit tests for `Substitution`, `unify` and `types_equal`.

    use super::*;

    // ── Type-construction helpers ──────────────────────────────────────────

    fn int() -> Type {
        Type::Primitive(PrimitiveType::Int)
    }

    fn bool_ty() -> Type {
        Type::Primitive(PrimitiveType::Bool)
    }

    fn string_ty() -> Type {
        Type::Primitive(PrimitiveType::String)
    }

    fn never() -> Type {
        Type::Primitive(PrimitiveType::Never)
    }

    fn var(id: TypeVarId) -> Type {
        Type::TypeVar(id)
    }

    fn optional(inner: Type) -> Type {
        Type::Optional(Box::new(inner))
    }

    fn result(ok: Type, err: Type) -> Type {
        Type::Result(Box::new(ok), Box::new(err))
    }

    fn generic(constructor: &str, args: Vec<Type>) -> Type {
        Type::Generic(GenericType {
            constructor: constructor.into(),
            args,
        })
    }

    /// Effect-free function type.
    fn function(params: Vec<Type>, ret: Type) -> Type {
        Type::Function(FnType {
            params,
            ret: Box::new(ret),
            effects: vec![],
        })
    }

    fn refined(base: Type, source: &str) -> Type {
        Type::Refined(
            Box::new(base),
            Predicate {
                source: source.into(),
            },
        )
    }

    // ── Substitution ───────────────────────────────────────────────────────

    #[test]
    fn subst_lookup_unbound() {
        assert_eq!(Substitution::new().lookup(0), var(0));
    }

    #[test]
    fn subst_bind_and_lookup() {
        let mut subst = Substitution::new();
        subst.bind(0, int());
        assert_eq!(subst.lookup(0), int());
    }

    #[test]
    fn subst_chain_lookup() {
        let mut subst = Substitution::new();
        subst.bind(0, var(1));
        subst.bind(1, int());
        assert_eq!(subst.lookup(0), int());
    }

    #[test]
    fn subst_apply_nested() {
        let mut subst = Substitution::new();
        subst.bind(0, int());
        assert_eq!(subst.apply(&optional(var(0))), optional(int()));
    }

    #[test]
    fn subst_apply_tuple() {
        let mut subst = Substitution::new();
        subst.bind(0, int());
        subst.bind(1, bool_ty());
        let applied = subst.apply(&Type::Tuple(vec![var(0), var(1)]));
        assert_eq!(applied, Type::Tuple(vec![int(), bool_ty()]));
    }

    #[test]
    fn subst_apply_function() {
        let mut subst = Substitution::new();
        subst.bind(0, int());
        subst.bind(1, bool_ty());
        let applied = subst.apply(&function(vec![var(0)], var(1)));
        assert_eq!(applied, function(vec![int()], bool_ty()));
    }

    // ── unify: primitives, Error, Never ────────────────────────────────────

    #[test]
    fn unify_same_primitive() {
        let mut subst = Substitution::new();
        assert!(unify(&int(), &int(), &mut subst).is_ok());
    }

    #[test]
    fn unify_different_primitives_fails() {
        let mut subst = Substitution::new();
        let outcome = unify(&int(), &bool_ty(), &mut subst);
        assert!(matches!(outcome, Err(TypeError::Mismatch { .. })));
    }

    #[test]
    fn unify_error_with_anything() {
        let mut subst = Substitution::new();
        let pairs = [
            (Type::Error, int()),
            (bool_ty(), Type::Error),
            (Type::Error, Type::Error),
            (Type::Error, var(0)),
        ];
        for (left, right) in &pairs {
            assert!(unify(left, right, &mut subst).is_ok());
        }
    }

    #[test]
    fn unify_never_with_anything() {
        let mut subst = Substitution::new();
        let pairs = [
            (never(), int()),
            (bool_ty(), never()),
            (never(), never()),
            (never(), var(10)),
        ];
        for (left, right) in &pairs {
            assert!(unify(left, right, &mut subst).is_ok());
        }
    }

    // ── unify: type variables ──────────────────────────────────────────────

    #[test]
    fn unify_var_with_concrete() {
        let mut subst = Substitution::new();
        assert!(unify(&var(0), &int(), &mut subst).is_ok());
        assert_eq!(subst.lookup(0), int());
    }

    #[test]
    fn unify_concrete_with_var() {
        let mut subst = Substitution::new();
        assert!(unify(&int(), &var(0), &mut subst).is_ok());
        assert_eq!(subst.lookup(0), int());
    }

    #[test]
    fn unify_var_with_var() {
        let mut subst = Substitution::new();
        assert!(unify(&var(0), &var(1), &mut subst).is_ok());
        // Binding ?1 afterwards must be visible through ?0 as well.
        subst.bind(1, int());
        assert_eq!(subst.lookup(0), int());
    }

    // ── unify: occurs check ────────────────────────────────────────────────

    #[test]
    fn occurs_check_prevents_infinite_type() {
        let mut subst = Substitution::new();
        let outcome = unify(&var(0), &optional(var(0)), &mut subst);
        assert!(matches!(
            outcome,
            Err(TypeError::OccursCheck { var: 0, .. })
        ));
    }

    #[test]
    fn occurs_check_list_generic() {
        let mut subst = Substitution::new();
        let outcome = unify(&var(0), &generic("List", vec![var(0)]), &mut subst);
        assert!(matches!(
            outcome,
            Err(TypeError::OccursCheck { var: 0, .. })
        ));
    }

    // ── unify: compound types ──────────────────────────────────────────────

    #[test]
    fn unify_optional() {
        let mut subst = Substitution::new();
        assert!(unify(&optional(var(0)), &optional(int()), &mut subst).is_ok());
        assert_eq!(subst.lookup(0), int());
    }

    #[test]
    fn unify_result() {
        let mut subst = Substitution::new();
        let left = result(var(0), var(1));
        let right = result(int(), string_ty());
        assert!(unify(&left, &right, &mut subst).is_ok());
        assert_eq!(subst.lookup(0), int());
        assert_eq!(subst.lookup(1), string_ty());
    }

    #[test]
    fn unify_tuple_element_wise() {
        let mut subst = Substitution::new();
        let left = Type::Tuple(vec![var(0), var(1)]);
        let right = Type::Tuple(vec![int(), bool_ty()]);
        assert!(unify(&left, &right, &mut subst).is_ok());
        assert_eq!(subst.lookup(0), int());
        assert_eq!(subst.lookup(1), bool_ty());
    }

    #[test]
    fn unify_tuple_arity_mismatch() {
        let mut subst = Substitution::new();
        let left = Type::Tuple(vec![int(), bool_ty()]);
        let right = Type::Tuple(vec![int()]);
        assert!(matches!(
            unify(&left, &right, &mut subst),
            Err(TypeError::TupleArity {
                expected: 2,
                found: 1
            })
        ));
    }

    #[test]
    fn unify_function_types() {
        let mut subst = Substitution::new();
        let left = function(vec![var(0)], var(1));
        let right = function(vec![int()], bool_ty());
        assert!(unify(&left, &right, &mut subst).is_ok());
        assert_eq!(subst.lookup(0), int());
        assert_eq!(subst.lookup(1), bool_ty());
    }

    #[test]
    fn unify_function_arity_mismatch() {
        let mut subst = Substitution::new();
        let left = function(vec![int(), bool_ty()], int());
        let right = function(vec![int()], int());
        assert!(matches!(
            unify(&left, &right, &mut subst),
            Err(TypeError::FnArity {
                expected: 2,
                found: 1
            })
        ));
    }

    #[test]
    fn unify_generic_same_constructor() {
        let mut subst = Substitution::new();
        let left = generic("List", vec![var(0)]);
        let right = generic("List", vec![int()]);
        assert!(unify(&left, &right, &mut subst).is_ok());
        assert_eq!(subst.lookup(0), int());
    }

    #[test]
    fn unify_generic_different_constructor_fails() {
        let mut subst = Substitution::new();
        let left = generic("List", vec![int()]);
        let right = generic("Set", vec![int()]);
        assert!(matches!(
            unify(&left, &right, &mut subst),
            Err(TypeError::Mismatch { .. })
        ));
    }

    // ── unify: refined types ───────────────────────────────────────────────

    #[test]
    fn unify_refined_base_types() {
        let mut subst = Substitution::new();
        // The predicates differ on purpose: only the base types are unified.
        let left = refined(var(0), "self > 0");
        let right = refined(int(), "self >= 0");
        assert!(unify(&left, &right, &mut subst).is_ok());
        assert_eq!(subst.lookup(0), int());
    }

    // ── types_equal ────────────────────────────────────────────────────────

    #[test]
    fn types_equal_same() {
        assert!(types_equal(&int(), &int(), &Substitution::new()));
    }

    #[test]
    fn types_equal_different() {
        assert!(!types_equal(&int(), &bool_ty(), &Substitution::new()));
    }

    #[test]
    fn types_equal_via_subst() {
        let mut subst = Substitution::new();
        subst.bind(0, int());
        assert!(types_equal(&var(0), &int(), &subst));
    }

    // ── PrimitiveType ──────────────────────────────────────────────────────

    #[test]
    fn all_primitive_variants() {
        let prims = [
            PrimitiveType::Int,
            PrimitiveType::Float,
            PrimitiveType::Int8,
            PrimitiveType::Int16,
            PrimitiveType::Int32,
            PrimitiveType::Int64,
            PrimitiveType::Int128,
            PrimitiveType::UInt8,
            PrimitiveType::UInt16,
            PrimitiveType::UInt32,
            PrimitiveType::UInt64,
            PrimitiveType::Float32,
            PrimitiveType::Float64,
            PrimitiveType::BigInt,
            PrimitiveType::BigFloat,
            PrimitiveType::Decimal,
            PrimitiveType::Bool,
            PrimitiveType::Char,
            PrimitiveType::String,
            PrimitiveType::Byte,
            PrimitiveType::Bytes,
            PrimitiveType::Void,
            PrimitiveType::Never,
        ];
        // Every variant can be wrapped in `Type::Primitive`.
        for prim in prims {
            let ty = Type::Primitive(prim);
            assert!(matches!(ty, Type::Primitive(_)));
        }
    }
}