1use std::{cmp::max, collections::HashSet};
2
3use id_arena::Arena;
4use indexmap::{IndexMap, IndexSet, indexset};
5use log::trace;
6use thiserror::Error;
7
8use crate::{
9 AtomRestriction, BuiltinTypes, Comparison, ComparisonContext, Pair, Type, TypeId, compare_impl,
10 stringify_impl, substitute,
11};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14#[must_use]
15pub enum Check {
16 None,
17 Impossible,
18 IsAtom,
19 IsPair,
20 Pair(Box<Check>, Box<Check>),
21 Atom(AtomRestriction),
22 And(Vec<Check>),
23 Or(Vec<Check>),
24}
25
26impl Check {
27 pub fn simplify(self) -> Check {
28 match self {
29 Check::None => Check::None,
30 Check::Impossible => Check::Impossible,
31 Check::Pair(first, rest) => {
32 let first = first.simplify();
33 let rest = rest.simplify();
34 Check::Pair(Box::new(first), Box::new(rest))
35 }
36 Check::IsAtom => Check::IsAtom,
37 Check::IsPair => Check::IsPair,
38 Check::Atom(restriction) => Check::Atom(restriction),
39 Check::And(checks) => {
40 let mut flattened = Vec::new();
41
42 for check in checks {
43 match check.simplify() {
44 Check::None => {}
45 Check::Impossible => {
46 return Check::Impossible;
47 }
48 Check::And(inner) => {
49 flattened.extend(inner);
50 }
51 check => {
52 flattened.push(check);
53 }
54 }
55 }
56
57 let mut listp = None;
58 let mut length = None;
59 let mut value = None;
60 let mut result = Vec::new();
61
62 for check in flattened {
63 match check {
64 Check::None | Check::Impossible | Check::And(_) => unreachable!(),
65 Check::IsAtom => {
66 if listp == Some(true) {
67 return Check::Impossible;
68 }
69 listp = Some(false);
70 }
71 Check::IsPair => {
72 if listp == Some(false) {
73 return Check::Impossible;
74 }
75 listp = Some(true);
76 }
77 Check::Atom(AtomRestriction::Length(check)) => {
78 if length.is_some_and(|length| length != check) {
79 return Check::Impossible;
80 }
81 length = Some(check);
82 }
83 Check::Atom(AtomRestriction::Value(check)) => {
84 if value.is_some_and(|value| value != check) {
85 return Check::Impossible;
86 }
87 value = Some(check);
88 }
89 check @ (Check::Or(_) | Check::Pair(..)) => {
90 result.push(check);
91 }
92 }
93 }
94
95 match (length, value) {
96 (Some(length), Some(value)) => {
97 if length != value.len() {
98 return Check::Impossible;
99 }
100 result.insert(0, Check::Atom(AtomRestriction::Value(value)));
101 }
102 (None, Some(value)) => {
103 result.insert(0, Check::Atom(AtomRestriction::Value(value)));
104 }
105 (Some(length), None) => {
106 result.insert(0, Check::Atom(AtomRestriction::Length(length)));
107 }
108 (None, None) => {}
109 }
110
111 match listp {
112 Some(true) => result.insert(0, Check::IsPair),
113 Some(false) => result.insert(0, Check::IsAtom),
114 None => {}
115 }
116
117 if result.is_empty() {
118 Check::None
119 } else if result.len() == 1 {
120 result[0].clone()
121 } else {
122 Check::And(result)
123 }
124 }
125 Check::Or(checks) => Check::Or(checks.into_iter().map(Check::simplify).collect()),
126 }
127 }
128}
129
130#[derive(Debug, Clone, Copy, Error)]
131pub enum CheckError {
132 #[error("Maximum type check depth reached")]
133 DepthExceeded,
134
135 #[error("Cannot check if value is of function type at runtime")]
136 FunctionType,
137}
138
/// Mutable state shared by the recursive helpers during one call to [`check`].
#[derive(Debug)]
struct CheckContext {
    /// Counter maintained by `check_each` to bound recursion; exceeding its
    /// limit yields `CheckError::DepthExceeded`.
    depth: usize,
}
143
144pub fn check(
145 arena: &mut Arena<Type>,
146 builtins: &BuiltinTypes,
147 lhs: TypeId,
148 rhs: TypeId,
149) -> Result<Check, CheckError> {
150 let lhs = substitute(arena, lhs);
151 let rhs = substitute(arena, rhs);
152 let lhs_name = stringify_impl(arena, lhs, &mut IndexMap::new());
153 let rhs_name = stringify_impl(arena, rhs, &mut IndexMap::new());
154 trace!("Checking {lhs_name} to {rhs_name}");
155 let result = check_impl(arena, builtins, &mut CheckContext { depth: 0 }, lhs, rhs);
156 trace!("Check from {lhs_name} to {rhs_name} yielded {result:?}");
157 result
158}
159
160fn check_impl(
161 arena: &Arena<Type>,
162 builtins: &BuiltinTypes,
163 ctx: &mut CheckContext,
164 lhs: TypeId,
165 rhs: TypeId,
166) -> Result<Check, CheckError> {
167 let mut variants = variants_of(arena, builtins, lhs)
168 .into_iter()
169 .enumerate()
170 .collect();
171 check_each(arena, builtins, ctx, &mut variants, rhs)
172}
173
174fn check_each(
175 arena: &Arena<Type>,
176 builtins: &BuiltinTypes,
177 ctx: &mut CheckContext,
178 lhs: &mut Vec<(usize, TypeId)>,
179 rhs: TypeId,
180) -> Result<Check, CheckError> {
181 ctx.depth += 1;
182
183 if ctx.depth > 25 {
184 return Err(CheckError::DepthExceeded);
185 }
186
187 let mut result = Comparison::Assign;
188
189 for &(_, lhs) in &*lhs {
190 result = max(
191 result,
192 compare_impl(
193 arena,
194 builtins,
195 &mut ComparisonContext::default(),
196 lhs,
197 rhs,
198 None,
199 None,
200 ),
201 );
202 }
203
204 if result <= Comparison::Cast {
205 return Ok(Check::None);
206 }
207
208 let target_atoms = atoms_of(arena, rhs)?;
209
210 let mut overlap = IndexSet::new();
211 let mut exceeds_overlap = false;
212 let mut unrestricted = false;
213 let mut lhs_has_atom = false;
214 let mut error = None;
215
216 lhs.retain(|&(_, id)| {
217 if error.is_some() {
218 return true;
219 }
220
221 let atoms = match atoms_of(arena, id) {
222 Ok(atoms) => atoms,
223 Err(err) => {
224 error = Some(err);
225 return true;
226 }
227 };
228
229 if atoms.is_some() {
230 lhs_has_atom = true;
231 }
232
233 match (atoms, &target_atoms) {
234 (Some(_), None) => false,
235 (Some(_), Some(Atoms::Unrestricted)) | (None, _) => true,
236 (Some(Atoms::Unrestricted), Some(Atoms::Restricted(restrictions))) => {
237 exceeds_overlap = true;
238 unrestricted = true;
239 overlap.clone_from(restrictions);
240 true
241 }
242 (
243 Some(Atoms::Restricted(restrictions)),
244 Some(Atoms::Restricted(target_restrictions)),
245 ) => {
246 let mut has_overlap = false;
247
248 for restriction in restrictions {
249 if target_restrictions.contains(&restriction) {
250 overlap.insert(restriction);
251 has_overlap = true;
252 continue;
253 }
254
255 match restriction {
256 AtomRestriction::Value(value) => {
257 let length = AtomRestriction::Length(value.len());
258 if target_restrictions.contains(&length) {
259 overlap.insert(length);
260 has_overlap = true;
261 continue;
262 }
263 }
264 AtomRestriction::Length(_) => {}
265 }
266
267 exceeds_overlap = true;
268 }
269
270 has_overlap
271 }
272 }
273 });
274
275 if let Some(error) = error {
276 return Err(error);
277 }
278
279 let atom_result = lhs_has_atom.then(|| {
280 if target_atoms.is_none() {
281 Check::Impossible
282 } else if !exceeds_overlap {
283 Check::None
284 } else if overlap.is_empty() {
285 Check::Impossible
286 } else if overlap.len() == 1 {
287 overlap.into_iter().next().map(Check::Atom).unwrap()
288 } else {
289 Check::And(overlap.into_iter().map(Check::Atom).collect())
290 }
291 });
292
293 let target_pairs = pairs_of(arena, builtins, rhs)?;
294
295 let mut checks = Vec::new();
296 let mut included_indices = IndexSet::new();
297 let mut candidate_pairs = HashSet::new();
298 let mut requires_check = false;
299 let mut lhs_has_pair = false;
300
301 let mut firsts = Vec::new();
302
303 for &(i, lhs) in &*lhs {
304 for pair in pairs_of(arena, builtins, lhs)? {
305 for ty in variants_of(arena, builtins, pair.first) {
306 candidate_pairs.insert(lhs);
307 firsts.push((i, ty));
308 lhs_has_pair = true;
309 }
310 }
311 }
312
313 for target_pair in &target_pairs {
314 let mut firsts = firsts.clone();
315
316 let first = check_each(arena, builtins, ctx, &mut firsts, target_pair.first)?;
317
318 if first == Check::Impossible {
319 requires_check = true;
320 continue;
321 }
322
323 let mut rests = Vec::new();
324
325 for (i, _) in firsts {
326 for pair in pairs_of(
327 arena,
328 builtins,
329 lhs.iter().find(|(j, _)| *j == i).unwrap().1,
330 )? {
331 for ty in variants_of(arena, builtins, pair.rest) {
332 rests.push((i, ty));
333 }
334 }
335 }
336
337 let rest = check_each(arena, builtins, ctx, &mut rests, target_pair.rest)?;
338
339 if rest == Check::Impossible {
340 requires_check = true;
341 continue;
342 }
343
344 for (i, _) in rests {
345 included_indices.insert(i);
346 }
347
348 if first == Check::None && rest == Check::None {
349 continue;
350 }
351
352 requires_check = true;
353
354 checks.push(Check::Pair(Box::new(first), Box::new(rest)));
355 }
356
357 lhs.retain(|&(i, type_id)| {
358 !candidate_pairs.contains(&type_id) || included_indices.contains(&i)
359 });
360
361 let pair_result = lhs_has_pair.then(|| {
362 if target_pairs.is_empty() {
363 Check::Impossible
364 } else if !requires_check {
365 Check::None
366 } else if checks.is_empty() {
367 Check::Impossible
368 } else if checks.len() == 1 {
369 checks[0].clone()
370 } else {
371 Check::Or(checks)
372 }
373 });
374
375 let check = match (atom_result, pair_result) {
376 (None, None) => Check::Impossible,
377 (Some(atom), None) => atom,
378 (None, Some(pair)) => pair,
379 (Some(atom), Some(Check::Impossible)) => Check::And(vec![Check::IsAtom, atom]),
380 (Some(Check::Impossible), Some(pair)) => Check::And(vec![Check::IsPair, pair]),
381 (Some(atom), Some(Check::None)) => Check::Or(vec![Check::IsPair, atom]),
382 (Some(Check::None), Some(pair)) => Check::Or(vec![Check::IsAtom, pair]),
383 (Some(atom), Some(pair)) => Check::Or(vec![
384 Check::And(vec![Check::IsAtom, atom]),
385 Check::And(vec![Check::IsPair, pair]),
386 ]),
387 };
388
389 Ok(check.simplify())
390}
391
392fn variants_of(arena: &Arena<Type>, builtins: &BuiltinTypes, id: TypeId) -> Vec<TypeId> {
393 match arena[id].clone() {
394 Type::Apply(_) => unreachable!(),
395 Type::Ref(id) => variants_of(arena, builtins, id),
396 Type::Unresolved | Type::Atom(_) | Type::Pair(_) | Type::Generic => {
397 vec![id]
398 }
399 Type::Never => vec![],
400 Type::Alias(alias) => variants_of(arena, builtins, alias.inner),
401 Type::Struct(ty) => variants_of(arena, builtins, ty.inner),
402 Type::Function(_) => vec![builtins.atom, builtins.any_pair],
403 Type::Union(ty) => {
404 let mut variants = Vec::new();
405
406 for variant in ty.types {
407 variants.extend(variants_of(arena, builtins, variant));
408 }
409
410 variants
411 }
412 }
413}
414
415#[derive(Debug, Clone)]
416enum Atoms {
417 Unrestricted,
418 Restricted(IndexSet<AtomRestriction>),
419}
420
421fn atoms_of(arena: &Arena<Type>, id: TypeId) -> Result<Option<Atoms>, CheckError> {
422 Ok(match arena[id].clone() {
423 Type::Apply(_) => unreachable!(),
424 Type::Ref(id) => atoms_of(arena, id)?,
425 Type::Unresolved | Type::Never | Type::Pair(_) => None,
426 Type::Generic => Some(Atoms::Unrestricted),
427 Type::Atom(atom) => {
428 let Some(restriction) = atom.restriction else {
429 return Ok(Some(Atoms::Unrestricted));
430 };
431 Some(Atoms::Restricted(indexset![restriction]))
432 }
433 Type::Alias(alias) => atoms_of(arena, alias.inner)?,
434 Type::Struct(ty) => atoms_of(arena, ty.inner)?,
435 Type::Function(_) => return Err(CheckError::FunctionType),
436 Type::Union(ty) => {
437 let mut restrictions = IndexSet::new();
438
439 for variant in ty.types {
440 match atoms_of(arena, variant)? {
441 None => {}
442 Some(Atoms::Unrestricted) => return Ok(Some(Atoms::Unrestricted)),
443 Some(Atoms::Restricted(new)) => {
444 for restriction in new {
445 match &restriction {
446 AtomRestriction::Value(value) => {
447 if restrictions.contains(&AtomRestriction::Length(value.len()))
448 {
449 continue;
450 }
451 restrictions.insert(restriction);
452 }
453 AtomRestriction::Length(length) => {
454 restrictions.retain(|item| match item {
455 AtomRestriction::Value(value) => value.len() != *length,
456 AtomRestriction::Length(_) => true,
457 });
458 restrictions.insert(restriction);
459 }
460 }
461 }
462 }
463 }
464 }
465
466 if restrictions.is_empty() {
467 None
468 } else {
469 Some(Atoms::Restricted(restrictions))
470 }
471 }
472 })
473}
474
475fn pairs_of(
476 arena: &Arena<Type>,
477 builtins: &BuiltinTypes,
478 id: TypeId,
479) -> Result<Vec<Pair>, CheckError> {
480 Ok(match arena[id].clone() {
481 Type::Apply(_) => unreachable!(),
482 Type::Ref(id) => pairs_of(arena, builtins, id)?,
483 Type::Unresolved | Type::Never | Type::Atom(_) => vec![],
484 Type::Pair(pair) => vec![pair],
485 Type::Generic => vec![Pair::new(builtins.any, builtins.any)],
486 Type::Alias(alias) => pairs_of(arena, builtins, alias.inner)?,
487 Type::Struct(ty) => pairs_of(arena, builtins, ty.inner)?,
488 Type::Function(_) => return Err(CheckError::FunctionType),
489 Type::Union(ty) => {
490 let mut pairs = Vec::new();
491
492 for variant in ty.types {
493 pairs.extend(pairs_of(arena, builtins, variant)?);
494 }
495
496 pairs
497 }
498 })
499}