use crate::error::IrError;
use crate::term::Term;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A finite mapping from variable names to [`Term`]s.
///
/// Application via [`Substitution::apply`] is a single, simultaneous pass: each
/// bound variable is replaced by its binding once, and the inserted term is not
/// itself re-substituted. A chain such as `{x -> y, y -> a}` therefore maps `x`
/// to `y`, not to `a`.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct Substitution {
    bindings: HashMap<String, Term>,
}

impl Substitution {
    /// Returns the substitution with no bindings (the identity).
    pub fn empty() -> Self {
        Substitution {
            bindings: HashMap::new(),
        }
    }

    /// Returns a substitution containing only the binding `var -> term`.
    pub fn singleton(var: String, term: Term) -> Self {
        let mut bindings = HashMap::with_capacity(1);
        bindings.insert(var, term);
        Substitution { bindings }
    }

    /// Wraps an existing map of bindings as-is (no validation is performed).
    pub fn from_map(bindings: HashMap<String, Term>) -> Self {
        Substitution { bindings }
    }

    /// Returns `true` when no variable is bound.
    pub fn is_empty(&self) -> bool {
        self.bindings.is_empty()
    }

    /// Returns the number of bound variables.
    pub fn len(&self) -> usize {
        self.bindings.len()
    }

    /// Looks up the term bound to `var`, if any.
    pub fn get(&self, var: &str) -> Option<&Term> {
        self.bindings.get(var)
    }

    /// Binds `var` to `term`, silently replacing any previous binding.
    ///
    /// Use [`Substitution::extend`] to reject a conflicting rebinding instead.
    pub fn bind(&mut self, var: String, term: Term) {
        self.bindings.insert(var, term);
    }

    /// Applies the substitution to `term` in a single pass.
    ///
    /// Bound variables are replaced by (a clone of) their binding; unbound
    /// variables and constants are returned unchanged; `Typed` wrappers are
    /// preserved and the substitution is applied to their inner value.
    pub fn apply(&self, term: &Term) -> Term {
        match term {
            Term::Var(name) => self
                .bindings
                .get(name)
                .cloned()
                .unwrap_or_else(|| term.clone()),
            Term::Const(_) => term.clone(),
            Term::Typed {
                value,
                type_annotation,
            } => Term::Typed {
                value: Box::new(self.apply(value)),
                type_annotation: type_annotation.clone(),
            },
        }
    }

    /// Composes two substitutions so that the result applies `other` first and
    /// then `self`.
    ///
    /// For every variable term `v`:
    /// `self.compose(other).apply(&v) == self.apply(&other.apply(&v))`.
    ///
    /// The result contains:
    /// - `v -> self.apply(t)` for each `v -> t` in `other`, omitting the
    ///   trivial binding `v -> v`;
    /// - `v -> t` for each `v -> t` in `self` whose `v` is not bound by `other`.
    pub fn compose(&self, other: &Substitution) -> Substitution {
        let mut result: HashMap<String, Term> = other
            .bindings
            .iter()
            .map(|(var, term)| (var.clone(), self.apply(term)))
            // `v -> v` is the identity on `v`; dropping it keeps the result
            // canonical so `is_empty`/`len`/`domain` reflect real bindings.
            .filter(|(var, term)| !matches!(term, Term::Var(name) if name == var))
            .collect();
        for (var, term) in &self.bindings {
            // Test against `other`'s domain (not `result`'s): a variable whose
            // binding collapsed to the identity is still governed by `other`,
            // so `self`'s own binding for it must not leak through.
            if !other.bindings.contains_key(var) {
                result.insert(var.clone(), term.clone());
            }
        }
        Substitution::from_map(result)
    }

    /// Returns the bound variable names, in arbitrary (hash-map) order.
    pub fn domain(&self) -> Vec<String> {
        self.bindings.keys().cloned().collect()
    }

    /// Returns the bound terms, in arbitrary (hash-map) order.
    pub fn range(&self) -> Vec<Term> {
        self.bindings.values().cloned().collect()
    }

