use crate::arena::Arena;
use crate::context::Context;
use crate::environment::Environment;
use crate::term::{MetaVarId, TermId, TermKind};
use std::collections::HashMap;
use std::collections::VecDeque;
/// A mapping from metavariables to the terms they have been assigned.
///
/// Each metavariable has at most one entry; [`Substitution::assign`] replaces
/// any existing one.
#[derive(Debug, Clone, Default)]
pub struct Substitution {
    assignments: HashMap<MetaVarId, TermId>,
}

impl Substitution {
    /// Creates a substitution with no assignments.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `mvar := term`, overwriting any previous assignment of `mvar`.
    pub fn assign(&mut self, mvar: MetaVarId, term: TermId) {
        self.assignments.insert(mvar, term);
    }

    /// Returns the term directly assigned to `mvar`, if any.
    ///
    /// Only one step is followed: if the result is itself an assigned
    /// metavariable, it is returned as-is.
    pub fn lookup(&self, mvar: MetaVarId) -> Option<TermId> {
        self.assignments.get(&mvar).copied()
    }

    /// Whether `mvar` has an assignment.
    pub fn is_assigned(&self, mvar: MetaVarId) -> bool {
        self.lookup(mvar).is_some()
    }

    /// Borrows the full assignment map.
    pub fn assignments(&self) -> &HashMap<MetaVarId, TermId> {
        &self.assignments
    }
}
/// A pending obligation processed by [`Unifier::solve`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Constraint {
    /// The two terms are unified syntactically, assigning metavariables as needed.
    Unify(TermId, TermId),
    /// The term must resolve (through the substitution) to a `Sort`.
    IsSort(TermId),
    /// Records the term as the type of the metavariable.
    HasType(MetaVarId, TermId),
}
/// Syntactic unifier over arena terms, driven by a FIFO queue of [`Constraint`]s.
pub struct Unifier {
    /// Metavariable assignments discovered so far.
    subst: Substitution,
    /// Constraints that have not been processed yet.
    constraints: VecDeque<Constraint>,
    /// Metavariable types recorded by `declare_mvar` and `Constraint::HasType`.
    mvar_types: HashMap<MetaVarId, TermId>,
}
impl Unifier {
pub fn new() -> Self {
Self {
subst: Substitution::new(),
constraints: VecDeque::new(),
mvar_types: HashMap::new(),
}
}
pub fn add_constraint(&mut self, constraint: Constraint) {
self.constraints.push_back(constraint);
}
pub fn unify(&mut self, t1: TermId, t2: TermId) {
self.add_constraint(Constraint::Unify(t1, t2));
}
pub fn declare_mvar(&mut self, mvar: MetaVarId, ty: TermId) {
self.mvar_types.insert(mvar, ty);
}
pub fn solve(
&mut self,
arena: &mut Arena,
env: &Environment,
ctx: &Context,
) -> crate::Result<()> {
while let Some(constraint) = self.constraints.pop_front() {
match constraint {
Constraint::Unify(t1, t2) => {
self.solve_unify(arena, env, ctx, t1, t2)?;
}
Constraint::IsSort(term) => {
let term = self.apply_subst(arena, term)?;
if let Some(TermKind::Sort(_)) = arena.kind(term) {
} else if let Some(TermKind::MVar(_mvar)) = arena.kind(term) {
self.add_constraint(Constraint::IsSort(term));
} else {
return Err(crate::Error::UnificationError(
"Expected sort".to_string(),
));
}
}
Constraint::HasType(mvar, ty) => {
self.mvar_types.insert(mvar, ty);
}
}
}
Ok(())
}
fn solve_unify(
&mut self,
arena: &mut Arena,
_env: &Environment,
_ctx: &Context,
t1: TermId,
t2: TermId,
) -> crate::Result<()> {
if t1 == t2 {
return Ok(());
}
let t1 = self.apply_subst(arena, t1)?;
let t2 = self.apply_subst(arena, t2)?;
if t1 == t2 {
return Ok(());
}
let kind1 = arena.kind(t1).ok_or_else(|| {
crate::Error::Internal(format!("Invalid term ID: {:?}", t1))
})?.clone();
let kind2 = arena.kind(t2).ok_or_else(|| {
crate::Error::Internal(format!("Invalid term ID: {:?}", t2))
})?.clone();
match (kind1, kind2) {
(TermKind::MVar(m), _) => {
if !self.subst.is_assigned(m) {
if self.occurs_check(m, t2, arena)? {
return Err(crate::Error::UnificationError(
"Occurs check failed".to_string(),
));
}
self.subst.assign(m, t2);
Ok(())
} else {
let assigned = self.subst.lookup(m).unwrap();
self.solve_unify(arena, _env, _ctx, assigned, t2)
}
}
(_, TermKind::MVar(m)) => {
if !self.subst.is_assigned(m) {
if self.occurs_check(m, t1, arena)? {
return Err(crate::Error::UnificationError(
"Occurs check failed".to_string(),
));
}
self.subst.assign(m, t1);
Ok(())
} else {
let assigned = self.subst.lookup(m).unwrap();
self.solve_unify(arena, _env, _ctx, t1, assigned)
}
}
(TermKind::App(f1, a1), TermKind::App(f2, a2)) => {
self.solve_unify(arena, _env, _ctx, f1, f2)?;
self.solve_unify(arena, _env, _ctx, a1, a2)?;
Ok(())
}
(TermKind::Lam(b1, body1), TermKind::Lam(b2, body2)) => {
self.solve_unify(arena, _env, _ctx, b1.ty, b2.ty)?;
self.solve_unify(arena, _env, _ctx, body1, body2)?;
Ok(())
}
(TermKind::Pi(b1, body1), TermKind::Pi(b2, body2)) => {
self.solve_unify(arena, _env, _ctx, b1.ty, b2.ty)?;
self.solve_unify(arena, _env, _ctx, body1, body2)?;
Ok(())
}
(TermKind::Sort(l1), TermKind::Sort(l2)) if l1 == l2 => Ok(()),
(TermKind::Var(i1), TermKind::Var(i2)) if i1 == i2 => Ok(()),
(TermKind::Const(n1, lvls1), TermKind::Const(n2, lvls2))
if n1 == n2 && lvls1 == lvls2 =>
{
Ok(())
}
_ => Err(crate::Error::UnificationError(format!(
"Cannot unify {:?} with {:?}",
t1, t2
))),
}
}
fn occurs_check(
&self,
mvar: MetaVarId,
term: TermId,
arena: &Arena,
) -> crate::Result<bool> {
let kind = arena.kind(term).ok_or_else(|| {
crate::Error::Internal(format!("Invalid term ID: {:?}", term))
})?;
match kind {
TermKind::MVar(m) if *m == mvar => Ok(true),
TermKind::MVar(m) => {
if let Some(assigned) = self.subst.lookup(*m) {
self.occurs_check(mvar, assigned, arena)
} else {
Ok(false)
}
}
TermKind::App(f, a) => {
let in_func = self.occurs_check(mvar, *f, arena)?;
let in_arg = self.occurs_check(mvar, *a, arena)?;
Ok(in_func || in_arg)
}
TermKind::Lam(b, body) | TermKind::Pi(b, body) => {
let in_ty = self.occurs_check(mvar, b.ty, arena)?;
let in_body = self.occurs_check(mvar, *body, arena)?;
Ok(in_ty || in_body)
}
TermKind::Let(b, val, body) => {
let in_ty = self.occurs_check(mvar, b.ty, arena)?;
let in_val = self.occurs_check(mvar, *val, arena)?;
let in_body = self.occurs_check(mvar, *body, arena)?;
Ok(in_ty || in_val || in_body)
}
TermKind::Sort(_) | TermKind::Const(_, _) | TermKind::Var(_) | TermKind::Lit(_) => {
Ok(false)
}
}
}
fn apply_subst(&self, arena: &Arena, term: TermId) -> crate::Result<TermId> {
let kind = arena.kind(term).ok_or_else(|| {
crate::Error::Internal(format!("Invalid term ID: {:?}", term))
})?;
match kind {
TermKind::MVar(m) => {
if let Some(assigned) = self.subst.lookup(*m) {
self.apply_subst(arena, assigned)
} else {
Ok(term)
}
}
_ => Ok(term),
}
}
pub fn substitution(&self) -> &Substitution {
&self.subst
}
pub fn is_solved(&self) -> bool {
self.constraints.is_empty()
}
pub fn num_constraints(&self) -> usize {
self.constraints.len()
}
}
impl Default for Unifier {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An unassigned metavariable unified with a variable is assigned to exactly that variable.
    #[test]
    fn test_basic_unification() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut unifier = Unifier::new();
        let var0 = arena.mk_var(0);
        let mvar0 = arena.mk_mvar(MetaVarId::new(0));
        unifier.unify(mvar0, var0);
        unifier.solve(&mut arena, &env, &ctx).unwrap();
        assert!(unifier.is_solved());
        assert_eq!(unifier.substitution().lookup(MetaVarId::new(0)), Some(var0));
    }

    /// `?0 =?= ?0 x` must be rejected: the binding would be cyclic.
    #[test]
    fn test_occurs_check() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut unifier = Unifier::new();
        let mvar0_id = MetaVarId::new(0);
        let mvar0 = arena.mk_mvar(mvar0_id);
        let x = arena.mk_var(0);
        let app = arena.mk_app(mvar0, x);
        unifier.unify(mvar0, app);
        let result = unifier.solve(&mut arena, &env, &ctx);
        assert!(result.is_err());
    }

    /// Unifying two applications unifies their heads, assigning `?0 := y`.
    #[test]
    fn test_structural_unification() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut unifier = Unifier::new();
        let mvar0 = arena.mk_mvar(MetaVarId::new(0));
        let x = arena.mk_var(0);
        let y = arena.mk_var(1);
        let app1 = arena.mk_app(mvar0, x);
        let app2 = arena.mk_app(y, x);
        unifier.unify(app1, app2);
        unifier.solve(&mut arena, &env, &ctx).unwrap();
        let assignment = unifier.substitution().lookup(MetaVarId::new(0)).unwrap();
        assert_eq!(assignment, y);
    }

    /// Distinct bound variables never unify.
    #[test]
    fn test_distinct_vars_do_not_unify() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut unifier = Unifier::new();
        let var0 = arena.mk_var(0);
        let var1 = arena.mk_var(1);
        unifier.unify(var0, var1);
        assert!(unifier.solve(&mut arena, &env, &ctx).is_err());
    }

    /// `?0 := ?1` followed by `?1 := x` resolves `?0` to `x`.
    #[test]
    fn test_assignment_chain_is_followed() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut unifier = Unifier::new();
        let mvar0 = arena.mk_mvar(MetaVarId::new(0));
        let mvar1 = arena.mk_mvar(MetaVarId::new(1));
        let x = arena.mk_var(0);
        unifier.unify(mvar0, mvar1);
        unifier.unify(mvar1, x);
        unifier.solve(&mut arena, &env, &ctx).unwrap();
        assert!(unifier.is_solved());
        assert_eq!(unifier.apply_subst(&arena, mvar0).unwrap(), x);
    }
}