1use id_arena::Arena;
2use indexmap::IndexMap;
3use log::trace;
4
5use crate::{
6 Alias, BuiltinTypes, Comparison, Struct, Type, TypeId, Union, compare, stringify_impl,
7 substitute,
8};
9
10pub fn subtract(
11 arena: &mut Arena<Type>,
12 builtins: &BuiltinTypes,
13 lhs_id: TypeId,
14 rhs_id: TypeId,
15) -> TypeId {
16 let lhs_id = substitute(arena, lhs_id);
17 let rhs_id = substitute(arena, rhs_id);
18 let lhs_name = stringify_impl(arena, lhs_id, &mut IndexMap::new());
19 let rhs_name = stringify_impl(arena, rhs_id, &mut IndexMap::new());
20 trace!("Subtracting {lhs_name} from {rhs_name}");
21 let result = subtract_impl(arena, builtins, lhs_id, rhs_id);
22 let new_name = stringify_impl(arena, result, &mut IndexMap::new());
23 trace!("Subtraction from {lhs_name} to {rhs_name} yielded {new_name}");
24 result
25}
26
27fn subtract_impl(
28 arena: &mut Arena<Type>,
29 builtins: &BuiltinTypes,
30 lhs_id: TypeId,
31 rhs_id: TypeId,
32) -> TypeId {
33 let mut lhs_variants = variants_of(arena, builtins, lhs_id);
34
35 lhs_variants
36 .retain(|variant| compare(arena, builtins, variant.type_id, rhs_id) > Comparison::Cast);
37
38 repackage(arena, lhs_variants)
39}
40
41fn repackage(arena: &mut Arena<Type>, variants: Vec<Variant>) -> TypeId {
42 let mut collected = vec![];
43 let mut semantic_type_id = None;
44 let mut leftover = vec![];
45
46 for mut variant in variants {
47 if let Some(id) = variant.semantic_type_ids.last().copied() {
48 if let Some(semantic_type_id) = semantic_type_id {
49 if id == semantic_type_id {
50 variant.semantic_type_ids.pop().unwrap();
51 collected.push(variant);
52 } else {
53 leftover.push(variant);
54 }
55 } else {
56 semantic_type_id = Some(id);
57 variant.semantic_type_ids.pop().unwrap();
58 collected.push(variant);
59 }
60 } else {
61 leftover.push(variant);
62 }
63 }
64
65 if collected.is_empty() {
66 let leftover: Vec<TypeId> = leftover
67 .into_iter()
68 .map(|variant| variant.type_id)
69 .collect();
70
71 if leftover.is_empty() {
72 return arena.alloc(Type::Never);
73 } else if leftover.len() == 1 {
74 return leftover[0];
75 }
76
77 return arena.alloc(Type::Union(Union::new(leftover)));
78 }
79
80 let semantic_type_id = semantic_type_id.unwrap();
81
82 let inner = repackage(arena, collected);
83
84 let result = match arena[semantic_type_id].clone() {
85 Type::Alias(alias) => arena.alloc(Type::Alias(Alias { inner, ..alias })),
86 Type::Struct(ty) => arena.alloc(Type::Struct(Struct { inner, ..ty })),
87 _ => unreachable!(),
88 };
89
90 if leftover.is_empty() {
91 return result;
92 }
93
94 let leftover = repackage(arena, leftover);
95
96 arena.alloc(Type::Union(Union::new(vec![result, leftover])))
97}
98
99struct Variant {
100 semantic_type_ids: Vec<TypeId>,
101 type_id: TypeId,
102}
103
104fn variants_of(arena: &Arena<Type>, builtins: &BuiltinTypes, id: TypeId) -> Vec<Variant> {
105 match arena[id].clone() {
106 Type::Apply(_) => unreachable!(),
107 Type::Ref(id) => variants_of(arena, builtins, id),
108 Type::Unresolved | Type::Atom(_) | Type::Pair(_) => vec![Variant {
109 semantic_type_ids: vec![],
110 type_id: id,
111 }],
112 Type::Never => vec![],
113 Type::Alias(alias) => {
114 let mut variants = variants_of(arena, builtins, alias.inner);
115
116 for variant in &mut variants {
117 variant.semantic_type_ids.push(id);
118 }
119
120 variants
121 }
122 Type::Struct(ty) => {
123 let mut variants = variants_of(arena, builtins, ty.inner);
124
125 for variant in &mut variants {
126 variant.semantic_type_ids.push(id);
127 }
128
129 variants
130 }
131 Type::Function(_) | Type::Generic => vec![
132 Variant {
133 semantic_type_ids: vec![],
134 type_id: builtins.atom,
135 },
136 Variant {
137 semantic_type_ids: vec![],
138 type_id: builtins.any_pair,
139 },
140 ],
141 Type::Union(ty) => {
142 let mut variants = Vec::new();
143
144 for variant in ty.types {
145 variants.extend(variants_of(arena, builtins, variant));
146 }
147
148 variants
149 }
150 }
151}