    /// Adds `var -> term`, failing if `var` is already bound to a different term.
    ///
    /// Re-adding an identical binding is a no-op.
    ///
    /// # Errors
    /// Returns `IrError::UnificationFailure` (carrying the `Debug` forms of the
    /// existing and proposed terms) when `var` is bound to a different term.
    pub fn extend(&mut self, var: String, term: Term) -> Result<(), IrError> {
        if let Some(existing) = self.bindings.get(&var) {
            if existing != &term {
                return Err(Self::conflict(existing, &term));
            }
            return Ok(());
        }
        self.bindings.insert(var, term);
        Ok(())
    }

    /// Merges every binding of `other` into `self`, all-or-nothing.
    ///
    /// Bindings already present with an equal term are left as they are.
    ///
    /// # Errors
    /// Returns `IrError::UnificationFailure` if some variable is bound to
    /// different terms in `self` and `other`. In that case `self` is left
    /// unmodified.
    pub fn try_extend(&mut self, other: &Substitution) -> Result<(), IrError> {
        // Validate first so that a conflict cannot leave `self` half-merged.
        for (var, term) in &other.bindings {
            if let Some(existing) = self.bindings.get(var) {
                if existing != term {
                    return Err(Self::conflict(existing, term));
                }
            }
        }
        for (var, term) in &other.bindings {
            self.bindings
                .entry(var.clone())
                .or_insert_with(|| term.clone());
        }
        Ok(())
    }

