1use std::{
2 cmp::{max, min},
3 collections::{HashMap, HashSet},
4};
5
6use id_arena::Arena;
7use indexmap::{IndexMap, IndexSet};
8use log::{debug, trace};
9
10use crate::{
11 AtomRestriction, AtomSemantic, BuiltinTypes, Type, TypeId, stringify_impl, substitute,
12};
13
/// Outcome of checking whether a value of one type can be used where another is expected.
///
/// Variants are declared from most to least permissive, and the derived `Ord` follows that
/// order (`Assign < Cast < Invalid`). Requirements that must all hold are therefore combined
/// with `max`, and the best of several alternatives is picked with `min`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum Comparison {
    /// The value is compatible with the target type as is.
    Assign = 0,
    /// The types differ in a way that requires a cast (e.g. differing atom or struct semantics).
    Cast = 1,
    /// The value is not compatible with the target type.
    Invalid = 2,
}
20
21#[derive(Debug, Default)]
22pub(crate) struct ComparisonContext<'a> {
23 infer: Option<&'a mut HashMap<TypeId, TypeId>>,
24 stack: IndexSet<(TypeId, TypeId)>,
25}
26
27pub fn compare_with_inference(
28 arena: &mut Arena<Type>,
29 builtins: &BuiltinTypes,
30 lhs: TypeId,
31 rhs: TypeId,
32 infer: Option<&mut HashMap<TypeId, TypeId>>,
33) -> Comparison {
34 let lhs = substitute(arena, lhs);
35 let rhs = substitute(arena, rhs);
36 let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
37 let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
38 trace!("Comparing {lhs_name} to {rhs_name}");
39 let result = compare_impl(
40 arena,
41 builtins,
42 &mut ComparisonContext {
43 infer,
44 ..Default::default()
45 },
46 lhs,
47 rhs,
48 None,
49 None,
50 );
51 trace!("Comparison from {lhs_name} to {rhs_name} yielded {result:?}");
52 result
53}
54
55pub fn compare(
56 arena: &mut Arena<Type>,
57 builtins: &BuiltinTypes,
58 lhs: TypeId,
59 rhs: TypeId,
60) -> Comparison {
61 compare_with_inference(arena, builtins, lhs, rhs, None)
62}
63
64pub(crate) fn compare_impl(
65 arena: &Arena<Type>,
66 builtins: &BuiltinTypes,
67 ctx: &mut ComparisonContext<'_>,
68 lhs: TypeId,
69 rhs: TypeId,
70 lhs_semantic: Option<TypeId>,
71 rhs_semantic: Option<TypeId>,
72) -> Comparison {
73 if !ctx.stack.insert((lhs, rhs)) {
74 return Comparison::Assign;
75 }
76
77 let result = match (arena[lhs].clone(), arena[rhs].clone()) {
78 (Type::Apply(_), _) | (_, Type::Apply(_)) => unreachable!(),
79 (Type::Ref(lhs), _) => {
80 compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
81 }
82 (_, Type::Ref(rhs)) => {
83 compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
84 }
85 (Type::Unresolved, _) | (_, Type::Unresolved) => Comparison::Assign,
86 (Type::Never, _) => Comparison::Assign,
87 (Type::Atom(lhs), Type::Atom(rhs)) => {
88 let semantic = if lhs.semantic == rhs.semantic || rhs.semantic == AtomSemantic::Any {
89 Comparison::Assign
90 } else {
91 Comparison::Cast
92 };
93
94 let restriction = match (lhs.restriction, rhs.restriction) {
95 (_, None) => Comparison::Assign,
96 (None, _) => Comparison::Invalid,
97 (Some(AtomRestriction::Length(lhs)), Some(AtomRestriction::Length(rhs))) => {
98 if lhs == rhs {
99 Comparison::Assign
100 } else {
101 Comparison::Invalid
102 }
103 }
104 (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Value(rhs))) => {
105 if lhs == rhs {
106 Comparison::Assign
107 } else {
108 Comparison::Invalid
109 }
110 }
111 (Some(AtomRestriction::Length(_)), Some(AtomRestriction::Value(_))) => {
112 Comparison::Invalid
113 }
114 (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Length(rhs))) => {
115 if lhs.len() == rhs {
116 Comparison::Assign
117 } else {
118 Comparison::Invalid
119 }
120 }
121 };
122
123 max(semantic, restriction)
124 }
125 (Type::Pair(lhs), Type::Pair(rhs)) => {
126 let first = compare_impl(arena, builtins, ctx, lhs.first, rhs.first, None, None);
127 let rest = compare_impl(arena, builtins, ctx, lhs.rest, rhs.rest, None, None);
128 max(first, rest)
129 }
130 (Type::Atom(_), Type::Pair(_)) => Comparison::Invalid,
131 (Type::Pair(_), Type::Atom(_)) => Comparison::Invalid,
132 (Type::Function(lhs), Type::Function(rhs)) => {
133 if lhs.nil_terminated != rhs.nil_terminated || lhs.params.len() != rhs.params.len() {
136 Comparison::Invalid
137 } else {
138 let mut result = compare_impl(arena, builtins, ctx, lhs.ret, rhs.ret, None, None);
139
140 for (i, param) in lhs.params.iter().enumerate() {
141 result = max(
142 result,
143 compare_impl(arena, builtins, ctx, *param, rhs.params[i], None, None),
144 );
145 }
146
147 result
148 }
149 }
150 (_, Type::Generic) => {
151 if lhs == rhs {
152 Comparison::Assign
153 } else if let Some(infer) = &mut ctx.infer {
154 if let Some(rhs) = infer.get(&rhs).copied() {
155 compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
156 } else {
157 debug!(
158 "Inferring {} is {}",
159 stringify_impl(arena, rhs, &mut IndexMap::new()),
160 stringify_impl(arena, lhs, &mut IndexMap::new())
161 );
162 infer.insert(rhs, lhs);
163 Comparison::Assign
164 }
165 } else if let Type::Union(lhs) = arena[lhs].clone() {
166 let mut result = Comparison::Assign;
167
168 for &id in &lhs.types {
169 result = max(
170 result,
171 compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
172 );
173 }
174
175 result
176 } else {
177 Comparison::Invalid
178 }
179 }
180 (Type::Struct(lhs), Type::Struct(rhs)) => max(
181 compare_impl(arena, builtins, ctx, lhs.inner, rhs.inner, None, None),
182 if lhs.semantic == rhs.semantic {
183 Comparison::Assign
184 } else {
185 Comparison::Cast
186 },
187 ),
188 (Type::Struct(lhs), _) => {
189 let inner = compare_impl(
190 arena,
191 builtins,
192 ctx,
193 lhs.inner,
194 rhs,
195 Some(lhs.semantic),
196 rhs_semantic,
197 );
198
199 if rhs_semantic == Some(lhs.semantic)
200 || semantics_of(arena, rhs).contains(&Some(lhs.semantic))
201 {
202 inner
203 } else {
204 max(inner, Comparison::Cast)
205 }
206 }
207 (_, Type::Struct(rhs)) => {
208 let inner = compare_impl(
209 arena,
210 builtins,
211 ctx,
212 lhs,
213 rhs.inner,
214 lhs_semantic,
215 Some(rhs.semantic),
216 );
217
218 let semantics = semantics_of(arena, lhs);
219
220 if (semantics.len() != 1 || !semantics.contains(&Some(rhs.semantic)))
221 && lhs_semantic != Some(rhs.semantic)
222 {
223 max(inner, Comparison::Cast)
224 } else {
225 inner
226 }
227 }
228 (Type::Alias(lhs), _) => compare_impl(
229 arena,
230 builtins,
231 ctx,
232 lhs.inner,
233 rhs,
234 lhs_semantic,
235 rhs_semantic,
236 ),
237 (_, Type::Alias(rhs)) => compare_impl(
238 arena,
239 builtins,
240 ctx,
241 lhs,
242 rhs.inner,
243 lhs_semantic,
244 rhs_semantic,
245 ),
246 (Type::Function(_) | Type::Generic, _) => compare_impl(
247 arena,
248 builtins,
249 ctx,
250 builtins.any,
251 rhs,
252 lhs_semantic,
253 rhs_semantic,
254 ),
255 (Type::Union(lhs), _) => {
256 let mut result = Comparison::Assign;
257
258 for &id in &lhs.types {
259 result = max(
260 result,
261 compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
262 );
263 }
264
265 result
266 }
267 (_, Type::Union(rhs)) => {
268 let mut result = Comparison::Invalid;
269
270 for &id in &rhs.types {
271 result = min(
272 result,
273 compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
274 );
275 }
276
277 result
278 }
279 (_, Type::Never) => Comparison::Invalid,
280 (_, Type::Function(_)) => Comparison::Invalid,
281 };
282
283 ctx.stack.pop().unwrap();
284
285 result
286}
287
288fn semantics_of(arena: &Arena<Type>, id: TypeId) -> HashSet<Option<TypeId>> {
289 match arena[id].clone() {
290 Type::Apply(_) => unreachable!(),
291 Type::Ref(id) => semantics_of(arena, id),
292 Type::Alias(alias) => semantics_of(arena, alias.inner),
293 Type::Never => HashSet::new(),
294 Type::Unresolved | Type::Generic | Type::Atom(_) | Type::Pair(_) | Type::Function(_) => {
295 HashSet::from_iter([None])
296 }
297 Type::Struct(ty) => HashSet::from_iter([Some(ty.semantic)]),
298 Type::Union(ty) => {
299 let mut semantics = HashSet::new();
300
301 for &id in &ty.types {
302 semantics.extend(semantics_of(arena, id));
303 }
304
305 semantics
306 }
307 }
308}
309
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use id_arena::Arena;
    use rstest::rstest;

    use crate::{Atom, Type, compare};

    use super::*;

    /// Checks the builtin atom matrix, plus hand-built atoms that exercise every arm of the
    /// restriction comparison: unrestricted vs restricted, value vs value, value vs length,
    /// length vs value, and length vs length.
    #[rstest]
    #[case(Atom::NIL, Atom::NIL, Comparison::Assign)]
    #[case(Atom::NIL, Atom::FALSE, Comparison::Cast)]
    #[case(Atom::NIL, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::NIL, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::INT, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::NIL, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::FALSE, Comparison::Assign)]
    #[case(Atom::FALSE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::INT, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::TRUE, Comparison::Assign)]
    #[case(Atom::TRUE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES_32, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::BYTES_32, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::INT, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::PUBLIC_KEY, Comparison::Assign)]
    #[case(Atom::PUBLIC_KEY, Atom::INT, Comparison::Cast)]
    #[case(Atom::INT, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::INT, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::INT, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::INT, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::INT, Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::BYTES, Comparison::Cast)]
    // Unrestricted source vs restricted target, and semantic `Any` on either side.
    #[case(Atom::INT, Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Comparison::Invalid)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::new(AtomSemantic::Any, None), Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Any, None), Atom::INT, Comparison::Cast)]
    // Value vs value.
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[2])))), Comparison::Invalid)]
    // Value vs length.
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1, 2])))), Atom::new(AtomSemantic::Int, Some(AtomRestriction::Length(2))), Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::new(AtomSemantic::Int, Some(AtomRestriction::Length(2))), Comparison::Invalid)]
    // Length vs value, length vs length, and length vs unrestricted.
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Length(2))), Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1, 2])))), Comparison::Invalid)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Length(2))), Atom::new(AtomSemantic::Int, Some(AtomRestriction::Length(2))), Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Length(2))), Atom::new(AtomSemantic::Int, Some(AtomRestriction::Length(3))), Comparison::Invalid)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Length(2))), Atom::INT, Comparison::Assign)]
    fn test_atoms(#[case] lhs: Atom, #[case] rhs: Atom, #[case] expected: Comparison) {
        let mut arena = Arena::new();
        let builtins = BuiltinTypes::new(&mut arena);
        let lhs_id = arena.alloc(Type::Atom(lhs.clone()));
        let rhs_id = arena.alloc(Type::Atom(rhs.clone()));
        assert_eq!(
            compare(&mut arena, &builtins, lhs_id, rhs_id),
            expected,
            "{lhs} -> {rhs}"
        );
    }
}