1use super::term::*;
2use itertools::Itertools;
3use std::collections::{HashMap, HashSet};
4use std::fmt;
5use std::hash::Hash;
6
7#[derive(Clone, Debug)]
8pub enum UnifyError<V, L, C> {
9 UnifyFailed(Term<V, L, C>, Term<V, L, C>),
10 OccurCheckFailed(V, Term<V, L, C>),
11 UnifyVecDiffLen(Vec<Term<V, L, C>>, Vec<Term<V, L, C>>),
12}
13
14use crate::cli::diagnostic::Diagnostic;
15impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> From<UnifyError<V, L, OptCons<C>>>
16 for Diagnostic
17{
18 fn from(val: UnifyError<V, L, OptCons<C>>) -> Self {
19 match val {
20 UnifyError::UnifyFailed(lhs, rhs) => {
21 Diagnostic::error(format!("Can not unify types: {} and {}!", lhs, rhs))
22 }
23 UnifyError::OccurCheckFailed(x, typ) => {
24 Diagnostic::error(format!("Occur check failed at variable: {} in {}!", x, typ))
25 }
26 UnifyError::UnifyVecDiffLen(vec1, vec2) => {
27 let vec1 = vec1.iter().format(", ");
28 let vec2 = vec2.iter().format(", ");
29 Diagnostic::error(format!(
30 "Unify vectors of different length: [{}] and [{}]!",
31 vec1, vec2
32 ))
33 }
34 }
35 }
36}
37
38#[derive(Debug)]
39pub struct Unifier<V, L, C> {
40 map: HashMap<V, Term<V, L, C>>,
41 freshs: HashSet<V>,
42}
43
44impl<V: Eq + Hash + Clone, L, C> Default for Unifier<V, L, C> {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl<V: Eq + Hash + Clone, L, C> Unifier<V, L, C> {
51 pub fn new() -> Unifier<V, L, C> {
52 Unifier {
53 map: HashMap::new(),
54 freshs: HashSet::new(),
55 }
56 }
57
58 pub fn is_empty(&self) -> bool {
59 self.map.is_empty() && self.freshs.is_empty()
60 }
61
62 pub fn reset(&mut self) {
63 self.map.clear();
64 }
65}
66
67impl<V: Eq + Hash + Clone, L: Eq + Clone, C: Eq + Clone> Unifier<V, L, C> {
68 pub fn deref<'a>(&'a self, term: &'a Term<V, L, C>) -> &'a Term<V, L, C> {
69 let mut term = term;
70 loop {
71 if let Term::Var(var) = term {
72 if let Some(term2) = self.map.get(var) {
73 term = term2;
74 continue;
75 } else {
76 return term;
77 }
78 } else {
79 return term;
80 }
81 }
82 }
83
84 pub fn merge(&self, term: &Term<V, L, C>) -> Term<V, L, C> {
85 match term {
86 Term::Var(var) => {
87 if let Some(term) = self.map.get(var) {
88 self.merge(term)
89 } else {
90 Term::Var(var.clone())
91 }
92 }
93 Term::Lit(lit) => Term::Lit(lit.clone()),
94 Term::Cons(cons, flds) => {
95 let flds = flds.iter().map(|fld| self.merge(fld)).collect();
96 Term::Cons(cons.clone(), flds)
97 }
98 }
99 }
100
101 pub fn merge_err(&self, err: &UnifyError<V, L, C>) -> UnifyError<V, L, C> {
102 match err {
103 UnifyError::UnifyFailed(lhs, rhs) => {
104 let lhs = self.merge(lhs);
105 let rhs = self.merge(rhs);
106 UnifyError::UnifyFailed(lhs, rhs)
107 }
108 UnifyError::OccurCheckFailed(x, typ) => {
109 let typ = self.merge(typ);
110 UnifyError::OccurCheckFailed(x.clone(), typ)
111 }
112 UnifyError::UnifyVecDiffLen(vec1, vec2) => {
113 let vec1 = vec1.iter().map(|typ| self.merge(typ)).collect();
114 let vec2 = vec2.iter().map(|typ| self.merge(typ)).collect();
115 UnifyError::UnifyVecDiffLen(vec1, vec2)
116 }
117 }
118 }
119
120 fn occur_check(&self, x: &V, term: &Term<V, L, C>) -> bool {
121 let term = self.deref(term);
122 match term {
123 Term::Var(y) => x == y,
124 Term::Lit(_) => false,
125 Term::Cons(_cons, flds) => flds.iter().any(|fld| self.occur_check(x, fld)),
126 }
127 }
128
129 pub fn fresh(&mut self, var: V) {
130 self.freshs.insert(var);
131 }
132
133 pub fn unify(
134 &mut self,
135 lhs: &Term<V, L, C>,
136 rhs: &Term<V, L, C>,
137 ) -> Result<(), UnifyError<V, L, C>> {
138 let lhs = self.deref(lhs).clone();
139 let rhs = self.deref(rhs).clone();
140 match (&lhs, &rhs) {
141 (Term::Var(x1), Term::Var(x2)) if x1 == x2 => Ok(()),
142 (Term::Var(x), term) | (term, Term::Var(x)) if !self.freshs.contains(x) => {
143 if self.occur_check(x, term) {
144 return Err(UnifyError::OccurCheckFailed(x.clone(), term.clone()));
145 }
146 self.map.insert(x.clone(), term.clone());
147 Ok(())
148 }
149 (Term::Lit(lit1), Term::Lit(lit2)) => {
150 if lit1 == lit2 {
151 Ok(())
152 } else {
153 Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
154 }
155 }
156 (Term::Cons(cons1, flds1), Term::Cons(cons2, flds2)) => {
157 if cons1 == cons2 {
158 self.unify_many(flds1, flds2)
159 } else {
160 Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
161 }
162 }
163 (lhs, rhs) => Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone())),
164 }
165 }
166
167 pub fn unify_many(
168 &mut self,
169 lhss: &[Term<V, L, C>],
170 rhss: &[Term<V, L, C>],
171 ) -> Result<(), UnifyError<V, L, C>> {
172 if lhss.len() == rhss.len() {
173 for (lhs, rhs) in lhss.iter().zip(rhss.iter()) {
174 self.unify(lhs, rhs)?;
175 }
176 Ok(())
177 } else {
178 Err(UnifyError::UnifyVecDiffLen(lhss.to_vec(), rhss.to_vec()))
179 }
180 }
181}