// prune_lang/utils/unify.rs

1use super::term::*;
2use itertools::Itertools;
3use std::collections::{HashMap, HashSet};
4use std::fmt;
5use std::hash::Hash;
6
7#[derive(Clone, Debug)]
8pub enum UnifyError<V, L, C> {
9    UnifyFailed(Term<V, L, C>, Term<V, L, C>),
10    OccurCheckFailed(V, Term<V, L, C>),
11    UnifyVecDiffLen(Vec<Term<V, L, C>>, Vec<Term<V, L, C>>),
12}
13
14use crate::cli::diagnostic::Diagnostic;
15impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> From<UnifyError<V, L, OptCons<C>>>
16    for Diagnostic
17{
18    fn from(val: UnifyError<V, L, OptCons<C>>) -> Self {
19        match val {
20            UnifyError::UnifyFailed(lhs, rhs) => {
21                Diagnostic::error(format!("Can not unify types: {} and {}!", lhs, rhs))
22            }
23            UnifyError::OccurCheckFailed(x, typ) => {
24                Diagnostic::error(format!("Occur check failed at variable: {} in {}!", x, typ))
25            }
26            UnifyError::UnifyVecDiffLen(vec1, vec2) => {
27                let vec1 = vec1.iter().format(", ");
28                let vec2 = vec2.iter().format(", ");
29                Diagnostic::error(format!(
30                    "Unify vectors of different length: [{}] and [{}]!",
31                    vec1, vec2
32                ))
33            }
34        }
35    }
36}
37
38#[derive(Debug)]
39pub struct Unifier<V, L, C> {
40    map: HashMap<V, Term<V, L, C>>,
41    freshs: HashSet<V>,
42}
43
44impl<V: Eq + Hash + Clone, L, C> Default for Unifier<V, L, C> {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl<V: Eq + Hash + Clone, L, C> Unifier<V, L, C> {
51    pub fn new() -> Unifier<V, L, C> {
52        Unifier {
53            map: HashMap::new(),
54            freshs: HashSet::new(),
55        }
56    }
57
58    pub fn is_empty(&self) -> bool {
59        self.map.is_empty() && self.freshs.is_empty()
60    }
61
62    pub fn reset(&mut self) {
63        self.map.clear();
64    }
65}
66
67impl<V: Eq + Hash + Clone, L: Eq + Clone, C: Eq + Clone> Unifier<V, L, C> {
68    pub fn deref<'a>(&'a self, term: &'a Term<V, L, C>) -> &'a Term<V, L, C> {
69        let mut term = term;
70        loop {
71            if let Term::Var(var) = term {
72                if let Some(term2) = self.map.get(var) {
73                    term = term2;
74                    continue;
75                } else {
76                    return term;
77                }
78            } else {
79                return term;
80            }
81        }
82    }
83
84    pub fn merge(&self, term: &Term<V, L, C>) -> Term<V, L, C> {
85        match term {
86            Term::Var(var) => {
87                if let Some(term) = self.map.get(var) {
88                    self.merge(term)
89                } else {
90                    Term::Var(var.clone())
91                }
92            }
93            Term::Lit(lit) => Term::Lit(lit.clone()),
94            Term::Cons(cons, flds) => {
95                let flds = flds.iter().map(|fld| self.merge(fld)).collect();
96                Term::Cons(cons.clone(), flds)
97            }
98        }
99    }
100
101    pub fn merge_err(&self, err: &UnifyError<V, L, C>) -> UnifyError<V, L, C> {
102        match err {
103            UnifyError::UnifyFailed(lhs, rhs) => {
104                let lhs = self.merge(lhs);
105                let rhs = self.merge(rhs);
106                UnifyError::UnifyFailed(lhs, rhs)
107            }
108            UnifyError::OccurCheckFailed(x, typ) => {
109                let typ = self.merge(typ);
110                UnifyError::OccurCheckFailed(x.clone(), typ)
111            }
112            UnifyError::UnifyVecDiffLen(vec1, vec2) => {
113                let vec1 = vec1.iter().map(|typ| self.merge(typ)).collect();
114                let vec2 = vec2.iter().map(|typ| self.merge(typ)).collect();
115                UnifyError::UnifyVecDiffLen(vec1, vec2)
116            }
117        }
118    }
119
120    fn occur_check(&self, x: &V, term: &Term<V, L, C>) -> bool {
121        let term = self.deref(term);
122        match term {
123            Term::Var(y) => x == y,
124            Term::Lit(_) => false,
125            Term::Cons(_cons, flds) => flds.iter().any(|fld| self.occur_check(x, fld)),
126        }
127    }
128
129    pub fn fresh(&mut self, var: V) {
130        self.freshs.insert(var);
131    }
132
133    pub fn unify(
134        &mut self,
135        lhs: &Term<V, L, C>,
136        rhs: &Term<V, L, C>,
137    ) -> Result<(), UnifyError<V, L, C>> {
138        let lhs = self.deref(lhs).clone();
139        let rhs = self.deref(rhs).clone();
140        match (&lhs, &rhs) {
141            (Term::Var(x1), Term::Var(x2)) if x1 == x2 => Ok(()),
142            (Term::Var(x), term) | (term, Term::Var(x)) if !self.freshs.contains(x) => {
143                if self.occur_check(x, term) {
144                    return Err(UnifyError::OccurCheckFailed(x.clone(), term.clone()));
145                }
146                self.map.insert(x.clone(), term.clone());
147                Ok(())
148            }
149            (Term::Lit(lit1), Term::Lit(lit2)) => {
150                if lit1 == lit2 {
151                    Ok(())
152                } else {
153                    Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
154                }
155            }
156            (Term::Cons(cons1, flds1), Term::Cons(cons2, flds2)) => {
157                if cons1 == cons2 {
158                    self.unify_many(flds1, flds2)
159                } else {
160                    Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
161                }
162            }
163            (lhs, rhs) => Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone())),
164        }
165    }
166
167    pub fn unify_many(
168        &mut self,
169        lhss: &[Term<V, L, C>],
170        rhss: &[Term<V, L, C>],
171    ) -> Result<(), UnifyError<V, L, C>> {
172        if lhss.len() == rhss.len() {
173            for (lhs, rhs) in lhss.iter().zip(rhss.iter()) {
174                self.unify(lhs, rhs)?;
175            }
176            Ok(())
177        } else {
178            Err(UnifyError::UnifyVecDiffLen(lhss.to_vec(), rhss.to_vec()))
179        }
180    }
181}