use mrs_core::Substitution;
use mrs_core::term::{Term, VarId};
use crate::{UnifyError, UnifyResult};
/// Attempts to unify the terms `s` and `t`.
///
/// Starts from an empty [`Substitution`] and lets `unify_rec` extend it.
///
/// # Errors
///
/// Returns the first [`UnifyError`] reported by `unify_rec`; in that case the
/// partially built substitution is discarded.
pub fn unify(s: &Term, t: &Term) -> UnifyResult {
    let mut bindings = Substitution::new();
    unify_rec(s, t, &mut bindings).map(|()| bindings)
}
/// Recursive worker behind [`unify`]: unifies `s` with `t`, extending `subst`.
///
/// Both terms are first rewritten with the bindings collected so far. If the
/// rewritten terms are already equal, nothing needs to be done. Otherwise a
/// variable on either side is bound to the opposite term, and two applications
/// are unified argument by argument, left to right.
///
/// # Errors
///
/// * [`UnifyError::SymbolClash`] when two applications have different heads.
/// * [`UnifyError::ArityMismatch`] when the heads agree but the argument
///   counts differ.
/// * Any error produced by `bind_var` or by a recursive call on an argument
///   pair; the first such error stops the traversal.
fn unify_rec(s: &Term, t: &Term, subst: &mut Substitution) -> Result<(), UnifyError> {
    let lhs = subst.apply_term(s);
    let rhs = subst.apply_term(t);

    // Equal terms (after applying current bindings) are trivially unified.
    if lhs == rhs {
        return Ok(());
    }

    match (&lhs, &rhs) {
        // A variable on the left takes precedence over one on the right.
        (Term::Var(v), _) => bind_var(*v, &rhs, subst),
        (_, Term::Var(v)) => bind_var(*v, &lhs, subst),
        (Term::App(head_l, args_l), Term::App(head_r, args_r)) => {
            // The head symbol is checked before the arity, so a clash is
            // reported even when the argument counts also differ.
            if head_l != head_r {
                return Err(UnifyError::SymbolClash {
                    left: format!("{:?}", head_l),
                    right: format!("{:?}", head_r),
                });
            }
            if args_l.len() != args_r.len() {
                return Err(UnifyError::ArityMismatch {
                    expected: args_l.len(),
                    found: args_r.len(),
                });
            }
            // Bindings made for earlier arguments are visible to later ones
            // because every recursive call re-applies `subst` first.
            args_l
                .iter()
                .zip(args_r.iter())
                .try_for_each(|(arg_l, arg_r)| unify_rec(arg_l, arg_r, subst))
        }
    }
}
/// Records the binding `var -> term` in `subst`.
///
/// Binding a variable to itself is a no-op that succeeds without touching
/// `subst`.
///
/// # Errors
///
/// Returns [`UnifyError::OccursCheck`] when `term` is not `var` itself but
/// contains `var`; no binding is recorded in that case.
fn bind_var(var: VarId, term: &Term, subst: &mut Substitution) -> Result<(), UnifyError> {
    let is_same_var = matches!(term, Term::Var(other) if *other == var);
    if is_same_var {
        return Ok(());
    }

    if term.contains_var(var) {
        Err(UnifyError::OccursCheck { var })
    } else {
        subst.bind(var, term.clone());
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mrs_core::{SymbolId, SymbolTable};

    /// Builds a fresh symbol table and interns `f`, `g`, `a`, `b` (in that
    /// order), returning the table together with the four symbol ids.
    fn syms() -> (SymbolTable, SymbolId, SymbolId, SymbolId, SymbolId) {
        let mut table = SymbolTable::new();
        let [f, g, a, b] = ["f", "g", "a", "b"].map(|name| table.intern(name));
        (table, f, g, a, b)
    }

    /// Unifying a variable with itself yields an empty substitution.
    #[test]
    fn unify_identical_vars() {
        let x = Term::var(0);
        let mgu = unify(&x, &Term::var(0)).unwrap();
        assert!(mgu.is_empty());
    }

    /// A variable unified with a constant resolves to that constant.
    #[test]
    fn unify_var_with_constant() {
        let (_, _, _, a, _) = syms();
        let x = Term::var(0);
        let mgu = unify(&x, &Term::constant(a)).unwrap();
        assert_eq!(mgu.apply_term(&Term::var(0)), Term::constant(a));
    }

    /// Two distinct variables resolve to the same term after unification.
    #[test]
    fn unify_two_vars() {
        let (x, y) = (Term::var(0), Term::var(1));
        let mgu = unify(&x, &y).unwrap();
        let resolved_x = mgu.apply_term(&Term::var(0));
        assert_eq!(resolved_x, mgu.apply_term(&Term::var(1)));
    }

    /// `f(X0, a)` against `f(b, X1)` binds `X0` to `b` and `X1` to `a`.
    #[test]
    fn unify_function_terms() {
        let (_, f, _, a, b) = syms();
        let lhs = Term::app(f, vec![Term::var(0), Term::constant(a)]);
        let rhs = Term::app(f, vec![Term::constant(b), Term::var(1)]);
        let mgu = unify(&lhs, &rhs).unwrap();
        assert_eq!(mgu.apply_term(&Term::var(0)), Term::constant(b));
        assert_eq!(mgu.apply_term(&Term::var(1)), Term::constant(a));
    }

    /// `f(X0, g(X0))` against `f(a, g(a))` binds `X0` to `a`.
    #[test]
    fn unify_nested() {
        let (_, f, g, a, _) = syms();
        let lhs = Term::app(f, vec![Term::var(0), Term::app(g, vec![Term::var(0)])]);
        let rhs = Term::app(
            f,
            vec![Term::constant(a), Term::app(g, vec![Term::constant(a)])],
        );
        let mgu = unify(&lhs, &rhs).unwrap();
        assert_eq!(mgu.apply_term(&Term::var(0)), Term::constant(a));
    }

    /// `f(X0, X1)` against `f(X1, a)` resolves both variables to `a`.
    #[test]
    fn unify_transitive() {
        let (_, f, _, a, _) = syms();
        let lhs = Term::app(f, vec![Term::var(0), Term::var(1)]);
        let rhs = Term::app(f, vec![Term::var(1), Term::constant(a)]);
        let mgu = unify(&lhs, &rhs).unwrap();
        assert_eq!(mgu.apply_term(&Term::var(0)), Term::constant(a));
        assert_eq!(mgu.apply_term(&Term::var(1)), Term::constant(a));
    }

    /// Applications with different head symbols fail with `SymbolClash`.
    #[test]
    fn unify_symbol_clash() {
        let (_, f, g, a, _) = syms();
        let lhs = Term::app(f, vec![Term::constant(a)]);
        let rhs = Term::app(g, vec![Term::constant(a)]);
        let outcome = unify(&lhs, &rhs);
        assert!(matches!(outcome, Err(UnifyError::SymbolClash { .. })));
    }

    /// `X0` against `f(X0)` fails the occurs check on variable 0.
    #[test]
    fn unify_occurs_check() {
        let (_, f, _, _, _) = syms();
        let lhs = Term::var(0);
        let rhs = Term::app(f, vec![Term::var(0)]);
        let outcome = unify(&lhs, &rhs);
        assert!(matches!(outcome, Err(UnifyError::OccursCheck { var: 0 })));
    }

    /// Two distinct constants cannot be unified.
    #[test]
    fn unify_different_constants() {
        let (_, _, _, a, b) = syms();
        let outcome = unify(&Term::constant(a), &Term::constant(b));
        assert!(outcome.is_err());
    }

    /// Applying the resulting substitution a second time changes nothing.
    #[test]
    fn unify_idempotent() {
        let (_, f, _, a, _) = syms();
        let lhs = Term::app(f, vec![Term::var(0), Term::var(1)]);
        let rhs = Term::app(f, vec![Term::constant(a), Term::var(0)]);
        let mgu = unify(&lhs, &rhs).unwrap();
        let once = mgu.apply_term(&lhs);
        let twice = mgu.apply_term(&once);
        assert_eq!(once, twice);
    }
}