1use std::{
2 cmp::{max, min},
3 collections::{HashMap, HashSet},
4};
5
6use id_arena::Arena;
7use indexmap::{IndexMap, IndexSet};
8use log::{debug, trace};
9
10use crate::{
11 AtomRestriction, AtomSemantic, BuiltinTypes, Type, TypeId, stringify_impl, substitute,
12};
13
/// Result of comparing one type (`lhs`) against another (`rhs`).
///
/// The derived `Ord` follows declaration order (`Assign < Cast < Invalid`).
/// `compare_impl` depends on that ordering: it merges sub-results with `max`
/// when every part must hold and with `min` when any alternative suffices,
/// so the variants must not be reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Comparison {
    /// The smallest variant; produced when the comparison finds nothing to object to
    /// (presumably: `lhs` may be assigned to `rhs` directly).
    Assign,
    /// Produced when the structure lines up but semantics differ, e.g. atoms with
    /// different `AtomSemantic`s or structs with different semantic ids
    /// (presumably: an explicit cast is required).
    Cast,
    /// The largest variant; produced when the types are incompatible, e.g. an atom
    /// compared with a pair or mismatching atom restrictions.
    Invalid,
}
20
/// Mutable state threaded through one run of `compare_impl`.
#[derive(Debug, Default)]
pub(crate) struct ComparisonContext<'a> {
    /// Optional table of inferred generics. When present and the right-hand type is a
    /// generic, `compare_impl` compares against the stored entry if there is one, and
    /// otherwise records `rhs -> lhs` in this map.
    infer: Option<&'a mut HashMap<TypeId, TypeId>>,
    /// `(lhs, rhs)` pairs that are currently being compared. `compare_impl` inserts the
    /// pair on entry and pops it on exit; re-entering a pair that is already present
    /// returns `Comparison::Assign` immediately, which stops cyclic types from recursing
    /// forever.
    stack: IndexSet<(TypeId, TypeId)>,
}
26
27pub fn compare_with_inference(
28 arena: &mut Arena<Type>,
29 builtins: &BuiltinTypes,
30 lhs: TypeId,
31 rhs: TypeId,
32 infer: Option<&mut HashMap<TypeId, TypeId>>,
33) -> Comparison {
34 let lhs = substitute(arena, lhs);
35 let rhs = substitute(arena, rhs);
36 let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
37 let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
38 trace!("Comparing {lhs_name} to {rhs_name}");
39 let result = compare_impl(
40 arena,
41 builtins,
42 &mut ComparisonContext {
43 infer,
44 ..Default::default()
45 },
46 lhs,
47 rhs,
48 None,
49 None,
50 );
51 trace!("Comparison from {lhs_name} to {rhs_name} yielded {result:?}");
52 result
53}
54
55pub fn compare(
56 arena: &mut Arena<Type>,
57 builtins: &BuiltinTypes,
58 lhs: TypeId,
59 rhs: TypeId,
60) -> Comparison {
61 compare_with_inference(arena, builtins, lhs, rhs, None)
62}
63
64pub(crate) fn compare_impl(
65 arena: &Arena<Type>,
66 builtins: &BuiltinTypes,
67 ctx: &mut ComparisonContext<'_>,
68 lhs: TypeId,
69 rhs: TypeId,
70 lhs_semantic: Option<TypeId>,
71 rhs_semantic: Option<TypeId>,
72) -> Comparison {
73 if !ctx.stack.insert((lhs, rhs)) {
74 return Comparison::Assign;
75 }
76
77 let result = match (arena[lhs].clone(), arena[rhs].clone()) {
78 (Type::Apply(_), _) | (_, Type::Apply(_)) => unreachable!(),
79 (Type::Ref(lhs), _) => {
80 compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
81 }
82 (_, Type::Ref(rhs)) => {
83 compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
84 }
85 (Type::Unresolved, _) | (_, Type::Unresolved) => Comparison::Assign,
86 (Type::Never, _) => Comparison::Assign,
87 (_, Type::Any) => Comparison::Assign,
88 (Type::Atom(lhs), Type::Atom(rhs)) => {
89 let semantic = if lhs.semantic == rhs.semantic
90 || rhs.semantic == AtomSemantic::Any
91 || (lhs.semantic == AtomSemantic::String && rhs.semantic == AtomSemantic::Bytes)
92 || (lhs.semantic == AtomSemantic::Bytes && rhs.semantic == AtomSemantic::String)
93 {
94 Comparison::Assign
95 } else {
96 Comparison::Cast
97 };
98
99 let restriction = match (lhs.restriction, rhs.restriction) {
100 (_, None) => Comparison::Assign,
101 (None, _) => Comparison::Invalid,
102 (Some(AtomRestriction::Length(lhs)), Some(AtomRestriction::Length(rhs))) => {
103 if lhs == rhs {
104 Comparison::Assign
105 } else {
106 Comparison::Invalid
107 }
108 }
109 (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Value(rhs))) => {
110 if lhs == rhs {
111 Comparison::Assign
112 } else {
113 Comparison::Invalid
114 }
115 }
116 (Some(AtomRestriction::Length(_)), Some(AtomRestriction::Value(_))) => {
117 Comparison::Invalid
118 }
119 (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Length(rhs))) => {
120 if lhs.len() == rhs {
121 Comparison::Assign
122 } else {
123 Comparison::Invalid
124 }
125 }
126 };
127
128 max(semantic, restriction)
129 }
130 (Type::Pair(lhs), Type::Pair(rhs)) => {
131 let first = compare_impl(arena, builtins, ctx, lhs.first, rhs.first, None, None);
132 let rest = compare_impl(arena, builtins, ctx, lhs.rest, rhs.rest, None, None);
133 max(first, rest)
134 }
135 (Type::Atom(_), Type::Pair(_)) => Comparison::Invalid,
136 (Type::Pair(_), Type::Atom(_)) => Comparison::Invalid,
137 (Type::Function(lhs), Type::Function(rhs)) => {
138 if lhs.nil_terminated != rhs.nil_terminated || lhs.params.len() != rhs.params.len() {
141 Comparison::Invalid
142 } else {
143 let mut result = compare_impl(arena, builtins, ctx, lhs.ret, rhs.ret, None, None);
144
145 for (i, param) in lhs.params.iter().enumerate() {
146 result = max(
147 result,
148 compare_impl(arena, builtins, ctx, *param, rhs.params[i], None, None),
149 );
150 }
151
152 result
153 }
154 }
155 (_, Type::Generic(_)) => {
156 if lhs == rhs {
157 Comparison::Assign
158 } else if let Some(infer) = &mut ctx.infer {
159 if let Some(rhs) = infer.get(&rhs).copied() {
160 compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
161 } else {
162 debug!(
163 "Inferring {} is {}",
164 stringify_impl(arena, rhs, &mut IndexMap::new()),
165 stringify_impl(arena, lhs, &mut IndexMap::new())
166 );
167 infer.insert(rhs, lhs);
168 Comparison::Assign
169 }
170 } else if let Type::Union(lhs) = arena[lhs].clone() {
171 let mut result = Comparison::Assign;
172
173 for &id in &lhs.types {
174 result = max(
175 result,
176 compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
177 );
178 }
179
180 result
181 } else {
182 Comparison::Invalid
183 }
184 }
185 (Type::Struct(lhs), Type::Struct(rhs)) => max(
186 compare_impl(arena, builtins, ctx, lhs.inner, rhs.inner, None, None),
187 if lhs.semantic == rhs.semantic {
188 Comparison::Assign
189 } else {
190 Comparison::Cast
191 },
192 ),
193 (Type::Struct(lhs), _) => {
194 let inner = compare_impl(
195 arena,
196 builtins,
197 ctx,
198 lhs.inner,
199 rhs,
200 Some(lhs.semantic),
201 rhs_semantic,
202 );
203
204 let rhs_semantics = semantics_of(arena, rhs);
205
206 if rhs_semantic == Some(lhs.semantic)
207 || rhs_semantics.contains(&Some(Semantic::Id(lhs.semantic)))
208 || rhs_semantics.contains(&Some(Semantic::All))
209 {
210 inner
211 } else {
212 max(inner, Comparison::Cast)
213 }
214 }
215 (_, Type::Struct(rhs)) => {
216 let inner = compare_impl(
217 arena,
218 builtins,
219 ctx,
220 lhs,
221 rhs.inner,
222 lhs_semantic,
223 Some(rhs.semantic),
224 );
225
226 let semantics = semantics_of(arena, lhs);
227
228 if (semantics.len() != 1
229 || (!semantics.contains(&Some(Semantic::Id(rhs.semantic)))
230 && !semantics.contains(&Some(Semantic::All))))
231 && lhs_semantic != Some(rhs.semantic)
232 {
233 max(inner, Comparison::Cast)
234 } else {
235 inner
236 }
237 }
238 (Type::Alias(lhs), _) => compare_impl(
239 arena,
240 builtins,
241 ctx,
242 lhs.inner,
243 rhs,
244 lhs_semantic,
245 rhs_semantic,
246 ),
247 (_, Type::Alias(rhs)) => compare_impl(
248 arena,
249 builtins,
250 ctx,
251 lhs,
252 rhs.inner,
253 lhs_semantic,
254 rhs_semantic,
255 ),
256 (Type::Generic(_) | Type::Any, _) => compare_impl(
257 arena,
258 builtins,
259 ctx,
260 builtins.recursive_any,
261 rhs,
262 lhs_semantic,
263 rhs_semantic,
264 ),
265 (Type::Function(_), _) => {
266 let result = if let Type::Union(rhs) = arena[rhs].clone() {
267 let mut result = Comparison::Invalid;
268
269 for &id in &rhs.types {
270 result = min(
271 result,
272 compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
273 );
274 }
275
276 result
277 } else {
278 Comparison::Invalid
279 };
280
281 min(
282 result,
283 compare_impl(
284 arena,
285 builtins,
286 ctx,
287 builtins.recursive_any,
288 rhs,
289 lhs_semantic,
290 rhs_semantic,
291 ),
292 )
293 }
294 (Type::Union(lhs), _) => {
295 let mut result = Comparison::Assign;
296
297 for &id in &lhs.types {
298 result = max(
299 result,
300 compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
301 );
302 }
303
304 result
305 }
306 (_, Type::Union(rhs)) => {
307 let mut result = Comparison::Invalid;
308
309 for &id in &rhs.types {
310 result = min(
311 result,
312 compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
313 );
314 }
315
316 result
317 }
318 (_, Type::Never) => Comparison::Invalid,
319 (_, Type::Function(_)) => Comparison::Invalid,
320 };
321
322 ctx.stack.pop().unwrap();
323
324 result
325}
326
/// One element of the set returned by `semantics_of`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Semantic {
    /// Reported for `Type::Any`. `compare_impl` accepts it wherever it looks for a
    /// specific struct semantic.
    All,
    /// Reported for `Type::Struct`, carrying that struct's `semantic` id.
    Id(TypeId),
}
332
333fn semantics_of(arena: &Arena<Type>, id: TypeId) -> HashSet<Option<Semantic>> {
334 match arena[id].clone() {
335 Type::Apply(_) => unreachable!(),
336 Type::Ref(id) => semantics_of(arena, id),
337 Type::Alias(alias) => semantics_of(arena, alias.inner),
338 Type::Never => HashSet::new(),
339 Type::Any => HashSet::from_iter([Some(Semantic::All)]),
340 Type::Unresolved | Type::Generic(_) | Type::Atom(_) | Type::Pair(_) | Type::Function(_) => {
341 HashSet::from_iter([None])
342 }
343 Type::Struct(ty) => HashSet::from_iter([Some(Semantic::Id(ty.semantic))]),
344 Type::Union(ty) => {
345 let mut semantics = HashSet::new();
346
347 for &id in &ty.types {
348 semantics.extend(semantics_of(arena, id));
349 }
350
351 semantics
352 }
353 }
354}
355
// Unit tests for atom-to-atom comparison.
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use id_arena::Arena;
    use rstest::rstest;

    use crate::{Atom, Type, compare};

    use super::*;

    // Each case is `(lhs, rhs, expected)`: comparing `lhs` against `rhs` must yield
    // `expected`. The table covers every ordered pair of the predefined atoms, followed
    // by two cases that use an int atom restricted to the single-byte value `[1]`.
    #[rstest]
    #[case(Atom::NIL, Atom::NIL, Comparison::Assign)]
    #[case(Atom::NIL, Atom::FALSE, Comparison::Cast)]
    #[case(Atom::NIL, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::NIL, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::INT, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::NIL, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::FALSE, Comparison::Assign)]
    #[case(Atom::FALSE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::INT, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::TRUE, Comparison::Assign)]
    #[case(Atom::TRUE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES_32, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::BYTES_32, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::INT, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::PUBLIC_KEY, Comparison::Assign)]
    #[case(Atom::PUBLIC_KEY, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::INT, Comparison::Cast)]
    #[case(Atom::SIGNATURE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::SIGNATURE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::SIGNATURE, Comparison::Assign)]
    #[case(Atom::SIGNATURE, Atom::INT, Comparison::Cast)]
    #[case(Atom::INT, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::INT, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::INT, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::INT, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::INT, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::BYTES, Comparison::Cast)]
    fn test_atoms(#[case] lhs: Atom, #[case] rhs: Atom, #[case] expected: Comparison) {
        // Every case gets a fresh arena holding the builtins plus the two atoms under test.
        let mut arena = Arena::new();
        let builtins = BuiltinTypes::new(&mut arena);
        // The atoms are cloned so the originals remain available for the failure message.
        let lhs_id = arena.alloc(Type::Atom(lhs.clone()));
        let rhs_id = arena.alloc(Type::Atom(rhs.clone()));
        assert_eq!(
            compare(&mut arena, &builtins, lhs_id, rhs_id),
            expected,
            "{lhs} -> {rhs}"
        );
    }
}