ssf/types/canonicalize/type_equality_checker.rs

1use crate::types::{Algebraic, Constructor, Type, Value};
2
3pub struct TypeEqualityChecker<'a> {
4    pairs: Vec<(&'a Algebraic, &'a Algebraic)>,
5}
6
7impl<'a> TypeEqualityChecker<'a> {
8    pub fn new(types: &'a [&'a Algebraic]) -> Self {
9        Self {
10            pairs: types.iter().cloned().zip(types.iter().cloned()).collect(),
11        }
12    }
13
14    pub fn equal_algebraics(&self, one: &Algebraic, other: &Algebraic) -> bool {
15        if one.constructors().len() != other.constructors().len() {
16            return false;
17        } else if self.pairs.contains(&(one, other)) {
18            return true;
19        }
20
21        let checker = self.push_pair(one, other);
22
23        one.constructors()
24            .iter()
25            .zip(other.constructors())
26            .all(|(one, other)| checker.equal_constructors(one, other))
27    }
28
29    fn equal_values(&self, one: &Value, other: &Value) -> bool {
30        match (one, other) {
31            (Value::Primitive(one), Value::Primitive(other)) => one == other,
32            (Value::Algebraic(one), Value::Algebraic(other)) => self.equal_algebraics(one, other),
33            (Value::Index(index), Value::Algebraic(other)) => {
34                self.equal_algebraics(self.pairs[*index].0, other)
35            }
36            (Value::Algebraic(other), Value::Index(index)) => {
37                self.equal_algebraics(other, self.pairs[*index].1)
38            }
39            (Value::Index(one), Value::Index(other)) => {
40                self.equal_algebraics(self.pairs[*one].0, self.pairs[*other].1)
41            }
42            _ => false,
43        }
44    }
45
46    fn equal(&self, one: &Type, other: &Type) -> bool {
47        match (one, other) {
48            (Type::Value(one), Type::Value(other)) => self.equal_values(one, other),
49            (Type::Function(one), Type::Function(other)) => {
50                one.arguments().len() == other.arguments().len()
51                    && one
52                        .arguments()
53                        .iter()
54                        .zip(other.arguments())
55                        .all(|(one, other)| self.equal(one, other))
56                    && self.equal_values(one.result(), other.result())
57            }
58            (_, _) => false,
59        }
60    }
61
62    fn equal_constructors(&self, one: &Constructor, other: &Constructor) -> bool {
63        one.elements().len() == other.elements().len()
64            && one
65                .elements()
66                .iter()
67                .zip(other.elements())
68                .all(|(one, other)| self.equal(one, other))
69    }
70
71    fn push_pair(&'a self, one: &'a Algebraic, other: &'a Algebraic) -> Self {
72        Self {
73            pairs: [(one, other)]
74                .iter()
75                .chain(self.pairs.iter())
76                .copied()
77                .collect(),
78        }
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Algebraic, Function, Primitive};

    /// The `Float64` primitive as a type.
    fn float64() -> Type {
        Primitive::Float64.into()
    }

    /// A recursive reference to the `index`-th enclosing algebraic type.
    fn recursive(index: usize) -> Type {
        Value::Index(index).into()
    }

    /// An algebraic type with exactly one constructor holding `elements`.
    fn single_constructor(elements: Vec<Type>) -> Algebraic {
        Algebraic::new(vec![Constructor::new(elements)])
    }

    /// Compares two types with a checker that has no enclosing types.
    fn check(one: &Type, other: &Type) -> bool {
        TypeEqualityChecker::new(&[]).equal(one, other)
    }

    #[test]
    fn equal() {
        let cases: Vec<(Type, Type)> = vec![
            (float64(), float64()),
            (
                Function::new(vec![float64()], Primitive::Float64).into(),
                Function::new(vec![float64()], Primitive::Float64).into(),
            ),
            (
                single_constructor(vec![float64()]).into(),
                single_constructor(vec![float64()]).into(),
            ),
            (
                single_constructor(vec![recursive(0)]).into(),
                single_constructor(vec![single_constructor(vec![recursive(0)]).into()]).into(),
            ),
            (
                single_constructor(vec![recursive(0)]).into(),
                single_constructor(vec![single_constructor(vec![recursive(1)]).into()]).into(),
            ),
            (
                single_constructor(vec![single_constructor(vec![recursive(0)]).into()]).into(),
                single_constructor(vec![single_constructor(vec![recursive(1)]).into()]).into(),
            ),
            (
                single_constructor(vec![
                    Function::new(vec![float64()], Value::Index(0)).into()
                ])
                .into(),
                single_constructor(vec![Function::new(
                    vec![float64()],
                    single_constructor(vec![
                        Function::new(vec![float64()], Value::Index(0)).into()
                    ]),
                )
                .into()])
                .into(),
            ),
            (
                single_constructor(vec![
                    Function::new(vec![float64()], Value::Index(0)).into()
                ])
                .into(),
                single_constructor(vec![Function::new(
                    vec![float64()],
                    single_constructor(vec![
                        Function::new(vec![float64()], Value::Index(1)).into()
                    ]),
                )
                .into()])
                .into(),
            ),
        ];

        for (one, other) in &cases {
            assert!(check(one, other));
        }
    }

    #[test]
    fn not_equal() {
        let cases: Vec<(Type, Type)> = vec![
            (
                float64(),
                Function::new(vec![float64()], Primitive::Float64).into(),
            ),
            (
                Function::new(vec![float64(), float64()], Primitive::Float64).into(),
                Function::new(vec![float64()], Primitive::Float64).into(),
            ),
            (
                single_constructor(vec![float64()]).into(),
                single_constructor(vec![float64(), float64()]).into(),
            ),
        ];

        for (one, other) in &cases {
            assert!(!check(one, other));
        }
    }
}