use mrs_core::Substitution;
use mrs_core::term::Term;
use crate::{UnifyError, UnifyResult};
/// Matches `pattern` against `target` one-way and returns the resulting bindings.
///
/// Only variables occurring in `pattern` are ever bound; `target` is never
/// instantiated. Matching starts from an empty [`Substitution`] and, on
/// success, that substitution (holding every binding made along the way) is
/// returned.
///
/// # Errors
///
/// Propagates the first error produced while walking the two terms:
///
/// * `UnifyError::SymbolClash` when two applications have different head
///   symbols, when an application in `pattern` faces a variable in `target`,
///   or when a pattern variable that is already bound meets a target that is
///   not equal (`==`) to its existing binding.
/// * `UnifyError::ArityMismatch` when two applications share a head symbol
///   but have a different number of arguments.
pub fn match_term(pattern: &Term, target: &Term) -> UnifyResult {
    let mut bindings = Substitution::new();
    match_rec(pattern, target, &mut bindings)?;
    Ok(bindings)
}
/// Recursive worker for [`match_term`]: matches `pattern` against `target`,
/// recording pattern-variable bindings in `subst`.
///
/// Bindings made before a failure are left in `subst`; nothing is rolled back.
fn match_rec(pattern: &Term, target: &Term, subst: &mut Substitution) -> Result<(), UnifyError> {
    match (pattern, target) {
        // A pattern variable matches anything, but every occurrence of the
        // same variable must agree with the first binding.
        (Term::Var(var), _) => match subst.lookup(*var) {
            None => {
                subst.bind(*var, target.clone());
                Ok(())
            }
            Some(existing) if existing == target => Ok(()),
            Some(_) => Err(UnifyError::SymbolClash {
                left: format!("X{} (bound)", var),
                right: format!("{:?}", target),
            }),
        },

        // Matching is one-way: a variable on the target side is never bound,
        // so an application in the pattern cannot match it.
        (Term::App(head, _), Term::Var(_)) => Err(UnifyError::SymbolClash {
            left: format!("{:?}", head),
            right: "variable".to_string(),
        }),

        (Term::App(pattern_head, pattern_args), Term::App(target_head, target_args)) => {
            // The head symbol is checked before the arity, so a term that
            // differs in both reports a symbol clash.
            if pattern_head != target_head {
                return Err(UnifyError::SymbolClash {
                    left: format!("{:?}", pattern_head),
                    right: format!("{:?}", target_head),
                });
            }
            if pattern_args.len() != target_args.len() {
                return Err(UnifyError::ArityMismatch {
                    expected: pattern_args.len(),
                    found: target_args.len(),
                });
            }
            // Match arguments pairwise, left to right, stopping at the first failure.
            pattern_args
                .iter()
                .zip(target_args.iter())
                .try_for_each(|(pattern_arg, target_arg)| match_rec(pattern_arg, target_arg, subst))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mrs_core::SymbolTable;

    /// A lone pattern variable matched against a constant is bound to that constant.
    #[test]
    fn match_var_to_constant() {
        let mut table = SymbolTable::new();
        let sym_a = table.intern("a");

        let pattern = Term::var(0);
        let target = Term::constant(sym_a);
        let bindings = match_term(&pattern, &target).unwrap();

        assert_eq!(bindings.apply_term(&Term::var(0)), Term::constant(sym_a));
    }

    /// `f(X0, X1)` against `f(a, b)` binds each variable to the argument in its position.
    #[test]
    fn match_function_pattern() {
        let mut table = SymbolTable::new();
        let sym_f = table.intern("f");
        let sym_a = table.intern("a");
        let sym_b = table.intern("b");

        let pattern = Term::app(sym_f, vec![Term::var(0), Term::var(1)]);
        let target = Term::app(sym_f, vec![Term::constant(sym_a), Term::constant(sym_b)]);
        let bindings = match_term(&pattern, &target).unwrap();

        assert_eq!(bindings.apply_term(&Term::var(0)), Term::constant(sym_a));
        assert_eq!(bindings.apply_term(&Term::var(1)), Term::constant(sym_b));
    }

    /// `f(X0, X0)` against `f(a, a)` succeeds because both occurrences agree.
    #[test]
    fn match_repeated_var() {
        let mut table = SymbolTable::new();
        let sym_f = table.intern("f");
        let sym_a = table.intern("a");

        let pattern = Term::app(sym_f, vec![Term::var(0), Term::var(0)]);
        let target = Term::app(sym_f, vec![Term::constant(sym_a), Term::constant(sym_a)]);
        let bindings = match_term(&pattern, &target).unwrap();

        assert_eq!(bindings.apply_term(&Term::var(0)), Term::constant(sym_a));
    }

    /// `f(X0, X0)` against `f(a, b)` fails: `X0` cannot be both `a` and `b`.
    #[test]
    fn match_repeated_var_conflict() {
        let mut table = SymbolTable::new();
        let sym_f = table.intern("f");
        let sym_a = table.intern("a");
        let sym_b = table.intern("b");

        let pattern = Term::app(sym_f, vec![Term::var(0), Term::var(0)]);
        let target = Term::app(sym_f, vec![Term::constant(sym_a), Term::constant(sym_b)]);

        assert!(match_term(&pattern, &target).is_err());
    }

    /// Matching is one-way: `f(a)` does not match `f(X0)`, since target
    /// variables are never bound.
    #[test]
    fn match_no_target_var_binding() {
        let mut table = SymbolTable::new();
        let sym_f = table.intern("f");
        let sym_a = table.intern("a");

        let pattern = Term::app(sym_f, vec![Term::constant(sym_a)]);
        let target = Term::app(sym_f, vec![Term::var(0)]);

        assert!(match_term(&pattern, &target).is_err());
    }

    /// A pattern variable can be bound to a whole compound term such as `g(a)`.
    #[test]
    fn match_var_to_compound() {
        let mut table = SymbolTable::new();
        // "f" is interned but not used in any term below.
        let _sym_f = table.intern("f");
        let sym_g = table.intern("g");
        let sym_a = table.intern("a");

        let pattern = Term::var(0);
        let target = Term::app(sym_g, vec![Term::constant(sym_a)]);
        let bindings = match_term(&pattern, &target).unwrap();

        assert_eq!(
            bindings.apply_term(&Term::var(0)),
            Term::app(sym_g, vec![Term::constant(sym_a)])
        );
    }
}