1use std::{
2 cmp::{max, min},
3 collections::{HashMap, HashSet},
4};
5
6use id_arena::Arena;
7use indexmap::{IndexMap, IndexSet};
8use log::{debug, trace};
9
10use crate::{
11 AtomRestriction, AtomSemantic, BuiltinTypes, Type, TypeId, stringify_impl, substitute,
12};
13
/// Outcome of checking whether a value of one type can be used where another
/// type is expected.
///
/// Variants are listed from most to least permissive. The derived ordering
/// follows declaration order (`Assign < Cast < Invalid`), so `max` of two
/// outcomes yields the stricter one and `min` yields the more lenient one.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Comparison {
    /// Directly assignable.
    Assign,
    /// Not directly assignable, but compatible via a cast (returned, for
    /// example, when atom or struct semantics differ).
    Cast,
    /// Not compatible.
    Invalid,
}
20
21#[derive(Debug, Default)]
22pub(crate) struct ComparisonContext<'a> {
23 infer: Option<&'a mut HashMap<TypeId, Vec<TypeId>>>,
24 stack: IndexSet<(TypeId, TypeId)>,
25}
26
27pub fn compare_with_inference(
28 arena: &mut Arena<Type>,
29 builtins: &BuiltinTypes,
30 lhs: TypeId,
31 rhs: TypeId,
32 infer: Option<&mut HashMap<TypeId, Vec<TypeId>>>,
33) -> Comparison {
34 let lhs = substitute(arena, lhs);
35 let rhs = substitute(arena, rhs);
36 let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
37 let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
38 trace!("Comparing {lhs_name} to {rhs_name}");
39 let result = compare_impl(
40 arena,
41 builtins,
42 &mut ComparisonContext {
43 infer,
44 ..Default::default()
45 },
46 lhs,
47 rhs,
48 None,
49 None,
50 );
51 trace!("Comparison from {lhs_name} to {rhs_name} yielded {result:?}");
52 result
53}
54
55pub fn compare(
56 arena: &mut Arena<Type>,
57 builtins: &BuiltinTypes,
58 lhs: TypeId,
59 rhs: TypeId,
60) -> Comparison {
61 compare_with_inference(arena, builtins, lhs, rhs, None)
62}
63
/// Recursive worker behind `compare` and `compare_with_inference`.
///
/// Decides how a value of type `lhs` may be used where `rhs` is expected. It
/// returns `Comparison::Assign`, `Comparison::Cast` or `Comparison::Invalid`.
///
/// Sub-results are combined in two ways, relying on the declaration order
/// `Assign < Cast < Invalid`:
/// - with `max` when every part must be compatible;
/// - with `min` when any one alternative suffices.
///
/// `lhs_semantic` / `rhs_semantic` carry the semantic id of a `Type::Struct`
/// unwrapped on that side further up the recursion. They are forwarded
/// unchanged through `Ref`, `Alias`, union, generic/`Any` and function-to-union
/// steps. They are reset to `None` when descending into pair halves,
/// function-to-function signatures, and struct-to-struct inner types.
///
/// `ctx.stack` holds the `(lhs, rhs)` pairs currently being compared. Meeting a
/// pair that is already on the stack returns `Assign`, so comparisons of
/// recursive types terminate by assuming the in-progress check succeeds.
///
/// # Panics
///
/// Panics (`unreachable!`) if either side is a `Type::Apply`. Callers are
/// expected to have removed those first, as `compare_with_inference` does via
/// `substitute`. NOTE(review): confirm this holds for every `pub(crate)` caller.
pub(crate) fn compare_impl(
    arena: &Arena<Type>,
    builtins: &BuiltinTypes,
    ctx: &mut ComparisonContext<'_>,
    lhs: TypeId,
    rhs: TypeId,
    lhs_semantic: Option<TypeId>,
    rhs_semantic: Option<TypeId>,
) -> Comparison {
    // Cycle guard: a pair already under comparison is assumed assignable.
    if !ctx.stack.insert((lhs, rhs)) {
        return Comparison::Assign;
    }

    // Arm order matters: earlier arms intentionally shadow later, more general ones.
    let result = match (arena[lhs].clone(), arena[rhs].clone()) {
        (Type::Apply(_), _) | (_, Type::Apply(_)) => unreachable!(),
        // Look through references on either side, keeping the carried semantics.
        (Type::Ref(lhs), _) => {
            compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
        }
        (_, Type::Ref(rhs)) => {
            compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
        }
        // Unresolved types are accepted in both directions.
        // Presumably this avoids cascading errors; confirm.
        (Type::Unresolved, _) | (_, Type::Unresolved) => Comparison::Assign,
        // `Never` is assignable to anything, and anything is assignable to `Any`.
        (Type::Never, _) => Comparison::Assign,
        (_, Type::Any) => Comparison::Assign,
        (Type::Atom(lhs), Type::Atom(rhs)) => {
            // Same semantic, an `Any` target, or String <-> Bytes is assignable.
            // Any other semantic mismatch needs a cast.
            let semantic = if lhs.semantic == rhs.semantic
                || rhs.semantic == AtomSemantic::Any
                || (lhs.semantic == AtomSemantic::String && rhs.semantic == AtomSemantic::Bytes)
                || (lhs.semantic == AtomSemantic::Bytes && rhs.semantic == AtomSemantic::String)
            {
                Comparison::Assign
            } else {
                Comparison::Cast
            };

            // The target's restriction must be guaranteed by the source's restriction.
            let restriction = match (lhs.restriction, rhs.restriction) {
                // Unrestricted target: any source fits.
                (_, None) => Comparison::Assign,
                // Unrestricted source cannot guarantee a restricted target.
                (None, _) => Comparison::Invalid,
                (Some(AtomRestriction::Length(lhs)), Some(AtomRestriction::Length(rhs))) => {
                    if lhs == rhs {
                        Comparison::Assign
                    } else {
                        Comparison::Invalid
                    }
                }
                (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Value(rhs))) => {
                    if lhs == rhs {
                        Comparison::Assign
                    } else {
                        Comparison::Invalid
                    }
                }
                // A known length does not pin down a specific value.
                (Some(AtomRestriction::Length(_)), Some(AtomRestriction::Value(_))) => {
                    Comparison::Invalid
                }
                // A known value satisfies a length restriction when its length matches.
                (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Length(rhs))) => {
                    if lhs.len() == rhs {
                        Comparison::Assign
                    } else {
                        Comparison::Invalid
                    }
                }
            };

            max(semantic, restriction)
        }
        // Pairs are compatible only as far as both halves are.
        (Type::Pair(lhs), Type::Pair(rhs)) => {
            let first = compare_impl(arena, builtins, ctx, lhs.first, rhs.first, None, None);
            let rest = compare_impl(arena, builtins, ctx, lhs.rest, rhs.rest, None, None);
            max(first, rest)
        }
        (Type::Atom(_), Type::Pair(_)) => Comparison::Invalid,
        (Type::Pair(_), Type::Atom(_)) => Comparison::Invalid,
        (Type::Function(lhs), Type::Function(rhs)) => {
            if lhs.nil_terminated != rhs.nil_terminated || lhs.params.len() != rhs.params.len() {
                Comparison::Invalid
            } else {
                let mut result = compare_impl(arena, builtins, ctx, lhs.ret, rhs.ret, None, None);

                // Parameters are matched by position, in the same lhs -> rhs
                // direction as the return type.
                // NOTE(review): conventional function subtyping is contravariant in
                // parameters. The covariant direction may be deliberate (for example,
                // so generic parameters in `rhs` are inferred from `lhs`); confirm.
                for (i, param) in lhs.params.iter().enumerate() {
                    result = max(
                        result,
                        compare_impl(arena, builtins, ctx, *param.1, rhs.params[i], None, None),
                    );
                }

                result
            }
        }
        (_, Type::Generic(_)) => {
            // The very same generic type id is trivially assignable.
            if lhs == rhs {
                Comparison::Assign
            } else if let Some(infer) = &mut ctx.infer {
                // Inference mode: record `lhs` as a candidate for this generic
                // and accept the assignment.
                debug!(
                    "Inferring {} could include {}",
                    stringify_impl(arena, rhs, &mut IndexMap::new()),
                    stringify_impl(arena, lhs, &mut IndexMap::new())
                );
                infer.entry(rhs).or_default().push(lhs);
                Comparison::Assign
            } else if let Type::Union(lhs) = arena[lhs].clone() {
                // Without inference, a union source is checked member by member.
                // Every member must be compatible.
                let mut result = Comparison::Assign;

                for &id in &lhs.types {
                    result = max(
                        result,
                        compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
                    );
                }

                result
            } else {
                Comparison::Invalid
            }
        }
        // Struct -> struct: compare inner types. Differing semantics need a cast.
        (Type::Struct(lhs), Type::Struct(rhs)) => max(
            compare_impl(arena, builtins, ctx, lhs.inner, rhs.inner, None, None),
            if lhs.semantic == rhs.semantic {
                Comparison::Assign
            } else {
                Comparison::Cast
            },
        ),
        (Type::Struct(lhs), _) => {
            // Unwrap the source struct, remembering its semantic for deeper levels.
            let inner = compare_impl(
                arena,
                builtins,
                ctx,
                lhs.inner,
                rhs,
                Some(lhs.semantic),
                rhs_semantic,
            );

            let rhs_semantics = semantics_of(arena, rhs);

            // The inner result stands when the target carries this semantic in any
            // of these ways:
            // - from an enclosing rhs struct;
            // - among any of its possible semantics;
            // - via `Any`.
            // Otherwise the struct wrapper is being dropped, which needs at least a cast.
            if rhs_semantic == Some(lhs.semantic)
                || rhs_semantics.contains(&Some(Semantic::Id(lhs.semantic)))
                || rhs_semantics.contains(&Some(Semantic::All))
            {
                inner
            } else {
                max(inner, Comparison::Cast)
            }
        }
        (_, Type::Struct(rhs)) => {
            // Unwrap the target struct, remembering its semantic for deeper levels.
            let inner = compare_impl(
                arena,
                builtins,
                ctx,
                lhs,
                rhs.inner,
                lhs_semantic,
                Some(rhs.semantic),
            );

            let semantics = semantics_of(arena, lhs);

            // Stricter than the arm above: the source must have exactly one
            // semantic, and it must be this struct's or `All`. Otherwise an
            // enclosing lhs struct must already carry this semantic. Failing both,
            // at least a cast is needed.
            if (semantics.len() != 1
                || (!semantics.contains(&Some(Semantic::Id(rhs.semantic)))
                    && !semantics.contains(&Some(Semantic::All))))
                && lhs_semantic != Some(rhs.semantic)
            {
                max(inner, Comparison::Cast)
            } else {
                inner
            }
        }
        // Aliases are transparent on both sides.
        (Type::Alias(lhs), _) => compare_impl(
            arena,
            builtins,
            ctx,
            lhs.inner,
            rhs,
            lhs_semantic,
            rhs_semantic,
        ),
        (_, Type::Alias(rhs)) => compare_impl(
            arena,
            builtins,
            ctx,
            lhs,
            rhs.inner,
            lhs_semantic,
            rhs_semantic,
        ),
        // A generic or `Any` source is compared as the builtin `recursive_any`.
        (Type::Generic(_) | Type::Any, _) => compare_impl(
            arena,
            builtins,
            ctx,
            builtins.recursive_any,
            rhs,
            lhs_semantic,
            rhs_semantic,
        ),
        // Function -> non-function (a function target was handled above).
        (Type::Function(_), _) => {
            // Best match among union members, if the target is a union...
            let result = if let Type::Union(rhs) = arena[rhs].clone() {
                let mut result = Comparison::Invalid;

                for &id in &rhs.types {
                    result = min(
                        result,
                        compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
                    );
                }

                result
            } else {
                Comparison::Invalid
            };

            // ...or whatever `recursive_any` achieves against the target, whichever is better.
            min(
                result,
                compare_impl(
                    arena,
                    builtins,
                    ctx,
                    builtins.recursive_any,
                    rhs,
                    lhs_semantic,
                    rhs_semantic,
                ),
            )
        }
        // Union source: every member must be compatible (worst case wins).
        (Type::Union(lhs), _) => {
            let mut result = Comparison::Assign;

            for &id in &lhs.types {
                result = max(
                    result,
                    compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
                );
            }

            result
        }
        // Union target: it is enough to fit one member (best case wins).
        (_, Type::Union(rhs)) => {
            let mut result = Comparison::Invalid;

            for &id in &rhs.types {
                result = min(
                    result,
                    compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
                );
            }

            result
        }
        // Only `Never` itself (matched above) is assignable to `Never`.
        (_, Type::Never) => Comparison::Invalid,
        (_, Type::Function(_)) => Comparison::Invalid,
    };

    // Nested calls have already popped their own entries, so the last entry is
    // the pair inserted at the top of this call.
    ctx.stack.pop().unwrap();

    result
}
322
323#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
324enum Semantic {
325 All,
326 Id(TypeId),
327}
328
329fn semantics_of(arena: &Arena<Type>, id: TypeId) -> HashSet<Option<Semantic>> {
330 match arena[id].clone() {
331 Type::Apply(_) => unreachable!(),
332 Type::Ref(id) => semantics_of(arena, id),
333 Type::Alias(alias) => semantics_of(arena, alias.inner),
334 Type::Never => HashSet::new(),
335 Type::Any => HashSet::from_iter([Some(Semantic::All)]),
336 Type::Unresolved | Type::Generic(_) | Type::Atom(_) | Type::Pair(_) | Type::Function(_) => {
337 HashSet::from_iter([None])
338 }
339 Type::Struct(ty) => HashSet::from_iter([Some(Semantic::Id(ty.semantic))]),
340 Type::Union(ty) => {
341 let mut semantics = HashSet::new();
342
343 for &id in &ty.types {
344 semantics.extend(semantics_of(arena, id));
345 }
346
347 semantics
348 }
349 }
350}
351
// Unit tests for atom-to-atom comparison through the public `compare` entry point.
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use id_arena::Arena;
    use rstest::rstest;

    use crate::{Atom, Type, compare};

    use super::*;

    /// Checks the expected `Comparison` for every ordered pair of the eight atom
    /// constants listed here (NIL, FALSE, TRUE, BYTES, BYTES_32, PUBLIC_KEY,
    /// SIGNATURE, INT). It also covers an `Int` atom restricted to the value
    /// `[1]` compared against `INT` and `BYTES`.
    #[rstest]
    #[case(Atom::NIL, Atom::NIL, Comparison::Assign)]
    #[case(Atom::NIL, Atom::FALSE, Comparison::Cast)]
    #[case(Atom::NIL, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::NIL, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::INT, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::NIL, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::FALSE, Comparison::Assign)]
    #[case(Atom::FALSE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::INT, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::TRUE, Comparison::Assign)]
    #[case(Atom::TRUE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES_32, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::BYTES_32, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::INT, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::PUBLIC_KEY, Comparison::Assign)]
    #[case(Atom::PUBLIC_KEY, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::INT, Comparison::Cast)]
    #[case(Atom::SIGNATURE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::SIGNATURE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::SIGNATURE, Comparison::Assign)]
    #[case(Atom::SIGNATURE, Atom::INT, Comparison::Cast)]
    #[case(Atom::INT, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::INT, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::INT, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::INT, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::INT, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::BYTES, Comparison::Cast)]
    fn test_atoms(#[case] lhs: Atom, #[case] rhs: Atom, #[case] expected: Comparison) {
        let mut arena = Arena::new();
        let builtins = BuiltinTypes::new(&mut arena);
        // Clone the atoms into the arena so the originals stay available for
        // the assertion message.
        let lhs_id = arena.alloc(Type::Atom(lhs.clone()));
        let rhs_id = arena.alloc(Type::Atom(rhs.clone()));
        assert_eq!(
            compare(&mut arena, &builtins, lhs_id, rhs_id),
            expected,
            "{lhs} -> {rhs}"
        );
    }
}