// prune-lang 0.2.2
//
// Prune is a constraint logic programming language with a branching heuristic.
// Documentation
use super::term::*;
use itertools::Itertools;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::hash::Hash;

/// Reason a call to `Unifier::unify` / `Unifier::unify_many` failed.
///
/// Every variant carries the terms involved so they can be reported; the
/// unifier's `subst_err` can resolve bound variables inside them first.
#[derive(Debug, Clone)]
pub enum UnifyError<V, L, C> {
    /// The two (dereferenced) terms cannot be made equal: different literals,
    /// different constructors, or a variable that the unifier may not bind.
    UnifyFailed(Term<V, L, C>, Term<V, L, C>),
    /// Binding the variable to the term would make the term contain itself.
    OccurCheckFailed(V, Term<V, L, C>),
    /// Two term sequences of different lengths were unified element-wise.
    UnifyVecDiffLen(Vec<Term<V, L, C>>, Vec<Term<V, L, C>>),
}

use crate::cli::diagnostic::Diagnostic;
impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> From<UnifyError<V, L, OptCons<C>>>
    for Diagnostic
{
    fn from(val: UnifyError<V, L, OptCons<C>>) -> Self {
        match val {
            UnifyError::UnifyFailed(lhs, rhs) => {
                Diagnostic::error(format!("Can not unify types: {lhs} and {rhs}!"))
            }
            UnifyError::OccurCheckFailed(x, typ) => {
                Diagnostic::error(format!("Occur check failed at variable: {x} in {typ}!"))
            }
            UnifyError::UnifyVecDiffLen(vec1, vec2) => {
                let vec1 = vec1.iter().format(", ");
                let vec2 = vec2.iter().format(", ");
                Diagnostic::error(format!(
                    "Unify vectors of different length: [{vec1}] and [{vec2}]!"
                ))
            }
        }
    }
}

/// Unification state: the variable bindings found so far, plus the set of
/// variables that unification must leave unbound.
#[derive(Debug)]
pub struct Unifier<V, L, C> {
    /// Maps a variable to the term it was unified with. That term may mention
    /// other bound variables, so `deref`/`subst` follow chains of bindings.
    map: HashMap<V, Term<V, L, C>>,
    /// Variables registered through `fresh`; `unify` never binds them.
    freshs: HashSet<V>,
}

impl<V: Eq + Hash + Clone, L, C> Default for Unifier<V, L, C> {
    fn default() -> Self {
        Self::new()
    }
}

impl<V, L, C> Unifier<V, L, C>
where
    V: Eq + Hash + Clone,
{
    /// Creates a unifier with no bindings and no fresh variables.
    pub fn new() -> Unifier<V, L, C> {
        let map = HashMap::new();
        let freshs = HashSet::new();
        Unifier { map, freshs }
    }

    /// Returns `true` when there are neither fresh variables nor bindings.
    pub fn is_empty(&self) -> bool {
        self.freshs.is_empty() && self.map.is_empty()
    }

    /// Discards every variable binding.
    ///
    /// Fresh variables registered with `fresh` are *not* cleared, so `is_empty`
    /// can still return `false` right after a reset.
    /// NOTE(review): confirm this asymmetry with `is_empty` is intended.
    pub fn reset(&mut self) {
        self.map.clear();
    }
}

impl<V: Eq + Hash + Clone, L: PartialEq + Clone, C: Eq + Clone> Unifier<V, L, C> {
    pub fn deref<'a>(&'a self, term: &'a Term<V, L, C>) -> &'a Term<V, L, C> {
        let mut term = term;
        loop {
            if let Term::Var(var) = term {
                if let Some(term2) = self.map.get(var) {
                    term = term2;
                } else {
                    return term;
                }
            } else {
                return term;
            }
        }
    }

    pub fn subst_opt(&self, term: &Term<V, L, C>) -> Option<Term<V, L, C>> {
        let mut flag = false;
        let res = self.subst_opt_help(term, &mut flag);
        if flag { Some(res) } else { None }
    }

