use crate::types::*;
use indexmap::IndexMap;
/// Unification state: substitutions for type variables and effect variables,
/// plus the counters used to mint fresh variables of each kind.
#[derive(Default)]
pub struct Unifier {
/// Id handed out by the next call to `fresh` / `fresh_id`.
next_var: TyVarId,
/// Bindings recorded by `unify`: type variable id -> the type it was bound to.
subst: IndexMap<TyVarId, Ty>,
/// Bindings recorded by `unify_effects`: effect variable id -> the effect set it stands for.
eff_subst: IndexMap<u32, EffectSet>,
/// Id handed out by the next call to `fresh_eff_id`.
next_eff_var: u32,
}
impl Unifier {
pub fn new() -> Self { Self::default() }
/// Allocates a new, unbound type variable and returns it wrapped as `Ty::Var`.
///
/// Draws from the same counter as `fresh_id`, so ids from the two never collide.
pub fn fresh(&mut self) -> Ty {
    Ty::Var(self.fresh_id())
}
/// Allocates a new type variable id and advances the type-variable counter.
pub fn fresh_id(&mut self) -> TyVarId {
    // Post-increment: hand out the current value, then bump the counter.
    let allocated = self.next_var;
    self.next_var += 1;
    allocated
}
/// Allocates a new effect variable id and advances the effect-variable counter.
pub fn fresh_eff_id(&mut self) -> u32 {
    let following = self.next_eff_var + 1;
    // `replace` stores the bumped counter and yields the previous value,
    // which is the id being handed out.
    std::mem::replace(&mut self.next_eff_var, following)
}
/// Returns a copy of `t` with the current substitution applied throughout.
///
/// Bound type variables are replaced (transitively) by what they are bound
/// to; unbound variables are left as `Ty::Var`. Effect sets inside function
/// types are resolved with `resolve_effects`.
pub fn resolve(&self, t: &Ty) -> Ty {
    match t {
        // Chase the binding if there is one; otherwise the variable stays as-is.
        Ty::Var(id) => self
            .subst
            .get(id)
            .map_or_else(|| Ty::Var(*id), |bound| self.resolve(bound)),
        // Leaves contain no variables.
        Ty::Prim(_) | Ty::Unit | Ty::Never => t.clone(),
        Ty::List(elem) => Ty::List(Box::new(self.resolve(elem))),
        Ty::Tuple(items) => Ty::Tuple(items.iter().map(|item| self.resolve(item)).collect()),
        // Collecting keeps the fields in their original iteration order.
        Ty::Record(fields) => Ty::Record(
            fields
                .iter()
                .map(|(name, field_ty)| (name.clone(), self.resolve(field_ty)))
                .collect(),
        ),
        Ty::Con(name, args) => {
            Ty::Con(name.clone(), args.iter().map(|arg| self.resolve(arg)).collect())
        }
        Ty::Function { params, effects, ret } => {
            let params = params.iter().map(|param| self.resolve(param)).collect();
            let effects = self.resolve_effects(effects);
            let ret = Box::new(self.resolve(ret));
            Ty::Function { params, effects, ret }
        }
    }
}
/// Flattens `eff` through the effect substitution.
///
/// Starting from `eff`'s own concrete labels, follows the chain of bound
/// tail variables, accumulating the concrete labels of every binding. The
/// returned set's `var` is `None` if the chain ends in a closed set, or the
/// first unbound variable encountered otherwise.
pub fn resolve_effects(&self, eff: &EffectSet) -> EffectSet {
    let mut concrete = eff.concrete.clone();
    let mut tail = eff.var;
    // Keep going only while the current tail exists *and* has a binding.
    while let Some(binding) = tail.and_then(|v| self.eff_subst.get(&v)) {
        concrete.extend(binding.concrete.iter().cloned());
        tail = binding.var;
    }
    // `tail` is now either `None` (closed) or an unbound variable.
    EffectSet { concrete, var: tail }
}
/// Unifies two effect sets, recording bindings for their tail variables.
///
/// Both sets are first flattened with `resolve_effects`, so any tail
/// variable seen below is unbound. The cases are:
///
/// * both closed, or both open with the *same* tail: the concrete labels
///   must be equal;
/// * both open with different tails: each tail is bound to the labels it is
///   missing from the other side. If `a` has labels that `b` lacks, both
///   tails continue with a shared fresh variable so the two sets resolve to
///   the same thing afterwards;
/// * one open, one closed: the open side's labels must be a subset of the
///   closed side's, and its tail is bound to the (closed) remainder.
///
/// # Errors
///
/// Returns `UnifyError::EffectMismatch` carrying the two resolved sets when
/// none of the cases above can make them equal.
pub fn unify_effects(&mut self, a: &EffectSet, b: &EffectSet) -> Result<(), UnifyError> {
    let a = self.resolve_effects(a);
    let b = self.resolve_effects(b);
    match (a.var, b.var) {
        (None, None) => {
            if a.concrete == b.concrete { Ok(()) }
            else { Err(UnifyError::EffectMismatch { a, b }) }
        }
        (Some(va), Some(vb)) if va == vb => {
            if a.concrete == b.concrete { Ok(()) }
            else { Err(UnifyError::EffectMismatch { a, b }) }
        }
        (Some(va), Some(vb)) => {
            let only_in_a: std::collections::BTreeSet<String> =
                a.concrete.difference(&b.concrete).cloned().collect();
            let only_in_b: std::collections::BTreeSet<String> =
                b.concrete.difference(&a.concrete).cloned().collect();
            if only_in_a.is_empty() {
                // `a`'s labels are all present in `b`: `a`'s tail absorbs
                // `b`'s extra labels and continues with `b`'s tail.
                self.eff_subst.insert(va, EffectSet { concrete: only_in_b, var: Some(vb) });
            } else {
                // Each side has to gain the other's extra labels. Binding only
                // `va` would silently drop `only_in_a`, leaving the two sets
                // different after a "successful" unification.
                // NOTE(review): assumes every effect variable id comes from
                // `fresh_eff_id`, so `shared` cannot already be bound.
                let shared = self.fresh_eff_id();
                self.eff_subst.insert(va, EffectSet { concrete: only_in_b, var: Some(shared) });
                self.eff_subst.insert(vb, EffectSet { concrete: only_in_a, var: Some(shared) });
            }
            Ok(())
        }
        (Some(va), None) => {
            // `b` is closed, so `a` may not contain anything `b` lacks.
            if !a.concrete.is_subset(&b.concrete) {
                return Err(UnifyError::EffectMismatch { a, b });
            }
            let extra: std::collections::BTreeSet<String> =
                b.concrete.difference(&a.concrete).cloned().collect();
            self.eff_subst.insert(va, EffectSet { concrete: extra, var: None });
            Ok(())
        }
        (None, Some(vb)) => {
            // Mirror image of the previous case.
            if !b.concrete.is_subset(&a.concrete) {
                return Err(UnifyError::EffectMismatch { a, b });
            }
            let extra: std::collections::BTreeSet<String> =
                a.concrete.difference(&b.concrete).cloned().collect();
            self.eff_subst.insert(vb, EffectSet { concrete: extra, var: None });
            Ok(())
        }
    }
}
/// Unifies two types, extending the substitution so that both resolve to
/// the same type.
///
/// Both inputs are resolved first. A type variable on either side is bound
/// to the other side (after an occurs check); `Ty::Never` unifies with
/// anything; structured types unify component-wise when their shapes agree.
/// Bindings made before a failure is detected are not rolled back.
///
/// # Errors
///
/// * `UnifyError::Infinite` if binding a variable would make it occur in
///   its own definition.
/// * `UnifyError::Mismatch` if the two types have incompatible shapes.
/// * Whatever `unify_effects` returns for the effect sets of two function
///   types.
pub fn unify(&mut self, a: &Ty, b: &Ty) -> Result<(), UnifyError> {
    let a = self.resolve(a);
    let b = self.resolve(b);
    match (&a, &b) {
        (Ty::Var(var), ty) | (ty, Ty::Var(var)) => {
            // A variable unified with itself needs no binding.
            if matches!(ty, Ty::Var(same) if same == var) {
                return Ok(());
            }
            if occurs(*var, ty, self) {
                return Err(UnifyError::Infinite { var: *var, ty: ty.clone() });
            }
            self.subst.insert(*var, ty.clone());
            Ok(())
        }
        (Ty::Prim(lhs), Ty::Prim(rhs)) if lhs == rhs => Ok(()),
        // `Never` on either side is accepted against any type.
        (Ty::Unit, Ty::Unit) | (Ty::Never, _) | (_, Ty::Never) => Ok(()),
        (Ty::List(lhs), Ty::List(rhs)) => self.unify(lhs, rhs),
        (Ty::Tuple(lhs), Ty::Tuple(rhs)) if lhs.len() == rhs.len() => lhs
            .iter()
            .zip(rhs.iter())
            .try_for_each(|(l, r)| self.unify(l, r)),
        (Ty::Record(lhs), Ty::Record(rhs)) => {
            let mismatch = || UnifyError::Mismatch {
                a: Ty::Record(lhs.clone()),
                b: Ty::Record(rhs.clone()),
            };
            // Equal sizes plus "every left key exists on the right" means
            // the key sets are identical.
            if lhs.len() != rhs.len() {
                return Err(mismatch());
            }
            for (name, left_ty) in lhs {
                let right_ty = rhs.get(name).ok_or_else(&mismatch)?;
                self.unify(left_ty, right_ty)?;
            }
            Ok(())
        }
        (Ty::Con(n1, args1), Ty::Con(n2, args2)) if n1 == n2 && args1.len() == args2.len() => {
            args1
                .iter()
                .zip(args2.iter())
                .try_for_each(|(l, r)| self.unify(l, r))
        }
        (
            Ty::Function { params: p1, effects: e1, ret: r1 },
            Ty::Function { params: p2, effects: e2, ret: r2 },
        ) if p1.len() == p2.len() => {
            // Order matters for which error surfaces first: params, effects, return.
            p1.iter()
                .zip(p2.iter())
                .try_for_each(|(l, r)| self.unify(l, r))?;
            self.unify_effects(e1, e2)?;
            self.unify(r1, r2)
        }
        _ => Err(UnifyError::Mismatch { a, b }),
    }
}
}
/// Occurs check: reports whether type variable `v` appears anywhere in `t`
/// once `u`'s current type substitution is taken into account.
///
/// Walks `t` by reference and follows a variable's binding only when one is
/// reached, instead of building a fully resolved copy of the subtree at
/// every recursion level. Effect sets of function types are not inspected,
/// since they hold effect variables rather than type variables.
fn occurs(v: TyVarId, t: &Ty, u: &Unifier) -> bool {
    match t {
        Ty::Var(w) => match u.subst.get(w) {
            // Bound variable: the question moves to whatever it is bound to.
            Some(bound) => occurs(v, bound, u),
            None => v == *w,
        },
        Ty::Prim(_) | Ty::Unit | Ty::Never => false,
        Ty::List(inner) => occurs(v, inner, u),
        Ty::Tuple(items) => items.iter().any(|item| occurs(v, item, u)),
        Ty::Record(fs) => fs.values().any(|field| occurs(v, field, u)),
        Ty::Con(_, args) => args.iter().any(|arg| occurs(v, arg, u)),
        Ty::Function { params, ret, .. } => {
            params.iter().any(|param| occurs(v, param, u)) || occurs(v, ret, u)
        }
    }
}
/// Failures reported by `Unifier::unify` and `Unifier::unify_effects`.
#[derive(Debug, Clone)]
pub enum UnifyError {
/// The two (resolved) types have incompatible shapes.
Mismatch { a: Ty, b: Ty },
/// Binding `var` to `ty` was rejected by the occurs check: `var` appears inside `ty`.
Infinite { var: TyVarId, ty: Ty },
/// The two (resolved) effect sets cannot be made equal.
EffectMismatch { a: EffectSet, b: EffectSet },
}