// prune_lang/utils/unify.rs

1use super::term::*;
2use itertools::Itertools;
3use std::collections::{HashMap, HashSet};
4use std::fmt;
5use std::hash::Hash;
6
7#[derive(Clone, Debug)]
8pub enum UnifyError<V, L, C> {
9    UnifyFailed(Term<V, L, C>, Term<V, L, C>),
10    OccurCheckFailed(V, Term<V, L, C>),
11    UnifyVecDiffLen(Vec<Term<V, L, C>>, Vec<Term<V, L, C>>),
12}
13
14use crate::cli::diagnostic::Diagnostic;
15impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> From<UnifyError<V, L, OptCons<C>>>
16    for Diagnostic
17{
18    fn from(val: UnifyError<V, L, OptCons<C>>) -> Self {
19        match val {
20            UnifyError::UnifyFailed(lhs, rhs) => {
21                Diagnostic::error(format!("Can not unify types: {lhs} and {rhs}!"))
22            }
23            UnifyError::OccurCheckFailed(x, typ) => {
24                Diagnostic::error(format!("Occur check failed at variable: {x} in {typ}!"))
25            }
26            UnifyError::UnifyVecDiffLen(vec1, vec2) => {
27                let vec1 = vec1.iter().format(", ");
28                let vec2 = vec2.iter().format(", ");
29                Diagnostic::error(format!(
30                    "Unify vectors of different length: [{vec1}] and [{vec2}]!"
31                ))
32            }
33        }
34    }
35}
36
37#[derive(Debug)]
38pub struct Unifier<V, L, C> {
39    map: HashMap<V, Term<V, L, C>>,
40    freshs: HashSet<V>,
41}
42
43impl<V: Eq + Hash + Clone, L, C> Default for Unifier<V, L, C> {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl<V: Eq + Hash + Clone, L, C> Unifier<V, L, C> {
50    pub fn new() -> Unifier<V, L, C> {
51        Unifier {
52            map: HashMap::new(),
53            freshs: HashSet::new(),
54        }
55    }
56
57    pub fn is_empty(&self) -> bool {
58        self.map.is_empty() && self.freshs.is_empty()
59    }
60
61    pub fn reset(&mut self) {
62        self.map.clear();
63    }
64}
65
66impl<V: Eq + Hash + Clone, L: PartialEq + Clone, C: Eq + Clone> Unifier<V, L, C> {
67    pub fn deref<'a>(&'a self, term: &'a Term<V, L, C>) -> &'a Term<V, L, C> {
68        let mut term = term;
69        loop {
70            if let Term::Var(var) = term {
71                if let Some(term2) = self.map.get(var) {
72                    term = term2;
73                } else {
74                    return term;
75                }
76            } else {
77                return term;
78            }
79        }
80    }
81
82    pub fn subst_opt(&self, term: &Term<V, L, C>) -> Option<Term<V, L, C>> {
83        let mut flag = false;
84        let res = self.subst_opt_help(term, &mut flag);
85        if flag { Some(res) } else { None }
86    }
87
88    fn subst_opt_help(&self, term: &Term<V, L, C>, flag: &mut bool) -> Term<V, L, C> {
89        match term {
90            Term::Var(var) => {
91                if let Some(term) = self.map.get(var) {
92                    *flag = true;
93                    self.subst_opt_help(term, flag)
94                } else {
95                    Term::Var(var.clone())
96                }
97            }
98            Term::Lit(lit) => Term::Lit(lit.clone()),
99            Term::Cons(cons, flds) => {
100                let flds = flds
101                    .iter()
102                    .map(|fld| self.subst_opt_help(fld, flag))
103                    .collect();
104                Term::Cons(cons.clone(), flds)
105            }
106        }
107    }
108
109    pub fn subst(&self, term: &Term<V, L, C>) -> Term<V, L, C> {
110        match term {
111            Term::Var(var) => {
112                if let Some(term) = self.map.get(var) {
113                    self.subst(term)
114                } else {
115                    Term::Var(var.clone())
116                }
117            }
118            Term::Lit(lit) => Term::Lit(lit.clone()),
119            Term::Cons(cons, flds) => {
120                let flds = flds.iter().map(|fld| self.subst(fld)).collect();
121                Term::Cons(cons.clone(), flds)
122            }
123        }
124    }
125
126    pub fn subst_err(&self, err: &UnifyError<V, L, C>) -> UnifyError<V, L, C> {
127        match err {
128            UnifyError::UnifyFailed(lhs, rhs) => {
129                let lhs = self.subst(lhs);
130                let rhs = self.subst(rhs);
131                UnifyError::UnifyFailed(lhs, rhs)
132            }
133            UnifyError::OccurCheckFailed(x, typ) => {
134                let typ = self.subst(typ);
135                UnifyError::OccurCheckFailed(x.clone(), typ)
136            }
137            UnifyError::UnifyVecDiffLen(vec1, vec2) => {
138                let vec1 = vec1.iter().map(|typ| self.subst(typ)).collect();
139                let vec2 = vec2.iter().map(|typ| self.subst(typ)).collect();
140                UnifyError::UnifyVecDiffLen(vec1, vec2)
141            }
142        }
143    }
144
145    fn occur_check(&self, x: &V, term: &Term<V, L, C>) -> bool {
146        let term = self.deref(term);
147        match term {
148            Term::Var(y) => x == y,
149            Term::Lit(_) => false,
150            Term::Cons(_cons, flds) => flds.iter().any(|fld| self.occur_check(x, fld)),
151        }
152    }
153
154    pub fn fresh(&mut self, var: V) {
155        self.freshs.insert(var);
156    }
157
158    pub fn unify(
159        &mut self,
160        lhs: &Term<V, L, C>,
161        rhs: &Term<V, L, C>,
162    ) -> Result<(), UnifyError<V, L, C>> {
163        let lhs = self.deref(lhs).clone();
164        let rhs = self.deref(rhs).clone();
165        match (&lhs, &rhs) {
166            (Term::Var(x1), Term::Var(x2)) if x1 == x2 => Ok(()),
167            (Term::Var(x), term) | (term, Term::Var(x)) if !self.freshs.contains(x) => {
168                if self.occur_check(x, term) {
169                    return Err(UnifyError::OccurCheckFailed(x.clone(), term.clone()));
170                }
171                self.map.insert(x.clone(), term.clone());
172                Ok(())
173            }
174            (Term::Lit(lit1), Term::Lit(lit2)) => {
175                if lit1 == lit2 {
176                    Ok(())
177                } else {
178                    Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
179                }
180            }
181            (Term::Cons(cons1, flds1), Term::Cons(cons2, flds2)) => {
182                if cons1 == cons2 {
183                    self.unify_many(flds1, flds2)
184                } else {
185                    Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
186                }
187            }
188            (lhs, rhs) => Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone())),
189        }
190    }
191
192    pub fn unify_many(
193        &mut self,
194        lhss: &[Term<V, L, C>],
195        rhss: &[Term<V, L, C>],
196    ) -> Result<(), UnifyError<V, L, C>> {
197        if lhss.len() == rhss.len() {
198            for (lhs, rhs) in lhss.iter().zip(rhss.iter()) {
199                self.unify(lhs, rhs)?;
200            }
201            Ok(())
202        } else {
203            Err(UnifyError::UnifyVecDiffLen(lhss.to_vec(), rhss.to_vec()))
204        }
205    }
206}