rue_types/
check.rs

1use std::{cmp::max, collections::HashSet};
2
3use id_arena::Arena;
4use indexmap::{IndexMap, IndexSet, indexset};
5use log::trace;
6use thiserror::Error;
7
8use crate::{
9    AtomRestriction, BuiltinTypes, Comparison, ComparisonContext, Pair, Type, TypeId, compare_impl,
10    stringify_impl, substitute,
11};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14#[must_use]
15pub enum Check {
16    None,
17    Impossible,
18    IsAtom,
19    IsPair,
20    Pair(Box<Check>, Box<Check>),
21    Atom(AtomRestriction),
22    And(Vec<Check>),
23    Or(Vec<Check>),
24}
25
26impl Check {
27    pub fn simplify(self) -> Check {
28        match self {
29            Check::None => Check::None,
30            Check::Impossible => Check::Impossible,
31            Check::Pair(first, rest) => {
32                let first = first.simplify();
33                let rest = rest.simplify();
34                Check::Pair(Box::new(first), Box::new(rest))
35            }
36            Check::IsAtom => Check::IsAtom,
37            Check::IsPair => Check::IsPair,
38            Check::Atom(restriction) => Check::Atom(restriction),
39            Check::And(checks) => {
40                let mut flattened = Vec::new();
41
42                for check in checks {
43                    match check.simplify() {
44                        Check::None => {}
45                        Check::Impossible => {
46                            return Check::Impossible;
47                        }
48                        Check::And(inner) => {
49                            flattened.extend(inner);
50                        }
51                        check => {
52                            flattened.push(check);
53                        }
54                    }
55                }
56
57                let mut listp = None;
58                let mut length = None;
59                let mut value = None;
60                let mut result = Vec::new();
61
62                for check in flattened {
63                    match check {
64                        Check::None | Check::Impossible | Check::And(_) => unreachable!(),
65                        Check::IsAtom => {
66                            if listp == Some(true) {
67                                return Check::Impossible;
68                            }
69                            listp = Some(false);
70                        }
71                        Check::IsPair => {
72                            if listp == Some(false) {
73                                return Check::Impossible;
74                            }
75                            listp = Some(true);
76                        }
77                        Check::Atom(AtomRestriction::Length(check)) => {
78                            if length.is_some_and(|length| length != check) {
79                                return Check::Impossible;
80                            }
81                            length = Some(check);
82                        }
83                        Check::Atom(AtomRestriction::Value(check)) => {
84                            if value.is_some_and(|value| value != check) {
85                                return Check::Impossible;
86                            }
87                            value = Some(check);
88                        }
89                        check @ (Check::Or(_) | Check::Pair(..)) => {
90                            result.push(check);
91                        }
92                    }
93                }
94
95                match (length, value) {
96                    (Some(length), Some(value)) => {
97                        if length != value.len() {
98                            return Check::Impossible;
99                        }
100                        result.insert(0, Check::Atom(AtomRestriction::Value(value)));
101                    }
102                    (None, Some(value)) => {
103                        result.insert(0, Check::Atom(AtomRestriction::Value(value)));
104                    }
105                    (Some(length), None) => {
106                        result.insert(0, Check::Atom(AtomRestriction::Length(length)));
107                    }
108                    (None, None) => {}
109                }
110
111                match listp {
112                    Some(true) => result.insert(0, Check::IsPair),
113                    Some(false) => result.insert(0, Check::IsAtom),
114                    None => {}
115                }
116
117                if result.is_empty() {
118                    Check::None
119                } else if result.len() == 1 {
120                    result[0].clone()
121                } else {
122                    Check::And(result)
123                }
124            }
125            Check::Or(checks) => Check::Or(checks.into_iter().map(Check::simplify).collect()),
126        }
127    }
128}
129
130#[derive(Debug, Clone, Copy, Error)]
131pub enum CheckError {
132    #[error("Maximum type check depth reached")]
133    DepthExceeded,
134
135    #[error("Cannot check if value is of function type at runtime")]
136    FunctionType,
137}
138
/// Mutable state threaded through one top-level type check.
#[derive(Debug)]
struct CheckContext {
    /// Counter compared against the checker's depth limit.
    depth: usize,
}
143
144pub fn check(
145    arena: &mut Arena<Type>,
146    builtins: &BuiltinTypes,
147    lhs: TypeId,
148    rhs: TypeId,
149) -> Result<Check, CheckError> {
150    let lhs = substitute(arena, lhs);
151    let rhs = substitute(arena, rhs);
152    let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
153    let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
154    trace!("Checking {lhs_name} to {rhs_name}");
155    let result = check_impl(arena, builtins, &mut CheckContext { depth: 0 }, lhs, rhs);
156    trace!("Check from {lhs_name} to {rhs_name} yielded {result:?}");
157    result
158}
159
160fn check_impl(
161    arena: &Arena<Type>,
162    builtins: &BuiltinTypes,
163    ctx: &mut CheckContext,
164    lhs: TypeId,
165    rhs: TypeId,
166) -> Result<Check, CheckError> {
167    let mut variants = variants_of(arena, builtins, lhs)
168        .into_iter()
169        .enumerate()
170        .collect();
171    check_each(arena, builtins, ctx, &mut variants, rhs)
172}
173
174fn check_each(
175    arena: &Arena<Type>,
176    builtins: &BuiltinTypes,
177    ctx: &mut CheckContext,
178    lhs: &mut Vec<(usize, TypeId)>,
179    rhs: TypeId,
180) -> Result<Check, CheckError> {
181    ctx.depth += 1;
182
183    if ctx.depth > 25 {
184        return Err(CheckError::DepthExceeded);
185    }
186
187    let mut result = Comparison::Assign;
188
189    for &(_, lhs) in &*lhs {
190        result = max(
191            result,
192            compare_impl(
193                arena,
194                builtins,
195                &mut ComparisonContext::default(),
196                lhs,
197                rhs,
198                None,
199                None,
200            ),
201        );
202    }
203
204    if result <= Comparison::Cast {
205        return Ok(Check::None);
206    }
207
208    let target_atoms = atoms_of(arena, rhs)?;
209
210    let mut overlap = IndexSet::new();
211    let mut exceeds_overlap = false;
212    let mut unrestricted = false;
213    let mut lhs_has_atom = false;
214    let mut error = None;
215
216    lhs.retain(|&(_, id)| {
217        if error.is_some() {
218            return true;
219        }
220
221        let atoms = match atoms_of(arena, id) {
222            Ok(atoms) => atoms,
223            Err(err) => {
224                error = Some(err);
225                return true;
226            }
227        };
228
229        if atoms.is_some() {
230            lhs_has_atom = true;
231        }
232
233        match (atoms, &target_atoms) {
234            (Some(_), None) => false,
235            (Some(_), Some(Atoms::Unrestricted)) | (None, _) => true,
236            (Some(Atoms::Unrestricted), Some(Atoms::Restricted(restrictions))) => {
237                exceeds_overlap = true;
238                unrestricted = true;
239                overlap.clone_from(restrictions);
240                true
241            }
242            (
243                Some(Atoms::Restricted(restrictions)),
244                Some(Atoms::Restricted(target_restrictions)),
245            ) => {
246                let mut has_overlap = false;
247
248                for restriction in restrictions {
249                    if target_restrictions.contains(&restriction) {
250                        overlap.insert(restriction);
251                        has_overlap = true;
252                        continue;
253                    }
254
255                    match restriction {
256                        AtomRestriction::Value(value) => {
257                            let length = AtomRestriction::Length(value.len());
258                            if target_restrictions.contains(&length) {
259                                overlap.insert(length);
260                                has_overlap = true;
261                                continue;
262                            }
263                        }
264                        AtomRestriction::Length(_) => {}
265                    }
266
267                    exceeds_overlap = true;
268                }
269
270                has_overlap
271            }
272        }
273    });
274
275    if let Some(error) = error {
276        return Err(error);
277    }
278
279    let atom_result = lhs_has_atom.then(|| {
280        if target_atoms.is_none() {
281            Check::Impossible
282        } else if !exceeds_overlap {
283            Check::None
284        } else if overlap.is_empty() {
285            Check::Impossible
286        } else if overlap.len() == 1 {
287            overlap.into_iter().next().map(Check::Atom).unwrap()
288        } else {
289            Check::And(overlap.into_iter().map(Check::Atom).collect())
290        }
291    });
292
293    let target_pairs = pairs_of(arena, builtins, rhs)?;
294
295    let mut checks = Vec::new();
296    let mut included_indices = IndexSet::new();
297    let mut candidate_pairs = HashSet::new();
298    let mut requires_check = false;
299    let mut lhs_has_pair = false;
300
301    let mut firsts = Vec::new();
302
303    for &(i, lhs) in &*lhs {
304        for pair in pairs_of(arena, builtins, lhs)? {
305            for ty in variants_of(arena, builtins, pair.first) {
306                candidate_pairs.insert(lhs);
307                firsts.push((i, ty));
308                lhs_has_pair = true;
309            }
310        }
311    }
312
313    for target_pair in &target_pairs {
314        let mut firsts = firsts.clone();
315
316        let first = check_each(arena, builtins, ctx, &mut firsts, target_pair.first)?;
317
318        if first == Check::Impossible {
319            requires_check = true;
320            continue;
321        }
322
323        let mut rests = Vec::new();
324
325        for (i, _) in firsts {
326            for pair in pairs_of(
327                arena,
328                builtins,
329                lhs.iter().find(|(j, _)| *j == i).unwrap().1,
330            )? {
331                for ty in variants_of(arena, builtins, pair.rest) {
332                    rests.push((i, ty));
333                }
334            }
335        }
336
337        let rest = check_each(arena, builtins, ctx, &mut rests, target_pair.rest)?;
338
339        if rest == Check::Impossible {
340            requires_check = true;
341            continue;
342        }
343
344        for (i, _) in rests {
345            included_indices.insert(i);
346        }
347
348        if first == Check::None && rest == Check::None {
349            continue;
350        }
351
352        requires_check = true;
353
354        checks.push(Check::Pair(Box::new(first), Box::new(rest)));
355    }
356
357    lhs.retain(|&(i, type_id)| {
358        !candidate_pairs.contains(&type_id) || included_indices.contains(&i)
359    });
360
361    let pair_result = lhs_has_pair.then(|| {
362        if target_pairs.is_empty() {
363            Check::Impossible
364        } else if !requires_check {
365            Check::None
366        } else if checks.is_empty() {
367            Check::Impossible
368        } else if checks.len() == 1 {
369            checks[0].clone()
370        } else {
371            Check::Or(checks)
372        }
373    });
374
375    let check = match (atom_result, pair_result) {
376        (None, None) => Check::Impossible,
377        (Some(atom), None) => atom,
378        (None, Some(pair)) => pair,
379        (Some(atom), Some(Check::Impossible)) => Check::And(vec![Check::IsAtom, atom]),
380        (Some(Check::Impossible), Some(pair)) => Check::And(vec![Check::IsPair, pair]),
381        (Some(atom), Some(Check::None)) => Check::Or(vec![Check::IsPair, atom]),
382        (Some(Check::None), Some(pair)) => Check::Or(vec![Check::IsAtom, pair]),
383        (Some(atom), Some(pair)) => Check::Or(vec![
384            Check::And(vec![Check::IsAtom, atom]),
385            Check::And(vec![Check::IsPair, pair]),
386        ]),
387    };
388
389    Ok(check.simplify())
390}
391
392fn variants_of(arena: &Arena<Type>, builtins: &BuiltinTypes, id: TypeId) -> Vec<TypeId> {
393    match arena[id].clone() {
394        Type::Apply(_) => unreachable!(),
395        Type::Ref(id) => variants_of(arena, builtins, id),
396        Type::Unresolved | Type::Atom(_) | Type::Pair(_) | Type::Generic => {
397            vec![id]
398        }
399        Type::Never => vec![],
400        Type::Alias(alias) => variants_of(arena, builtins, alias.inner),
401        Type::Struct(ty) => variants_of(arena, builtins, ty.inner),
402        Type::Function(_) => vec![builtins.atom, builtins.any_pair],
403        Type::Union(ty) => {
404            let mut variants = Vec::new();
405
406            for variant in ty.types {
407                variants.extend(variants_of(arena, builtins, variant));
408            }
409
410            variants
411        }
412    }
413}
414
415#[derive(Debug, Clone)]
416enum Atoms {
417    Unrestricted,
418    Restricted(IndexSet<AtomRestriction>),
419}
420
421fn atoms_of(arena: &Arena<Type>, id: TypeId) -> Result<Option<Atoms>, CheckError> {
422    Ok(match arena[id].clone() {
423        Type::Apply(_) => unreachable!(),
424        Type::Ref(id) => atoms_of(arena, id)?,
425        Type::Unresolved | Type::Never | Type::Pair(_) => None,
426        Type::Generic => Some(Atoms::Unrestricted),
427        Type::Atom(atom) => {
428            let Some(restriction) = atom.restriction else {
429                return Ok(Some(Atoms::Unrestricted));
430            };
431            Some(Atoms::Restricted(indexset![restriction]))
432        }
433        Type::Alias(alias) => atoms_of(arena, alias.inner)?,
434        Type::Struct(ty) => atoms_of(arena, ty.inner)?,
435        Type::Function(_) => return Err(CheckError::FunctionType),
436        Type::Union(ty) => {
437            let mut restrictions = IndexSet::new();
438
439            for variant in ty.types {
440                match atoms_of(arena, variant)? {
441                    None => {}
442                    Some(Atoms::Unrestricted) => return Ok(Some(Atoms::Unrestricted)),
443                    Some(Atoms::Restricted(new)) => {
444                        for restriction in new {
445                            match &restriction {
446                                AtomRestriction::Value(value) => {
447                                    if restrictions.contains(&AtomRestriction::Length(value.len()))
448                                    {
449                                        continue;
450                                    }
451                                    restrictions.insert(restriction);
452                                }
453                                AtomRestriction::Length(length) => {
454                                    restrictions.retain(|item| match item {
455                                        AtomRestriction::Value(value) => value.len() != *length,
456                                        AtomRestriction::Length(_) => true,
457                                    });
458                                    restrictions.insert(restriction);
459                                }
460                            }
461                        }
462                    }
463                }
464            }
465
466            if restrictions.is_empty() {
467                None
468            } else {
469                Some(Atoms::Restricted(restrictions))
470            }
471        }
472    })
473}
474
475fn pairs_of(
476    arena: &Arena<Type>,
477    builtins: &BuiltinTypes,
478    id: TypeId,
479) -> Result<Vec<Pair>, CheckError> {
480    Ok(match arena[id].clone() {
481        Type::Apply(_) => unreachable!(),
482        Type::Ref(id) => pairs_of(arena, builtins, id)?,
483        Type::Unresolved | Type::Never | Type::Atom(_) => vec![],
484        Type::Pair(pair) => vec![pair],
485        Type::Generic => vec![Pair::new(builtins.any, builtins.any)],
486        Type::Alias(alias) => pairs_of(arena, builtins, alias.inner)?,
487        Type::Struct(ty) => pairs_of(arena, builtins, ty.inner)?,
488        Type::Function(_) => return Err(CheckError::FunctionType),
489        Type::Union(ty) => {
490            let mut pairs = Vec::new();
491
492            for variant in ty.types {
493                pairs.extend(pairs_of(arena, builtins, variant)?);
494            }
495
496            pairs
497        }
498    })
499}