use std::collections::HashMap;
pub use bock_air::stubs::EffectRef;
pub mod checker;
pub use checker::{TypeChecker, TypeEnv};
pub mod traits;
pub use traits::{
check_supertrait_obligations, resolve_impl, resolve_method, ImplId, ImplTable, ResolvedMethod,
TraitRef,
};
pub mod ownership;
pub use ownership::{analyze_ownership, AIRModule, OwnershipInfo, OwnershipState};
pub mod effects;
pub use effects::{infer_effects, track_effects, Strictness};
pub mod capabilities;
pub use capabilities::{compute_capabilities, verify_capabilities, CapabilitySet};
pub mod exports;
pub use exports::{collect_exports, type_to_type_ref};
pub mod seed_imports;
pub use seed_imports::seed_imports;
pub mod vocab;
/// Built-in primitive types, used as the payload of [`Type::Primitive`].
///
/// Apart from `Never`, the unifier gives these variants no special treatment:
/// two primitives unify only when they are the same variant.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PrimitiveType {
    Int,
    Float,
    Int8,
    Int16,
    Int32,
    Int64,
    Int128,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Float32,
    Float64,
    BigInt,
    BigFloat,
    Decimal,
    Bool,
    Char,
    String,
    Byte,
    Bytes,
    Void,
    /// Bottom type: [`unify`] accepts `Never` against any type and does not
    /// bind type variables when doing so.
    Never,
}
/// A nominal type identified only by its name (no type arguments).
///
/// [`unify`] treats a `NamedType` as compatible with a [`GenericType`] whose
/// `constructor` equals `name`, without inspecting the generic's arguments.
#[derive(Debug, Clone, PartialEq)]
pub struct NamedType {
    /// The type's name.
    pub name: String,
}
/// A type constructor applied to type arguments, e.g. constructor `"List"`
/// with a single argument.
///
/// Two generic types unify only if their constructors are equal and they have
/// the same number of arguments; the arguments are then unified pairwise.
#[derive(Debug, Clone, PartialEq)]
pub struct GenericType {
    /// Name of the type constructor.
    pub constructor: String,
    /// Type arguments, in declaration order.
    pub args: Vec<Type>,
}
/// A function type: parameter types, return type and effect annotations.
///
/// Only `params` and `ret` take part in [`unify`] and
/// [`Substitution::apply`]; `effects` is carried along unchanged and is not
/// compared during unification.
#[derive(Debug, Clone, PartialEq)]
pub struct FnType {
    /// Parameter types, in positional order.
    pub params: Vec<Type>,
    /// Return type.
    pub ret: Box<Type>,
    /// Effects attached to the function type.
    pub effects: Vec<EffectRef>,
}
/// Source text of a refinement predicate attached via [`Type::Refined`]
/// (for example `self > 0`).
///
/// The text is stored verbatim; [`unify`] ignores predicates and only unifies
/// the refined base types.
#[derive(Debug, Clone, PartialEq)]
pub struct Predicate {
    /// The predicate's source text.
    pub source: String,
}
/// Named-field requirements carried by [`Type::Flexible`].
///
/// Each entry pairs a field name with a field type. Field types may contain
/// type variables; [`Substitution::apply`] substitutes into them.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct StructuralConstraints {
    /// `(field name, field type)` pairs.
    pub fields: Vec<(String, Type)>,
}
/// Identifier of an inference type variable; rendered as `?{id}` in
/// diagnostics such as [`TypeError::OccursCheck`].
pub type TypeVarId = u32;
/// A type as seen by the type checker.
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    /// A built-in primitive.
    Primitive(PrimitiveType),
    /// A nominal type referenced by name only.
    Named(NamedType),
    /// A type constructor applied to arguments.
    Generic(GenericType),
    /// A tuple; unification requires equal arity.
    Tuple(Vec<Type>),
    /// A function type.
    Function(FnType),
    /// An optional value of the inner type.
    Optional(Box<Type>),
    /// A result type: `(ok, err)`.
    Result(Box<Type>, Box<Type>),
    /// An inference variable, resolved through a [`Substitution`].
    TypeVar(TypeVarId),
    /// A base type with an attached refinement predicate.
    Refined(Box<Type>, Predicate),
    /// A type described by structural field constraints.
    Flexible(StructuralConstraints),
    /// Placeholder for an ill-typed term; [`unify`] accepts it against any
    /// type without binding anything.
    Error,
}
/// Errors reported by [`unify`].
#[derive(Debug, Clone, PartialEq)]
pub enum TypeError {
    /// The two (substituted) types cannot be made equal.
    Mismatch {
        left: Type,
        right: Type,
    },
    /// Binding `var` to `ty` would create an infinite type because `var`
    /// occurs inside `ty`.
    OccursCheck {
        var: TypeVarId,
        ty: Type,
    },
    /// Tuples of different lengths; `expected` is the left-hand arity.
    TupleArity { expected: usize, found: usize },
    /// Functions with different parameter counts; `expected` is the
    /// left-hand count.
    FnArity { expected: usize, found: usize },
}
impl std::fmt::Display for TypeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TypeError::Mismatch { left, right } => {
write!(f, "type mismatch: {left:?} vs {right:?}")
}
TypeError::OccursCheck { var, ty } => {
write!(f, "occurs check failed: ?{var} in {ty:?}")
}
TypeError::TupleArity { expected, found } => {
write!(
f,
"tuple arity mismatch: expected {expected}, found {found}"
)
}
TypeError::FnArity { expected, found } => {
write!(
f,
"function arity mismatch: expected {expected}, found {found}"
)
}
}
}
}
// Lets `TypeError` be used wherever a `std::error::Error` is expected; the
// message comes from the `Display` impl, and the default `source()` (no
// underlying cause) is kept.
impl std::error::Error for TypeError {}
/// A mapping from type variables to the types they are bound to.
///
/// A binding may point at another type variable, forming a chain that
/// [`Substitution::lookup`] follows to its end.
#[derive(Debug, Clone, Default)]
pub struct Substitution {
    // Variable -> bound type (possibly another `Type::TypeVar`).
    map: HashMap<TypeVarId, Type>,
}
impl Substitution {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn lookup(&self, mut id: TypeVarId) -> Type {
loop {
match self.map.get(&id) {
None => return Type::TypeVar(id),
Some(Type::TypeVar(next)) => {
id = *next;
}
Some(ty) => return ty.clone(),
}
}
}
pub fn bind(&mut self, id: TypeVarId, ty: Type) {
debug_assert!(
!self.map.contains_key(&id),
"TypeVar ?{id} is already bound"
);
self.map.insert(id, ty);
}
#[must_use]
pub fn apply(&self, ty: &Type) -> Type {
match ty {
Type::TypeVar(id) => {
let resolved = self.lookup(*id);
if resolved == *ty {
resolved
} else {
self.apply(&resolved)
}
}
Type::Primitive(_) | Type::Error => ty.clone(),
Type::Named(_) => ty.clone(),
Type::Generic(g) => Type::Generic(GenericType {
constructor: g.constructor.clone(),
args: g.args.iter().map(|a| self.apply(a)).collect(),
}),
Type::Tuple(elems) => Type::Tuple(elems.iter().map(|e| self.apply(e)).collect()),
Type::Function(f) => Type::Function(FnType {
params: f.params.iter().map(|p| self.apply(p)).collect(),
ret: Box::new(self.apply(&f.ret)),
effects: f.effects.clone(),
}),
Type::Optional(inner) => Type::Optional(Box::new(self.apply(inner))),
Type::Result(ok, err) => {
Type::Result(Box::new(self.apply(ok)), Box::new(self.apply(err)))
}
Type::Refined(base, pred) => Type::Refined(Box::new(self.apply(base)), pred.clone()),
Type::Flexible(constraints) => Type::Flexible(StructuralConstraints {
fields: constraints
.fields
.iter()
.map(|(name, ty)| (name.clone(), self.apply(ty)))
.collect(),
}),
}
}
#[must_use]
pub fn is_unbound(&self, id: TypeVarId) -> bool {
matches!(self.lookup(id), Type::TypeVar(_))
}
}
/// Reports whether type variable `id` appears anywhere inside `ty` once
/// `subst` is taken into account.
///
/// Variables encountered in `ty` are first resolved through `subst`, so `id`
/// is found even when it is only reachable through existing bindings. Used by
/// [`unify`] to reject bindings that would create infinite types.
fn occurs(id: TypeVarId, ty: &Type, subst: &Substitution) -> bool {
    let occurs_in = |t: &Type| occurs(id, t, subst);
    match ty {
        Type::TypeVar(var) => match subst.lookup(*var) {
            Type::TypeVar(root) => root == id,
            bound => occurs_in(&bound),
        },
        Type::Primitive(_) | Type::Named(_) | Type::Error => false,
        Type::Generic(generic) => generic.args.iter().any(occurs_in),
        Type::Tuple(items) => items.iter().any(occurs_in),
        // Parameters first, then the return type; effects hold no type vars.
        Type::Function(func) => func
            .params
            .iter()
            .chain(std::iter::once(&*func.ret))
            .any(occurs_in),
        Type::Optional(inner) | Type::Refined(inner, _) => occurs(id, inner, subst),
        Type::Result(ok, err) => occurs(id, ok, subst) || occurs(id, err, subst),
        Type::Flexible(constraints) => constraints
            .fields
            .iter()
            .any(|(_, field)| occurs_in(field)),
    }
}
pub fn unify(a: &Type, b: &Type, subst: &mut Substitution) -> Result<(), TypeError> {
let a = subst.apply(a);
let b = subst.apply(b);
match (&a, &b) {
(Type::Error, _) | (_, Type::Error) => Ok(()),
(Type::Primitive(PrimitiveType::Never), _)
| (_, Type::Primitive(PrimitiveType::Never)) => Ok(()),
_ if a == b => Ok(()),
(Type::TypeVar(id), other) | (other, Type::TypeVar(id)) => {
let id = *id;
if occurs(id, other, subst) {
return Err(TypeError::OccursCheck {
var: id,
ty: other.clone(),
});
}
subst.bind(id, other.clone());
Ok(())
}
(Type::Optional(a_inner), Type::Optional(b_inner)) => unify(a_inner, b_inner, subst),
(Type::Result(a_ok, a_err), Type::Result(b_ok, b_err)) => {
unify(a_ok, b_ok, subst)?;
unify(a_err, b_err, subst)
}
(Type::Tuple(a_elems), Type::Tuple(b_elems)) => {
if a_elems.len() != b_elems.len() {
return Err(TypeError::TupleArity {
expected: a_elems.len(),
found: b_elems.len(),
});
}
for (ae, be) in a_elems.iter().zip(b_elems.iter()) {
unify(ae, be, subst)?;
}
Ok(())
}
(Type::Function(fa), Type::Function(fb)) => {
if fa.params.len() != fb.params.len() {
return Err(TypeError::FnArity {
expected: fa.params.len(),
found: fb.params.len(),
});
}
for (ap, bp) in fa.params.iter().zip(fb.params.iter()) {
unify(ap, bp, subst)?;
}
unify(&fa.ret, &fb.ret, subst)
}
(Type::Generic(ga), Type::Generic(gb)) => {
if ga.constructor != gb.constructor {
return Err(TypeError::Mismatch {
left: a.clone(),
right: b.clone(),
});
}
if ga.args.len() != gb.args.len() {
return Err(TypeError::Mismatch {
left: a.clone(),
right: b.clone(),
});
}
for (aa, ba) in ga.args.iter().zip(gb.args.iter()) {
unify(aa, ba, subst)?;
}
Ok(())
}
(Type::Refined(base_a, _), Type::Refined(base_b, _)) => unify(base_a, base_b, subst),
(Type::Named(nt), Type::Generic(g)) | (Type::Generic(g), Type::Named(nt))
if nt.name == g.constructor =>
{
Ok(())
}
_ => Err(TypeError::Mismatch {
left: a.clone(),
right: b.clone(),
}),
}
}
#[must_use]
pub fn types_equal(a: &Type, b: &Type, subst: &Substitution) -> bool {
let mut scratch = subst.clone();
unify(a, b, &mut scratch).is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    fn int() -> Type {
        Type::Primitive(PrimitiveType::Int)
    }

    fn bool_ty() -> Type {
        Type::Primitive(PrimitiveType::Bool)
    }

    fn string_ty() -> Type {
        Type::Primitive(PrimitiveType::String)
    }

    fn var(id: TypeVarId) -> Type {
        Type::TypeVar(id)
    }

    fn named(name: &str) -> Type {
        Type::Named(NamedType { name: name.into() })
    }

    fn generic(constructor: &str, args: Vec<Type>) -> Type {
        Type::Generic(GenericType {
            constructor: constructor.into(),
            args,
        })
    }

    #[test]
    fn subst_lookup_unbound() {
        let s = Substitution::new();
        assert_eq!(s.lookup(0), var(0));
    }

    #[test]
    fn subst_bind_and_lookup() {
        let mut s = Substitution::new();
        s.bind(0, int());
        assert_eq!(s.lookup(0), int());
    }

    #[test]
    fn subst_chain_lookup() {
        let mut s = Substitution::new();
        s.bind(0, var(1));
        s.bind(1, int());
        assert_eq!(s.lookup(0), int());
    }

    #[test]
    fn subst_apply_nested() {
        let mut s = Substitution::new();
        s.bind(0, int());
        let ty = Type::Optional(Box::new(var(0)));
        assert_eq!(s.apply(&ty), Type::Optional(Box::new(int())));
    }

    #[test]
    fn subst_apply_tuple() {
        let mut s = Substitution::new();
        s.bind(0, int());
        s.bind(1, bool_ty());
        let ty = Type::Tuple(vec![var(0), var(1)]);
        assert_eq!(s.apply(&ty), Type::Tuple(vec![int(), bool_ty()]));
    }

    #[test]
    fn subst_apply_function() {
        let mut s = Substitution::new();
        s.bind(0, int());
        s.bind(1, bool_ty());
        let ty = Type::Function(FnType {
            params: vec![var(0)],
            ret: Box::new(var(1)),
            effects: vec![],
        });
        let result = s.apply(&ty);
        assert_eq!(
            result,
            Type::Function(FnType {
                params: vec![int()],
                ret: Box::new(bool_ty()),
                effects: vec![],
            })
        );
    }

    #[test]
    fn subst_apply_unbound_chain_yields_last_var() {
        let mut s = Substitution::new();
        s.bind(0, var(1));
        assert_eq!(s.apply(&var(0)), var(1));
        assert!(s.is_unbound(0));
    }

    #[test]
    fn unify_same_primitive() {
        let mut s = Substitution::new();
        assert!(unify(&int(), &int(), &mut s).is_ok());
    }

    #[test]
    fn unify_different_primitives_fails() {
        let mut s = Substitution::new();
        assert!(matches!(
            unify(&int(), &bool_ty(), &mut s),
            Err(TypeError::Mismatch { .. })
        ));
    }

    #[test]
    fn unify_error_with_anything() {
        let mut s = Substitution::new();
        assert!(unify(&Type::Error, &int(), &mut s).is_ok());
        assert!(unify(&bool_ty(), &Type::Error, &mut s).is_ok());
        assert!(unify(&Type::Error, &Type::Error, &mut s).is_ok());
        assert!(unify(&Type::Error, &var(0), &mut s).is_ok());
    }

    #[test]
    fn unify_never_with_anything() {
        let mut s = Substitution::new();
        let never = Type::Primitive(PrimitiveType::Never);
        assert!(unify(&never, &int(), &mut s).is_ok());
        assert!(unify(&bool_ty(), &never, &mut s).is_ok());
        assert!(unify(&never, &never, &mut s).is_ok());
        assert!(unify(&never, &var(10), &mut s).is_ok());
    }

    #[test]
    fn unify_never_does_not_bind_var() {
        let mut s = Substitution::new();
        let never = Type::Primitive(PrimitiveType::Never);
        assert!(unify(&var(10), &never, &mut s).is_ok());
        assert!(s.is_unbound(10));
    }

    #[test]
    fn unify_var_with_concrete() {
        let mut s = Substitution::new();
        assert!(unify(&var(0), &int(), &mut s).is_ok());
        assert_eq!(s.lookup(0), int());
    }

    #[test]
    fn unify_concrete_with_var() {
        let mut s = Substitution::new();
        assert!(unify(&int(), &var(0), &mut s).is_ok());
        assert_eq!(s.lookup(0), int());
    }

    #[test]
    fn unify_var_with_var() {
        let mut s = Substitution::new();
        assert!(unify(&var(0), &var(1), &mut s).is_ok());
        s.bind(1, int());
        assert_eq!(s.lookup(0), int());
    }

    #[test]
    fn unify_var_with_itself_leaves_it_unbound() {
        let mut s = Substitution::new();
        assert!(unify(&var(0), &var(0), &mut s).is_ok());
        assert!(s.is_unbound(0));
    }

    #[test]
    fn occurs_check_prevents_infinite_type() {
        let mut s = Substitution::new();
        let ty = Type::Optional(Box::new(var(0)));
        assert!(matches!(
            unify(&var(0), &ty, &mut s),
            Err(TypeError::OccursCheck { var: 0, .. })
        ));
    }

    #[test]
    fn occurs_check_list_generic() {
        let mut s = Substitution::new();
        let list_t = generic("List", vec![var(0)]);
        assert!(matches!(
            unify(&var(0), &list_t, &mut s),
            Err(TypeError::OccursCheck { var: 0, .. })
        ));
    }

    #[test]
    fn occurs_check_through_var_chain() {
        let mut s = Substitution::new();
        assert!(unify(&var(0), &var(1), &mut s).is_ok());
        // ?0 now resolves to ?1, so Optional(?0) contains ?1.
        let ty = Type::Optional(Box::new(var(0)));
        assert!(matches!(
            unify(&var(1), &ty, &mut s),
            Err(TypeError::OccursCheck { var: 1, .. })
        ));
    }

    #[test]
    fn unify_optional() {
        let mut s = Substitution::new();
        assert!(unify(
            &Type::Optional(Box::new(var(0))),
            &Type::Optional(Box::new(int())),
            &mut s
        )
        .is_ok());
        assert_eq!(s.lookup(0), int());
    }

    #[test]
    fn unify_result() {
        let mut s = Substitution::new();
        let a = Type::Result(Box::new(var(0)), Box::new(var(1)));
        let b = Type::Result(Box::new(int()), Box::new(string_ty()));
        assert!(unify(&a, &b, &mut s).is_ok());
        assert_eq!(s.lookup(0), int());
        assert_eq!(s.lookup(1), string_ty());
    }

    #[test]
    fn unify_tuple_element_wise() {
        let mut s = Substitution::new();
        let a = Type::Tuple(vec![var(0), var(1)]);
        let b = Type::Tuple(vec![int(), bool_ty()]);
        assert!(unify(&a, &b, &mut s).is_ok());
        assert_eq!(s.lookup(0), int());
        assert_eq!(s.lookup(1), bool_ty());
    }

    #[test]
    fn unify_tuple_arity_mismatch() {
        let mut s = Substitution::new();
        let a = Type::Tuple(vec![int(), bool_ty()]);
        let b = Type::Tuple(vec![int()]);
        assert!(matches!(
            unify(&a, &b, &mut s),
            Err(TypeError::TupleArity {
                expected: 2,
                found: 1
            })
        ));
    }

    #[test]
    fn failed_unify_keeps_earlier_bindings() {
        let mut s = Substitution::new();
        let a = Type::Tuple(vec![var(0), int()]);
        let b = Type::Tuple(vec![bool_ty(), bool_ty()]);
        assert!(matches!(
            unify(&a, &b, &mut s),
            Err(TypeError::Mismatch { .. })
        ));
        assert_eq!(s.lookup(0), bool_ty());
    }

    #[test]
    fn unify_function_types() {
        let mut s = Substitution::new();
        let a = Type::Function(FnType {
            params: vec![var(0)],
            ret: Box::new(var(1)),
            effects: vec![],
        });
        let b = Type::Function(FnType {
            params: vec![int()],
            ret: Box::new(bool_ty()),
            effects: vec![],
        });
        assert!(unify(&a, &b, &mut s).is_ok());
        assert_eq!(s.lookup(0), int());
        assert_eq!(s.lookup(1), bool_ty());
    }

    #[test]
    fn unify_function_arity_mismatch() {
        let mut s = Substitution::new();
        let a = Type::Function(FnType {
            params: vec![int(), bool_ty()],
            ret: Box::new(int()),
            effects: vec![],
        });
        let b = Type::Function(FnType {
            params: vec![int()],
            ret: Box::new(int()),
            effects: vec![],
        });
        assert!(matches!(
            unify(&a, &b, &mut s),
            Err(TypeError::FnArity {
                expected: 2,
                found: 1
            })
        ));
    }

    #[test]
    fn unify_generic_same_constructor() {
        let mut s = Substitution::new();
        let a = generic("List", vec![var(0)]);
        let b = generic("List", vec![int()]);
        assert!(unify(&a, &b, &mut s).is_ok());
        assert_eq!(s.lookup(0), int());
    }

    #[test]
    fn unify_generic_different_constructor_fails() {
        let mut s = Substitution::new();
        let a = generic("List", vec![int()]);
        let b = generic("Set", vec![int()]);
        assert!(matches!(
            unify(&a, &b, &mut s),
            Err(TypeError::Mismatch { .. })
        ));
    }

    #[test]
    fn unify_generic_arity_mismatch_fails() {
        let mut s = Substitution::new();
        let a = generic("Map", vec![int()]);
        let b = generic("Map", vec![int(), bool_ty()]);
        assert!(matches!(
            unify(&a, &b, &mut s),
            Err(TypeError::Mismatch { .. })
        ));
    }

    #[test]
    fn unify_named_types() {
        let mut s = Substitution::new();
        assert!(unify(&named("User"), &named("User"), &mut s).is_ok());
        assert!(matches!(
            unify(&named("User"), &named("Post"), &mut s),
            Err(TypeError::Mismatch { .. })
        ));
    }

    #[test]
    fn unify_named_with_generic_of_same_name() {
        let mut s = Substitution::new();
        let list_int = generic("List", vec![int()]);
        assert!(unify(&named("List"), &list_int, &mut s).is_ok());
        assert!(unify(&list_int, &named("List"), &mut s).is_ok());
        assert!(matches!(
            unify(&named("Set"), &list_int, &mut s),
            Err(TypeError::Mismatch { .. })
        ));
    }

    #[test]
    fn unify_refined_base_types() {
        let mut s = Substitution::new();
        let a = Type::Refined(
            Box::new(var(0)),
            Predicate {
                source: "self > 0".into(),
            },
        );
        let b = Type::Refined(
            Box::new(int()),
            Predicate {
                source: "self >= 0".into(),
            },
        );
        assert!(unify(&a, &b, &mut s).is_ok());
        assert_eq!(s.lookup(0), int());
    }

    #[test]
    fn types_equal_same() {
        let s = Substitution::new();
        assert!(types_equal(&int(), &int(), &s));
    }

    #[test]
    fn types_equal_different() {
        let s = Substitution::new();
        assert!(!types_equal(&int(), &bool_ty(), &s));
    }

    #[test]
    fn types_equal_via_subst() {
        let mut s = Substitution::new();
        s.bind(0, int());
        assert!(types_equal(&var(0), &int(), &s));
    }

    #[test]
    fn types_equal_does_not_mutate_subst() {
        let s = Substitution::new();
        assert!(types_equal(&var(0), &int(), &s));
        assert!(s.is_unbound(0));
    }

    #[test]
    fn display_messages() {
        assert_eq!(
            TypeError::Mismatch {
                left: int(),
                right: bool_ty(),
            }
            .to_string(),
            "type mismatch: Primitive(Int) vs Primitive(Bool)"
        );
        assert_eq!(
            TypeError::OccursCheck { var: 3, ty: int() }.to_string(),
            "occurs check failed: ?3 in Primitive(Int)"
        );
        assert_eq!(
            TypeError::TupleArity {
                expected: 2,
                found: 1
            }
            .to_string(),
            "tuple arity mismatch: expected 2, found 1"
        );
        assert_eq!(
            TypeError::FnArity {
                expected: 0,
                found: 3
            }
            .to_string(),
            "function arity mismatch: expected 0, found 3"
        );
    }

    #[test]
    fn all_primitive_variants() {
        let prims = [
            PrimitiveType::Int,
            PrimitiveType::Float,
            PrimitiveType::Int8,
            PrimitiveType::Int16,
            PrimitiveType::Int32,
            PrimitiveType::Int64,
            PrimitiveType::Int128,
            PrimitiveType::UInt8,
            PrimitiveType::UInt16,
            PrimitiveType::UInt32,
            PrimitiveType::UInt64,
            PrimitiveType::Float32,
            PrimitiveType::Float64,
            PrimitiveType::BigInt,
            PrimitiveType::BigFloat,
            PrimitiveType::Decimal,
            PrimitiveType::Bool,
            PrimitiveType::Char,
            PrimitiveType::String,
            PrimitiveType::Byte,
            PrimitiveType::Bytes,
            PrimitiveType::Void,
            PrimitiveType::Never,
        ];
        for p in &prims {
            let ty = Type::Primitive(p.clone());
            assert!(matches!(ty, Type::Primitive(_)));
        }
    }
}