1use crate::types::{Algebraic, Constructor, Type, Value};
2
3pub struct TypeEqualityChecker<'a> {
4 pairs: Vec<(&'a Algebraic, &'a Algebraic)>,
5}
6
7impl<'a> TypeEqualityChecker<'a> {
8 pub fn new(types: &'a [&'a Algebraic]) -> Self {
9 Self {
10 pairs: types.iter().cloned().zip(types.iter().cloned()).collect(),
11 }
12 }
13
14 pub fn equal_algebraics(&self, one: &Algebraic, other: &Algebraic) -> bool {
15 if one.constructors().len() != other.constructors().len() {
16 return false;
17 } else if self.pairs.contains(&(one, other)) {
18 return true;
19 }
20
21 let checker = self.push_pair(one, other);
22
23 one.constructors()
24 .iter()
25 .zip(other.constructors())
26 .all(|(one, other)| checker.equal_constructors(one, other))
27 }
28
29 fn equal_values(&self, one: &Value, other: &Value) -> bool {
30 match (one, other) {
31 (Value::Primitive(one), Value::Primitive(other)) => one == other,
32 (Value::Algebraic(one), Value::Algebraic(other)) => self.equal_algebraics(one, other),
33 (Value::Index(index), Value::Algebraic(other)) => {
34 self.equal_algebraics(self.pairs[*index].0, other)
35 }
36 (Value::Algebraic(other), Value::Index(index)) => {
37 self.equal_algebraics(other, self.pairs[*index].1)
38 }
39 (Value::Index(one), Value::Index(other)) => {
40 self.equal_algebraics(self.pairs[*one].0, self.pairs[*other].1)
41 }
42 _ => false,
43 }
44 }
45
46 fn equal(&self, one: &Type, other: &Type) -> bool {
47 match (one, other) {
48 (Type::Value(one), Type::Value(other)) => self.equal_values(one, other),
49 (Type::Function(one), Type::Function(other)) => {
50 one.arguments().len() == other.arguments().len()
51 && one
52 .arguments()
53 .iter()
54 .zip(other.arguments())
55 .all(|(one, other)| self.equal(one, other))
56 && self.equal_values(one.result(), other.result())
57 }
58 (_, _) => false,
59 }
60 }
61
62 fn equal_constructors(&self, one: &Constructor, other: &Constructor) -> bool {
63 one.elements().len() == other.elements().len()
64 && one
65 .elements()
66 .iter()
67 .zip(other.elements())
68 .all(|(one, other)| self.equal(one, other))
69 }
70
71 fn push_pair(&'a self, one: &'a Algebraic, other: &'a Algebraic) -> Self {
72 Self {
73 pairs: [(one, other)]
74 .iter()
75 .chain(self.pairs.iter())
76 .copied()
77 .collect(),
78 }
79 }
80}
81
// Unit tests for `TypeEqualityChecker`. Every case starts from an empty list of enclosing
// types (`TypeEqualityChecker::new(&[])`), so each `Value::Index` below is bound by an
// algebraic type written inside the case itself.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Algebraic, Function, Primitive};

    // Pairs of types that the checker must consider equal.
    #[test]
    fn equal() {
        for (one, other) in &[
            // Identical primitives.
            (Primitive::Float64.into(), Primitive::Float64.into()),
            // Identical function types.
            (
                Function::new(vec![Primitive::Float64.into()], Primitive::Float64).into(),
                Function::new(vec![Primitive::Float64.into()], Primitive::Float64).into(),
            ),
            // Identical non-recursive algebraic types.
            (
                Algebraic::new(vec![Constructor::new(vec![Primitive::Float64.into()])]).into(),
                Algebraic::new(vec![Constructor::new(vec![Primitive::Float64.into()])]).into(),
            ),
            // A self-recursive type (`Value::Index(0)` is the type itself) vs. a type whose
            // only element is another self-recursive type of the same shape.
            (
                Algebraic::new(vec![Constructor::new(vec![Value::Index(0).into()])]).into(),
                Algebraic::new(vec![Constructor::new(vec![Algebraic::new(vec![
                    Constructor::new(vec![Value::Index(0).into()]),
                ])
                .into()])])
                .into(),
            ),
            // The same self-recursive type vs. a copy unrolled once: the inner
            // `Value::Index(1)` refers back to the outer algebraic type.
            (
                Algebraic::new(vec![Constructor::new(vec![Value::Index(0).into()])]).into(),
                Algebraic::new(vec![Constructor::new(vec![Algebraic::new(vec![
                    Constructor::new(vec![Value::Index(1).into()]),
                ])
                .into()])])
                .into(),
            ),
            // Both sides are unrolled once. One closes the cycle on its inner type
            // (`Index(0)`), the other on its outer type (`Index(1)`).
            (
                Algebraic::new(vec![Constructor::new(vec![Algebraic::new(vec![
                    Constructor::new(vec![Value::Index(0).into()]),
                ])
                .into()])])
                .into(),
                Algebraic::new(vec![Constructor::new(vec![Algebraic::new(vec![
                    Constructor::new(vec![Value::Index(1).into()]),
                ])
                .into()])])
                .into(),
            ),
            // Recursion through a function result: the recursive type vs. a copy whose
            // nested algebraic type is self-recursive.
            (
                Algebraic::new(vec![Constructor::new(vec![Function::new(
                    vec![Primitive::Float64.into()],
                    Value::Index(0),
                )
                .into()])])
                .into(),
                Algebraic::new(vec![Constructor::new(vec![Function::new(
                    vec![Primitive::Float64.into()],
                    Algebraic::new(vec![Constructor::new(vec![Function::new(
                        vec![Primitive::Float64.into()],
                        Value::Index(0),
                    )
                    .into()])]),
                )
                .into()])])
                .into(),
            ),
            // Recursion through a function result: the recursive type vs. a copy unrolled
            // once, where `Value::Index(1)` refers back to the outer algebraic type.
            (
                Algebraic::new(vec![Constructor::new(vec![Function::new(
                    vec![Primitive::Float64.into()],
                    Value::Index(0),
                )
                .into()])])
                .into(),
                Algebraic::new(vec![Constructor::new(vec![Function::new(
                    vec![Primitive::Float64.into()],
                    Algebraic::new(vec![Constructor::new(vec![Function::new(
                        vec![Primitive::Float64.into()],
                        Value::Index(1),
                    )
                    .into()])]),
                )
                .into()])])
                .into(),
            ),
        ] {
            assert!(TypeEqualityChecker::new(&[]).equal(one, other));
        }
    }

    // Pairs of types that the checker must consider different.
    #[test]
    fn not_equal() {
        for (one, other) in &[
            // A primitive vs. a function type.
            (
                Primitive::Float64.into(),
                Function::new(vec![Primitive::Float64.into()], Primitive::Float64).into(),
            ),
            // Function types with different numbers of arguments.
            (
                Function::new(
                    vec![Primitive::Float64.into(), Primitive::Float64.into()],
                    Primitive::Float64,
                )
                .into(),
                Function::new(vec![Primitive::Float64.into()], Primitive::Float64).into(),
            ),
            // Constructors with different numbers of elements.
            (
                Algebraic::new(vec![Constructor::new(vec![Primitive::Float64.into()])]).into(),
                Algebraic::new(vec![Constructor::new(vec![
                    Primitive::Float64.into(),
                    Primitive::Float64.into(),
                ])])
                .into(),
            ),
        ] {
            assert!(!TypeEqualityChecker::new(&[]).equal(one, other));
        }
    }
}