1use id_arena::Arena;
2use indexmap::{IndexSet, indexset};
3
4use crate::{AtomRestriction, FunctionType, Pair, Type, TypeId, substitute};
5
6#[derive(Debug, Clone)]
7pub enum Atoms {
8 Unrestricted,
9 Restricted(IndexSet<AtomRestriction>),
10}
11
12pub fn extract_atoms(arena: &mut Arena<Type>, id: TypeId, strict: bool) -> Option<Atoms> {
13 let id = substitute(arena, id);
14 extract_atoms_impl(arena, id, strict)
15}
16
17fn extract_atoms_impl(arena: &Arena<Type>, id: TypeId, strict: bool) -> Option<Atoms> {
18 match arena[id].clone() {
19 Type::Apply(_) => unreachable!(),
20 Type::Ref(id) => extract_atoms_impl(arena, id, strict),
21 Type::Unresolved => Some(Atoms::Unrestricted),
22 Type::Generic(_) | Type::Never | Type::Function(_) | Type::Pair(_) | Type::Any => None,
23 Type::Atom(atom) => atom
24 .restriction
25 .map_or(Some(Atoms::Unrestricted), |restriction| {
26 Some(Atoms::Restricted(indexset![restriction]))
27 }),
28 Type::Struct(ty) => extract_atoms_impl(arena, ty.inner, strict),
29 Type::Alias(alias) => extract_atoms_impl(arena, alias.inner, strict),
30 Type::Union(ty) => {
31 let mut result = None;
32
33 for ty in ty.types {
34 let inner = extract_atoms_impl(arena, ty, strict);
35
36 let inner = if strict {
37 inner?
38 } else if let Some(inner) = inner {
39 inner
40 } else {
41 continue;
42 };
43
44 match (&result, &inner) {
45 (None, _) | (Some(_), Atoms::Unrestricted) => result = Some(inner),
46 (Some(Atoms::Unrestricted), _) => {}
47 (Some(Atoms::Restricted(restrictions)), Atoms::Restricted(inner)) => {
48 let mut restrictions = restrictions.clone();
49 restrictions.extend(inner.clone());
50 result = Some(Atoms::Restricted(restrictions));
51 }
52 }
53 }
54
55 result
56 }
57 }
58}
59
60pub fn extract_pairs(arena: &mut Arena<Type>, id: TypeId, strict: bool) -> Vec<Pair> {
61 let id = substitute(arena, id);
62 extract_pairs_impl(arena, id, strict).unwrap_or_default()
63}
64
65fn extract_pairs_impl(arena: &Arena<Type>, id: TypeId, strict: bool) -> Option<Vec<Pair>> {
66 match arena[id].clone() {
67 Type::Apply(_) => unreachable!(),
68 Type::Ref(id) => extract_pairs_impl(arena, id, strict),
69 Type::Unresolved => Some(vec![]),
70 Type::Generic(_) | Type::Atom(_) | Type::Function(_) | Type::Any => None,
71 Type::Never => Some(vec![]),
72 Type::Pair(pair) => Some(vec![pair]),
73 Type::Struct(ty) => extract_pairs_impl(arena, ty.inner, strict),
74 Type::Alias(alias) => extract_pairs_impl(arena, alias.inner, strict),
75 Type::Union(ty) => {
76 let mut pairs = Vec::new();
77
78 for ty in ty.types {
79 let inner = extract_pairs_impl(arena, ty, strict);
80
81 if strict {
82 pairs.extend(inner?);
83 } else {
84 pairs.extend(inner.unwrap_or_default());
85 }
86 }
87
88 Some(pairs)
89 }
90 }
91}
92
93pub fn extract_functions(arena: &mut Arena<Type>, id: TypeId) -> Vec<FunctionType> {
94 let id = substitute(arena, id);
95 extract_functions_impl(arena, id).unwrap_or_default()
96}
97
98fn extract_functions_impl(arena: &Arena<Type>, id: TypeId) -> Option<Vec<FunctionType>> {
99 match arena[id].clone() {
100 Type::Apply(_) => unreachable!(),
101 Type::Ref(id) => extract_functions_impl(arena, id),
102 Type::Unresolved | Type::Never => Some(vec![]),
103 Type::Generic(_) | Type::Atom(_) | Type::Pair(_) | Type::Any => None,
104 Type::Function(function) => Some(vec![function]),
105 Type::Struct(ty) => extract_functions_impl(arena, ty.inner),
106 Type::Alias(alias) => extract_functions_impl(arena, alias.inner),
107 Type::Union(ty) => {
108 let mut pairs = Vec::new();
109
110 for ty in ty.types {
111 pairs.extend(extract_functions_impl(arena, ty)?);
112 }
113
114 Some(pairs)
115 }
116 }
117}