use std::collections::HashSet;
use mrs_core::Substitution;
use mrs_core::SymbolId;
use mrs_core::term::{Term, VarId};
use crate::{UnifyError, UnifyResult};
/// Syntactically unifies `s` with `t`.
///
/// Unification starts from an empty [`Substitution`]; the recursive worker
/// `unify_rec` extends it with the bindings that make the two terms equal.
///
/// # Errors
///
/// Propagates the first error reported by `unify_rec`: a symbol clash, an
/// arity mismatch, or an occurs-check failure.
pub fn unify(s: &Term, t: &Term) -> UnifyResult {
    let mut mgu = Substitution::new();
    unify_rec(s, t, &mut mgu)?;
    Ok(mgu)
}
/// Recursive worker for [`unify`]: extends `subst` in place until `s` and `t`
/// become equal under it.
///
/// Both terms are first rewritten with the bindings collected so far, so every
/// comparison below sees the current state of `subst`.
///
/// # Errors
///
/// * `UnifyError::SymbolClash` when two applications have different heads.
/// * `UnifyError::ArityMismatch` when the heads agree but the argument counts
///   differ (`expected` is the left count, `found` the right one).
/// * Any error raised by `bind_var` (occurs check).
///
/// On error `subst` keeps whatever bindings were added before the failure.
fn unify_rec(s: &Term, t: &Term, subst: &mut Substitution) -> Result<(), UnifyError> {
    let lhs = subst.apply_term(s);
    let rhs = subst.apply_term(t);

    // Already identical under the current bindings: nothing to record.
    if lhs == rhs {
        return Ok(());
    }

    match (&lhs, &rhs) {
        // A variable on either side gets bound to the opposite term.
        (Term::Var(v), _) => bind_var(*v, &rhs, subst),
        (_, Term::Var(v)) => bind_var(*v, &lhs, subst),
        (Term::App(f1, args1), Term::App(f2, args2)) => {
            if f1 != f2 {
                return Err(UnifyError::SymbolClash {
                    left: format!("{:?}", f1),
                    right: format!("{:?}", f2),
                });
            }
            if args1.len() != args2.len() {
                return Err(UnifyError::ArityMismatch {
                    expected: args1.len(),
                    found: args2.len(),
                });
            }
            // Unify arguments pairwise, left to right, stopping at the first failure.
            args1
                .iter()
                .zip(args2.iter())
                .try_for_each(|(a1, a2)| unify_rec(a1, a2, subst))
        }
    }
}
/// Unifies `s` with `t`, allowing the arguments of symbols listed in `comm`
/// to be matched in swapped order.
///
/// With an empty `comm` set this is exactly [`unify`]. Otherwise a fresh
/// [`Substitution`] is filled in by `unify_comm_rec`.
///
/// # Errors
///
/// Propagates the error from [`unify`] or `unify_comm_rec` when no unifier
/// is found.
pub fn unify_comm(s: &Term, t: &Term, comm: &HashSet<SymbolId>) -> UnifyResult {
    if comm.is_empty() {
        // No commutative symbols at all: plain syntactic unification suffices.
        unify(s, t)
    } else {
        let mut mgu = Substitution::new();
        unify_comm_rec(s, t, &mut mgu, comm)?;
        Ok(mgu)
    }
}
/// Recursive worker for [`unify_comm`]: extends `subst` in place until `s` and
/// `t` become equal, optionally swapping the two arguments of a commutative
/// symbol.
///
/// For two applications with the same head and arity the arguments are first
/// unified pairwise in their written order. Only if that fails, the head is in
/// `comm`, and it has exactly two arguments, the crossed pairing
/// (`args1[0]` with `args2[1]`, `args1[1]` with `args2[0]`) is tried from the
/// bindings that were in effect before the first attempt.
///
/// The substitution is snapshotted only when such a retry is possible, so
/// non-commutative and non-binary applications never pay for a clone.
///
/// # Limitations
///
/// Backtracking is local to a single application node: once an ordering of a
/// commutative node succeeds it is never revisited, even if a later sibling
/// argument fails because of the bindings that ordering introduced.
///
/// # Errors
///
/// * `UnifyError::SymbolClash` when two applications have different heads.
/// * `UnifyError::ArityMismatch` when the heads agree but the argument counts
///   differ (`expected` is the left count, `found` the right one).
/// * Any error raised by `bind_var` (occurs check).
///
/// When both orderings fail, the error of the in-order attempt is returned.
/// On error the contents of `subst` are unspecified (it may hold bindings from
/// a failed attempt); callers must discard it or restore their own snapshot.
fn unify_comm_rec(
    s: &Term,
    t: &Term,
    subst: &mut Substitution,
    comm: &HashSet<SymbolId>,
) -> Result<(), UnifyError> {
    let s = subst.apply_term(s);
    let t = subst.apply_term(t);

    // Already identical under the current bindings: nothing to record.
    if s == t {
        return Ok(());
    }

    match (&s, &t) {
        (Term::Var(v), _) => bind_var(*v, &t, subst),
        (_, Term::Var(v)) => bind_var(*v, &s, subst),
        (Term::App(f1, args1), Term::App(f2, args2)) => {
            if f1 != f2 {
                return Err(UnifyError::SymbolClash {
                    left: format!("{:?}", f1),
                    right: format!("{:?}", f2),
                });
            }
            if args1.len() != args2.len() {
                return Err(UnifyError::ArityMismatch {
                    expected: args1.len(),
                    found: args2.len(),
                });
            }

            // A swap is only defined for binary applications of a commutative head.
            let swappable = args1.len() == 2 && comm.contains(f1);
            // Snapshot the bindings only if a retry can actually happen.
            let snapshot = swappable.then(|| subst.clone());

            // First choice: arguments in their written order.
            let in_order = args1
                .iter()
                .zip(args2.iter())
                .try_for_each(|(a1, a2)| unify_comm_rec(a1, a2, subst, comm));

            match (in_order, snapshot) {
                (Err(first_err), Some(saved)) => {
                    // Drop the partial bindings of the failed attempt and retry crossed.
                    *subst = saved;
                    let swapped = unify_comm_rec(&args1[0], &args2[1], subst, comm)
                        .and_then(|()| unify_comm_rec(&args1[1], &args2[0], subst, comm));
                    // Report the in-order failure if the swap did not help either.
                    swapped.or(Err(first_err))
                }
                (outcome, _) => outcome,
            }
        }
    }
}
/// Records the binding `var := term` in `subst`, guarding against trivial and
/// cyclic bindings.
///
/// `term` is rewritten with the current bindings before it is stored. Binding
/// a variable to itself (before or after that rewrite) is a no-op.
///
/// # Errors
///
/// Returns `UnifyError::OccursCheck` when the rewritten term still contains
/// `var`, since such a binding would be cyclic.
fn bind_var(var: VarId, term: &Term, subst: &mut Substitution) -> Result<(), UnifyError> {
    // True when the candidate is exactly the variable being bound.
    let is_same_var = |candidate: &Term| matches!(candidate, Term::Var(v) if *v == var);

    // Fast path: skip the rewrite when the term is literally `var`.
    if is_same_var(term) {
        return Ok(());
    }

    let resolved = subst.apply_term(term);
    if is_same_var(&resolved) {
        return Ok(());
    }

    if resolved.contains_var(var) {
        return Err(UnifyError::OccursCheck { var });
    }

    subst.bind(var, resolved);
    Ok(())
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use mrs_core::{SymbolId, SymbolTable};

    /// Interns the symbols `f`, `g`, `a` and `b` in a fresh table and returns
    /// them (in that order) together with the table.
    fn syms() -> (SymbolTable, SymbolId, SymbolId, SymbolId, SymbolId) {
        let mut st = SymbolTable::new();
        let f = st.intern("f");
        let g = st.intern("g");
        let a = st.intern("a");
        let b = st.intern("b");
        (st, f, g, a, b)
    }

    /// A variable unified with itself needs no bindings.
    #[test]
    fn unify_identical_vars() {
        let mgu = unify(&Term::var(0), &Term::var(0)).unwrap();
        assert!(mgu.is_empty());
    }

    /// A variable unified with a constant is bound to that constant.
    #[test]
    fn unify_var_with_constant() {
        let (_, _, _, a, _) = syms();
        let mgu = unify(&Term::var(0), &Term::constant(a)).unwrap();
        assert_eq!(mgu.apply_term(&Term::var(0)), Term::constant(a));
    }

    /// Two distinct variables end up mapped to the same term.
    #[test]
    fn unify_two_vars() {
        let mgu = unify(&Term::var(0), &Term::var(1)).unwrap();
        let result = mgu.apply_term(&Term::var(0));
        assert_eq!(result, mgu.apply_term(&Term::var(1)));
    }

    /// Arguments are unified position by position.
    #[test]
    fn unify_function_terms() {
        let (_, f, _, a, b) = syms();
        let t1 = Term::app(f, vec![Term::var(0), Term::constant(a)]);
        let t2 = Term::app(f, vec![Term::constant(b), Term::var(1)]);
        let mgu = unify(&t1, &t2).unwrap();
        assert_eq!(mgu.apply_term(&Term::var(0)), Term::constant(b));
        assert_eq!(mgu.apply_term(&Term::var(1)), Term::constant(a));
    }

    /// A binding found in one argument is reused inside a nested argument.
    #[test]
    fn unify_nested() {
        let (_, f, g, a, _) = syms();
        let t1 = Term::app(f, vec![Term::var(0), Term::app(g, vec![Term::var(0)])]);
        let t2 = Term::app(
            f,
            vec![Term::constant(a), Term::app(g, vec![Term::constant(a)])],
        );
        let mgu = unify(&t1, &t2).unwrap();
        assert_eq!(mgu.apply_term(&Term::var(0)), Term::constant(a));
    }

    /// Bindings chain: x0 = x1 and x1 = a together give x0 = a.
    #[test]
    fn unify_transitive() {
        let (_, f, _, a, _) = syms();
        let t1 = Term::app(f, vec![Term::var(0), Term::var(1)]);
        let t2 = Term::app(f, vec![Term::var(1), Term::constant(a)]);
        let mgu = unify(&t1, &t2).unwrap();
        assert_eq!(mgu.apply_term(&Term::var(0)), Term::constant(a));
        assert_eq!(mgu.apply_term(&Term::var(1)), Term::constant(a));
    }

    /// Different head symbols are reported as a symbol clash.
    #[test]
    fn unify_symbol_clash() {
        let (_, f, g, a, _) = syms();
        let t1 = Term::app(f, vec![Term::constant(a)]);
        let t2 = Term::app(g, vec![Term::constant(a)]);
        assert!(matches!(
            unify(&t1, &t2),
            Err(UnifyError::SymbolClash { .. })
        ));
    }

    /// Binding x0 to a term containing x0 is rejected by the occurs check.
    #[test]
    fn unify_occurs_check() {
        let (_, f, _, _, _) = syms();
        let t1 = Term::var(0);
        let t2 = Term::app(f, vec![Term::var(0)]);
        assert!(matches!(
            unify(&t1, &t2),
            Err(UnifyError::OccursCheck { var: 0 })
        ));
    }

    /// Two distinct constants never unify.
    #[test]
    fn unify_different_constants() {
        let (_, _, _, a, b) = syms();
        assert!(unify(&Term::constant(a), &Term::constant(b)).is_err());
    }

    /// Applying the resulting substitution twice changes nothing further.
    #[test]
    fn unify_idempotent() {
        let (_, f, _, a, _) = syms();
        let t1 = Term::app(f, vec![Term::var(0), Term::var(1)]);
        let t2 = Term::app(f, vec![Term::constant(a), Term::var(0)]);
        let mgu = unify(&t1, &t2).unwrap();
        let applied1 = mgu.apply_term(&t1);
        let applied2 = mgu.apply_term(&applied1);
        assert_eq!(applied1, applied2);
    }

    /// Identical terms unify in written order without any bindings.
    #[test]
    fn unify_comm_normal_order_preferred() {
        let (_, f, _, a, b) = syms();
        let mut comm = HashSet::new();
        comm.insert(f);
        let t1 = Term::app(f, vec![Term::constant(a), Term::constant(b)]);
        let t2 = Term::app(f, vec![Term::constant(a), Term::constant(b)]);
        let mgu = unify_comm(&t1, &t2, &comm).unwrap();
        assert!(mgu.is_empty());
    }

    /// f(a, b) and f(b, a) unify once `f` is commutative.
    #[test]
    fn unify_comm_swapped_constants() {
        let (_, f, _, a, b) = syms();
        let mut comm = HashSet::new();
        comm.insert(f);
        let t1 = Term::app(f, vec![Term::constant(a), Term::constant(b)]);
        let t2 = Term::app(f, vec![Term::constant(b), Term::constant(a)]);
        assert!(unify_comm(&t1, &t2, &comm).is_ok());
    }

    /// Commutative unification still produces the expected variable binding.
    #[test]
    fn unify_comm_swapped_with_vars() {
        let (_, f, _, a, _) = syms();
        let mut comm = HashSet::new();
        comm.insert(f);
        let t1 = Term::app(f, vec![Term::var(0), Term::constant(a)]);
        let t2 = Term::app(f, vec![Term::constant(a), Term::var(0)]);
        let mgu = unify_comm(&t1, &t2, &comm).unwrap();
        assert_eq!(mgu.apply_term(&Term::var(0)), Term::constant(a));
    }

    /// With an empty commutative set `unify_comm` behaves like plain `unify`,
    /// so swapped constant arguments do not unify.
    #[test]
    fn unify_comm_empty_set_is_syntactic() {
        let (_, f, _, a, b) = syms();
        let comm = HashSet::new();
        let t1 = Term::app(f, vec![Term::constant(a), Term::constant(b)]);
        let t2 = Term::app(f, vec![Term::constant(b), Term::constant(a)]);
        assert!(unify_comm(&t1, &t2, &comm).is_err());
    }

    /// `f` is commutative, yet f(a, b) vs f(a, a) clashes on `b`/`a` in the
    /// written order and again in the swapped order, so unification must fail.
    #[test]
    fn unify_comm_fails_when_both_orderings_clash() {
        let (_, f, _, a, b) = syms();
        let mut comm = HashSet::new();
        comm.insert(f);
        let t1 = Term::app(f, vec![Term::constant(a), Term::constant(b)]);
        let t2 = Term::app(f, vec![Term::constant(a), Term::constant(a)]);
        assert!(unify_comm(&t1, &t2, &comm).is_err());
    }

    /// Only symbols listed in the set are swapped: with just `g` commutative,
    /// f(a, b) and f(b, a) still fail to unify.
    #[test]
    fn unify_comm_non_commutative_symbol_not_swapped() {
        let (_, f, g, a, b) = syms();
        let mut comm = HashSet::new();
        comm.insert(g);
        let t1 = Term::app(f, vec![Term::constant(a), Term::constant(b)]);
        let t2 = Term::app(f, vec![Term::constant(b), Term::constant(a)]);
        assert!(unify_comm(&t1, &t2, &comm).is_err());
    }

    /// A commutative symbol applied to three arguments is never reordered.
    #[test]
    fn unify_comm_non_binary_not_swapped() {
        let mut st = SymbolTable::new();
        let g = st.intern("g");
        let a = st.intern("a");
        let b = st.intern("b");
        let c = st.intern("c");
        let mut comm = HashSet::new();
        comm.insert(g);
        let t1 = Term::app(
            g,
            vec![Term::constant(a), Term::constant(b), Term::constant(c)],
        );
        let t2 = Term::app(
            g,
            vec![Term::constant(c), Term::constant(b), Term::constant(a)],
        );
        assert!(unify_comm(&t1, &t2, &comm).is_err());
    }
}