1use std::cmp::max;
2
3use id_arena::Arena;
4use indexmap::{IndexMap, IndexSet, indexset};
5use log::trace;
6use thiserror::Error;
7
8use crate::{
9 AtomRestriction, Atoms, BuiltinTypes, Comparison, ComparisonContext, Pair, Type, TypeId,
10 compare_impl, stringify_impl, substitute,
11};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14#[must_use]
15pub enum Check {
16 None,
17 Impossible,
18 IsAtom { can_be_truthy: bool },
19 IsPair { can_be_truthy: bool },
20 Pair(Box<Check>, Box<Check>),
21 Atom(AtomRestriction),
22 And(Vec<Check>),
23 Or(Vec<Check>),
24}
25
26impl Check {
27 pub fn simplify(self) -> Check {
28 match self {
29 Check::None => Check::None,
30 Check::Impossible => Check::Impossible,
31 Check::Pair(first, rest) => {
32 let first = first.simplify();
33 let rest = rest.simplify();
34 Check::Pair(Box::new(first), Box::new(rest))
35 }
36 Check::IsAtom { can_be_truthy } => Check::IsAtom { can_be_truthy },
37 Check::IsPair { can_be_truthy } => Check::IsPair { can_be_truthy },
38 Check::Atom(restriction) => Check::Atom(restriction),
39 Check::And(checks) => {
40 let mut flattened = Vec::new();
41
42 for check in checks {
43 match check.simplify() {
44 Check::None => {}
45 Check::Impossible => {
46 return Check::Impossible;
47 }
48 Check::And(inner) => {
49 flattened.extend(inner);
50 }
51 check => {
52 flattened.push(check);
53 }
54 }
55 }
56
57 let mut listp = None;
58 let mut length = None;
59 let mut value = None;
60 let mut result = Vec::new();
61
62 for check in flattened {
63 match check {
64 Check::None | Check::Impossible | Check::And(_) => unreachable!(),
65 Check::IsAtom { can_be_truthy } => {
66 if matches!(listp, Some(Check::IsPair { .. })) {
67 return Check::Impossible;
68 }
69 listp = Some(Check::IsAtom { can_be_truthy });
70 }
71 Check::IsPair { can_be_truthy } => {
72 if matches!(listp, Some(Check::IsAtom { .. })) {
73 return Check::Impossible;
74 }
75 listp = Some(Check::IsPair { can_be_truthy });
76 }
77 Check::Atom(AtomRestriction::Length(check)) => {
78 if length.is_some_and(|length| length != check) {
79 return Check::Impossible;
80 }
81 length = Some(check);
82 }
83 Check::Atom(AtomRestriction::Value(check)) => {
84 if value.is_some_and(|value| value != check) {
85 return Check::Impossible;
86 }
87 value = Some(check);
88 }
89 check @ (Check::Or(_) | Check::Pair(..)) => {
90 result.push(check);
91 }
92 }
93 }
94
95 match (length, value) {
96 (Some(length), Some(value)) => {
97 if length != value.len() {
98 return Check::Impossible;
99 }
100 result.insert(0, Check::Atom(AtomRestriction::Value(value)));
101 }
102 (None, Some(value)) => {
103 result.insert(0, Check::Atom(AtomRestriction::Value(value)));
104 }
105 (Some(length), None) => {
106 result.insert(0, Check::Atom(AtomRestriction::Length(length)));
107 }
108 (None, None) => {}
109 }
110
111 if let Some(listp) = listp {
112 result.insert(0, listp);
113 }
114
115 if result.is_empty() {
116 Check::None
117 } else if result.len() == 1 {
118 result[0].clone()
119 } else {
120 Check::And(result)
121 }
122 }
123 Check::Or(checks) => Check::Or(checks.into_iter().map(Check::simplify).collect()),
124 }
125 }
126}
127
128#[derive(Debug, Clone, Copy, Error)]
129pub enum CheckError {
130 #[error("Maximum type check depth reached")]
131 DepthExceeded,
132
133 #[error("Cannot check if value is of function type at runtime")]
134 FunctionType,
135}
136
/// Mutable state shared by every nested `check_each` call made on behalf of one [`check`].
#[derive(Debug)]
struct CheckContext {
    /// Nesting counter that `check_each` increments and compares against a fixed limit
    /// before reporting [`CheckError::DepthExceeded`].
    depth: usize,
}
141
142pub fn check(
143 arena: &mut Arena<Type>,
144 builtins: &BuiltinTypes,
145 lhs: TypeId,
146 rhs: TypeId,
147) -> Result<Check, CheckError> {
148 let lhs = substitute(arena, lhs);
149 let rhs = substitute(arena, rhs);
150 let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
151 let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
152 trace!("Checking {lhs_name} to {rhs_name}");
153 let result = check_impl(arena, builtins, &mut CheckContext { depth: 0 }, lhs, rhs);
154 trace!("Check from {lhs_name} to {rhs_name} yielded {result:?}");
155 result
156}
157
158fn check_impl(
159 arena: &Arena<Type>,
160 builtins: &BuiltinTypes,
161 ctx: &mut CheckContext,
162 lhs: TypeId,
163 rhs: TypeId,
164) -> Result<Check, CheckError> {
165 let variants = variants_of(arena, builtins, lhs);
166 check_each(arena, builtins, ctx, &variants, rhs)
167}
168
169fn check_each(
170 arena: &Arena<Type>,
171 builtins: &BuiltinTypes,
172 ctx: &mut CheckContext,
173 lhs: &[TypeId],
174 rhs: TypeId,
175) -> Result<Check, CheckError> {
176 ctx.depth += 1;
177
178 if ctx.depth > 25 {
179 return Err(CheckError::DepthExceeded);
180 }
181
182 let mut result = Comparison::Assign;
183
184 for &id in lhs {
185 result = max(
186 result,
187 compare_impl(
188 arena,
189 builtins,
190 &mut ComparisonContext::default(),
191 id,
192 rhs,
193 None,
194 None,
195 ),
196 );
197 }
198
199 if result <= Comparison::Cast {
200 return Ok(Check::None);
201 }
202
203 let target_atoms = atoms_of(arena, rhs)?;
204
205 let mut overlap = IndexSet::new();
206 let mut exceeds_overlap = false;
207 let mut lhs_has_atom = false;
208 let mut can_be_truthy = false;
209
210 for &id in lhs {
211 let atoms = atoms_of(arena, id)?;
212
213 if let Some(atoms) = &atoms {
214 lhs_has_atom = true;
215
216 match atoms {
217 Atoms::Unrestricted => {
218 can_be_truthy = true;
219 }
220 Atoms::Restricted(restrictions) => {
221 for restriction in restrictions {
222 match restriction {
223 AtomRestriction::Value(value) => {
224 if !value.is_empty() {
225 can_be_truthy = true;
226 }
227 }
228 AtomRestriction::Length(length) => {
229 if *length > 0 {
230 can_be_truthy = true;
231 }
232 }
233 }
234 }
235 }
236 }
237 }
238
239 match (atoms, &target_atoms) {
240 (Some(_), None) => {}
241 (Some(_), Some(Atoms::Unrestricted)) | (None, _) => {}
242 (Some(Atoms::Unrestricted), Some(Atoms::Restricted(restrictions))) => {
243 exceeds_overlap = true;
244
245 for restriction in restrictions {
246 if let AtomRestriction::Value(value) = restriction
247 && overlap.contains(&AtomRestriction::Length(value.len()))
248 {
249 continue;
250 }
251 overlap.insert(restriction.clone());
252 }
253 }
254 (
255 Some(Atoms::Restricted(restrictions)),
256 Some(Atoms::Restricted(target_restrictions)),
257 ) => {
258 for restriction in restrictions {
259 if target_restrictions.contains(&restriction) {
260 overlap.insert(restriction);
261 continue;
262 }
263
264 match restriction {
265 AtomRestriction::Value(value) => {
266 let length = AtomRestriction::Length(value.len());
267 if target_restrictions.contains(&length) {
268 overlap.insert(length);
269 continue;
270 }
271 }
272 AtomRestriction::Length(_) => {}
273 }
274
275 exceeds_overlap = true;
276 }
277 }
278 }
279 }
280
281 let atom_result = lhs_has_atom.then(|| {
282 if target_atoms.is_none() {
283 Check::Impossible
284 } else if !exceeds_overlap {
285 Check::None
286 } else if overlap.is_empty() {
287 Check::Impossible
288 } else if overlap.len() == 1 {
289 overlap.into_iter().next().map(Check::Atom).unwrap()
290 } else {
291 Check::Or(overlap.into_iter().map(Check::Atom).collect())
292 }
293 });
294
295 let target_pairs = pairs_of(arena, builtins, rhs)?;
296
297 let mut checks = Vec::new();
298 let mut requires_check = false;
299
300 let mut pairs = Vec::new();
301
302 for &lhs in lhs {
303 for pair in pairs_of(arena, builtins, lhs)? {
304 pairs.push(pair);
305 }
306 }
307
308 for target_pair in &target_pairs {
309 let pairs = pairs.clone();
310
311 let firsts: Vec<TypeId> = pairs.iter().map(|pair| pair.first).collect();
312 let first = check_each(arena, builtins, ctx, &firsts, target_pair.first)?;
313
314 if first == Check::Impossible {
315 requires_check = true;
316 continue;
317 }
318
319 let mut rests = Vec::new();
320
321 for pair in pairs {
323 if compare_impl(
324 arena,
325 builtins,
326 &mut ComparisonContext::default(),
327 target_pair.first,
328 pair.first,
329 None,
330 None,
331 ) > Comparison::Cast
332 {
333 continue;
334 }
335
336 rests.push(pair.rest);
337 }
338
339 if rests.is_empty() {
340 requires_check = true;
341 continue;
342 }
343
344 let rest = check_each(arena, builtins, ctx, &rests, target_pair.rest)?;
345
346 if rest == Check::Impossible {
347 requires_check = true;
348 continue;
349 }
350
351 if first == Check::None && rest == Check::None {
352 continue;
354 }
355
356 requires_check = true;
357
358 checks.push(Check::Pair(Box::new(first), Box::new(rest)));
359 }
360
361 let pair_result = (!pairs.is_empty()).then(|| {
362 if target_pairs.is_empty() {
363 Check::Impossible
364 } else if !requires_check {
365 Check::None
366 } else if checks.is_empty() {
367 Check::Impossible
368 } else if checks.len() == 1 {
369 checks[0].clone()
370 } else {
371 Check::Or(checks)
372 }
373 });
374
375 let check = match (atom_result, pair_result) {
376 (None, None) => Check::Impossible,
377 (Some(atom), None) => atom,
378 (None, Some(pair)) => pair,
379 (Some(atom), Some(Check::Impossible)) => {
380 Check::And(vec![Check::IsAtom { can_be_truthy }, atom])
381 }
382 (Some(Check::Impossible), Some(pair)) => {
383 Check::And(vec![Check::IsPair { can_be_truthy }, pair])
384 }
385 (Some(atom), Some(Check::None)) => Check::Or(vec![Check::IsPair { can_be_truthy }, atom]),
386 (Some(Check::None), Some(pair)) => Check::Or(vec![Check::IsAtom { can_be_truthy }, pair]),
387 (Some(atom), Some(pair)) => Check::Or(vec![
388 Check::And(vec![Check::IsAtom { can_be_truthy }, atom]),
389 Check::And(vec![Check::IsPair { can_be_truthy }, pair]),
390 ]),
391 };
392
393 Ok(check.simplify())
394}
395
396fn variants_of(arena: &Arena<Type>, builtins: &BuiltinTypes, id: TypeId) -> Vec<TypeId> {
397 match arena[id].clone() {
398 Type::Apply(_) => unreachable!(),
399 Type::Ref(id) => variants_of(arena, builtins, id),
400 Type::Unresolved | Type::Atom(_) | Type::Pair(_) | Type::Generic(_) => {
401 vec![id]
402 }
403 Type::Never => vec![],
404 Type::Alias(alias) => variants_of(arena, builtins, alias.inner),
405 Type::Struct(ty) => variants_of(arena, builtins, ty.inner),
406 Type::Function(_) | Type::Any => vec![builtins.atom, builtins.recursive_any_pair],
407 Type::Union(ty) => {
408 let mut variants = Vec::new();
409
410 for variant in ty.types {
411 variants.extend(variants_of(arena, builtins, variant));
412 }
413
414 variants
415 }
416 }
417}
418
419fn atoms_of(arena: &Arena<Type>, id: TypeId) -> Result<Option<Atoms>, CheckError> {
420 Ok(match arena[id].clone() {
421 Type::Apply(_) => unreachable!(),
422 Type::Ref(id) => atoms_of(arena, id)?,
423 Type::Unresolved | Type::Never | Type::Pair(_) => None,
424 Type::Generic(_) | Type::Any => Some(Atoms::Unrestricted),
425 Type::Atom(atom) => {
426 let Some(restriction) = atom.restriction else {
427 return Ok(Some(Atoms::Unrestricted));
428 };
429 Some(Atoms::Restricted(indexset![restriction]))
430 }
431 Type::Alias(alias) => atoms_of(arena, alias.inner)?,
432 Type::Struct(ty) => atoms_of(arena, ty.inner)?,
433 Type::Function(_) => return Err(CheckError::FunctionType),
434 Type::Union(ty) => {
435 let mut restrictions = IndexSet::new();
436
437 for variant in ty.types {
438 match atoms_of(arena, variant)? {
439 None => {}
440 Some(Atoms::Unrestricted) => return Ok(Some(Atoms::Unrestricted)),
441 Some(Atoms::Restricted(new)) => {
442 for restriction in new {
443 match &restriction {
444 AtomRestriction::Value(value) => {
445 if restrictions.contains(&AtomRestriction::Length(value.len()))
446 {
447 continue;
448 }
449 restrictions.insert(restriction);
450 }
451 AtomRestriction::Length(length) => {
452 restrictions.retain(|item| match item {
453 AtomRestriction::Value(value) => value.len() != *length,
454 AtomRestriction::Length(_) => true,
455 });
456 restrictions.insert(restriction);
457 }
458 }
459 }
460 }
461 }
462 }
463
464 if restrictions.is_empty() {
465 None
466 } else {
467 Some(Atoms::Restricted(restrictions))
468 }
469 }
470 })
471}
472
473fn pairs_of(
474 arena: &Arena<Type>,
475 builtins: &BuiltinTypes,
476 id: TypeId,
477) -> Result<Vec<Pair>, CheckError> {
478 Ok(match arena[id].clone() {
479 Type::Apply(_) => unreachable!(),
480 Type::Ref(id) => pairs_of(arena, builtins, id)?,
481 Type::Unresolved | Type::Never | Type::Atom(_) => vec![],
482 Type::Pair(pair) => vec![pair],
483 Type::Generic(_) | Type::Any => {
484 vec![Pair::new(builtins.recursive_any, builtins.recursive_any)]
485 }
486 Type::Alias(alias) => pairs_of(arena, builtins, alias.inner)?,
487 Type::Struct(ty) => pairs_of(arena, builtins, ty.inner)?,
488 Type::Function(_) => return Err(CheckError::FunctionType),
489 Type::Union(ty) => {
490 let mut pairs = Vec::new();
491
492 for variant in ty.types {
493 pairs.extend(pairs_of(arena, builtins, variant)?);
494 }
495
496 pairs
497 }
498 })
499}