use super::types::Type;
/// Table of type-variable bindings with path compression.
///
/// Slot `i` holds the type that variable `i` is currently bound to.
/// - A variable-like entry (`Var`, `IntVar`, `FloatVar`, `RecordVar`) whose index differs from
///   its own slot is a link to that other slot.
/// - Any other entry ends a chain. This includes a variable-like entry that points at its own
///   slot, which represents an unbound variable.
#[derive(Clone, Debug, Default)]
pub struct UnionFind {
    inner: Vec<Type>,
}

impl UnionFind {
    /// Returns the slot that the entry at `index` links to, or `None` if that entry ends a chain.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds.
    fn next_link(&self, index: usize) -> Option<usize> {
        match &self.inner[index] {
            Type::Var(i)
            | Type::IntVar(i, _)
            | Type::FloatVar(i)
            | Type::RecordVar(i, _)
                if *i != index =>
            {
                Some(*i)
            }
            _ => None,
        }
    }

    /// Follows links starting at `index` and returns the slot that ends the chain.
    ///
    /// # Panics
    /// Panics if any slot on the chain is out of bounds. Never returns if the links form a cycle.
    fn root_of(&self, index: usize) -> usize {
        let mut current = index;
        while let Some(next) = self.next_link(current) {
            current = next;
        }
        current
    }

    /// Resolves `index` to the entry at the end of its link chain, compressing the path.
    ///
    /// Every slot visited before the end of the chain is overwritten with a clone of the
    /// resolved entry, which shortens later lookups. The whole entry is replaced, so any payload
    /// carried by an intermediate `IntVar`/`RecordVar` link is discarded.
    ///
    /// This runs as a loop rather than by recursion. Chain length is not bounded by this type
    /// (there is no union-by-rank), and recursion would overflow the stack on very long chains.
    ///
    /// # Panics
    /// Panics if any slot on the chain is out of bounds. Never returns if the links form a cycle.
    pub fn find(&mut self, index: usize) -> Type {
        let resolved = self.inner[self.root_of(index)].clone();
        let mut current = index;
        // Read the outgoing link before overwriting the slot, because the overwrite erases it.
        while let Some(next) = self.next_link(current) {
            self.inner[current] = resolved.clone();
            current = next;
        }
        resolved
    }

    /// Allocates a new slot and returns the type stored in it.
    ///
    /// `f` receives the index of the new slot (the current table length). Its result is stored
    /// in that slot, and a clone of it is returned.
    pub fn fresh(&mut self, f: impl FnOnce(usize) -> Type) -> Type {
        let n = self.inner.len();
        let t = f(n);
        self.inner.push(t.clone());
        t
    }

    /// Overwrites slot `index` with `t`. No cycle check is performed.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds.
    pub fn set(&mut self, index: usize, t: Type) {
        self.inner[index] = t;
    }

    /// Returns a reference to the entry at the end of the chain starting at `index`.
    ///
    /// Unlike [`UnionFind::find`], this does not compress the path. It runs as a loop, so long
    /// chains cannot overflow the stack.
    ///
    /// # Panics
    /// Panics if any slot on the chain is out of bounds. Never returns if the links form a cycle.
    pub fn find_ref(&self, index: usize) -> &Type {
        &self.inner[self.root_of(index)]
    }
}