kodept_inference/
algorithm_u.rs

1use std::fmt::{Display, Formatter};
2
3use itertools::Itertools;
4use thiserror::Error;
5
6use MonomorphicType::*;
7
8use crate::algorithm_u::AlgorithmUError::{InfiniteType, UnificationFail};
9use crate::r#type::{MonomorphicType, TVar};
10use crate::substitution::Substitutions;
11use crate::traits::{FreeTypeVars, Substitutable};
12
/// Two sequences of types that could not be unified pairwise because their
/// lengths differ.
///
/// The first field holds the left-hand sequence, the second the right-hand one.
/// No `#[error(...)]` attribute is given, so `Display` is implemented by hand in
/// this module.
#[derive(Debug, Error)]
pub struct UnificationMismatch(pub Vec<MonomorphicType>, pub Vec<MonomorphicType>);
15
/// Errors produced while unifying monomorphic types.
#[derive(Debug, Error)]
pub enum AlgorithmUError {
    /// The two types cannot be made equal because their outer shapes differ,
    /// e.g. two different constants, or a function against a tuple.
    #[error("Cannot unify types: {0} with {1}")]
    UnificationFail(MonomorphicType, MonomorphicType),
    /// Binding the type variable to the type would require the type to contain
    /// itself (the variable occurs free in the type).
    #[error("Cannot construct an infinite type: {0} ~ {1}")]
    InfiniteType(TVar, MonomorphicType),
    /// Two type sequences of different lengths were unified; the message is
    /// delegated to the wrapped error's `Display`.
    #[error(transparent)]
    UnificationMismatch(#[from] UnificationMismatch),
}
25
26struct AlgorithmU;
27
28impl AlgorithmU {
29    fn occurs_check(var: &TVar, with: impl FreeTypeVars) -> bool {
30        with.free_types().contains(var)
31    }
32
33    fn unify_vec(
34        vec1: &[MonomorphicType],
35        vec2: &[MonomorphicType],
36    ) -> Result<Substitutions, AlgorithmUError> {
37        match (vec1, vec2) {
38            ([], []) => Ok(Substitutions::empty()),
39            ([t1, ts1 @ ..], [t2, ts2 @ ..]) => {
40                let s1 = t1.unify(t2)?;
41                let s2 = Self::unify_vec(&ts1.substitute(&s1), &ts2.substitute(&s1))?;
42                Ok(s1 + s2)
43            }
44            (t1, t2) => Err(UnificationMismatch(t1.to_vec(), t2.to_vec()).into()),
45        }
46    }
47
48    fn bind(var: &TVar, ty: &MonomorphicType) -> Result<Substitutions, AlgorithmUError> {
49        match ty {
50            Var(v) if var == v => Ok(Substitutions::empty()),
51            _ if Self::occurs_check(var, ty) => Err(InfiniteType(*var, ty.clone())),
52            _ => Ok(Substitutions::single(*var, ty.clone())),
53        }
54    }
55
56    fn apply(
57        lhs: &MonomorphicType,
58        rhs: &MonomorphicType,
59    ) -> Result<Substitutions, AlgorithmUError> {
60        match (lhs, rhs) {
61            (a, b) if a == b => Ok(Substitutions::empty()),
62            (Var(var), b) => Self::bind(var, b),
63            (a, Var(var)) => Self::bind(var, a),
64            (Fn(i1, o1), Fn(i2, o2)) => Self::unify_vec(
65                &[i1.as_ref().clone(), o1.as_ref().clone()],
66                &[i2.as_ref().clone(), o2.as_ref().clone()],
67            ),
68            (Tuple(t1), Tuple(t2)) => Self::unify_vec(&t1.0, &t2.0),
69            (Pointer(t1), Pointer(t2)) => t1.unify(t2),
70            _ => Err(UnificationFail(lhs.clone(), rhs.clone())),
71        }
72    }
73}
74
75impl MonomorphicType {
76    pub fn unify(&self, other: &MonomorphicType) -> Result<Substitutions, AlgorithmUError> {
77        AlgorithmU::apply(self, other)
78    }
79}
80
81impl Display for UnificationMismatch {
82    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
83        write!(
84            f,
85            "Cannot unify types: [{}] with [{}]; different structure",
86            self.0.iter().join(", "),
87            self.1.iter().join(", ")
88        )
89    }
90}
91
// Unit tests for `MonomorphicType::unify` covering:
// - identical types;
// - binding variables from either side;
// - function and nested-function unification;
// - the occurs check;
// - threading substitutions through successive unifications.
#[cfg(test)]
// Tests unwrap results directly so that an unexpected error fails the test and
// shows its message.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;

    use nonempty_collections::nev;

    use crate::algorithm_u::AlgorithmUError;
    use crate::r#type::{fun, fun1, MonomorphicType, PrimitiveType, Tuple, TVar, var};
    use crate::r#type::MonomorphicType::Constant;
    use crate::substitution::Substitutions;
    use crate::traits::Substitutable;

    // Equal constants unify without producing any substitution.
    #[test]
    fn test_tautology_example_on_constants() {
        let a = Constant("A".to_string());
        let b = Constant("A".to_string());

        let s = a.unify(&b).unwrap();
        assert_eq!(s.into_inner(), HashMap::new());
    }

    #[test]
    fn test_different_constants_should_not_unify() {
        let a = Constant("A".to_string());
        let b = Constant("B".to_string());

        let e = a.unify(&b).unwrap_err();
        assert!(matches!(e, AlgorithmUError::UnificationFail(..)))
    }

    // A variable unified with itself produces no substitution.
    #[test]
    fn test_tautology_example_on_vars() {
        let a = var(0);
        let b = var(0);

        let s = a.unify(&b).unwrap();
        assert_eq!(s.into_inner(), HashMap::new());
    }

    // A variable is bound to the other type regardless of which side it is on.
    #[test]
    fn test_variables_should_be_always_unified() {
        let a = TVar(1);
        let b = Constant("A".to_string());

        let s1 = MonomorphicType::Var(a).unify(&b).unwrap();
        let s2 = b.unify(&MonomorphicType::Var(a)).unwrap();

        assert_eq!(s1, s2);
        assert_eq!(s1, Substitutions::single(a, b))
    }

    // Unifying two distinct variables binds the left-hand one to the right-hand
    // one, so the result depends on argument order.
    #[test]
    fn test_aliasing() {
        let a = TVar(1);
        let b = TVar(2);

        let a_ = MonomorphicType::Var(a);
        let b_ = MonomorphicType::Var(b);

        let s1 = a_.unify(&b_).unwrap();
        let s2 = b_.unify(&a_).unwrap();

        assert_eq!(s1, Substitutions::single(a, b_.clone()));
        assert_eq!(s2, Substitutions::single(b, a_))
    }

    // Only the differing argument yields a binding; the shared `var(1)` unifies
    // with itself.
    #[test]
    fn test_simple_function_unifying() {
        let a = fun(nev![var(1), Constant("A".to_string())], Tuple::unit());
        let b = fun(nev![var(1), var(2)], Tuple::unit());

        let s = a.unify(&b).unwrap();
        assert_eq!(s, Substitutions::single(TVar(2), Constant("A".to_string())));
    }

    #[test]
    fn test_aliasing_in_functions() {
        let a = fun(nev![var(1)], Tuple::unit());
        let b = fun(nev![var(2)], Tuple::unit());

        let s = a.unify(&b).unwrap();
        assert_eq!(s, Substitutions::single(TVar(1), var(2)));
    }

    // `fun` is assumed to build curried `Fn` chains (test_complex_unification
    // relies on that). An extra argument therefore surfaces as the unit return
    // type clashing with a function, i.e. `UnificationFail`.
    #[test]
    fn test_functions_with_different_arity_should_not_unify() {
        let a = fun(nev![Constant("A".to_string())], Tuple::unit());
        let b = fun(
            nev![Constant("A".to_string()), Constant("B".to_string())],
            Tuple::unit(),
        );

        let s = a.unify(&b).unwrap_err();
        assert!(matches!(s, AlgorithmUError::UnificationFail(..)))
    }

    // The binding found for `var(1)` in a later argument must be propagated
    // into the earlier binding of `var(2)`.
    #[test]
    fn test_multiple_substitutions() {
        let a = fun(
            nev![fun1(var(1), PrimitiveType::Integral), var(1)],
            Tuple::unit(),
        );
        let b = fun(nev![var(2), Constant("A".to_string())], Tuple::unit());

        let s = a.unify(&b).unwrap();
        assert_eq!(
            s.into_inner(),
            HashMap::from([
                (TVar(1), Constant("A".to_string())),
                (
                    TVar(2),
                    fun1(Constant("A".to_string()), PrimitiveType::Integral)
                )
            ])
        )
    }

    // Unifying a variable with a function that mentions it must trip the occurs
    // check.
    #[test]
    fn test_infinite_substitution() {
        let a = var(1);
        let b = fun1(var(1), Tuple::unit());

        let e = a.unify(&b).unwrap_err();
        assert!(matches!(e, AlgorithmUError::InfiniteType { .. }))
    }

    // Each step applies a previously obtained substitution before unifying
    // again, as a caller threading substitutions would.
    #[test]
    fn test_transitive_substitutions() {
        let a = var(1);
        let b = var(2);
        let c = Constant("A".to_string());

        let s1 = a.unify(&b).unwrap();
        let s2 = b.unify(&a).unwrap();
        let s3 = c.unify(&b.substitute(&s2)).unwrap();
        let s4 = a.substitute(&s1).unify(&c).unwrap();

        assert_eq!(s1, Substitutions::single(TVar(1), b.clone()));
        assert_eq!(s2, Substitutions::single(TVar(2), a.clone()));
        assert_eq!(s3, Substitutions::single(TVar(1), c.clone()));
        assert_eq!(s4, Substitutions::single(TVar(2), c));
    }

    // Once `var(1)` is bound to `A`, the substituted type can no longer unify
    // with `B`.
    #[test]
    fn test_different_substitutions_of_same_variable() {
        let a = var(1);
        let b = Constant("A".to_string());
        let c = Constant("B".to_string());

        let s = a.unify(&b).unwrap();
        let e = a.substitute(&s).unify(&c).unwrap_err();

        assert_eq!(s, Substitutions::single(TVar(1), b));
        assert!(matches!(e, AlgorithmUError::UnificationFail(..)))
    }

    // Nested functions: the result does not depend on argument order here.
    // Applying it makes both sides equal to the fully expanded type.
    #[test]
    fn test_complex_unification() {
        let a = fun1(
            fun1(fun1(Constant("A".to_string()), var(1)), var(2)),
            var(3),
        );
        let b = fun(nev![var(3), var(2), var(1)], Constant("A".to_string()));

        let s1 = a.unify(&b).unwrap();
        let s2 = b.unify(&a).unwrap();

        assert_eq!(s1, s2);
        assert_eq!(a.substitute(&s1), b.substitute(&s1));
        let h = fun1(Constant("A".to_string()), Constant("A".to_string()));
        assert_eq!(
            a.substitute(&s1),
            fun1(fun1(h.clone(), h.clone()), fun1(h.clone(), h))
        )
    }
}