Skip to main content

rue_types/
compare.rs

1use std::{
2    cmp::{max, min},
3    collections::{HashMap, HashSet},
4};
5
6use id_arena::Arena;
7use indexmap::{IndexMap, IndexSet};
8use log::{debug, trace};
9
10use crate::{
11    AtomRestriction, AtomSemantic, BuiltinTypes, Type, TypeId, stringify_impl, substitute,
12};
13
/// Outcome of checking one type against another.
///
/// The variants are declared from most to least permissive, so the derived
/// ordering is `Assign < Cast < Invalid`. `compare_impl` depends on that order:
/// it uses `max` to keep the strictest of several sub-results and `min` to keep
/// the most lenient one. Reordering the variants would silently change results.
#[derive(Debug, Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum Comparison {
    /// Best outcome; produced, for example, when two atoms agree in both
    /// semantic and restriction.
    Assign,
    /// Middle outcome; produced, for example, when two atoms have compatible
    /// restrictions but different semantics.
    Cast,
    /// Worst outcome; produced, for example, when an atom is compared with a pair.
    Invalid,
}
20
21#[derive(Debug, Default)]
22pub(crate) struct ComparisonContext<'a> {
23    infer: Option<&'a mut HashMap<TypeId, Vec<TypeId>>>,
24    stack: IndexSet<(TypeId, TypeId)>,
25}
26
27pub fn compare_with_inference(
28    arena: &mut Arena<Type>,
29    builtins: &BuiltinTypes,
30    lhs: TypeId,
31    rhs: TypeId,
32    infer: Option<&mut HashMap<TypeId, Vec<TypeId>>>,
33) -> Comparison {
34    let lhs = substitute(arena, lhs);
35    let rhs = substitute(arena, rhs);
36    let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
37    let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
38    trace!("Comparing {lhs_name} to {rhs_name}");
39    let result = compare_impl(
40        arena,
41        builtins,
42        &mut ComparisonContext {
43            infer,
44            ..Default::default()
45        },
46        lhs,
47        rhs,
48        None,
49        None,
50    );
51    trace!("Comparison from {lhs_name} to {rhs_name} yielded {result:?}");
52    result
53}
54
55pub fn compare(
56    arena: &mut Arena<Type>,
57    builtins: &BuiltinTypes,
58    lhs: TypeId,
59    rhs: TypeId,
60) -> Comparison {
61    compare_with_inference(arena, builtins, lhs, rhs, None)
62}
63
64pub(crate) fn compare_impl(
65    arena: &Arena<Type>,
66    builtins: &BuiltinTypes,
67    ctx: &mut ComparisonContext<'_>,
68    lhs: TypeId,
69    rhs: TypeId,
70    lhs_semantic: Option<TypeId>,
71    rhs_semantic: Option<TypeId>,
72) -> Comparison {
73    if !ctx.stack.insert((lhs, rhs)) {
74        return Comparison::Assign;
75    }
76
77    let result = match (arena[lhs].clone(), arena[rhs].clone()) {
78        (Type::Apply(_), _) | (_, Type::Apply(_)) => unreachable!(),
79        (Type::Ref(lhs), _) => {
80            compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
81        }
82        (_, Type::Ref(rhs)) => {
83            compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
84        }
85        (Type::Unresolved, _) | (_, Type::Unresolved) => Comparison::Assign,
86        (Type::Never, _) => Comparison::Assign,
87        (_, Type::Any) => Comparison::Assign,
88        (Type::Atom(lhs), Type::Atom(rhs)) => {
89            let semantic = if lhs.semantic == rhs.semantic
90                || rhs.semantic == AtomSemantic::Any
91                || (lhs.semantic == AtomSemantic::String && rhs.semantic == AtomSemantic::Bytes)
92                || (lhs.semantic == AtomSemantic::Bytes && rhs.semantic == AtomSemantic::String)
93            {
94                Comparison::Assign
95            } else {
96                Comparison::Cast
97            };
98
99            let restriction = match (lhs.restriction, rhs.restriction) {
100                (_, None) => Comparison::Assign,
101                (None, _) => Comparison::Invalid,
102                (Some(AtomRestriction::Length(lhs)), Some(AtomRestriction::Length(rhs))) => {
103                    if lhs == rhs {
104                        Comparison::Assign
105                    } else {
106                        Comparison::Invalid
107                    }
108                }
109                (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Value(rhs))) => {
110                    if lhs == rhs {
111                        Comparison::Assign
112                    } else {
113                        Comparison::Invalid
114                    }
115                }
116                (Some(AtomRestriction::Length(_)), Some(AtomRestriction::Value(_))) => {
117                    Comparison::Invalid
118                }
119                (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Length(rhs))) => {
120                    if lhs.len() == rhs {
121                        Comparison::Assign
122                    } else {
123                        Comparison::Invalid
124                    }
125                }
126            };
127
128            max(semantic, restriction)
129        }
130        (Type::Pair(lhs), Type::Pair(rhs)) => {
131            let first = compare_impl(arena, builtins, ctx, lhs.first, rhs.first, None, None);
132            let rest = compare_impl(arena, builtins, ctx, lhs.rest, rhs.rest, None, None);
133            max(first, rest)
134        }
135        (Type::Atom(_), Type::Pair(_)) => Comparison::Invalid,
136        (Type::Pair(_), Type::Atom(_)) => Comparison::Invalid,
137        (Type::Function(lhs), Type::Function(rhs)) => {
138            // TODO: We could relax the identical parameter requirement
139
140            if lhs.nil_terminated != rhs.nil_terminated || lhs.params.len() != rhs.params.len() {
141                Comparison::Invalid
142            } else {
143                let mut result = compare_impl(arena, builtins, ctx, lhs.ret, rhs.ret, None, None);
144
145                for (i, param) in lhs.params.iter().enumerate() {
146                    result = max(
147                        result,
148                        compare_impl(arena, builtins, ctx, *param.1, rhs.params[i], None, None),
149                    );
150                }
151
152                result
153            }
154        }
155        (_, Type::Generic(_)) => {
156            if lhs == rhs {
157                Comparison::Assign
158            } else if let Some(infer) = &mut ctx.infer {
159                debug!(
160                    "Inferring {} could include {}",
161                    stringify_impl(arena, rhs, &mut IndexMap::new()),
162                    stringify_impl(arena, lhs, &mut IndexMap::new())
163                );
164                infer.entry(rhs).or_default().push(lhs);
165                Comparison::Assign
166            } else if let Type::Union(lhs) = arena[lhs].clone() {
167                let mut result = Comparison::Assign;
168
169                for &id in &lhs.types {
170                    result = max(
171                        result,
172                        compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
173                    );
174                }
175
176                result
177            } else {
178                Comparison::Invalid
179            }
180        }
181        (Type::Struct(lhs), Type::Struct(rhs)) => max(
182            compare_impl(arena, builtins, ctx, lhs.inner, rhs.inner, None, None),
183            if lhs.semantic == rhs.semantic {
184                Comparison::Assign
185            } else {
186                Comparison::Cast
187            },
188        ),
189        (Type::Struct(lhs), _) => {
190            let inner = compare_impl(
191                arena,
192                builtins,
193                ctx,
194                lhs.inner,
195                rhs,
196                Some(lhs.semantic),
197                rhs_semantic,
198            );
199
200            let rhs_semantics = semantics_of(arena, rhs);
201
202            if rhs_semantic == Some(lhs.semantic)
203                || rhs_semantics.contains(&Some(Semantic::Id(lhs.semantic)))
204                || rhs_semantics.contains(&Some(Semantic::All))
205            {
206                inner
207            } else {
208                max(inner, Comparison::Cast)
209            }
210        }
211        (_, Type::Struct(rhs)) => {
212            let inner = compare_impl(
213                arena,
214                builtins,
215                ctx,
216                lhs,
217                rhs.inner,
218                lhs_semantic,
219                Some(rhs.semantic),
220            );
221
222            let semantics = semantics_of(arena, lhs);
223
224            if (semantics.len() != 1
225                || (!semantics.contains(&Some(Semantic::Id(rhs.semantic)))
226                    && !semantics.contains(&Some(Semantic::All))))
227                && lhs_semantic != Some(rhs.semantic)
228            {
229                max(inner, Comparison::Cast)
230            } else {
231                inner
232            }
233        }
234        (Type::Alias(lhs), _) => compare_impl(
235            arena,
236            builtins,
237            ctx,
238            lhs.inner,
239            rhs,
240            lhs_semantic,
241            rhs_semantic,
242        ),
243        (_, Type::Alias(rhs)) => compare_impl(
244            arena,
245            builtins,
246            ctx,
247            lhs,
248            rhs.inner,
249            lhs_semantic,
250            rhs_semantic,
251        ),
252        (Type::Generic(_) | Type::Any, _) => compare_impl(
253            arena,
254            builtins,
255            ctx,
256            builtins.recursive_any,
257            rhs,
258            lhs_semantic,
259            rhs_semantic,
260        ),
261        (Type::Function(_), _) => {
262            let result = if let Type::Union(rhs) = arena[rhs].clone() {
263                let mut result = Comparison::Invalid;
264
265                for &id in &rhs.types {
266                    result = min(
267                        result,
268                        compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
269                    );
270                }
271
272                result
273            } else {
274                Comparison::Invalid
275            };
276
277            min(
278                result,
279                compare_impl(
280                    arena,
281                    builtins,
282                    ctx,
283                    builtins.recursive_any,
284                    rhs,
285                    lhs_semantic,
286                    rhs_semantic,
287                ),
288            )
289        }
290        (Type::Union(lhs), _) => {
291            let mut result = Comparison::Assign;
292
293            for &id in &lhs.types {
294                result = max(
295                    result,
296                    compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
297                );
298            }
299
300            result
301        }
302        (_, Type::Union(rhs)) => {
303            let mut result = Comparison::Invalid;
304
305            for &id in &rhs.types {
306                result = min(
307                    result,
308                    compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
309                );
310            }
311
312            result
313        }
314        (_, Type::Never) => Comparison::Invalid,
315        (_, Type::Function(_)) => Comparison::Invalid,
316    };
317
318    ctx.stack.pop().unwrap();
319
320    result
321}
322
323#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
324enum Semantic {
325    All,
326    Id(TypeId),
327}
328
329fn semantics_of(arena: &Arena<Type>, id: TypeId) -> HashSet<Option<Semantic>> {
330    match arena[id].clone() {
331        Type::Apply(_) => unreachable!(),
332        Type::Ref(id) => semantics_of(arena, id),
333        Type::Alias(alias) => semantics_of(arena, alias.inner),
334        Type::Never => HashSet::new(),
335        Type::Any => HashSet::from_iter([Some(Semantic::All)]),
336        Type::Unresolved | Type::Generic(_) | Type::Atom(_) | Type::Pair(_) | Type::Function(_) => {
337            HashSet::from_iter([None])
338        }
339        Type::Struct(ty) => HashSet::from_iter([Some(Semantic::Id(ty.semantic))]),
340        Type::Union(ty) => {
341            let mut semantics = HashSet::new();
342
343            for &id in &ty.types {
344                semantics.extend(semantics_of(arena, id));
345            }
346
347            semantics
348        }
349    }
350}
351
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use id_arena::Arena;
    use rstest::rstest;

    use crate::{Atom, Type, compare};

    use super::*;

    /// Exhaustive atom-to-atom table: every builtin atom is compared against
    /// every other one, plus two cases for a value-restricted integer atom.
    #[rstest]
    #[case(Atom::NIL, Atom::NIL, Comparison::Assign)]
    #[case(Atom::NIL, Atom::FALSE, Comparison::Cast)]
    #[case(Atom::NIL, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::NIL, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::INT, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::NIL, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::FALSE, Comparison::Assign)]
    #[case(Atom::FALSE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::INT, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::TRUE, Comparison::Assign)]
    #[case(Atom::TRUE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES_32, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::BYTES_32, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::INT, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::PUBLIC_KEY, Comparison::Assign)]
    #[case(Atom::PUBLIC_KEY, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::INT, Comparison::Cast)]
    #[case(Atom::SIGNATURE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::SIGNATURE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::SIGNATURE, Atom::SIGNATURE, Comparison::Assign)]
    #[case(Atom::SIGNATURE, Atom::INT, Comparison::Cast)]
    #[case(Atom::INT, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::INT, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::INT, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::INT, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::INT, Atom::SIGNATURE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::BYTES, Comparison::Cast)]
    fn test_atoms(#[case] lhs: Atom, #[case] rhs: Atom, #[case] expected: Comparison) {
        let mut arena = Arena::new();
        let builtins = BuiltinTypes::new(&mut arena);

        // Render the failure label first so both atoms can be moved into the arena.
        let label = format!("{lhs} -> {rhs}");
        let lhs_id = arena.alloc(Type::Atom(lhs));
        let rhs_id = arena.alloc(Type::Atom(rhs));

        let actual = compare(&mut arena, &builtins, lhs_id, rhs_id);
        assert_eq!(actual, expected, "{label}");
    }
}