// rue_types/compare.rs

1use std::{
2    cmp::{max, min},
3    collections::{HashMap, HashSet},
4};
5
6use id_arena::Arena;
7use indexmap::{IndexMap, IndexSet};
8use log::{debug, trace};
9
10use crate::{
11    AtomRestriction, AtomSemantic, BuiltinTypes, Type, TypeId, stringify_impl, substitute,
12};
13
/// Outcome of comparing a source type (`lhs`) against a target type (`rhs`).
///
/// Variants are declared from most to least permissive and the derived `Ord` follows
/// declaration order (`Assign < Cast < Invalid`). `compare_impl` combines sub-results
/// with `max` (every part must hold) and `min` (any alternative suffices), so this
/// declaration order is load-bearing and must not be changed.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Comparison {
    /// Compatible as-is (e.g. two identical atoms).
    Assign,
    /// Compatible only with a cast (e.g. atoms whose semantics differ).
    Cast,
    /// Not compatible.
    Invalid,
}
20
21#[derive(Debug, Default)]
22pub(crate) struct ComparisonContext<'a> {
23    infer: Option<&'a mut HashMap<TypeId, TypeId>>,
24    stack: IndexSet<(TypeId, TypeId)>,
25}
26
27pub fn compare_with_inference(
28    arena: &mut Arena<Type>,
29    builtins: &BuiltinTypes,
30    lhs: TypeId,
31    rhs: TypeId,
32    infer: Option<&mut HashMap<TypeId, TypeId>>,
33) -> Comparison {
34    let lhs = substitute(arena, lhs);
35    let rhs = substitute(arena, rhs);
36    let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
37    let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
38    trace!("Comparing {lhs_name} to {rhs_name}");
39    let result = compare_impl(
40        arena,
41        builtins,
42        &mut ComparisonContext {
43            infer,
44            ..Default::default()
45        },
46        lhs,
47        rhs,
48        None,
49        None,
50    );
51    trace!("Comparison from {lhs_name} to {rhs_name} yielded {result:?}");
52    result
53}
54
55pub fn compare(
56    arena: &mut Arena<Type>,
57    builtins: &BuiltinTypes,
58    lhs: TypeId,
59    rhs: TypeId,
60) -> Comparison {
61    compare_with_inference(arena, builtins, lhs, rhs, None)
62}
63
64pub(crate) fn compare_impl(
65    arena: &Arena<Type>,
66    builtins: &BuiltinTypes,
67    ctx: &mut ComparisonContext<'_>,
68    lhs: TypeId,
69    rhs: TypeId,
70    lhs_semantic: Option<TypeId>,
71    rhs_semantic: Option<TypeId>,
72) -> Comparison {
73    if !ctx.stack.insert((lhs, rhs)) {
74        return Comparison::Assign;
75    }
76
77    let result = match (arena[lhs].clone(), arena[rhs].clone()) {
78        (Type::Apply(_), _) | (_, Type::Apply(_)) => unreachable!(),
79        (Type::Ref(lhs), _) => {
80            compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
81        }
82        (_, Type::Ref(rhs)) => {
83            compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
84        }
85        (Type::Unresolved, _) | (_, Type::Unresolved) => Comparison::Assign,
86        (Type::Never, _) => Comparison::Assign,
87        (Type::Atom(lhs), Type::Atom(rhs)) => {
88            let semantic = if lhs.semantic == rhs.semantic || rhs.semantic == AtomSemantic::Any {
89                Comparison::Assign
90            } else {
91                Comparison::Cast
92            };
93
94            let restriction = match (lhs.restriction, rhs.restriction) {
95                (_, None) => Comparison::Assign,
96                (None, _) => Comparison::Invalid,
97                (Some(AtomRestriction::Length(lhs)), Some(AtomRestriction::Length(rhs))) => {
98                    if lhs == rhs {
99                        Comparison::Assign
100                    } else {
101                        Comparison::Invalid
102                    }
103                }
104                (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Value(rhs))) => {
105                    if lhs == rhs {
106                        Comparison::Assign
107                    } else {
108                        Comparison::Invalid
109                    }
110                }
111                (Some(AtomRestriction::Length(_)), Some(AtomRestriction::Value(_))) => {
112                    Comparison::Invalid
113                }
114                (Some(AtomRestriction::Value(lhs)), Some(AtomRestriction::Length(rhs))) => {
115                    if lhs.len() == rhs {
116                        Comparison::Assign
117                    } else {
118                        Comparison::Invalid
119                    }
120                }
121            };
122
123            max(semantic, restriction)
124        }
125        (Type::Pair(lhs), Type::Pair(rhs)) => {
126            let first = compare_impl(arena, builtins, ctx, lhs.first, rhs.first, None, None);
127            let rest = compare_impl(arena, builtins, ctx, lhs.rest, rhs.rest, None, None);
128            max(first, rest)
129        }
130        (Type::Atom(_), Type::Pair(_)) => Comparison::Invalid,
131        (Type::Pair(_), Type::Atom(_)) => Comparison::Invalid,
132        (Type::Function(lhs), Type::Function(rhs)) => {
133            // TODO: We could relax the identical parameter requirement
134
135            if lhs.nil_terminated != rhs.nil_terminated || lhs.params.len() != rhs.params.len() {
136                Comparison::Invalid
137            } else {
138                let mut result = compare_impl(arena, builtins, ctx, lhs.ret, rhs.ret, None, None);
139
140                for (i, param) in lhs.params.iter().enumerate() {
141                    result = max(
142                        result,
143                        compare_impl(arena, builtins, ctx, *param, rhs.params[i], None, None),
144                    );
145                }
146
147                result
148            }
149        }
150        (_, Type::Generic) => {
151            if lhs == rhs {
152                Comparison::Assign
153            } else if let Some(infer) = &mut ctx.infer {
154                if let Some(rhs) = infer.get(&rhs).copied() {
155                    compare_impl(arena, builtins, ctx, lhs, rhs, lhs_semantic, rhs_semantic)
156                } else {
157                    debug!(
158                        "Inferring {} is {}",
159                        stringify_impl(arena, rhs, &mut IndexMap::new()),
160                        stringify_impl(arena, lhs, &mut IndexMap::new())
161                    );
162                    infer.insert(rhs, lhs);
163                    Comparison::Assign
164                }
165            } else if let Type::Union(lhs) = arena[lhs].clone() {
166                let mut result = Comparison::Assign;
167
168                for &id in &lhs.types {
169                    result = max(
170                        result,
171                        compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
172                    );
173                }
174
175                result
176            } else {
177                Comparison::Invalid
178            }
179        }
180        (Type::Struct(lhs), Type::Struct(rhs)) => max(
181            compare_impl(arena, builtins, ctx, lhs.inner, rhs.inner, None, None),
182            if lhs.semantic == rhs.semantic {
183                Comparison::Assign
184            } else {
185                Comparison::Cast
186            },
187        ),
188        (Type::Struct(lhs), _) => {
189            let inner = compare_impl(
190                arena,
191                builtins,
192                ctx,
193                lhs.inner,
194                rhs,
195                Some(lhs.semantic),
196                rhs_semantic,
197            );
198
199            if rhs_semantic == Some(lhs.semantic)
200                || semantics_of(arena, rhs).contains(&Some(lhs.semantic))
201            {
202                inner
203            } else {
204                max(inner, Comparison::Cast)
205            }
206        }
207        (_, Type::Struct(rhs)) => {
208            let inner = compare_impl(
209                arena,
210                builtins,
211                ctx,
212                lhs,
213                rhs.inner,
214                lhs_semantic,
215                Some(rhs.semantic),
216            );
217
218            let semantics = semantics_of(arena, lhs);
219
220            if (semantics.len() != 1 || !semantics.contains(&Some(rhs.semantic)))
221                && lhs_semantic != Some(rhs.semantic)
222            {
223                max(inner, Comparison::Cast)
224            } else {
225                inner
226            }
227        }
228        (Type::Alias(lhs), _) => compare_impl(
229            arena,
230            builtins,
231            ctx,
232            lhs.inner,
233            rhs,
234            lhs_semantic,
235            rhs_semantic,
236        ),
237        (_, Type::Alias(rhs)) => compare_impl(
238            arena,
239            builtins,
240            ctx,
241            lhs,
242            rhs.inner,
243            lhs_semantic,
244            rhs_semantic,
245        ),
246        (Type::Function(_) | Type::Generic, _) => compare_impl(
247            arena,
248            builtins,
249            ctx,
250            builtins.any,
251            rhs,
252            lhs_semantic,
253            rhs_semantic,
254        ),
255        (Type::Union(lhs), _) => {
256            let mut result = Comparison::Assign;
257
258            for &id in &lhs.types {
259                result = max(
260                    result,
261                    compare_impl(arena, builtins, ctx, id, rhs, lhs_semantic, rhs_semantic),
262                );
263            }
264
265            result
266        }
267        (_, Type::Union(rhs)) => {
268            let mut result = Comparison::Invalid;
269
270            for &id in &rhs.types {
271                result = min(
272                    result,
273                    compare_impl(arena, builtins, ctx, lhs, id, lhs_semantic, rhs_semantic),
274                );
275            }
276
277            result
278        }
279        (_, Type::Never) => Comparison::Invalid,
280        (_, Type::Function(_)) => Comparison::Invalid,
281    };
282
283    ctx.stack.pop().unwrap();
284
285    result
286}
287
288fn semantics_of(arena: &Arena<Type>, id: TypeId) -> HashSet<Option<TypeId>> {
289    match arena[id].clone() {
290        Type::Apply(_) => unreachable!(),
291        Type::Ref(id) => semantics_of(arena, id),
292        Type::Alias(alias) => semantics_of(arena, alias.inner),
293        Type::Never => HashSet::new(),
294        Type::Unresolved | Type::Generic | Type::Atom(_) | Type::Pair(_) | Type::Function(_) => {
295            HashSet::from_iter([None])
296        }
297        Type::Struct(ty) => HashSet::from_iter([Some(ty.semantic)]),
298        Type::Union(ty) => {
299            let mut semantics = HashSet::new();
300
301            for &id in &ty.types {
302                semantics.extend(semantics_of(arena, id));
303            }
304
305            semantics
306        }
307    }
308}
309
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use id_arena::Arena;
    use rstest::rstest;

    use crate::{Atom, Type, compare};

    use super::*;

    /// Checks the pairwise comparison verdict between built-in atom types, plus a couple of
    /// value-restricted atoms.
    #[rstest]
    #[case(Atom::NIL, Atom::NIL, Comparison::Assign)]
    #[case(Atom::NIL, Atom::FALSE, Comparison::Cast)]
    #[case(Atom::NIL, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::NIL, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::NIL, Atom::INT, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::NIL, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::FALSE, Comparison::Assign)]
    #[case(Atom::FALSE, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::FALSE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::FALSE, Atom::INT, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::TRUE, Comparison::Assign)]
    #[case(Atom::TRUE, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::TRUE, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::TRUE, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES, Atom::INT, Comparison::Cast)]
    #[case(Atom::BYTES_32, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::BYTES, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::BYTES_32, Comparison::Assign)]
    #[case(Atom::BYTES_32, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::BYTES_32, Atom::INT, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::PUBLIC_KEY, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::PUBLIC_KEY, Atom::PUBLIC_KEY, Comparison::Assign)]
    #[case(Atom::PUBLIC_KEY, Atom::INT, Comparison::Cast)]
    #[case(Atom::INT, Atom::NIL, Comparison::Invalid)]
    #[case(Atom::INT, Atom::FALSE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::TRUE, Comparison::Invalid)]
    #[case(Atom::INT, Atom::BYTES, Comparison::Cast)]
    #[case(Atom::INT, Atom::BYTES_32, Comparison::Invalid)]
    #[case(Atom::INT, Atom::PUBLIC_KEY, Comparison::Invalid)]
    #[case(Atom::INT, Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::INT, Comparison::Assign)]
    #[case(Atom::new(AtomSemantic::Int, Some(AtomRestriction::Value(Cow::Borrowed(&[1])))), Atom::BYTES, Comparison::Cast)]
    fn test_atoms(#[case] lhs: Atom, #[case] rhs: Atom, #[case] expected: Comparison) {
        // Render the failure label up front so the atoms can be moved into the arena.
        let description = format!("{lhs} -> {rhs}");

        let mut arena = Arena::new();
        let builtins = BuiltinTypes::new(&mut arena);
        let source = arena.alloc(Type::Atom(lhs));
        let target = arena.alloc(Type::Atom(rhs));

        let actual = compare(&mut arena, &builtins, source, target);
        assert_eq!(actual, expected, "{description}");
    }
}