1use super::term::*;
2use itertools::Itertools;
3use std::collections::{HashMap, HashSet};
4use std::fmt;
5use std::hash::Hash;
6
7#[derive(Clone, Debug)]
8pub enum UnifyError<V, L, C> {
9 UnifyFailed(Term<V, L, C>, Term<V, L, C>),
10 OccurCheckFailed(V, Term<V, L, C>),
11 UnifyVecDiffLen(Vec<Term<V, L, C>>, Vec<Term<V, L, C>>),
12}
13
14use crate::cli::diagnostic::Diagnostic;
15impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> From<UnifyError<V, L, OptCons<C>>>
16 for Diagnostic
17{
18 fn from(val: UnifyError<V, L, OptCons<C>>) -> Self {
19 match val {
20 UnifyError::UnifyFailed(lhs, rhs) => {
21 Diagnostic::error(format!("Can not unify types: {lhs} and {rhs}!"))
22 }
23 UnifyError::OccurCheckFailed(x, typ) => {
24 Diagnostic::error(format!("Occur check failed at variable: {x} in {typ}!"))
25 }
26 UnifyError::UnifyVecDiffLen(vec1, vec2) => {
27 let vec1 = vec1.iter().format(", ");
28 let vec2 = vec2.iter().format(", ");
29 Diagnostic::error(format!(
30 "Unify vectors of different length: [{vec1}] and [{vec2}]!"
31 ))
32 }
33 }
34 }
35}
36
/// Substitution-based unifier over `Term`s.
#[derive(Debug)]
pub struct Unifier<V, L, C> {
    // Variable bindings. A variable may be bound to another variable, so
    // lookups follow chains of bindings (see `deref`).
    map: HashMap<V, Term<V, L, C>>,
    // Variables registered through `fresh`; `unify` never inserts a binding
    // for any of these.
    freshs: HashSet<V>,
}
42
43impl<V: Eq + Hash + Clone, L, C> Default for Unifier<V, L, C> {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl<V: Eq + Hash + Clone, L, C> Unifier<V, L, C> {
50 pub fn new() -> Unifier<V, L, C> {
51 Unifier {
52 map: HashMap::new(),
53 freshs: HashSet::new(),
54 }
55 }
56
57 pub fn is_empty(&self) -> bool {
58 self.map.is_empty() && self.freshs.is_empty()
59 }
60
61 pub fn reset(&mut self) {
62 self.map.clear();
63 }
64}
65
66impl<V: Eq + Hash + Clone, L: PartialEq + Clone, C: Eq + Clone> Unifier<V, L, C> {
67 pub fn deref<'a>(&'a self, term: &'a Term<V, L, C>) -> &'a Term<V, L, C> {
68 let mut term = term;
69 loop {
70 if let Term::Var(var) = term {
71 if let Some(term2) = self.map.get(var) {
72 term = term2;
73 } else {
74 return term;
75 }
76 } else {
77 return term;
78 }
79 }
80 }
81
82 pub fn subst_opt(&self, term: &Term<V, L, C>) -> Option<Term<V, L, C>> {
83 let mut flag = false;
84 let res = self.subst_opt_help(term, &mut flag);
85 if flag { Some(res) } else { None }
86 }
87
88 fn subst_opt_help(&self, term: &Term<V, L, C>, flag: &mut bool) -> Term<V, L, C> {
89 match term {
90 Term::Var(var) => {
91 if let Some(term) = self.map.get(var) {
92 *flag = true;
93 self.subst_opt_help(term, flag)
94 } else {
95 Term::Var(var.clone())
96 }
97 }
98 Term::Lit(lit) => Term::Lit(lit.clone()),
99 Term::Cons(cons, flds) => {
100 let flds = flds
101 .iter()
102 .map(|fld| self.subst_opt_help(fld, flag))
103 .collect();
104 Term::Cons(cons.clone(), flds)
105 }
106 }
107 }
108
109 pub fn subst(&self, term: &Term<V, L, C>) -> Term<V, L, C> {
110 match term {
111 Term::Var(var) => {
112 if let Some(term) = self.map.get(var) {
113 self.subst(term)
114 } else {
115 Term::Var(var.clone())
116 }
117 }
118 Term::Lit(lit) => Term::Lit(lit.clone()),
119 Term::Cons(cons, flds) => {
120 let flds = flds.iter().map(|fld| self.subst(fld)).collect();
121 Term::Cons(cons.clone(), flds)
122 }
123 }
124 }
125
126 pub fn subst_err(&self, err: &UnifyError<V, L, C>) -> UnifyError<V, L, C> {
127 match err {
128 UnifyError::UnifyFailed(lhs, rhs) => {
129 let lhs = self.subst(lhs);
130 let rhs = self.subst(rhs);
131 UnifyError::UnifyFailed(lhs, rhs)
132 }
133 UnifyError::OccurCheckFailed(x, typ) => {
134 let typ = self.subst(typ);
135 UnifyError::OccurCheckFailed(x.clone(), typ)
136 }
137 UnifyError::UnifyVecDiffLen(vec1, vec2) => {
138 let vec1 = vec1.iter().map(|typ| self.subst(typ)).collect();
139 let vec2 = vec2.iter().map(|typ| self.subst(typ)).collect();
140 UnifyError::UnifyVecDiffLen(vec1, vec2)
141 }
142 }
143 }
144
145 fn occur_check(&self, x: &V, term: &Term<V, L, C>) -> bool {
146 let term = self.deref(term);
147 match term {
148 Term::Var(y) => x == y,
149 Term::Lit(_) => false,
150 Term::Cons(_cons, flds) => flds.iter().any(|fld| self.occur_check(x, fld)),
151 }
152 }
153
154 pub fn fresh(&mut self, var: V) {
155 self.freshs.insert(var);
156 }
157
158 pub fn unify(
159 &mut self,
160 lhs: &Term<V, L, C>,
161 rhs: &Term<V, L, C>,
162 ) -> Result<(), UnifyError<V, L, C>> {
163 let lhs = self.deref(lhs).clone();
164 let rhs = self.deref(rhs).clone();
165 match (&lhs, &rhs) {
166 (Term::Var(x1), Term::Var(x2)) if x1 == x2 => Ok(()),
167 (Term::Var(x), term) | (term, Term::Var(x)) if !self.freshs.contains(x) => {
168 if self.occur_check(x, term) {
169 return Err(UnifyError::OccurCheckFailed(x.clone(), term.clone()));
170 }
171 self.map.insert(x.clone(), term.clone());
172 Ok(())
173 }
174 (Term::Lit(lit1), Term::Lit(lit2)) => {
175 if lit1 == lit2 {
176 Ok(())
177 } else {
178 Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
179 }
180 }
181 (Term::Cons(cons1, flds1), Term::Cons(cons2, flds2)) => {
182 if cons1 == cons2 {
183 self.unify_many(flds1, flds2)
184 } else {
185 Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone()))
186 }
187 }
188 (lhs, rhs) => Err(UnifyError::UnifyFailed(lhs.clone(), rhs.clone())),
189 }
190 }
191
192 pub fn unify_many(
193 &mut self,
194 lhss: &[Term<V, L, C>],
195 rhss: &[Term<V, L, C>],
196 ) -> Result<(), UnifyError<V, L, C>> {
197 if lhss.len() == rhss.len() {
198 for (lhs, rhs) in lhss.iter().zip(rhss.iter()) {
199 self.unify(lhs, rhs)?;
200 }
201 Ok(())
202 } else {
203 Err(UnifyError::UnifyVecDiffLen(lhss.to_vec(), rhss.to_vec()))
204 }
205 }
206}