Skip to main content

prune_lang/utils/
unify.rs

1use super::term::*;
2use itertools::Itertools;
3use std::collections::{HashMap, HashSet};
4use std::fmt;
5use std::hash::Hash;
6
7#[derive(Clone, Debug)]
8pub enum UnifyError<V, L, C> {
9    UnifyFailed(Term<V, L, C>, Term<V, L, C>),
10    OccurCheckFailed(V, Term<V, L, C>),
11    UnifyVecDiffLen(Vec<Term<V, L, C>>, Vec<Term<V, L, C>>),
12}
13
14use crate::cli::diagnostic::Diagnostic;
15impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> From<UnifyError<V, L, OptCons<C>>>
16    for Diagnostic
17{
18    fn from(val: UnifyError<V, L, OptCons<C>>) -> Self {
19        match val {
20            UnifyError::UnifyFailed(lhs, rhs) => {
21                Diagnostic::error(format!("Can not unify types: {} and {}!", lhs, rhs))
22            }
23            UnifyError::OccurCheckFailed(x, typ) => {
24                Diagnostic::error(format!("Occur check failed at variable: {} in {}!", x, typ))
25            }
26            UnifyError::UnifyVecDiffLen(vec1, vec2) => {
27                let vec1 = vec1.iter().format(", ");
28                let vec2 = vec2.iter().format(", ");
29                Diagnostic::error(format!(
30                    "Unify vectors of different length: [{}] and [{}]!",
31                    vec1, vec2
32                ))
33            }
34        }
35    }
36}
37
38#[derive(Debug)]
39pub struct Unifier<V, L, C> {
40    map: HashMap<V, Term<V, L, C>>,
41    freshs: HashSet<V>,
42}
43
44impl<V: Eq + Hash + Clone, L, C> Default for Unifier<V, L, C> {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl<V: Eq + Hash + Clone, L, C> Unifier<V, L, C> {
51    pub fn new() -> Unifier<V, L, C> {
52        Unifier {
53            map: HashMap::new(),
54            freshs: HashSet::new(),
55        }
56    }
57
58    pub fn is_empty(&self) -> bool {
59        self.map.is_empty() && self.freshs.is_empty()
60    }
61
62    pub fn reset(&mut self) {
63        self.map.clear();
64    }
65}
66
67impl<V: Eq + Hash + Clone, L: PartialEq + Clone, C: Eq + Clone> Unifier<V, L, C> {
68    pub fn deref<'a>(&'a self, term: &'a Term<V, L, C>) -> &'a Term<V, L, C> {
69        let mut term = term;
70        loop {
71            if let Term::Var(var) = term {
72                if let Some(term2) = self.map.get(var) {
73                    term = term2;
74                    continue;
75                } else {
76                    return term;
77                }
78            } else {
79                return term;
80            }
81        }
82    }
83
84    pub fn subst_opt(&self, term: &Term<V, L, C>) -> Option<Term<V, L, C>> {
85        let mut flag = false;
86        let res = self.subst_opt_help(term, &mut flag);
87        if flag { Some(res) } else { None }
88    }
89
90    fn subst_opt_help(&self, term: &Term<V, L, C>, flag: &mut bool) -> Term<V, L, C> {
91        match term {
92            Term::Var(var) => {
93                if let Some(term) = self.map.get(var) {
94                    *flag = true;
95                    self.subst_opt_help(term, flag)
96                } else {
97                    Term::Var(var.clone())
98                }
99            }
100            Term::Lit(lit) => Term::Lit(lit.clone()),
101            Term::Cons(cons, flds) => {
102                let flds = flds
103                    .iter()
104                    .map(|fld| self.subst_opt_help(fld, flag))
105                    .collect();
106                Term::Cons(cons.clone(), flds)
107            }
108        }
109    }
110
111    pub fn subst(&self, term: &Term<V, L, C>) -> Term<V, L, C> {
112        match term {
113            Term::Var(var) => {
114                if let Some(term) = self.map.get(var) {
115                    self.subst(term)
116                } else {
117                    Term::Var(var.clone())
118                }
119            }
120            Term::Lit(lit) => Term::Lit(lit.clone()),
121            Term::Cons(cons, flds) => {
122                let flds = flds.iter().map(|fld| self.subst(fld)).collect();
123                Term::Cons(cons.clone(), flds)
124            }
125        }
126    }
127
128    pub fn subst_err(&self, err: &UnifyError<V, L, C>) -> UnifyError<V, L, C> {
129        match err {
130            UnifyError::UnifyFailed(lhs, rhs) => {
131                let lhs = self.subst(lhs);
132                let rhs = self.subst(rhs);
133                UnifyError::UnifyFailed(lhs, rhs)
134            }
135            UnifyError::OccurCheckFailed(x, typ) => {
136                let typ = self.subst(typ);
137                UnifyError::OccurCheckFailed(x.clone(), typ)
138            }
139            UnifyError::UnifyVecDiffLen(vec1, vec2) => {
140                let vec1 = vec1.iter().map(|typ| self.subst(typ)).collect();
141                let vec2 = vec2.iter().map(|typ| self.subst(typ)).collect();
142                UnifyError::UnifyVecDiffLen(vec1, vec2)
143            }
144        }
145    }
146
147    fn occur_check(&self, x: &V, term: &Term<V, L, C>) -> bool {
148        let term = self.deref(term);
149        match term {
150            Term::Var(y) => x == y,
151            Term::Lit(_) => false,
152            Term::Cons(_cons, flds) => flds.iter().any(|fld| self.occur_check(x, fld)),
153        }
154    }
155
156    pub fn fresh(&mut self, var: V) {
157        self.freshs.insert(var);
158    }
159
160    pub fn unify(
161        &mut self,
162        lhs: &Term<V, L, C>,
163        rhs: &Term<V, L, C>,
164    ) -> Result<(), UnifyError<V, L, C>> {
165        let lhs = self.deref(lhs).clone();
166        let rhs = self.deref(rhs).clone();
167        match (&lhs, &rhs) {
168            (Term::Var(x1), Term::Var(x2)) if x1 == x2 => Ok(()),
169            (Term::Var(x), term) | (term, Term::Var(x)) if !self.freshs.contains(x) => {
170                if self.occur_check(x, term) {
171                    return Err(UnifyError::OccurCheckFailed(x.clone(), term.clone()));
172                }
173                self.map.insert(x.clone(), term.clone());
174                Ok(())
175            }
176            (Term::Lit(lit1), Term::Lit(lit2)) => {
177                if lit1 == lit2 {
178                    Ok(())
179                } else {
180                    Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
181                }
182            }
183            (Term::Cons(cons1, flds1), Term::Cons(cons2, flds2)) => {
184                if cons1 == cons2 {
185                    self.unify_many(flds1, flds2)
186                } else {
187                    Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
188                }
189            }
190            (lhs, rhs) => Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone())),
191        }
192    }
193
194    pub fn unify_many(
195        &mut self,
196        lhss: &[Term<V, L, C>],
197        rhss: &[Term<V, L, C>],
198    ) -> Result<(), UnifyError<V, L, C>> {
199        if lhss.len() == rhss.len() {
200            for (lhs, rhs) in lhss.iter().zip(rhss.iter()) {
201                self.unify(lhs, rhs)?;
202            }
203            Ok(())
204        } else {
205            Err(UnifyError::UnifyVecDiffLen(lhss.to_vec(), rhss.to_vec()))
206        }
207    }
208}