1use std::fmt::{Display, Formatter};
2
3use itertools::Itertools;
4use thiserror::Error;
5
6use MonomorphicType::*;
7
8use crate::algorithm_u::AlgorithmUError::{InfiniteType, UnificationFail};
9use crate::r#type::{MonomorphicType, TVar};
10use crate::substitution::Substitutions;
11use crate::traits::{FreeTypeVars, Substitutable};
12
/// Error raised when two sequences of types cannot be unified because they do not
/// line up structurally (one sequence has elements where the other has none).
///
/// Field `0` is the left-hand sequence and field `1` the right-hand sequence involved
/// in the mismatch. There is no `#[error(...)]` attribute here; the `Display`
/// implementation is written by hand elsewhere in this module.
#[derive(Debug, Error)]
pub struct UnificationMismatch(pub Vec<MonomorphicType>, pub Vec<MonomorphicType>);
15
/// Errors that can be produced while unifying two [`MonomorphicType`]s.
#[derive(Debug, Error)]
pub enum AlgorithmUError {
    /// The two types have incompatible shapes (different type constructors, or two
    /// distinct constants); holds the left- and right-hand types.
    #[error("Cannot unify types: {0} with {1}")]
    UnificationFail(MonomorphicType, MonomorphicType),
    /// Binding the variable to the type would create an infinite type, because the
    /// type mentions that variable among its free type variables (occurs check).
    #[error("Cannot construct an infinite type: {0} ~ {1}")]
    InfiniteType(TVar, MonomorphicType),
    /// Two type sequences differ in structure. Displays exactly like the wrapped
    /// [`UnificationMismatch`], which converts into this variant via `From`.
    #[error(transparent)]
    UnificationMismatch(#[from] UnificationMismatch),
}
25
26struct AlgorithmU;
27
28impl AlgorithmU {
29 fn occurs_check(var: &TVar, with: impl FreeTypeVars) -> bool {
30 with.free_types().contains(var)
31 }
32
33 fn unify_vec(
34 vec1: &[MonomorphicType],
35 vec2: &[MonomorphicType],
36 ) -> Result<Substitutions, AlgorithmUError> {
37 match (vec1, vec2) {
38 ([], []) => Ok(Substitutions::empty()),
39 ([t1, ts1 @ ..], [t2, ts2 @ ..]) => {
40 let s1 = t1.unify(t2)?;
41 let s2 = Self::unify_vec(&ts1.substitute(&s1), &ts2.substitute(&s1))?;
42 Ok(s1 + s2)
43 }
44 (t1, t2) => Err(UnificationMismatch(t1.to_vec(), t2.to_vec()).into()),
45 }
46 }
47
48 fn bind(var: &TVar, ty: &MonomorphicType) -> Result<Substitutions, AlgorithmUError> {
49 match ty {
50 Var(v) if var == v => Ok(Substitutions::empty()),
51 _ if Self::occurs_check(var, ty) => Err(InfiniteType(*var, ty.clone())),
52 _ => Ok(Substitutions::single(*var, ty.clone())),
53 }
54 }
55
56 fn apply(
57 lhs: &MonomorphicType,
58 rhs: &MonomorphicType,
59 ) -> Result<Substitutions, AlgorithmUError> {
60 match (lhs, rhs) {
61 (a, b) if a == b => Ok(Substitutions::empty()),
62 (Var(var), b) => Self::bind(var, b),
63 (a, Var(var)) => Self::bind(var, a),
64 (Fn(i1, o1), Fn(i2, o2)) => Self::unify_vec(
65 &[i1.as_ref().clone(), o1.as_ref().clone()],
66 &[i2.as_ref().clone(), o2.as_ref().clone()],
67 ),
68 (Tuple(t1), Tuple(t2)) => Self::unify_vec(&t1.0, &t2.0),
69 (Pointer(t1), Pointer(t2)) => t1.unify(t2),
70 _ => Err(UnificationFail(lhs.clone(), rhs.clone())),
71 }
72 }
73}
74
75impl MonomorphicType {
76 pub fn unify(&self, other: &MonomorphicType) -> Result<Substitutions, AlgorithmUError> {
77 AlgorithmU::apply(self, other)
78 }
79}
80
81impl Display for UnificationMismatch {
82 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
83 write!(
84 f,
85 "Cannot unify types: [{}] with [{}]; different structure",
86 self.0.iter().join(", "),
87 self.1.iter().join(", ")
88 )
89 }
90}
91
#[cfg(test)]
// Tests unwrap deliberately: an unexpected `Err`/`Ok` panics and fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;

    use nonempty_collections::nev;

    use crate::algorithm_u::AlgorithmUError;
    use crate::r#type::{fun, fun1, MonomorphicType, PrimitiveType, Tuple, TVar, var};
    use crate::r#type::MonomorphicType::Constant;
    use crate::substitution::Substitutions;
    use crate::traits::Substitutable;

    /// Shorthand for the constant type named `name`.
    fn constant(name: &str) -> MonomorphicType {
        Constant(name.to_string())
    }

    #[test]
    fn test_tautology_example_on_constants() {
        let a = constant("A");
        let b = constant("A");

        let s = a.unify(&b).unwrap();
        assert_eq!(s.into_inner(), HashMap::new());
    }

    #[test]
    fn test_different_constants_should_not_unify() {
        let a = constant("A");
        let b = constant("B");

        let e = a.unify(&b).unwrap_err();
        assert!(
            matches!(e, AlgorithmUError::UnificationFail(..)),
            "unexpected error: {:?}",
            e
        );
    }

    #[test]
    fn test_tautology_example_on_vars() {
        let a = var(0);
        let b = var(0);

        let s = a.unify(&b).unwrap();
        assert_eq!(s.into_inner(), HashMap::new());
    }

    #[test]
    fn test_variables_should_be_always_unified() {
        let a = TVar(1);
        let b = constant("A");

        let s1 = MonomorphicType::Var(a).unify(&b).unwrap();
        let s2 = b.unify(&MonomorphicType::Var(a)).unwrap();

        assert_eq!(s1, s2);
        assert_eq!(s1, Substitutions::single(a, b));
    }

    #[test]
    fn test_aliasing() {
        let a = TVar(1);
        let b = TVar(2);

        let a_ = MonomorphicType::Var(a);
        let b_ = MonomorphicType::Var(b);

        let s1 = a_.unify(&b_).unwrap();
        let s2 = b_.unify(&a_).unwrap();

        assert_eq!(s1, Substitutions::single(a, b_.clone()));
        assert_eq!(s2, Substitutions::single(b, a_));
    }

    #[test]
    fn test_simple_function_unifying() {
        let a = fun(nev![var(1), constant("A")], Tuple::unit());
        let b = fun(nev![var(1), var(2)], Tuple::unit());

        let s = a.unify(&b).unwrap();
        assert_eq!(s, Substitutions::single(TVar(2), constant("A")));
    }

    #[test]
    fn test_aliasing_in_functions() {
        let a = fun(nev![var(1)], Tuple::unit());
        let b = fun(nev![var(2)], Tuple::unit());

        let s = a.unify(&b).unwrap();
        assert_eq!(s, Substitutions::single(TVar(1), var(2)));
    }

    #[test]
    fn test_functions_with_different_arity_should_not_unify() {
        let a = fun(nev![constant("A")], Tuple::unit());
        let b = fun(nev![constant("A"), constant("B")], Tuple::unit());

        let e = a.unify(&b).unwrap_err();
        assert!(
            matches!(e, AlgorithmUError::UnificationFail(..)),
            "unexpected error: {:?}",
            e
        );
    }

    #[test]
    fn test_multiple_substitutions() {
        let a = fun(
            nev![fun1(var(1), PrimitiveType::Integral), var(1)],
            Tuple::unit(),
        );
        let b = fun(nev![var(2), constant("A")], Tuple::unit());

        let s = a.unify(&b).unwrap();
        assert_eq!(
            s.into_inner(),
            HashMap::from([
                (TVar(1), constant("A")),
                (TVar(2), fun1(constant("A"), PrimitiveType::Integral)),
            ])
        );
    }

    #[test]
    fn test_infinite_substitution() {
        let a = var(1);
        let b = fun1(var(1), Tuple::unit());

        let e = a.unify(&b).unwrap_err();
        assert!(
            matches!(e, AlgorithmUError::InfiniteType(..)),
            "unexpected error: {:?}",
            e
        );
    }

    #[test]
    fn test_transitive_substitutions() {
        let a = var(1);
        let b = var(2);
        let c = constant("A");

        let s1 = a.unify(&b).unwrap();
        let s2 = b.unify(&a).unwrap();
        let s3 = c.unify(&b.substitute(&s2)).unwrap();
        let s4 = a.substitute(&s1).unify(&c).unwrap();

        assert_eq!(s1, Substitutions::single(TVar(1), b.clone()));
        assert_eq!(s2, Substitutions::single(TVar(2), a.clone()));
        assert_eq!(s3, Substitutions::single(TVar(1), c.clone()));
        assert_eq!(s4, Substitutions::single(TVar(2), c));
    }

    #[test]
    fn test_different_substitutions_of_same_variable() {
        let a = var(1);
        let b = constant("A");
        let c = constant("B");

        let s = a.unify(&b).unwrap();
        let e = a.substitute(&s).unify(&c).unwrap_err();

        assert_eq!(s, Substitutions::single(TVar(1), b));
        assert!(
            matches!(e, AlgorithmUError::UnificationFail(..)),
            "unexpected error: {:?}",
            e
        );
    }

    #[test]
    fn test_complex_unification() {
        let a = fun1(fun1(fun1(constant("A"), var(1)), var(2)), var(3));
        let b = fun(nev![var(3), var(2), var(1)], constant("A"));

        let s1 = a.unify(&b).unwrap();
        let s2 = b.unify(&a).unwrap();

        assert_eq!(s1, s2);
        assert_eq!(a.substitute(&s1), b.substitute(&s1));
        let h = fun1(constant("A"), constant("A"));
        assert_eq!(
            a.substitute(&s1),
            fun1(fun1(h.clone(), h.clone()), fun1(h.clone(), h))
        );
    }
}