use super::types::{SynthField, SynthRecord, SynthType, SynthUnion};
pub fn unify(t1: SynthType, t2: SynthType) -> SynthType {
if t1 == t2 {
return t1;
}
match (t1, t2) {
(SynthType::Any, t) | (t, SynthType::Any) => t,
(SynthType::Never, t) | (t, SynthType::Never) => t,
(SynthType::Hole(_), t) | (t, SynthType::Hole(_)) => t,
(SynthType::Union(u1), SynthType::Union(u2)) => {
let all_variants = u1.variants.into_iter().chain(u2.variants);
SynthUnion::from_variants(all_variants)
}
(SynthType::Union(u), t) | (t, SynthType::Union(u)) => {
let all_variants = u.variants.into_iter().chain(std::iter::once(t));
SynthUnion::from_variants(all_variants)
}
(SynthType::Array(a), SynthType::Array(b)) => SynthType::Array(Box::new(unify(*a, *b))),
(SynthType::Tuple(a), SynthType::Tuple(b)) if a.len() == b.len() => {
let unified: Vec<_> = a.into_iter().zip(b).map(|(x, y)| unify(x, y)).collect();
SynthType::Tuple(unified)
}
(SynthType::Record(r1), SynthType::Record(r2)) => unify_records(r1, r2),
(SynthType::Text(l1), SynthType::Text(l2)) => {
match (l1, l2) {
(None, l) | (l, None) => SynthType::Text(l),
(Some(a), Some(b)) if a == b => SynthType::Text(Some(a)),
(Some(a), Some(b)) => {
SynthUnion::from_variants([SynthType::Text(Some(a)), SynthType::Text(Some(b))])
}
}
}
(t1, t2) => SynthUnion::from_variants([t1, t2]),
}
}
/// Unifies two record types.
///
/// When both records declare exactly the same set of field names, the result is a single
/// record whose fields are merged pairwise. Each field's type is combined with [`unify`], and
/// the field is optional if it was optional on either side.
///
/// Records whose field names differ are not merged. Both are handed to
/// [`SynthUnion::from_variants`] as separate variants.
fn unify_records(r1: SynthRecord, mut r2: SynthRecord) -> SynthType {
    // Map keys are unique, so equal lengths plus one-way containment is exactly set equality.
    // This avoids allocating two temporary `HashSet`s just to compare the shapes.
    let same_shape = r1.fields.len() == r2.fields.len()
        && r1.fields.keys().all(|name| r2.fields.contains_key(name));

    if !same_shape {
        return SynthUnion::from_variants([SynthType::Record(r1), SynthType::Record(r2)]);
    }

    let fields = r1
        .fields
        .into_iter()
        .map(|(name, f1)| {
            // Move the matching field out of `r2` instead of deep-cloning its type.
            let f2 = r2
                .fields
                .remove(&name)
                .expect("records with the same shape share every field name");
            let merged = SynthField {
                ty: unify(f1.ty, f2.ty),
                optional: f1.optional || f2.optional,
            };
            (name, merged)
        })
        .collect();

    SynthType::Record(SynthRecord { fields })
}
#[cfg(test)]
mod tests {
    //! Unit tests for `unify` and record unification.
    use super::*;

    #[test]
    fn test_unify_same() {
        assert_eq!(
            unify(SynthType::Integer, SynthType::Integer),
            SynthType::Integer
        );
        assert_eq!(
            unify(
                SynthType::Text(Some("rust".into())),
                SynthType::Text(Some("rust".into()))
            ),
            SynthType::Text(Some("rust".into()))
        );
    }

    #[test]
    fn test_unify_any() {
        assert_eq!(
            unify(SynthType::Any, SynthType::Integer),
            SynthType::Integer
        );
        assert_eq!(
            unify(SynthType::Integer, SynthType::Any),
            SynthType::Integer
        );
    }

    #[test]
    fn test_unify_never() {
        assert_eq!(
            unify(SynthType::Never, SynthType::Integer),
            SynthType::Integer
        );
        assert_eq!(
            unify(SynthType::Integer, SynthType::Never),
            SynthType::Integer
        );
    }

    #[test]
    fn test_unify_hole() {
        assert_eq!(
            unify(SynthType::Hole(None), SynthType::Integer),
            SynthType::Integer
        );
        assert_eq!(
            unify(SynthType::Integer, SynthType::Hole(None)),
            SynthType::Integer
        );
    }

    #[test]
    fn test_unify_different_primitives() {
        assert_eq!(
            unify(SynthType::Integer, SynthType::Boolean),
            SynthType::Union(SynthUnion {
                variants: vec![SynthType::Integer, SynthType::Boolean]
            })
        );
    }

    #[test]
    fn test_unify_arrays() {
        let arr1 = SynthType::Array(Box::new(SynthType::Integer));
        let arr2 = SynthType::Array(Box::new(SynthType::Boolean));
        assert_eq!(
            unify(arr1, arr2),
            SynthType::Array(Box::new(SynthType::Union(SynthUnion {
                variants: vec![SynthType::Integer, SynthType::Boolean]
            })))
        );
    }

    #[test]
    fn test_unify_tuples_same_length() {
        let t1 = SynthType::Tuple(vec![SynthType::Integer, SynthType::Boolean]);
        let t2 = SynthType::Tuple(vec![SynthType::Integer, SynthType::Integer]);
        assert_eq!(
            unify(t1, t2),
            SynthType::Tuple(vec![
                SynthType::Integer,
                SynthType::Union(SynthUnion {
                    variants: vec![SynthType::Boolean, SynthType::Integer]
                })
            ])
        );
    }

