1use super::term::*;
2use itertools::Itertools;
3use std::collections::{HashMap, HashSet};
4use std::fmt;
5use std::hash::Hash;
6
/// Reasons why unifying two terms can fail.
#[derive(Clone, Debug)]
pub enum UnifyError<V, L, C> {
    /// The two (dereferenced) terms cannot be made equal: differing literals,
    /// differing constructors, mismatched term kinds, or a variable registered
    /// via `fresh` that would have to be bound.
    UnifyFailed(Term<V, L, C>, Term<V, L, C>),
    /// Binding the variable to the term would make the variable occur inside
    /// its own binding (the term contains the variable).
    OccurCheckFailed(V, Term<V, L, C>),
    /// Two argument lists of different lengths were unified element-wise.
    UnifyVecDiffLen(Vec<Term<V, L, C>>, Vec<Term<V, L, C>>),
}
13
14use crate::cli::diagnostic::Diagnostic;
15impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> From<UnifyError<V, L, OptCons<C>>>
16 for Diagnostic
17{
18 fn from(val: UnifyError<V, L, OptCons<C>>) -> Self {
19 match val {
20 UnifyError::UnifyFailed(lhs, rhs) => {
21 Diagnostic::error(format!("Can not unify types: {} and {}!", lhs, rhs))
22 }
23 UnifyError::OccurCheckFailed(x, typ) => {
24 Diagnostic::error(format!("Occur check failed at variable: {} in {}!", x, typ))
25 }
26 UnifyError::UnifyVecDiffLen(vec1, vec2) => {
27 let vec1 = vec1.iter().format(", ");
28 let vec2 = vec2.iter().format(", ");
29 Diagnostic::error(format!(
30 "Unify vectors of different length: [{}] and [{}]!",
31 vec1, vec2
32 ))
33 }
34 }
35 }
36}
37
/// Substitution-based unifier over `Term`s.
///
/// Successful calls to `unify` record variable bindings in `map`; variables
/// registered through `fresh` are never bound by `unify`.
#[derive(Debug)]
pub struct Unifier<V, L, C> {
    // Variable -> bound term. A bound term may itself be a bound variable;
    // `deref` and `subst` follow such chains.
    map: HashMap<V, Term<V, L, C>>,
    // Variables registered via `fresh`; `unify` refuses to bind them, so they
    // only match themselves or get a non-fresh variable bound to them.
    freshs: HashSet<V>,
}
43
44impl<V: Eq + Hash + Clone, L, C> Default for Unifier<V, L, C> {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl<V: Eq + Hash + Clone, L, C> Unifier<V, L, C> {
51 pub fn new() -> Unifier<V, L, C> {
52 Unifier {
53 map: HashMap::new(),
54 freshs: HashSet::new(),
55 }
56 }
57
58 pub fn is_empty(&self) -> bool {
59 self.map.is_empty() && self.freshs.is_empty()
60 }
61
62 pub fn reset(&mut self) {
63 self.map.clear();
64 }
65}
66
67impl<V: Eq + Hash + Clone, L: PartialEq + Clone, C: Eq + Clone> Unifier<V, L, C> {
68 pub fn deref<'a>(&'a self, term: &'a Term<V, L, C>) -> &'a Term<V, L, C> {
69 let mut term = term;
70 loop {
71 if let Term::Var(var) = term {
72 if let Some(term2) = self.map.get(var) {
73 term = term2;
74 continue;
75 } else {
76 return term;
77 }
78 } else {
79 return term;
80 }
81 }
82 }
83
84 pub fn subst_opt(&self, term: &Term<V, L, C>) -> Option<Term<V, L, C>> {
85 let mut flag = false;
86 let res = self.subst_opt_help(term, &mut flag);
87 if flag { Some(res) } else { None }
88 }
89
90 fn subst_opt_help(&self, term: &Term<V, L, C>, flag: &mut bool) -> Term<V, L, C> {
91 match term {
92 Term::Var(var) => {
93 if let Some(term) = self.map.get(var) {
94 *flag = true;
95 self.subst_opt_help(term, flag)
96 } else {
97 Term::Var(var.clone())
98 }
99 }
100 Term::Lit(lit) => Term::Lit(lit.clone()),
101 Term::Cons(cons, flds) => {
102 let flds = flds
103 .iter()
104 .map(|fld| self.subst_opt_help(fld, flag))
105 .collect();
106 Term::Cons(cons.clone(), flds)
107 }
108 }
109 }
110
111 pub fn subst(&self, term: &Term<V, L, C>) -> Term<V, L, C> {
112 match term {
113 Term::Var(var) => {
114 if let Some(term) = self.map.get(var) {
115 self.subst(term)
116 } else {
117 Term::Var(var.clone())
118 }
119 }
120 Term::Lit(lit) => Term::Lit(lit.clone()),
121 Term::Cons(cons, flds) => {
122 let flds = flds.iter().map(|fld| self.subst(fld)).collect();
123 Term::Cons(cons.clone(), flds)
124 }
125 }
126 }
127
128 pub fn subst_err(&self, err: &UnifyError<V, L, C>) -> UnifyError<V, L, C> {
129 match err {
130 UnifyError::UnifyFailed(lhs, rhs) => {
131 let lhs = self.subst(lhs);
132 let rhs = self.subst(rhs);
133 UnifyError::UnifyFailed(lhs, rhs)
134 }
135 UnifyError::OccurCheckFailed(x, typ) => {
136 let typ = self.subst(typ);
137 UnifyError::OccurCheckFailed(x.clone(), typ)
138 }
139 UnifyError::UnifyVecDiffLen(vec1, vec2) => {
140 let vec1 = vec1.iter().map(|typ| self.subst(typ)).collect();
141 let vec2 = vec2.iter().map(|typ| self.subst(typ)).collect();
142 UnifyError::UnifyVecDiffLen(vec1, vec2)
143 }
144 }
145 }
146
147 fn occur_check(&self, x: &V, term: &Term<V, L, C>) -> bool {
148 let term = self.deref(term);
149 match term {
150 Term::Var(y) => x == y,
151 Term::Lit(_) => false,
152 Term::Cons(_cons, flds) => flds.iter().any(|fld| self.occur_check(x, fld)),
153 }
154 }
155
156 pub fn fresh(&mut self, var: V) {
157 self.freshs.insert(var);
158 }
159
160 pub fn unify(
161 &mut self,
162 lhs: &Term<V, L, C>,
163 rhs: &Term<V, L, C>,
164 ) -> Result<(), UnifyError<V, L, C>> {
165 let lhs = self.deref(lhs).clone();
166 let rhs = self.deref(rhs).clone();
167 match (&lhs, &rhs) {
168 (Term::Var(x1), Term::Var(x2)) if x1 == x2 => Ok(()),
169 (Term::Var(x), term) | (term, Term::Var(x)) if !self.freshs.contains(x) => {
170 if self.occur_check(x, term) {
171 return Err(UnifyError::OccurCheckFailed(x.clone(), term.clone()));
172 }
173 self.map.insert(x.clone(), term.clone());
174 Ok(())
175 }
176 (Term::Lit(lit1), Term::Lit(lit2)) => {
177 if lit1 == lit2 {
178 Ok(())
179 } else {
180 Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
181 }
182 }
183 (Term::Cons(cons1, flds1), Term::Cons(cons2, flds2)) => {
184 if cons1 == cons2 {
185 self.unify_many(flds1, flds2)
186 } else {
187 Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
188 }
189 }
190 (lhs, rhs) => Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone())),
191 }
192 }
193
194 pub fn unify_many(
195 &mut self,
196 lhss: &[Term<V, L, C>],
197 rhss: &[Term<V, L, C>],
198 ) -> Result<(), UnifyError<V, L, C>> {
199 if lhss.len() == rhss.len() {
200 for (lhs, rhs) in lhss.iter().zip(rhss.iter()) {
201 self.unify(lhs, rhs)?;
202 }
203 Ok(())
204 } else {
205 Err(UnifyError::UnifyVecDiffLen(lhss.to_vec(), rhss.to_vec()))
206 }
207 }
208}