    fn subst_opt_help(&self, term: &Term<V, L, C>, flag: &mut bool) -> Term<V, L, C> {
        match term {
            Term::Var(var) => {
                if let Some(term) = self.map.get(var) {
                    *flag = true;
                    self.subst_opt_help(term, flag)
                } else {
                    Term::Var(var.clone())
                }
            }
            Term::Lit(lit) => Term::Lit(lit.clone()),
            Term::Cons(cons, flds) => {
                let flds = flds
                    .iter()
                    .map(|fld| self.subst_opt_help(fld, flag))
                    .collect();
                Term::Cons(cons.clone(), flds)
            }
        }
    }

    pub fn subst(&self, term: &Term<V, L, C>) -> Term<V, L, C> {
        match term {
            Term::Var(var) => {
                if let Some(term) = self.map.get(var) {
                    self.subst(term)
                } else {
                    Term::Var(var.clone())
                }
            }
            Term::Lit(lit) => Term::Lit(lit.clone()),
            Term::Cons(cons, flds) => {
                let flds = flds.iter().map(|fld| self.subst(fld)).collect();
                Term::Cons(cons.clone(), flds)
            }
        }
    }

    pub fn subst_err(&self, err: &UnifyError<V, L, C>) -> UnifyError<V, L, C> {
        match err {
            UnifyError::UnifyFailed(lhs, rhs) => {
                let lhs = self.subst(lhs);
                let rhs = self.subst(rhs);
                UnifyError::UnifyFailed(lhs, rhs)
            }
            UnifyError::OccurCheckFailed(x, typ) => {
                let typ = self.subst(typ);
                UnifyError::OccurCheckFailed(x.clone(), typ)
            }
            UnifyError::UnifyVecDiffLen(vec1, vec2) => {
                let vec1 = vec1.iter().map(|typ| self.subst(typ)).collect();
                let vec2 = vec2.iter().map(|typ| self.subst(typ)).collect();
                UnifyError::UnifyVecDiffLen(vec1, vec2)
            }
        }
    }

    fn occur_check(&self, x: &V, term: &Term<V, L, C>) -> bool {
        let term = self.deref(term);
        match term {
            Term::Var(y) => x == y,
            Term::Lit(_) => false,
            Term::Cons(_cons, flds) => flds.iter().any(|fld| self.occur_check(x, fld)),
        }
    }

    pub fn fresh(&mut self, var: V) {
        self.freshs.insert(var);
    }

    pub fn unify(
        &mut self,
        lhs: &Term<V, L, C>,
        rhs: &Term<V, L, C>,
    ) -> Result<(), UnifyError<V, L, C>> {
        let lhs = self.deref(lhs).clone();
        let rhs = self.deref(rhs).clone();
        match (&lhs, &rhs) {
            (Term::Var(x1), Term::Var(x2)) if x1 == x2 => Ok(()),
            (Term::Var(x), term) | (term, Term::Var(x)) if !self.freshs.contains(x) => {
                if self.occur_check(x, term) {
                    return Err(UnifyError::OccurCheckFailed(x.clone(), term.clone()));
                }
                self.map.insert(x.clone(), term.clone());
                Ok(())
            }
            (Term::Lit(lit1), Term::Lit(lit2)) => {
                if lit1 == lit2 {
                    Ok(())
                } else {
                    Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
                }
            }
            (Term::Cons(cons1, flds1), Term::Cons(cons2, flds2)) => {
                if cons1 == cons2 {
                    self.unify_many(flds1, flds2)
                } else {
                    Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
                }
            }
            (lhs, rhs) => Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone())),
        }
    }

    pub fn unify_many(
        &mut self,
        lhss: &[Term<V, L, C>],
        rhss: &[Term<V, L, C>],
    ) -> Result<(), UnifyError<V, L, C>> {
        if lhss.len() == rhss.len() {
            for (lhs, rhs) in lhss.iter().zip(rhss.iter()) {
                self.unify(lhs, rhs)?;
            }
            Ok(())
        } else {
            Err(UnifyError::UnifyVecDiffLen(lhss.to_vec(), rhss.to_vec()))
        }
    }
}