rue_types/
compare.rs

1use std::{
2    cmp::{max, min},
3    collections::{HashMap, HashSet},
4};
5
6use id_arena::Arena;
7use indexmap::{IndexMap, IndexSet};
8use log::{debug, trace};
9
10use crate::{
11    AtomRestriction, AtomSemantic, BuiltinTypes, Type, TypeId, stringify_impl, substitute,
12};
13
/// Outcome of checking whether a value of one type may be used where another type is expected.
///
/// Variants are declared from most to least permissive and the derived `Ord` follows that
/// declaration order (`Assign < Cast < Invalid`). As a result, `max` combines two outcomes by
/// keeping the worse one, and `min` combines them by keeping the better one.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Comparison {
    /// The value can be used as-is.
    Assign,
    /// The types are only compatible through a cast.
    Cast,
    /// The types are incompatible.
    Invalid,
}
20
21#[derive(Debug, Default)]
22pub(crate) struct ComparisonContext<'a> {
23    infer: Option<&'a mut HashMap<TypeId, TypeId>>,
24    stack: IndexSet<(TypeId, TypeId)>,
25}
26
27pub fn compare_with_inference(
28    arena: &mut Arena<Type>,
29    builtins: &BuiltinTypes,
30    lhs: TypeId,
31    rhs: TypeId,
32    infer: Option<&mut HashMap<TypeId, TypeId>>,
33) -> Comparison {
34    let lhs = substitute(arena, lhs);
35    let rhs = substitute(arena, rhs);
36    let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
37    let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
38    trace!("Comparing {lhs_name} to {rhs_name}");
39    let result = compare_impl(
40        arena,
41        builtins,
42        &mut ComparisonContext {
43            infer,
44            ..Default::default()
45        },
46        lhs,
47        rhs,
48        None,
49        None,
50    );
51    trace!("Comparison from {lhs_name} to {rhs_name} yielded {result:?}");
52    result
53}
54
55pub fn compare(
56    arena: &mut Arena<Type>,
57    builtins: &BuiltinTypes,
58    lhs: TypeId,
59    rhs: TypeId,
60) -> Comparison {
61    compare_with_inference(arena, builtins, lhs, rhs, None)
62}
63
/// Core recursive type comparison: decides whether a value of type `lhs` can be used where a
/// value of type `rhs` is expected.
///
/// Returns:
/// - [`Comparison::Assign`] when the value can be used directly.
/// - [`Comparison::Cast`] when it is only compatible through a cast.
/// - [`Comparison::Invalid`] otherwise.
///
/// Sub-results are combined in two ways. `max` is used when every part must fit, and the
/// worst outcome wins because the derived `Ord` is `Assign < Cast < Invalid`. `min` is used
/// when any one alternative is enough.
///
/// `lhs_semantic` / `rhs_semantic` carry the semantic id of an enclosing `Struct` that has
/// been unwrapped on that side. This lets the inner type still be matched against a struct
/// on the other side. They are passed as `None` when descending into pair or function
/// components, and when comparing the inner types of two structs.
///
/// # Panics
///
/// Panics (`unreachable!`) if either side is a `Type::Apply`. The public entry points run
/// `substitute` first, presumably so that no `Apply` remains; TODO confirm.
pub(crate) fn compare_impl(
    arena: &Arena<Type>,
    builtins: &BuiltinTypes,
    ctx: &mut ComparisonContext<'_>,
    lhs: TypeId,
    rhs: TypeId,
    lhs_semantic: Option<TypeId>,
    rhs_semantic: Option<TypeId>,
) -> Comparison {
    // Cycle guard for recursive types. If this exact pair is already being compared further
    // up the call stack, assume it succeeds instead of recursing forever.
    if !ctx.stack.insert((lhs, rhs)) {
        return Comparison::Assign;
    }

    // NOTE: arm order is significant. Earlier arms take priority over later, more general
    // ones (e.g. `Ref`/`Unresolved`/`Never`/`Any` are resolved before any structural case).
    let result = match (arena[lhs].clone(), arena[rhs].clone()) {
        (Type::Apply(_), _) | (_, Type::Apply(_)) => unreachable!(),
        // References are transparent: follow them on either side.
        (Type::Ref(lhs), _) => {
            compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
        }
        (_, Type::Ref(rhs)) => {
            compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
        }
        // Unresolved types are accepted in either position. Presumably this avoids cascading
        // errors once resolution has already failed; TODO confirm.
        (Type::Unresolved, _) | (_, Type::Unresolved) => Comparison::Assign,
        // `Never` is assignable to everything, and everything is assignable to `Any`.
        (Type::Never, _) => Comparison::Assign,
        (_, Type::Any) => Comparison::Assign,
        (Type::Atom(lhs), Type::Atom(rhs)) => {
            // Semantic part. These are direct assignments:
            // - identical semantics,
            // - an `Any`-semantic target,
            // - String <-> Bytes in either direction.
            // Any other semantic mismatch requires a cast.
            let semantic = if lhs.semantic == rhs.semantic
                || rhs.semantic == AtomSemantic::Any
                || (lhs.semantic == AtomSemantic::String && rhs.semantic == AtomSemantic::Bytes)
                || (lhs.semantic == AtomSemantic::Bytes && rhs.semantic == AtomSemantic::String)
            {
                Comparison::Assign
            } else {
                Comparison::Cast
            };

            // Restriction part: a restricted target only accepts a source whose restriction
            // guarantees the target's restriction.
            let restriction = match (lhs.restriction, rhs.restriction) {
                // An unrestricted target accepts any source restriction.
                (_, None) => Comparison::Assign,
                // An unrestricted source cannot satisfy a restricted target.
                (None, _) => Comparison::Invalid,
                (Some(AtomRestriction::Length(lhs)), Some(AtomRestriction::Length(rhs))) => {
                    if lhs == rhs {
                        Comparison::Assign
                    } else {
                        Comparison::Invalid
                    }
                }
                (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Value(rhs))) => {
                    if lhs == rhs {
                        Comparison::Assign
                    } else {
                        Comparison::Invalid
                    }
                }
                // A known length does not pin down an exact value.
                (Some(AtomRestriction::Length(_)), Some(AtomRestriction::Value(_))) => {
                    Comparison::Invalid
                }
                // A known value satisfies a length restriction when its length matches.
                (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Length(rhs))) => {
                    if lhs.len() == rhs {
                        Comparison::Assign
                    } else {
                        Comparison::Invalid
                    }
                }
            };

            // The worse of the two parts wins.
            max(semantic, restriction)
        }
        // Pairs are compared component-wise; struct semantics do not carry into components.
        (Type::Pair(lhs), Type::Pair(rhs)) => {
            let first = compare_impl(arena, builtins, ctx, lhs.first, rhs.first, None, None);
            let rest = compare_impl(arena, builtins, ctx, lhs.rest, rhs.rest, None, None);
            max(first, rest)
        }
        (Type::Atom(_), Type::Pair(_)) => Comparison::Invalid,
        (Type::Pair(_), Type::Atom(_)) => Comparison::Invalid,
        (Type::Function(lhs), Type::Function(rhs)) => {
            // TODO: We could relax the identical parameter requirement
            // NOTE(review): parameters are compared in the same direction as the return type
            // (lhs param -> rhs param). Conventional function subtyping is contravariant in
            // parameters, so confirm this is intended.

            if lhs.nil_terminated != rhs.nil_terminated || lhs.params.len() != rhs.params.len() {
                Comparison::Invalid
            } else {
                let mut result = compare_impl(arena, builtins, ctx, lhs.ret, rhs.ret, None, None);

                for (i, param) in lhs.params.iter().enumerate() {
                    result = max(
                        result,
                        compare_impl(arena, builtins, ctx, *param, rhs.params[i], None, None),
                    );
                }

                result
            }
        }
        // The target is a generic type parameter.
        (_, Type::Generic(_)) => {
            if lhs == rhs {
                Comparison::Assign
            } else if let Some(infer) = &mut ctx.infer {
                // With inference enabled, either compare against the existing binding or bind
                // the generic to `lhs`. Bindings are never removed here, even when this call
                // happens inside a union alternative that ultimately loses.
                if let Some(rhs) = infer.get(&rhs).copied() {
                    compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
                } else {
                    debug!(
                        "Inferring {} is {}",
                        stringify_impl(arena, rhs, &mut IndexMap::new()),
                        stringify_impl(arena, lhs, &mut IndexMap::new())
                    );
                    infer.insert(rhs, lhs);
                    Comparison::Assign
                }
            } else if let Type::Union(lhs) = arena[lhs].clone() {
                // Without inference, each member of a union source is checked against the
                // generic on its own, and the worst result wins.
                let mut result = Comparison::Assign;

                for &id in &lhs.types {
                    result = max(
                        result,
                        compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
                    );
                }

                result
            } else {
                Comparison::Invalid
            }
        }
        // Two structs: compare their inner types; differing semantics additionally require a
        // cast.
        (Type::Struct(lhs), Type::Struct(rhs)) => max(
            compare_impl(arena, builtins, ctx, lhs.inner, rhs.inner, None, None),
            if lhs.semantic == rhs.semantic {
                Comparison::Assign
            } else {
                Comparison::Cast
            },
        ),
        // A struct source against a non-struct target: compare the inner type (remembering
        // this struct's semantic). A cast is needed unless the target carries this struct's
        // semantic, either passed down via `rhs_semantic` or present in `rhs` itself (or
        // `rhs` accepts all semantics).
        (Type::Struct(lhs), _) => {
            let inner = compare_impl(
                arena,
                builtins,
                ctx,
                lhs.inner,
                rhs,
                Some(lhs.semantic),
                rhs_semantic,
            );

            let rhs_semantics = semantics_of(arena, rhs);

            if rhs_semantic == Some(lhs.semantic)
                || rhs_semantics.contains(&Some(Semantic::Id(lhs.semantic)))
                || rhs_semantics.contains(&Some(Semantic::All))
            {
                inner
            } else {
                max(inner, Comparison::Cast)
            }
        }
        // A non-struct source against a struct target: compare against the struct's inner
        // type. A cast is needed unless one of these holds:
        // - the source has exactly one semantic, and it is this struct's or `All`;
        // - an enclosing struct on the source side already carried this semantic.
        (_, Type::Struct(rhs)) => {
            let inner = compare_impl(
                arena,
                builtins,
                ctx,
                lhs,
                rhs.inner,
                lhs_semantic,
                Some(rhs.semantic),
            );

            let semantics = semantics_of(arena, lhs);

            if (semantics.len() != 1
                || (!semantics.contains(&Some(Semantic::Id(rhs.semantic)))
                    && !semantics.contains(&Some(Semantic::All))))
                && lhs_semantic != Some(rhs.semantic)
            {
                max(inner, Comparison::Cast)
            } else {
                inner
            }
        }
        // Aliases are transparent on either side.
        (Type::Alias(lhs), _) => compare_impl(
            arena,
            builtins,
            ctx,
            lhs.inner,
            rhs,
            lhs_semantic,
            rhs_semantic,
        ),
        (_, Type::Alias(rhs)) => compare_impl(
            arena,
            builtins,
            ctx,
            lhs,
            rhs.inner,
            lhs_semantic,
            rhs_semantic,
        ),
        // A generic or `Any` source is compared as `builtins.recursive_any`. That is
        // presumably the structural type of an arbitrary value; TODO confirm.
        (Type::Generic(_) | Type::Any, _) => compare_impl(
            arena,
            builtins,
            ctx,
            builtins.recursive_any,
            rhs,
            lhs_semantic,
            rhs_semantic,
        ),
        // A function source against a non-function target. The best of two outcomes wins:
        // - the best match against any member of a union target;
        // - the result of comparing `recursive_any` against the target.
        (Type::Function(_), _) => {
            let result = if let Type::Union(rhs) = arena[rhs].clone() {
                let mut result = Comparison::Invalid;

                for &id in &rhs.types {
                    result = min(
                        result,
                        compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
                    );
                }

                result
            } else {
                Comparison::Invalid
            };

            min(
                result,
                compare_impl(
                    arena,
                    builtins,
                    ctx,
                    builtins.recursive_any,
                    rhs,
                    lhs_semantic,
                    rhs_semantic,
                ),
            )
        }
        // Every member of a union source must fit; the worst member decides.
        (Type::Union(lhs), _) => {
            let mut result = Comparison::Assign;

            for &id in &lhs.types {
                result = max(
                    result,
                    compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
                );
            }

            result
        }
        // The source only has to fit one member of a union target; the best member decides.
        (_, Type::Union(rhs)) => {
            let mut result = Comparison::Invalid;

            for &id in &rhs.types {
                result = min(
                    result,
                    compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
                );
            }

            result
        }
        // By this point the source can only be an atom or a pair. Neither fits `Never` (only
        // `Never` does, handled above) or a function type.
        (_, Type::Never) => Comparison::Invalid,
        (_, Type::Function(_)) => Comparison::Invalid,
    };

    // Remove this frame's pair. Nested calls pop their own entries before returning, so the
    // last entry is the one inserted above.
    ctx.stack.pop().unwrap();

    result
}
326
327#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
328enum Semantic {
329    All,
330    Id(TypeId),
331}
332
333fn semantics_of(arena: &Arena<Type>, id: TypeId) -> HashSet<Option<Semantic>> {
334    match arena[id].clone() {
335        Type::Apply(_) => unreachable!(),
336        Type::Ref(id) => semantics_of(arena, id),
337        Type::Alias(alias) => semantics_of(arena, alias.inner),
338        Type::Never => HashSet::new(),
339        Type::Any => HashSet::from_iter([Some(Semantic::All)]),
340        Type::Unresolved | Type::Generic(_) | Type::Atom(_) | Type::Pair(_) | Type::Function(_) => {
341            HashSet::from_iter([None])
342        }
343        Type::Struct(ty) => HashSet::from_iter([Some(Semantic::Id(ty.semantic))]),
344        Type::Union(ty) => {
345            let mut semantics = HashSet::new();
346
347            for &id in &ty.types {
348                semantics.extend(semantics_of(arena, id));
349            }
350
351            semantics
352        }
353    }
354}
355
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use id_arena::Arena;
    use rstest::rstest;

    use crate::{Atom, Type, compare};

    use super::*;

    // Atom-to-atom comparison table. Each case is `(source, target, expected)`: `lhs` is the
    // type of the value and `rhs` is the type it is being used as.
    #[rstest]
    #[case(Atom::NIL, Atom::NIL, Comparison::Assign)]
    #[case(Atom::NIL, Atom::FALSE, Comparison::Cast)]
    #[case(Atom::NIL, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::NIL, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::INT, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::NIL, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::FALSE, Comparison::Assign)]
    #[case(Atom::FALSE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::INT, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::TRUE, Comparison::Assign)]
    #[case(Atom::TRUE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES_32, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::BYTES_32, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::INT, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::PUBLIC_KEY, Comparison::Assign)]
    #[case(Atom::PUBLIC_KEY, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::INT, Comparison::Cast)]
    #[case(Atom::SIGNATURE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::SIGNATURE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::SIGNATURE, Comparison::Assign)]
    #[case(Atom::SIGNATURE, Atom::INT, Comparison::Cast)]
    #[case(Atom::INT, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::INT, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::INT, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::INT, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::INT, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::BYTES, Comparison::Cast)]
    fn test_atoms(#[case] lhs: Atom, #[case] rhs: Atom, #[case] expected: Comparison) {
        // Builtins are registered first so the atom ids are allocated after them, as before.
        let mut arena = Arena::new();
        let builtins = BuiltinTypes::new(&mut arena);

        let source = arena.alloc(Type::Atom(lhs.clone()));
        let target = arena.alloc(Type::Atom(rhs.clone()));
        let actual = compare(&mut arena, &builtins, source, target);

        assert_eq!(actual, expected, "{lhs} -> {rhs}");
    }
}