1use id_arena::{Arena, Id};
2
3use crate::{
4 check_type, compare_type, debug_type, difference_type, replace_type, simplify_check,
5 stringify_type, substitute_type, Alias, Callable, Check, CheckError, Comparison,
6 ComparisonContext, HashMap, HashSet, StandardTypes, Type, TypePath,
7};
8
/// Handle to a `Type` stored in a `TypeSystem`'s arena.
pub type TypeId = Id<Type>;
10
/// Owns every `Type` in an arena and addresses them through `TypeId` handles.
///
/// A default-constructed system already contains the built-in types (their ids are
/// returned by `TypeSystem::std`) together with display names for most of them.
#[derive(Debug, Clone)]
pub struct TypeSystem {
    /// Backing storage; every `TypeId` returned by `alloc` indexes into it.
    arena: Arena<Type>,
    /// Ids of the built-in types allocated when the system is constructed.
    types: StandardTypes,
    /// Built-in display names, merged by `stringify_named` as fallbacks for ids the
    /// caller did not name explicitly.
    names: HashMap<TypeId, String>,
}
17
18impl Default for TypeSystem {
19 fn default() -> Self {
20 let mut arena = Arena::new();
21
22 let unknown = arena.alloc(Type::Unknown);
23 let never = arena.alloc(Type::Never);
24 let any = arena.alloc(Type::Any);
25 let bytes = arena.alloc(Type::Bytes);
26 let bytes32 = arena.alloc(Type::Bytes32);
27 let public_key = arena.alloc(Type::PublicKey);
28 let int = arena.alloc(Type::Int);
29 let true_bool = arena.alloc(Type::True);
30 let false_bool = arena.alloc(Type::False);
31 let nil = arena.alloc(Type::Nil);
32 let bool = arena.alloc(Type::Union(vec![false_bool, true_bool]));
33
34 let generic_list_item = arena.alloc(Type::Generic);
35 let inner = arena.alloc(Type::Unknown);
36 let unmapped_list = arena.alloc(Type::Unknown);
37 arena[unmapped_list] = Type::Alias(Alias {
38 original_type_id: unmapped_list,
39 type_id: inner,
40 generic_types: vec![generic_list_item],
41 });
42 let pair = arena.alloc(Type::Pair(generic_list_item, unmapped_list));
43 arena[inner] = Type::Union(vec![pair, nil]);
44
45 let mut names = HashMap::new();
46 names.insert(never, "Never".to_string());
47 names.insert(any, "Any".to_string());
48 names.insert(bytes, "Bytes".to_string());
49 names.insert(bytes32, "Bytes32".to_string());
50 names.insert(public_key, "PublicKey".to_string());
51 names.insert(int, "Int".to_string());
52 names.insert(bool, "Bool".to_string());
53 names.insert(true_bool, "True".to_string());
54 names.insert(false_bool, "False".to_string());
55 names.insert(nil, "Nil".to_string());
56 names.insert(unmapped_list, "List".to_string());
57 names.insert(generic_list_item, "{item}".to_string());
58
59 Self {
60 arena,
61 types: StandardTypes {
62 unknown,
63 never,
64 any,
65 unmapped_list,
66 generic_list_item,
67 bytes,
68 bytes32,
69 public_key,
70 int,
71 bool,
72 true_bool,
73 false_bool,
74 nil,
75 },
76 names,
77 }
78 }
79}
80
81impl TypeSystem {
82 pub fn new() -> Self {
83 Self::default()
84 }
85
86 pub fn std(&self) -> StandardTypes {
87 self.types
88 }
89
90 pub fn alloc(&mut self, ty: Type) -> TypeId {
91 self.arena.alloc(ty)
92 }
93
94 pub fn get_raw(&self, type_id: TypeId) -> &Type {
95 &self.arena[type_id]
96 }
97
98 pub fn get_raw_mut(&mut self, type_id: TypeId) -> &mut Type {
99 &mut self.arena[type_id]
100 }
101
102 pub fn get_recursive(&self, type_id: TypeId) -> &Type {
103 match self.get(type_id) {
104 Type::Alias(ty) => self.get_recursive(ty.type_id),
105 Type::Struct(ty) => self.get_recursive(ty.type_id),
106 Type::Enum(ty) => self.get_recursive(ty.type_id),
107 Type::Variant(ty) => self.get_recursive(ty.type_id),
108 ty => ty,
109 }
110 }
111
112 pub fn get_unaliased(&self, type_id: TypeId) -> &Type {
113 match self.get(type_id) {
114 Type::Alias(ty) => self.get_unaliased(ty.type_id),
115 ty => ty,
116 }
117 }
118
119 pub fn get(&self, type_id: TypeId) -> &Type {
120 match &self.arena[type_id] {
121 Type::Ref(type_id) => self.get(*type_id),
122 ty => ty,
123 }
124 }
125
126 pub fn get_mut(&mut self, type_id: TypeId) -> &mut Type {
127 match &self.arena[type_id] {
128 Type::Ref(type_id) => self.get_mut(*type_id),
129 _ => &mut self.arena[type_id],
130 }
131 }
132
133 pub fn get_pair(&self, type_id: TypeId) -> Option<(TypeId, TypeId)> {
134 match self.get(type_id) {
135 Type::Pair(first, rest) => Some((*first, *rest)),
136 _ => None,
137 }
138 }
139
140 pub fn get_union(&self, type_id: TypeId) -> Option<&[TypeId]> {
141 match self.get(type_id) {
142 Type::Union(types) => Some(types),
143 _ => None,
144 }
145 }
146
147 pub fn get_callable(&self, type_id: TypeId) -> Option<&Callable> {
148 match self.get(type_id) {
149 Type::Callable(callable) => Some(callable),
150 _ => None,
151 }
152 }
153
154 pub fn get_callable_recursive(&mut self, type_id: TypeId) -> Option<&Callable> {
155 match self.get_recursive(type_id) {
156 Type::Callable(callable) => Some(callable),
157 _ => None,
158 }
159 }
160
161 pub fn stringify_named(&self, type_id: TypeId, mut names: HashMap<TypeId, String>) -> String {
162 for (id, name) in &self.names {
163 names.entry(*id).or_insert_with(|| name.clone());
164 }
165 stringify_type(self, type_id, &names, &mut HashSet::new())
166 }
167
168 pub fn stringify(&self, type_id: TypeId) -> String {
169 self.stringify_named(type_id, HashMap::new())
170 }
171
172 pub fn debug(&self, type_id: TypeId) -> String {
173 debug_type(self, "", type_id, 0, &mut HashSet::new())
174 }
175
176 pub fn compare(&self, lhs: TypeId, rhs: TypeId) -> Comparison {
177 self.compare_with_generics(lhs, rhs, &mut Vec::new(), false)
178 }
179
180 pub fn compare_with_generics(
181 &self,
182 lhs: TypeId,
183 rhs: TypeId,
184 substitution_stack: &mut Vec<HashMap<TypeId, TypeId>>,
185 infer_generics: bool,
186 ) -> Comparison {
187 compare_type(
188 self,
189 lhs,
190 rhs,
191 &mut ComparisonContext {
192 visited: HashSet::new(),
193 lhs_substitutions: Vec::new(),
194 rhs_substitutions: Vec::new(),
195 inferred: substitution_stack,
196 infer_generics,
197 },
198 )
199 }
200
201 pub fn substitute(
202 &mut self,
203 type_id: TypeId,
204 substitutions: HashMap<TypeId, TypeId>,
205 ) -> TypeId {
206 substitute_type(self, type_id, &mut vec![substitutions])
207 }
208
209 pub fn check(&mut self, lhs: TypeId, rhs: TypeId) -> Result<Check, CheckError> {
210 check_type(self, lhs, rhs, &mut HashSet::new()).map(simplify_check)
211 }
212
213 pub fn difference(&mut self, lhs: TypeId, rhs: TypeId) -> TypeId {
214 difference_type(self, lhs, rhs, &mut HashSet::new())
215 }
216
217 pub fn replace(
218 &mut self,
219 type_id: TypeId,
220 replace_type_id: TypeId,
221 path: &[TypePath],
222 ) -> TypeId {
223 replace_type(self, type_id, replace_type_id, path)
224 }
225}