    #[test]
    fn test_unify_tuples_different_length() {
        let t1 = SynthType::Tuple(vec![SynthType::Integer]);
        let t2 = SynthType::Tuple(vec![SynthType::Integer, SynthType::Boolean]);
        assert_eq!(
            unify(t1.clone(), t2.clone()),
            SynthType::Union(SynthUnion {
                variants: vec![t1, t2]
            })
        );
    }

    #[test]
    fn test_unify_records_same_shape() {
        let r1 = SynthRecord::new([
            ("a".into(), SynthField::required(SynthType::Integer)),
            ("b".into(), SynthField::required(SynthType::Boolean)),
        ]);
        let r2 = SynthRecord::new([
            ("a".into(), SynthField::required(SynthType::Text(None))),
            ("b".into(), SynthField::required(SynthType::Boolean)),
        ]);
        let expected = SynthType::Record(SynthRecord::new([
            (
                "a".into(),
                SynthField::required(SynthType::Union(SynthUnion {
                    variants: vec![SynthType::Integer, SynthType::Text(None)],
                })),
            ),
            ("b".into(), SynthField::required(SynthType::Boolean)),
        ]));
        assert_eq!(
            unify(SynthType::Record(r1), SynthType::Record(r2)),
            expected
        );
    }

    #[test]
    fn test_unify_records_optional_field() {
        // A field that is optional on either side must stay optional after unification.
        let r1 = SynthRecord::new([(
            "a".into(),
            SynthField {
                ty: SynthType::Integer,
                optional: true,
            },
        )]);
        let r2 = SynthRecord::new([("a".into(), SynthField::required(SynthType::Boolean))]);
        let expected = SynthType::Record(SynthRecord::new([(
            "a".into(),
            SynthField {
                ty: SynthType::Union(SynthUnion {
                    variants: vec![SynthType::Integer, SynthType::Boolean],
                }),
                optional: true,
            },
        )]));
        assert_eq!(
            unify(SynthType::Record(r1), SynthType::Record(r2)),
            expected
        );
    }

    #[test]
    fn test_unify_records_different_shape() {
        let r1 = SynthRecord::new([("a".into(), SynthField::required(SynthType::Integer))]);
        let r2 = SynthRecord::new([
            ("a".into(), SynthField::required(SynthType::Text(None))),
            ("b".into(), SynthField::required(SynthType::Boolean)),
        ]);
        let expected = SynthType::Union(SynthUnion {
            variants: vec![
                SynthType::Record(SynthRecord::new([(
                    "a".into(),
                    SynthField::required(SynthType::Integer),
                )])),
                SynthType::Record(SynthRecord::new([
                    ("a".into(), SynthField::required(SynthType::Text(None))),
                    ("b".into(), SynthField::required(SynthType::Boolean)),
                ])),
            ],
        });
        assert_eq!(
            unify(SynthType::Record(r1), SynthType::Record(r2)),
            expected
        );
    }

    #[test]
    fn test_unify_text_languages() {
        assert_eq!(
            unify(SynthType::Text(None), SynthType::Text(Some("rust".into()))),
            SynthType::Text(Some("rust".into()))
        );
        // The known language is kept regardless of which side carries it.
        assert_eq!(
            unify(SynthType::Text(Some("rust".into())), SynthType::Text(None)),
            SynthType::Text(Some("rust".into()))
        );
        assert_eq!(
            unify(
                SynthType::Text(Some("rust".into())),
                SynthType::Text(Some("rust".into()))
            ),
            SynthType::Text(Some("rust".into()))
        );
        assert_eq!(
            unify(
                SynthType::Text(Some("rust".into())),
                SynthType::Text(Some("python".into())),
            ),
            SynthType::Union(SynthUnion {
                variants: vec![
                    SynthType::Text(Some("rust".into())),
                    SynthType::Text(Some("python".into()))
                ]
            })
        );
    }

    #[test]
    fn test_unify_unions() {
        let u1 = SynthUnion::from_variants([SynthType::Integer, SynthType::Boolean]);
        let u2 = SynthUnion::from_variants([SynthType::Text(None), SynthType::Float]);
        assert_eq!(
            unify(u1, u2),
            SynthType::Union(SynthUnion {
                variants: vec![
                    SynthType::Integer,
                    SynthType::Boolean,
                    SynthType::Text(None),
                    SynthType::Float
                ]
            })
        );
    }

    #[test]
    fn test_unify_union_with_member() {
        let union = SynthUnion::from_variants([SynthType::Integer, SynthType::Boolean]);
        assert_eq!(
            unify(union, SynthType::Text(None)),
            SynthType::Union(SynthUnion {
                variants: vec![
                    SynthType::Integer,
                    SynthType::Boolean,
                    SynthType::Text(None)
                ]
            })
        );
    }

    #[test]
    fn test_unify_member_with_union() {
        // The union may appear on either side; the new member is appended after its variants.
        let union = SynthUnion::from_variants([SynthType::Integer, SynthType::Boolean]);
        assert_eq!(
            unify(SynthType::Text(None), union),
            SynthType::Union(SynthUnion {
                variants: vec![
                    SynthType::Integer,
                    SynthType::Boolean,
                    SynthType::Text(None)
                ]
            })
        );
    }
}