use crate::utils::metadata::Span;
use std::collections::HashSet;
use super::*;
/// How two types relate after a successful unification.
///
/// Returned by the unification functions in this module alongside any variable
/// bindings they performed. NOTE(review): the direction convention (which side is
/// the subtype) is not applied uniformly across all cases of `unify_types`
/// (compare the record and union cases) — confirm before relying on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum Relation {
    /// One side is a subtype of the other.
    Subtype,
    /// Both sides unified to the same type.
    Identical,
    /// One side is a supertype of the other.
    Supertype,
}
/// Failures reported by unification.
///
/// Each variant carries the offending types together with the best source span
/// found for each side (see `best_span`).
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum Error {
    /// The two types cannot be unified.
    TypeMismatch {
        left: (TypeNodeId, Span),
        right: (TypeNodeId, Span),
    },
    /// Two tuples have different numbers of elements; the element lists are kept.
    LengthMismatch {
        left: (Vec<TypeNodeId>, Span),
        right: (Vec<TypeNodeId>, Span),
    },
    /// Binding a type variable would make it contain itself (occurs check failed).
    CircularType { left: Span, right: Span },
    /// Each record has fields the other lacks; the `(key, type)` pairs of both
    /// records are kept. (The variant name's spelling is part of the crate API.)
    ImcompatibleRecords {
        left: (Vec<(Symbol, TypeNodeId)>, Span),
        right: (Vec<(Symbol, TypeNodeId)>, Span),
    },
}
/// Reports whether `span` is the placeholder `0..0` span given to types that
/// were created without a source location.
fn is_dummy_span(span: &Span) -> bool {
    matches!((span.start, span.end), (0, 0))
}
/// Finds a real source span for `t`, looking through nested types when `t`
/// itself only carries a dummy span.
///
/// Nodes are explored depth-first from an explicit work list, where the most
/// recently queued child is examined first. Each node is examined at most once,
/// so type-variable links that form a cycle cannot make the search loop forever.
/// Returns `None` when no reachable node has a non-dummy span.
fn span_from_type_tree(t: TypeNodeId) -> Option<Span> {
    let mut worklist = vec![t];
    let mut seen = HashSet::new();
    while let Some(node) = worklist.pop() {
        if !seen.insert(node) {
            continue;
        }
        if let Some(span) = Some(node.to_span()).filter(|s| !is_dummy_span(s)) {
            return Some(span);
        }
        // Children are queued in declaration order, so the last one is popped first.
        let children: Vec<TypeNodeId> = match node.to_type() {
            Type::Array(inner) | Type::Ref(inner) | Type::Code(inner) | Type::Boxed(inner) => {
                vec![inner]
            }
            Type::Tuple(members) | Type::Union(members) => members,
            Type::Record(fields) => fields.iter().map(|field| field.ty).collect(),
            Type::Function { arg, ret } => vec![arg, ret],
            // Follow a type variable's link to whatever it was unified with.
            Type::Intermediate(cell) => cell
                .read()
                .ok()
                .and_then(|tv| tv.parent)
                .into_iter()
                .collect(),
            Type::UserSum { variants, .. } => {
                variants.iter().filter_map(|(_, payload)| *payload).collect()
            }
            Type::Primitive(_)
            | Type::TypeScheme(_)
            | Type::TypeAlias(_)
            | Type::Any
            | Type::Failure
            | Type::Unknown => Vec::new(),
        };
        worklist.extend(children);
    }
    None
}
fn best_span(primary: TypeNodeId, secondary: TypeNodeId) -> Span {
span_from_type_tree(primary)
.or_else(|| span_from_type_tree(secondary))
.unwrap_or_else(|| primary.to_span())
}
/// Occurs check: returns `true` if type variable `id1` appears anywhere inside `t2`.
///
/// Binding `id1` to such a `t2` would create an infinite type. Type-variable links
/// (`parent`) are followed. A type variable whose cell cannot be read (e.g. a
/// poisoned lock) is conservatively treated as an occurrence.
///
/// A variable counts as occurring in a compound type if it occurs in *any*
/// component. In particular, occurring in only the argument or only the return of
/// a function is enough. `UserSum` types are nominal (they are unified by name
/// only), so their payloads are not traversed.
fn occur_check(id1: IntermediateId, t2: TypeNodeId) -> bool {
    let occurs_in = |t: TypeNodeId| occur_check(id1, t);
    match &t2.to_type() {
        Type::Intermediate(cell) => cell
            .read()
            .map(|tv2| id1 == tv2.var || tv2.parent.is_some_and(|p| occur_check(id1, p)))
            .unwrap_or(true),
        Type::Array(inner) | Type::Ref(inner) | Type::Code(inner) | Type::Boxed(inner) => {
            occurs_in(*inner)
        }
        Type::Tuple(members) | Type::Union(members) => members.iter().any(|m| occurs_in(*m)),
        Type::Record(fields) => fields
            .iter()
            .any(|RecordTypeField { ty, .. }| occurs_in(*ty)),
        Type::Function { arg, ret } => occurs_in(*arg) || occurs_in(*ret),
        Type::UserSum { .. }
        | Type::Primitive(_)
        | Type::TypeScheme(_)
        | Type::TypeAlias(_)
        | Type::Any
        | Type::Failure
        | Type::Unknown => false,
    }
}
fn unify_vec(a1: &[TypeNodeId], a2: &[TypeNodeId]) -> Result<Relation, Vec<Error>> {
assert_eq!(a1.len(), a2.len());
let (res, errs): (Vec<_>, Vec<_>) = a1
.iter()
.zip(a2)
.map(|(a1, a2)| unify_types(*a1, *a2))
.partition_result();
let errs: Vec<_> = errs.into_iter().flatten().collect();
let res_relation = if res.iter().all(|r| *r != Relation::Subtype) {
Relation::Supertype
} else if res.iter().all(|r| *r != Relation::Supertype) {
Relation::Subtype
} else {
return Err(errs);
};
Ok(res_relation)
}
/// Unifies `t1` and `t2` as function *argument* types.
///
/// Used for the `arg` side of function types. On top of plain unification it
/// allows:
/// - a single-element record or tuple to match the lone type inside it (a
///   one-field record on the right only when that field has no default);
/// - a record to match a tuple positionally, using the record's field types in
///   their stored order;
/// - a union on the left to match when any one of its members unifies with `t2`.
///   NOTE(review): a failed trial against an earlier member may already have bound
///   type variables.
///
/// When a type variable is bound to a non-variable type, that type is recorded
/// in the variable's `bound.upper`. All remaining combinations are delegated to
/// [`unify_types`].
///
/// # Errors
/// `CircularType` when the occurs check fails, `TypeMismatch` when no union member
/// matches, plus anything reported by [`unify_types`] / recursive calls.
///
/// # Panics
/// Panics if a type variable's cell cannot be read or written (e.g. a poisoned lock).
fn unify_types_args(t1: TypeNodeId, t2: TypeNodeId) -> Result<Relation, Vec<Error>> {
    log::trace!("unify_args {} and {}", t1.to_type(), t2.to_type());
    let loc1 = best_span(t1, t2);
    let loc2 = best_span(t2, t1);
    let t1r = t1.get_root();
    let t2r = t2.get_root();
    let res = match &(t1r.to_type(), t2r.to_type()) {
        (Type::Record(_), Type::Record(_)) | (Type::Tuple(_), Type::Tuple(_)) => {
            unify_types(t1, t2)?
        }
        (Type::Record(v), _t) if v.len() == 1 => unify_types_args(v.first().unwrap().ty, t2)?,
        (_t, Type::Record(v)) if v.len() == 1 && !v.first().unwrap().has_default => {
            unify_types_args(t1, v.first().unwrap().ty)?
        }
        (_t, Type::Tuple(v)) if v.len() == 1 => unify_types_args(t1, *v.first().unwrap())?,
        (Type::Tuple(v), _t) if v.len() == 1 => unify_types_args(*v.first().unwrap(), t2)?,
        (Type::Intermediate(i1), Type::Intermediate(i2)) => {
            // Snapshot both variables so no read guard is held while writing below.
            let (tv1_eq, var1, level1, parent1) = {
                let guard = i1.read().unwrap();
                (guard.clone(), guard.var, guard.level, guard.parent)
            };
            let (tv2_eq, var2, level2, parent2) = {
                let guard = i2.read().unwrap();
                (guard.clone(), guard.var, guard.level, guard.parent)
            };
            if tv1_eq == tv2_eq {
                return Ok(Relation::Identical);
            }
            if occur_check(var1, t2) {
                return Err(vec![Error::CircularType {
                    left: loc1,
                    right: loc2,
                }]);
            }
            // The two variables are merged below, and whichever one becomes the
            // representative must carry the smaller level, so lower both to the
            // minimum. NOTE(review): assumes a lower level denotes an outer let-scope.
            if level1 > level2 {
                i1.write().unwrap().level = level2;
            } else if level2 > level1 {
                i2.write().unwrap().level = level1;
            }
            match (parent1, parent2) {
                (None, None) => {
                    if var1 > var2 {
                        i2.write().unwrap().parent = Some(t1r);
                    } else {
                        i1.write().unwrap().parent = Some(t2r);
                    };
                }
                (_, Some(p2)) => {
                    i1.write().unwrap().parent = Some(p2);
                }
                (Some(p1), _) => {
                    i2.write().unwrap().parent = Some(p1);
                }
            };
            Relation::Identical
        }
        (Type::Intermediate(i1), _) => {
            let var1 = i1.read().unwrap().var;
            if occur_check(var1, t2r) {
                return Err(vec![Error::CircularType {
                    left: loc1,
                    right: loc2,
                }]);
            }
            let mut tv1 = i1.write().unwrap();
            tv1.parent = Some(t2r);
            tv1.bound.upper = t2r;
            drop(tv1);
            Relation::Identical
        }
        (_, Type::Intermediate(i2)) => {
            let var2 = i2.read().unwrap().var;
            if occur_check(var2, t1r) {
                return Err(vec![Error::CircularType {
                    left: loc1,
                    right: loc2,
                }]);
            }
            let mut tv2 = i2.write().unwrap();
            tv2.parent = Some(t1r);
            tv2.bound.upper = t1r;
            drop(tv2);
            Relation::Identical
        }
        (Type::Record(kvs), Type::Tuple(_)) => {
            // Treat the record as a tuple of its field types, in stored order.
            let recordvec = kvs
                .iter()
                .map(|RecordTypeField { ty, .. }| *ty)
                .collect::<Vec<_>>();
            let loc_record = t1.to_loc();
            let new_tup = Type::Tuple(recordvec).into_id_with_location(loc_record);
            unify_types_args(new_tup, t2)?
        }
        (Type::Tuple(_), Type::Record(_)) => unify_types_args(t2, t1)?,
        (Type::Union(union_types), _t) => {
            for union_member in union_types {
                if unify_types_args(*union_member, t2r).is_ok() {
                    return Ok(Relation::Identical);
                }
            }
            return Err(vec![Error::TypeMismatch {
                left: (t1, loc1.clone()),
                right: (t2, loc2.clone()),
            }]);
        }
        (_, _) => unify_types(t1, t2)?,
    };
    Ok(res)
}
pub(crate) fn unify_types(t1: TypeNodeId, t2: TypeNodeId) -> Result<Relation, Vec<Error>> {
let loc1 = best_span(t1, t2);
let loc2 = best_span(t2, t1);
let t1r = t1.get_root();
let t2r = t2.get_root();
let res = match &(t1r.to_type(), t2r.to_type()) {
(Type::Intermediate(i1), Type::Intermediate(i2)) => {
let (tv1_eq, var1, level1, parent1) = {
let guard = i1.read().unwrap();
(guard.clone(), guard.var, guard.level, guard.parent)
};
let (tv2_eq, var2, level2, parent2) = {
let guard = i2.read().unwrap();
(guard.clone(), guard.var, guard.level, guard.parent)
};
if tv1_eq == tv2_eq {
return Ok(Relation::Identical);
}
if occur_check(var1, t2) {
return Err(vec![Error::CircularType {
left: loc1,
right: loc2,
}]);
}
if level1 < level2 {
i1.write().unwrap().level = level2;
}
match (parent1, parent2) {
(None, None) => {
if var1 > var2 {
i2.write().unwrap().parent = Some(t1r);
} else {
i1.write().unwrap().parent = Some(t2r);
};
}
(_, Some(p2)) => {
i1.write().unwrap().parent = Some(p2);
}
(Some(p1), _) => {
i2.write().unwrap().parent = Some(p1);
}
};
Relation::Identical
}
(Type::Intermediate(i1), _) => {
let var1 = i1.read().unwrap().var;
if occur_check(var1, t2r) {
return Err(vec![Error::CircularType {
left: loc1,
right: loc2,
}]);
}
let mut tv1 = i1.write().unwrap();
tv1.parent = Some(t2r);
tv1.bound.lower = t2r;
drop(tv1);
Relation::Identical
}
(_, Type::Intermediate(i2)) => {
let var2 = i2.read().unwrap().var;
if occur_check(var2, t1r) {
return Err(vec![Error::CircularType {
left: loc1,
right: loc2,
}]);
}
let mut tv2 = i2.write().unwrap();
tv2.parent = Some(t1r);
tv2.bound.lower = t1r;
drop(tv2); Relation::Identical
}
(Type::Array(a1), Type::Array(a2)) => {
let res = unify_types(*a1, *a2)?;
match res {
Relation::Identical => Relation::Identical,
_ => {
return Err(vec![Error::TypeMismatch {
left: (*a1, loc1.clone()),
right: (*a2, loc2.clone()),
}]);
}
}
}
(Type::Ref(x1), Type::Ref(x2)) => unify_types(*x1, *x2)?,
(Type::Tuple(a1), Type::Tuple(a2)) => {
use std::cmp::Ordering;
match a1.len().cmp(&a2.len()) {
Ordering::Equal => {
let _ = unify_vec(a1, a2)?;
Relation::Identical
}
_ => {
return Err(vec![Error::LengthMismatch {
left: (a1.to_vec(), loc1),
right: (a2.to_vec(), loc2),
}]);
}
}
}
(Type::Record(a1), Type::Record(a2)) => {
let keys_a = a1.iter().sorted_by(move |a, b| {
let keya = a.key;
let keyb = b.key;
keya.as_str().cmp(keyb.as_str())
});
let keys_b = a2.iter().sorted_by(move |a, b| {
let keya = a.key;
let keyb = b.key;
keya.as_str().cmp(keyb.as_str())
});
let allkeys = keys_a
.clone()
.chain(keys_b.clone())
.unique_by(|RecordTypeField { key, .. }| key.as_str());
let sparse_fields1 = allkeys.clone().map(|parent| {
a1.iter()
.find(|RecordTypeField { key, .. }| parent.key == *key)
.or(parent.has_default.then_some(parent))
});
let sparse_fields2 = allkeys.map(|parent| {
a2.iter()
.find(|RecordTypeField { key, .. }| parent.key == *key)
.or(parent.has_default.then_some(parent))
});
#[derive(PartialEq, Eq, Debug)]
enum SearchRes {
Both,
A,
B,
}
let searchresults = sparse_fields1.zip(sparse_fields2).map(|pair| match pair {
(Some(s1), Some(s2)) => unify_types(s1.ty, s2.ty).map(|_| SearchRes::Both),
(Some(_), None) => Ok(SearchRes::A),
(None, Some(_)) => Ok(SearchRes::B),
(None, None) => unreachable!(),
});
log::trace!(
"unify_records {} and {}: {:?}",
t1,
t2,
searchresults.clone().collect_vec()
);
let all_both = searchresults
.clone()
.all(|r| r.is_ok_and(|r| r == SearchRes::Both));
let collected_errs = searchresults
.clone()
.filter_map(|r| r.err())
.flatten()
.collect::<Vec<_>>();
let mut all_errs = vec![];
let contains_err = !collected_errs.is_empty();
if contains_err {
all_errs = collected_errs;
}
let contains_a = searchresults
.clone()
.any(|r| r.is_ok_and(|r| r == SearchRes::A));
let contains_b = searchresults
.clone()
.any(|r| r.is_ok_and(|r| r == SearchRes::B));
if all_both {
Relation::Identical
} else if !contains_err && contains_a && !contains_b {
Relation::Supertype
} else if !contains_err && contains_b && !contains_a {
Relation::Subtype
} else if contains_b && contains_a {
let keys_a = a1
.iter()
.map(|RecordTypeField { key, ty, .. }| (*key, *ty))
.collect::<Vec<_>>();
let keys_b = a2
.iter()
.map(|RecordTypeField { key, ty, .. }| (*key, *ty))
.collect::<Vec<_>>();
all_errs.push(Error::ImcompatibleRecords {
left: (keys_a, loc1.clone()),
right: (keys_b, loc2.clone()),
});
return Err(all_errs);
} else {
return Err(all_errs);
}
}
(
Type::Function {
arg: arg1,
ret: ret1,
},
Type::Function {
arg: arg2,
ret: ret2,
},
) => {
let arg_res = unify_types_args(*arg1, *arg2);
let ret_res = unify_types(*ret1, *ret2);
match (arg_res, ret_res) {
(Ok(Relation::Subtype), Ok(_)) | (Ok(_), Ok(Relation::Supertype)) => {
return Err(vec![Error::TypeMismatch {
left: (t1, loc1.clone()),
right: (t2, loc2.clone()),
}]);
}
(Ok(Relation::Identical), Ok(Relation::Identical)) => Relation::Identical,
(Ok(_), Err(errs)) | (Err(errs), Ok(_)) => {
return Err(errs);
}
(Err(mut e1), Err(mut e2)) => {
e1.append(&mut e2);
return Err(e1);
}
_ => Relation::Subtype,
}
}
(Type::Primitive(p1), Type::Primitive(p2)) if p1 == p2 => Relation::Identical,
(Type::TypeScheme(s1), Type::TypeScheme(s2)) if s1 == s2 => Relation::Identical,
(Type::TypeScheme(_), _) | (_, Type::TypeScheme(_)) => {
return Err(vec![Error::TypeMismatch {
left: (t1, loc1.clone()),
right: (t2, loc2.clone()),
}]);
}
(Type::Primitive(PType::Unit), Type::Tuple(v))
| (Type::Tuple(v), Type::Primitive(PType::Unit))
if v.is_empty() =>
{
Relation::Identical
}
(_t, Type::Tuple(v)) if v.len() == 1 => unify_types(t1, *v.first().unwrap())?,
(Type::Tuple(v), _t) if v.len() == 1 => unify_types(*v.first().unwrap(), t2)?,
(Type::Primitive(PType::Unit), Type::Record(v))
| (Type::Record(v), Type::Primitive(PType::Unit))
if v.is_empty() =>
{
Relation::Identical
}
(Type::Failure, _t) | (_t, Type::Any) => Relation::Identical,
(Type::Any, _t) | (_t, Type::Failure) => Relation::Identical,
(Type::Code(p1), Type::Code(p2)) => unify_types(*p1, *p2)?,
(Type::Union(union_types1), Type::Union(union_types2)) => {
if union_types1.len() != union_types2.len() {
return Err(vec![Error::TypeMismatch {
left: (t1, loc1.clone()),
right: (t2, loc2.clone()),
}]);
}
let all_match = union_types1.iter().all(|m1| {
union_types2
.iter()
.any(|m2| unify_types(*m1, *m2).is_ok_and(|r| r == Relation::Identical))
});
if all_match {
Relation::Identical
} else {
return Err(vec![Error::TypeMismatch {
left: (t1, loc1.clone()),
right: (t2, loc2.clone()),
}]);
}
}
(_t, Type::Union(union_types)) => {
for union_member in union_types {
if unify_types(t1r, *union_member).is_ok() {
return Ok(Relation::Subtype);
}
}
return Err(vec![Error::TypeMismatch {
left: (t1, loc1.clone()),
right: (t2, loc2.clone()),
}]);
}
(Type::Union(union_types), _t) => {
let all_match = union_types
.iter()
.all(|union_member| unify_types(*union_member, t2r).is_ok());
if all_match {
return Ok(Relation::Supertype);
}
return Err(vec![Error::TypeMismatch {
left: (t1, loc1.clone()),
right: (t2, loc2.clone()),
}]);
}
(Type::UserSum { name: n1, .. }, Type::UserSum { name: n2, .. }) => {
if n1 == n2 {
Relation::Identical
} else {
return Err(vec![Error::TypeMismatch {
left: (t1, loc1.clone()),
right: (t2, loc2.clone()),
}]);
}
}
(Type::Boxed(b1), Type::Boxed(b2)) => unify_types(*b1, *b2)?,
(Type::Boxed(inner), _t2) => {
let inner_res = unify_types(*inner, t2r);
match inner_res {
Ok(_) => Relation::Identical,
Err(_) => {
return Err(vec![Error::TypeMismatch {
left: (t1, loc1.clone()),
right: (t2, loc2.clone()),
}]);
}
}
}
(_t1, Type::Boxed(inner)) => {
let inner_res = unify_types(t1r, *inner);
match inner_res {
Ok(_) => Relation::Identical,
Err(_) => {
return Err(vec![Error::TypeMismatch {
left: (t1, loc1.clone()),
right: (t2, loc2.clone()),
}]);
}
}
}
(_p1, _p2) => {
return Err(vec![Error::TypeMismatch {
left: (t1, loc1),
right: (t2, loc2),
}]);
}
};
log::trace!("unified {} and {}:{:?}", t1.to_type(), t2.to_type(), res);
Ok(res)
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{Error, unify_types};
    use crate::{
        types::{PType, Type},
        utils::metadata::{Location, Span},
    };

    /// Builds a location covering `span` in the (fictional) source file `file`.
    fn at(span: &Span, file: &str) -> Location {
        Location::new(span.clone(), PathBuf::from(file))
    }

    /// Returns the (left, right) spans of the first error, which must be a
    /// `LengthMismatch`; panics otherwise.
    fn length_mismatch_spans(errs: &[Error]) -> (&Span, &Span) {
        match errs.first() {
            Some(Error::LengthMismatch {
                left: (_, lspan),
                right: (_, rspan),
            }) => (lspan, rspan),
            _ => panic!("unexpected error variant"),
        }
    }

    /// When both tuples carry real locations, each side of the error keeps its own span.
    #[test]
    fn length_mismatch_keeps_both_spans() {
        let left_span = 10..20;
        let right_span = 30..40;
        let num = || Type::Primitive(PType::Numeric);
        let left_elem = num().into_id_with_location(at(&left_span, "left.mmm"));
        let right_elem = num().into_id_with_location(at(&right_span, "right.mmm"));
        let left = Type::Tuple(vec![left_elem]).into_id_with_location(at(&left_span, "left.mmm"));
        let right = Type::Tuple(vec![right_elem, right_elem])
            .into_id_with_location(at(&right_span, "right.mmm"));

        let errs = unify_types(left, right).expect_err("expected tuple length mismatch");
        let (lspan, rspan) = length_mismatch_spans(&errs);
        assert_eq!(lspan, &left_span);
        assert_eq!(rspan, &right_span);
    }

    /// A side without any located node borrows the other side's span instead of
    /// reporting the dummy one.
    #[test]
    fn length_mismatch_avoids_dummy_span_when_one_side_has_location() {
        let right_span = 30..40;
        let num = || Type::Primitive(PType::Numeric);
        let left_elem = num().into_id();
        let right_elem = num().into_id_with_location(at(&right_span, "right.mmm"));
        let left = Type::Tuple(vec![left_elem]).into_id();
        let right = Type::Tuple(vec![right_elem, right_elem])
            .into_id_with_location(at(&right_span, "right.mmm"));

        let errs = unify_types(left, right).expect_err("expected tuple length mismatch");
        let (lspan, rspan) = length_mismatch_spans(&errs);
        assert_eq!(lspan, &right_span);
        assert_eq!(rspan, &right_span);
    }
}