// rue_types/check.rs

1use std::cmp::max;
2
3use id_arena::Arena;
4use indexmap::{IndexMap, IndexSet, indexset};
5use log::trace;
6use thiserror::Error;
7
8use crate::{
9    AtomRestriction, Atoms, BuiltinTypes, Comparison, ComparisonContext, Pair, Type, TypeId,
10    compare_impl, stringify_impl, substitute,
11};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14#[must_use]
15pub enum Check {
16    None,
17    Impossible,
18    IsAtom { can_be_truthy: bool },
19    IsPair { can_be_truthy: bool },
20    Pair(Box<Check>, Box<Check>),
21    Atom(AtomRestriction),
22    And(Vec<Check>),
23    Or(Vec<Check>),
24}
25
26impl Check {
27    pub fn simplify(self) -> Check {
28        match self {
29            Check::None => Check::None,
30            Check::Impossible => Check::Impossible,
31            Check::Pair(first, rest) => {
32                let first = first.simplify();
33                let rest = rest.simplify();
34                Check::Pair(Box::new(first), Box::new(rest))
35            }
36            Check::IsAtom { can_be_truthy } => Check::IsAtom { can_be_truthy },
37            Check::IsPair { can_be_truthy } => Check::IsPair { can_be_truthy },
38            Check::Atom(restriction) => Check::Atom(restriction),
39            Check::And(checks) => {
40                let mut flattened = Vec::new();
41
42                for check in checks {
43                    match check.simplify() {
44                        Check::None => {}
45                        Check::Impossible => {
46                            return Check::Impossible;
47                        }
48                        Check::And(inner) => {
49                            flattened.extend(inner);
50                        }
51                        check => {
52                            flattened.push(check);
53                        }
54                    }
55                }
56
57                let mut listp = None;
58                let mut length = None;
59                let mut value = None;
60                let mut result = Vec::new();
61
62                for check in flattened {
63                    match check {
64                        Check::None | Check::Impossible | Check::And(_) => unreachable!(),
65                        Check::IsAtom { can_be_truthy } => {
66                            if matches!(listp, Some(Check::IsPair { .. })) {
67                                return Check::Impossible;
68                            }
69                            listp = Some(Check::IsAtom { can_be_truthy });
70                        }
71                        Check::IsPair { can_be_truthy } => {
72                            if matches!(listp, Some(Check::IsAtom { .. })) {
73                                return Check::Impossible;
74                            }
75                            listp = Some(Check::IsPair { can_be_truthy });
76                        }
77                        Check::Atom(AtomRestriction::Length(check)) => {
78                            if length.is_some_and(|length| length != check) {
79                                return Check::Impossible;
80                            }
81                            length = Some(check);
82                        }
83                        Check::Atom(AtomRestriction::Value(check)) => {
84                            if value.is_some_and(|value| value != check) {
85                                return Check::Impossible;
86                            }
87                            value = Some(check);
88                        }
89                        check @ (Check::Or(_) | Check::Pair(..)) => {
90                            result.push(check);
91                        }
92                    }
93                }
94
95                match (length, value) {
96                    (Some(length), Some(value)) => {
97                        if length != value.len() {
98                            return Check::Impossible;
99                        }
100                        result.insert(0, Check::Atom(AtomRestriction::Value(value)));
101                    }
102                    (None, Some(value)) => {
103                        result.insert(0, Check::Atom(AtomRestriction::Value(value)));
104                    }
105                    (Some(length), None) => {
106                        result.insert(0, Check::Atom(AtomRestriction::Length(length)));
107                    }
108                    (None, None) => {}
109                }
110
111                if let Some(listp) = listp {
112                    result.insert(0, listp);
113                }
114
115                if result.is_empty() {
116                    Check::None
117                } else if result.len() == 1 {
118                    result[0].clone()
119                } else {
120                    Check::And(result)
121                }
122            }
123            Check::Or(checks) => Check::Or(checks.into_iter().map(Check::simplify).collect()),
124        }
125    }
126}
127
128#[derive(Debug, Clone, Copy, Error)]
129pub enum CheckError {
130    #[error("Maximum type check depth reached")]
131    DepthExceeded,
132
133    #[error("Cannot check if value is of function type at runtime")]
134    FunctionType,
135}
136
/// Mutable state shared by all recursive steps of a single top-level [`check`] call.
#[derive(Debug)]
struct CheckContext {
    /// Running count of `check_each` invocations.
    ///
    /// The count is compared against the depth limit and is never decremented.
    depth: usize,
}
141
142pub fn check(
143    arena: &mut Arena<Type>,
144    builtins: &BuiltinTypes,
145    lhs: TypeId,
146    rhs: TypeId,
147) -> Result<Check, CheckError> {
148    let lhs = substitute(arena, lhs);
149    let rhs = substitute(arena, rhs);
150    let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
151    let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
152    trace!("Checking {lhs_name} to {rhs_name}");
153    let result = check_impl(arena, builtins, &mut CheckContext { depth: 0 }, lhs, rhs);
154    trace!("Check from {lhs_name} to {rhs_name} yielded {result:?}");
155    result
156}
157
158fn check_impl(
159    arena: &Arena<Type>,
160    builtins: &BuiltinTypes,
161    ctx: &mut CheckContext,
162    lhs: TypeId,
163    rhs: TypeId,
164) -> Result<Check, CheckError> {
165    let variants = variants_of(arena, builtins, lhs);
166    check_each(arena, builtins, ctx, &variants, rhs)
167}
168
169fn check_each(
170    arena: &Arena<Type>,
171    builtins: &BuiltinTypes,
172    ctx: &mut CheckContext,
173    lhs: &[TypeId],
174    rhs: TypeId,
175) -> Result<Check, CheckError> {
176    ctx.depth += 1;
177
178    if ctx.depth > 25 {
179        return Err(CheckError::DepthExceeded);
180    }
181
182    let mut result = Comparison::Assign;
183
184    for &id in lhs {
185        result = max(
186            result,
187            compare_impl(
188                arena,
189                builtins,
190                &mut ComparisonContext::default(),
191                id,
192                rhs,
193                None,
194                None,
195            ),
196        );
197    }
198
199    if result <= Comparison::Cast {
200        return Ok(Check::None);
201    }
202
203    let target_atoms = atoms_of(arena, rhs)?;
204
205    let mut overlap = IndexSet::new();
206    let mut exceeds_overlap = false;
207    let mut lhs_has_atom = false;
208    let mut can_be_truthy = false;
209
210    for &id in lhs {
211        let atoms = atoms_of(arena, id)?;
212
213        if let Some(atoms) = &atoms {
214            lhs_has_atom = true;
215
216            match atoms {
217                Atoms::Unrestricted => {
218                    can_be_truthy = true;
219                }
220                Atoms::Restricted(restrictions) => {
221                    for restriction in restrictions {
222                        match restriction {
223                            AtomRestriction::Value(value) => {
224                                if !value.is_empty() {
225                                    can_be_truthy = true;
226                                }
227                            }
228                            AtomRestriction::Length(length) => {
229                                if *length > 0 {
230                                    can_be_truthy = true;
231                                }
232                            }
233                        }
234                    }
235                }
236            }
237        }
238
239        match (atoms, &target_atoms) {
240            (Some(_), None) => {}
241            (Some(_), Some(Atoms::Unrestricted)) | (None, _) => {}
242            (Some(Atoms::Unrestricted), Some(Atoms::Restricted(restrictions))) => {
243                exceeds_overlap = true;
244
245                for restriction in restrictions {
246                    if let AtomRestriction::Value(value) = restriction
247                        && overlap.contains(&AtomRestriction::Length(value.len()))
248                    {
249                        continue;
250                    }
251                    overlap.insert(restriction.clone());
252                }
253            }
254            (
255                Some(Atoms::Restricted(restrictions)),
256                Some(Atoms::Restricted(target_restrictions)),
257            ) => {
258                for restriction in restrictions {
259                    if target_restrictions.contains(&restriction) {
260                        overlap.insert(restriction);
261                        continue;
262                    }
263
264                    match restriction {
265                        AtomRestriction::Value(value) => {
266                            let length = AtomRestriction::Length(value.len());
267                            if target_restrictions.contains(&length) {
268                                overlap.insert(length);
269                                continue;
270                            }
271                        }
272                        AtomRestriction::Length(_) => {}
273                    }
274
275                    exceeds_overlap = true;
276                }
277            }
278        }
279    }
280
281    let atom_result = lhs_has_atom.then(|| {
282        if target_atoms.is_none() {
283            Check::Impossible
284        } else if !exceeds_overlap {
285            Check::None
286        } else if overlap.is_empty() {
287            Check::Impossible
288        } else if overlap.len() == 1 {
289            overlap.into_iter().next().map(Check::Atom).unwrap()
290        } else {
291            Check::Or(overlap.into_iter().map(Check::Atom).collect())
292        }
293    });
294
295    let target_pairs = pairs_of(arena, builtins, rhs)?;
296
297    let mut checks = Vec::new();
298    let mut requires_check = false;
299
300    let mut pairs = Vec::new();
301
302    for &lhs in lhs {
303        for pair in pairs_of(arena, builtins, lhs)? {
304            pairs.push(pair);
305        }
306    }
307
308    for target_pair in &target_pairs {
309        let pairs = pairs.clone();
310
311        let firsts: Vec<TypeId> = pairs.iter().map(|pair| pair.first).collect();
312        let first = check_each(arena, builtins, ctx, &firsts, target_pair.first)?;
313
314        if first == Check::Impossible {
315            requires_check = true;
316            continue;
317        }
318
319        let mut rests = Vec::new();
320
321        // TODO: We can do this in inverse (rest then first) as well and see which check is simpler
322        for pair in pairs {
323            if compare_impl(
324                arena,
325                builtins,
326                &mut ComparisonContext::default(),
327                target_pair.first,
328                pair.first,
329                None,
330                None,
331            ) > Comparison::Cast
332            {
333                continue;
334            }
335
336            rests.push(pair.rest);
337        }
338
339        if rests.is_empty() {
340            requires_check = true;
341            continue;
342        }
343
344        let rest = check_each(arena, builtins, ctx, &rests, target_pair.rest)?;
345
346        if rest == Check::Impossible {
347            requires_check = true;
348            continue;
349        }
350
351        if first == Check::None && rest == Check::None {
352            // TODO: Should this set requires_check to false and break?
353            continue;
354        }
355
356        requires_check = true;
357
358        checks.push(Check::Pair(Box::new(first), Box::new(rest)));
359    }
360
361    let pair_result = (!pairs.is_empty()).then(|| {
362        if target_pairs.is_empty() {
363            Check::Impossible
364        } else if !requires_check {
365            Check::None
366        } else if checks.is_empty() {
367            Check::Impossible
368        } else if checks.len() == 1 {
369            checks[0].clone()
370        } else {
371            Check::Or(checks)
372        }
373    });
374
375    let check = match (atom_result, pair_result) {
376        (None, None) => Check::Impossible,
377        (Some(atom), None) => atom,
378        (None, Some(pair)) => pair,
379        (Some(atom), Some(Check::Impossible)) => {
380            Check::And(vec![Check::IsAtom { can_be_truthy }, atom])
381        }
382        (Some(Check::Impossible), Some(pair)) => {
383            Check::And(vec![Check::IsPair { can_be_truthy }, pair])
384        }
385        (Some(atom), Some(Check::None)) => Check::Or(vec![Check::IsPair { can_be_truthy }, atom]),
386        (Some(Check::None), Some(pair)) => Check::Or(vec![Check::IsAtom { can_be_truthy }, pair]),
387        (Some(atom), Some(pair)) => Check::Or(vec![
388            Check::And(vec![Check::IsAtom { can_be_truthy }, atom]),
389            Check::And(vec![Check::IsPair { can_be_truthy }, pair]),
390        ]),
391    };
392
393    Ok(check.simplify())
394}
395
396fn variants_of(arena: &Arena<Type>, builtins: &BuiltinTypes, id: TypeId) -> Vec<TypeId> {
397    match arena[id].clone() {
398        Type::Apply(_) => unreachable!(),
399        Type::Ref(id) => variants_of(arena, builtins, id),
400        Type::Unresolved | Type::Atom(_) | Type::Pair(_) | Type::Generic(_) => {
401            vec![id]
402        }
403        Type::Never => vec![],
404        Type::Alias(alias) => variants_of(arena, builtins, alias.inner),
405        Type::Struct(ty) => variants_of(arena, builtins, ty.inner),
406        Type::Function(_) | Type::Any => vec![builtins.atom, builtins.recursive_any_pair],
407        Type::Union(ty) => {
408            let mut variants = Vec::new();
409
410            for variant in ty.types {
411                variants.extend(variants_of(arena, builtins, variant));
412            }
413
414            variants
415        }
416    }
417}
418
419fn atoms_of(arena: &Arena<Type>, id: TypeId) -> Result<Option<Atoms>, CheckError> {
420    Ok(match arena[id].clone() {
421        Type::Apply(_) => unreachable!(),
422        Type::Ref(id) => atoms_of(arena, id)?,
423        Type::Unresolved | Type::Never | Type::Pair(_) => None,
424        Type::Generic(_) | Type::Any => Some(Atoms::Unrestricted),
425        Type::Atom(atom) => {
426            let Some(restriction) = atom.restriction else {
427                return Ok(Some(Atoms::Unrestricted));
428            };
429            Some(Atoms::Restricted(indexset![restriction]))
430        }
431        Type::Alias(alias) => atoms_of(arena, alias.inner)?,
432        Type::Struct(ty) => atoms_of(arena, ty.inner)?,
433        Type::Function(_) => return Err(CheckError::FunctionType),
434        Type::Union(ty) => {
435            let mut restrictions = IndexSet::new();
436
437            for variant in ty.types {
438                match atoms_of(arena, variant)? {
439                    None => {}
440                    Some(Atoms::Unrestricted) => return Ok(Some(Atoms::Unrestricted)),
441                    Some(Atoms::Restricted(new)) => {
442                        for restriction in new {
443                            match &restriction {
444                                AtomRestriction::Value(value) => {
445                                    if restrictions.contains(&AtomRestriction::Length(value.len()))
446                                    {
447                                        continue;
448                                    }
449                                    restrictions.insert(restriction);
450                                }
451                                AtomRestriction::Length(length) => {
452                                    restrictions.retain(|item| match item {
453                                        AtomRestriction::Value(value) => value.len() != *length,
454                                        AtomRestriction::Length(_) => true,
455                                    });
456                                    restrictions.insert(restriction);
457                                }
458                            }
459                        }
460                    }
461                }
462            }
463
464            if restrictions.is_empty() {
465                None
466            } else {
467                Some(Atoms::Restricted(restrictions))
468            }
469        }
470    })
471}
472
473fn pairs_of(
474    arena: &Arena<Type>,
475    builtins: &BuiltinTypes,
476    id: TypeId,
477) -> Result<Vec<Pair>, CheckError> {
478    Ok(match arena[id].clone() {
479        Type::Apply(_) => unreachable!(),
480        Type::Ref(id) => pairs_of(arena, builtins, id)?,
481        Type::Unresolved | Type::Never | Type::Atom(_) => vec![],
482        Type::Pair(pair) => vec![pair],
483        Type::Generic(_) | Type::Any => {
484            vec![Pair::new(builtins.recursive_any, builtins.recursive_any)]
485        }
486        Type::Alias(alias) => pairs_of(arena, builtins, alias.inner)?,
487        Type::Struct(ty) => pairs_of(arena, builtins, ty.inner)?,
488        Type::Function(_) => return Err(CheckError::FunctionType),
489        Type::Union(ty) => {
490            let mut pairs = Vec::new();
491
492            for variant in ty.types {
493                pairs.extend(pairs_of(arena, builtins, variant)?);
494            }
495
496            pairs
497        }
498    })
499}