    /// Builds the error reported when a variable is bound to two different terms.
    fn conflict(existing: &Term, proposed: &Term) -> IrError {
        IrError::UnificationFailure {
            type1: format!("{:?}", existing),
            type2: format!("{:?}", proposed),
        }
    }
}
/// Reports whether the variable named `var` appears anywhere inside `term`.
///
/// Walks down through `Typed` wrappers until it reaches an atom: a variable
/// matches when its name equals `var`, and a constant never matches.
fn occurs_check(var: &str, term: &Term) -> bool {
    let mut cursor = term;
    loop {
        match cursor {
            Term::Typed { value, .. } => cursor = &**value,
            Term::Var(name) => return name == var,
            Term::Const(_) => return false,
        }
    }
}
/// Computes a most general unifier of `term1` and `term2`.
///
/// Starts from an empty substitution and delegates to `unify_impl`.
///
/// # Errors
/// Propagates `IrError::UnificationFailure` when the terms cannot be unified.
pub fn unify_terms(term1: &Term, term2: &Term) -> Result<Substitution, IrError> {
    let mut working = Substitution::default();
    unify_impl(term1, term2, &mut working)
}
fn unify_impl(
term1: &Term,
term2: &Term,
subst: &mut Substitution,
) -> Result<Substitution, IrError> {
let t1 = subst.apply(term1);
let t2 = subst.apply(term2);
match (&t1, &t2) {
(Term::Var(n1), Term::Var(n2)) if n1 == n2 => Ok(subst.clone()),
(Term::Var(name), _) => {
if occurs_check(name, &t2) {
return Err(IrError::UnificationFailure {
type1: format!("{:?}", t1),
type2: format!("{:?}", t2),
});
}
subst.bind(name.clone(), t2.clone());
Ok(subst.clone())
}
(_, Term::Var(name)) => {
if occurs_check(name, &t1) {
return Err(IrError::UnificationFailure {
type1: format!("{:?}", t1),
type2: format!("{:?}", t2),
});
}
subst.bind(name.clone(), t1.clone());
Ok(subst.clone())
}
(Term::Const(v1), Term::Const(v2)) => {
if v1 == v2 {
Ok(subst.clone())
} else {
Err(IrError::UnificationFailure {
type1: format!("{:?}", t1),
type2: format!("{:?}", t2),
})
}
}
(
Term::Typed {
value: inner1,
type_annotation: ty1,
},
Term::Typed {
value: inner2,
type_annotation: ty2,
},
) => {
if ty1 != ty2 {
return Err(IrError::UnificationFailure {
type1: format!("{:?}", t1),
type2: format!("{:?}", t2),
});
}
unify_impl(inner1, inner2, subst)
}
(Term::Typed { value, .. }, other) | (other, Term::Typed { value, .. }) => {
unify_impl(value, other, subst)
}
}
}
/// Simultaneously unifies every `(lhs, rhs)` pair, threading one substitution
/// through the pairs in order.
///
/// # Errors
/// Stops at, and returns, the first `IrError::UnificationFailure`.
pub fn unify_term_list(pairs: &[(Term, Term)]) -> Result<Substitution, IrError> {
    pairs
        .iter()
        .try_fold(Substitution::empty(), |mut acc, (lhs, rhs)| {
            unify_impl(lhs, rhs, &mut acc)
        })
}
/// Returns `true` when `unify_terms` succeeds on the two terms.
pub fn are_unifiable(term1: &Term, term2: &Term) -> bool {
    let unifier = unify_terms(term1, term2);
    unifier.is_ok()
}
/// Returns a copy of `term` in which every variable `v` is renamed to
/// `v_<suffix>`.
///
/// Constants are unchanged; `Typed` wrappers and their annotations are kept.
pub fn rename_vars(term: &Term, suffix: &str) -> Term {
    match term {
        Term::Typed {
            value,
            type_annotation,
        } => {
            let renamed_inner = rename_vars(value, suffix);
            Term::Typed {
                value: Box::new(renamed_inner),
                type_annotation: type_annotation.clone(),
            }
        }
        Term::Var(name) => Term::Var(format!("{}_{}", name, suffix)),
        Term::Const(_) => term.clone(),
    }
}
/// Anti-unifies two terms.
///
/// Returns `(generalization, left, right)`, where `left` and `right` bind the
/// fresh `_G<n>` variables introduced in `generalization` to the corresponding
/// parts of `term1` and `term2`. Fresh-variable numbering starts at 1 on every
/// call.
pub fn anti_unify_terms(term1: &Term, term2: &Term) -> (Term, Substitution, Substitution) {
    let mut fresh_counter: usize = 0;
    let (mut left_subst, mut right_subst) = (Substitution::empty(), Substitution::empty());
    let generalization = anti_unify_impl(
        term1,
        term2,
        &mut fresh_counter,
        &mut left_subst,
        &mut right_subst,
    );
    (generalization, left_subst, right_subst)
}
/// Recursive worker for `anti_unify_terms`.
///
/// Identical atoms are kept. Equally-annotated `Typed` terms keep the wrapper,
/// and their payloads are generalized. Any other disagreement becomes a fresh
/// variable `_G<n>` (after incrementing `var_counter`), recorded in `subst1`
/// and `subst2`.
fn anti_unify_impl(
    term1: &Term,
    term2: &Term,
    var_counter: &mut usize,
    subst1: &mut Substitution,
    subst2: &mut Substitution,
) -> Term {
    match (term1, term2) {
        (Term::Var(left), Term::Var(right)) if left == right => term1.clone(),
        (Term::Const(left), Term::Const(right)) if left == right => term1.clone(),
        (
            Term::Typed {
                value: left,
                type_annotation: left_ty,
            },
            Term::Typed {
                value: right,
                type_annotation: right_ty,
            },
        ) if left_ty == right_ty => Term::Typed {
            value: Box::new(anti_unify_impl(left, right, var_counter, subst1, subst2)),
            type_annotation: left_ty.clone(),
        },
        _ => {
            *var_counter += 1;
            let fresh_name = format!("_G{}", var_counter);
            subst1.bind(fresh_name.clone(), term1.clone());
            subst2.bind(fresh_name.clone(), term2.clone());
            Term::Var(fresh_name)
        }
    }
}
pub fn lgg_terms(terms: &[Term]) -> (Term, Vec<Substitution>) {
if terms.is_empty() {
return (Term::Var("_Empty".to_string()), vec![]);
}
if terms.len() == 1 {
return (terms[0].clone(), vec![Substitution::empty()]);
}
let (mut gen, subst1, subst2) = anti_unify_terms(&terms[0], &terms[1]);
let mut substs = vec![subst1, subst2];
for term in &terms[2..] {
let (new_gen, gen_subst, term_subst) = anti_unify_terms(&gen, term);
gen = new_gen;
for s in &mut substs {
*s = gen_subst.compose(s);
}
substs.push(term_subst);
}
(gen, substs)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_substitution() {
        let subst = Substitution::empty();
        assert!(subst.is_empty());
        assert_eq!(subst.len(), 0);
        let term = Term::var("x");
        assert_eq!(subst.apply(&term), term);
    }

    #[test]
    fn test_singleton_substitution() {
        let subst = Substitution::singleton("x".to_string(), Term::constant("a"));
        assert_eq!(subst.len(), 1);
        let x = Term::var("x");
        let a = Term::constant("a");
        assert_eq!(subst.apply(&x), a);
    }

    #[test]
    fn test_substitution_application() {
        let mut subst = Substitution::empty();
        subst.bind("x".to_string(), Term::constant("a"));
        subst.bind("y".to_string(), Term::constant("b"));
        let x = Term::var("x");
        let y = Term::var("y");
        let z = Term::var("z");
        assert_eq!(subst.apply(&x), Term::constant("a"));
        assert_eq!(subst.apply(&y), Term::constant("b"));
        // Unbound variables are left untouched.
        assert_eq!(subst.apply(&z), z);
    }

    #[test]
    fn test_unify_var_constant() {
        let x = Term::var("x");
        let a = Term::constant("a");
        let mgu = unify_terms(&x, &a).expect("unwrap");
        assert_eq!(mgu.apply(&x), a);
    }

    #[test]
    fn test_unify_two_variables() {
        let x = Term::var("x");
        let y = Term::var("y");
        let mgu = unify_terms(&x, &y).expect("distinct variables always unify");
        assert_eq!(mgu.len(), 1);
        assert_eq!(mgu.apply(&x), mgu.apply(&y));
    }

    #[test]
    fn test_unify_same_variable() {
        let x = Term::var("x");
        let mgu = unify_terms(&x, &x).expect("unwrap");
        assert!(mgu.is_empty());
    }

    #[test]
    fn test_unify_different_constants() {
        let a = Term::constant("a");
        let b = Term::constant("b");
        let result = unify_terms(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_unify_same_constant() {
        let a = Term::constant("a");
        let mgu = unify_terms(&a, &a).expect("unwrap");
        assert!(mgu.is_empty());
    }

    #[test]
    fn test_occur_check() {
        let x = Term::var("x");
        assert!(occurs_check("x", &x));
        assert!(!occurs_check("y", &x));
        let a = Term::constant("a");
        assert!(!occurs_check("x", &a));
    }

    #[test]
    fn test_unify_occurs_check_failure() {
        use crate::term::TypeAnnotation;
        let x = Term::var("x");
        let typed_x = Term::Typed {
            value: Box::new(Term::var("x")),
            type_annotation: TypeAnnotation::new("Int"),
        };
        assert!(unify_terms(&x, &typed_x).is_err());
    }

    #[test]
    fn test_substitution_composition() {
        let sigma = Substitution::singleton("x".to_string(), Term::constant("a"));
        let theta = Substitution::singleton("y".to_string(), Term::var("x"));
        let composed = sigma.compose(&theta);
        assert_eq!(composed.len(), 2);
        assert_eq!(composed.apply(&Term::var("x")), Term::constant("a"));
        assert_eq!(composed.apply(&Term::var("y")), Term::constant("a"));
    }

    #[test]
    fn test_unify_term_list() {
        let pairs = vec![
            (Term::var("x"), Term::constant("a")),
            (Term::var("y"), Term::constant("b")),
            (Term::var("z"), Term::var("x")),
        ];
        let mgu = unify_term_list(&pairs).expect("unwrap");
        assert_eq!(mgu.len(), 3);
        assert_eq!(mgu.apply(&Term::var("x")), Term::constant("a"));
        assert_eq!(mgu.apply(&Term::var("y")), Term::constant("b"));
        assert_eq!(mgu.apply(&Term::var("z")), Term::constant("a"));
    }

    #[test]
    fn test_are_unifiable() {
        let x = Term::var("x");
        let a = Term::constant("a");
        let b = Term::constant("b");
        assert!(are_unifiable(&x, &a));
        assert!(are_unifiable(&a, &a));
        assert!(!are_unifiable(&a, &b));
    }

    #[test]
    fn test_rename_vars() {
        let x = Term::var("x");
        let renamed = rename_vars(&x, "1");
        assert_eq!(renamed, Term::var("x_1"));
        let a = Term::constant("a");
        let renamed_const = rename_vars(&a, "1");
        assert_eq!(renamed_const, a);
    }

    #[test]
    fn test_extend_substitution() {
        let mut subst = Substitution::empty();
        assert!(subst.extend("x".to_string(), Term::constant("a")).is_ok());
        assert!(subst.extend("y".to_string(), Term::constant("b")).is_ok());
        assert!(subst.extend("x".to_string(), Term::constant("a")).is_ok());
        assert!(subst.extend("x".to_string(), Term::constant("b")).is_err());
    }

    #[test]
    fn test_try_extend_conflict() {
        let mut subst = Substitution::singleton("x".to_string(), Term::constant("a"));
        let compatible = Substitution::singleton("x".to_string(), Term::constant("a"));
        assert!(subst.try_extend(&compatible).is_ok());
        let conflicting = Substitution::singleton("x".to_string(), Term::constant("b"));
        assert!(subst.try_extend(&conflicting).is_err());
        // The original binding survives a rejected merge.
        assert_eq!(subst.apply(&Term::var("x")), Term::constant("a"));
    }

    #[test]
    fn test_typed_term_unification() {
        use crate::term::TypeAnnotation;
        let x = Term::Typed {
            value: Box::new(Term::var("x")),
            type_annotation: TypeAnnotation::new("Int"),
        };
        let a = Term::Typed {
            value: Box::new(Term::constant("5")),
            type_annotation: TypeAnnotation::new("Int"),
        };
        let mgu = unify_terms(&x, &a).expect("unwrap");
        assert_eq!(mgu.len(), 1);
        assert_eq!(mgu.apply(&x), a);
    }

    #[test]
    fn test_typed_term_annotation_mismatch() {
        use crate::term::TypeAnnotation;
        let int_five = Term::Typed {
            value: Box::new(Term::constant("5")),
            type_annotation: TypeAnnotation::new("Int"),
        };
        let str_five = Term::Typed {
            value: Box::new(Term::constant("5")),
            type_annotation: TypeAnnotation::new("Str"),
        };
        assert!(unify_terms(&int_five, &str_five).is_err());
    }

    #[test]
    fn test_anti_unify_same_constant() {
        let a1 = Term::constant("a");
        let a2 = Term::constant("a");
        let (generalization, subst1, subst2) = anti_unify_terms(&a1, &a2);
        assert_eq!(generalization, a1);
        assert!(subst1.is_empty());
        assert!(subst2.is_empty());
    }

    #[test]
    fn test_anti_unify_different_constants() {
        let a = Term::constant("a");
        let b = Term::constant("b");
        let (generalization, subst1, subst2) = anti_unify_terms(&a, &b);
        match &generalization {
            Term::Var(name) => assert!(name.starts_with("_G")),
            _ => panic!("Expected fresh variable"),
        }
        assert_eq!(subst1.len(), 1);
        assert_eq!(subst2.len(), 1);
        assert_eq!(subst1.apply(&generalization), a);
        assert_eq!(subst2.apply(&generalization), b);
    }

    #[test]
    fn test_anti_unify_variable_constant() {
        let x = Term::var("x");
        let a = Term::constant("a");
        let (generalization, subst1, subst2) = anti_unify_terms(&x, &a);
        // A variable and a constant disagree, so the result must be a fresh
        // variable. Assert it rather than silently passing on other shapes.
        match &generalization {
            Term::Var(name) => assert!(name.starts_with("_G")),
            _ => panic!("Expected fresh variable"),
        }
        assert_eq!(subst1.apply(&generalization), x);
        assert_eq!(subst2.apply(&generalization), a);
    }

    #[test]
    fn test_anti_unify_same_variable() {
        let x1 = Term::var("x");
        let x2 = Term::var("x");
        let (generalization, subst1, subst2) = anti_unify_terms(&x1, &x2);
        assert_eq!(generalization, x1);
        assert!(subst1.is_empty());
        assert!(subst2.is_empty());
    }

    #[test]
    fn test_anti_unify_typed_terms() {
        use crate::term::TypeAnnotation;
        let t1 = Term::Typed {
            value: Box::new(Term::constant("5")),
            type_annotation: TypeAnnotation::new("Int"),
        };
        let t2 = Term::Typed {
            value: Box::new(Term::constant("10")),
            type_annotation: TypeAnnotation::new("Int"),
        };
        let (generalization, subst1, subst2) = anti_unify_terms(&t1, &t2);
        assert_eq!(subst1.apply(&generalization), t1);
        assert_eq!(subst2.apply(&generalization), t2);
        match generalization {
            Term::Typed {
                value,
                type_annotation,
            } => {
                assert_eq!(type_annotation.type_name, "Int");
                match *value {
                    Term::Var(name) => assert!(name.starts_with("_G")),
                    _ => panic!("Expected fresh variable inside typed term"),
                }
            }
            _ => panic!("Expected typed term"),
        }
    }

    #[test]
    fn test_lgg_single_term() {
        let terms = vec![Term::constant("a")];
        let (generalization, substs) = lgg_terms(&terms);
        assert_eq!(generalization, Term::constant("a"));
        assert_eq!(substs.len(), 1);
        assert!(substs[0].is_empty());
    }

    #[test]
    fn test_lgg_two_same_terms() {
        let terms = vec![Term::constant("a"), Term::constant("a")];
        let (generalization, substs) = lgg_terms(&terms);
        assert_eq!(generalization, Term::constant("a"));
        assert_eq!(substs.len(), 2);
        assert!(substs.iter().all(Substitution::is_empty));
    }

    #[test]
    fn test_lgg_two_different_terms() {
        let terms = vec![Term::constant("a"), Term::constant("b")];
        let (generalization, substs) = lgg_terms(&terms);
        match &generalization {
            Term::Var(name) => assert!(name.starts_with("_G")),
            _ => panic!("Expected fresh variable"),
        }
        assert_eq!(substs.len(), 2);
        for (term, subst) in terms.iter().zip(&substs) {
            assert_eq!(&subst.apply(&generalization), term);
        }
    }

    #[test]
    fn test_lgg_three_terms() {
        let terms = vec![
            Term::constant("a"),
            Term::constant("b"),
            Term::constant("c"),
        ];
        let (generalization, substs) = lgg_terms(&terms);
        match &generalization {
            Term::Var(name) => assert!(name.starts_with("_G")),
            _ => panic!("Expected fresh variable"),
        }
        assert_eq!(substs.len(), 3);
        for (term, subst) in terms.iter().zip(&substs) {
            assert_eq!(&subst.apply(&generalization), term);
        }
    }

    #[test]
    fn test_lgg_empty() {
        let terms: Vec<Term> = vec![];
        let (generalization, substs) = lgg_terms(&terms);
        match generalization {
            Term::Var(name) => assert_eq!(name, "_Empty"),
            _ => panic!("Expected _Empty variable"),
        }
        assert_eq!(substs.len(), 0);
    }

    #[test]
    fn test_anti_unify_preserves_structure() {
        let a1 = Term::constant("a");
        let a2 = Term::constant("a");
        let (generalization, _, _) = anti_unify_terms(&a1, &a2);
        assert_eq!(generalization, Term::constant("a"));
    }
}