rue_typing/
type_system.rs

1use id_arena::{Arena, Id};
2
3use crate::{
4    check_type, compare_type, debug_type, difference_type, replace_type, simplify_check,
5    stringify_type, substitute_type, Alias, Callable, Check, CheckError, Comparison,
6    ComparisonContext, HashMap, HashSet, StandardTypes, Type, TypePath,
7};
8
/// Handle to a [`Type`] stored in a [`TypeSystem`]'s arena.
pub type TypeId = Id<Type>;
10
/// Owns every [`Type`] along with the ids of the built-in types and their display names.
#[derive(Debug, Clone)]
pub struct TypeSystem {
    /// Backing storage for all types; a [`TypeId`] returned by `alloc` indexes into it.
    arena: Arena<Type>,
    /// Ids of the built-in types, set when the system is constructed and returned by `std()`.
    types: StandardTypes,
    /// Default display names keyed by type id. `stringify_named` falls back to these
    /// for any id the caller did not name explicitly.
    names: HashMap<TypeId, String>,
}
17
18impl Default for TypeSystem {
19    fn default() -> Self {
20        let mut arena = Arena::new();
21
22        let unknown = arena.alloc(Type::Unknown);
23        let never = arena.alloc(Type::Never);
24        let any = arena.alloc(Type::Any);
25        let bytes = arena.alloc(Type::Bytes);
26        let bytes32 = arena.alloc(Type::Bytes32);
27        let public_key = arena.alloc(Type::PublicKey);
28        let int = arena.alloc(Type::Int);
29        let true_bool = arena.alloc(Type::True);
30        let false_bool = arena.alloc(Type::False);
31        let nil = arena.alloc(Type::Nil);
32        let bool = arena.alloc(Type::Union(vec![false_bool, true_bool]));
33
34        let generic_list_item = arena.alloc(Type::Generic);
35        let inner = arena.alloc(Type::Unknown);
36        let unmapped_list = arena.alloc(Type::Unknown);
37        arena[unmapped_list] = Type::Alias(Alias {
38            original_type_id: unmapped_list,
39            type_id: inner,
40            generic_types: vec![generic_list_item],
41        });
42        let pair = arena.alloc(Type::Pair(generic_list_item, unmapped_list));
43        arena[inner] = Type::Union(vec![pair, nil]);
44
45        let mut names = HashMap::new();
46        names.insert(never, "Never".to_string());
47        names.insert(any, "Any".to_string());
48        names.insert(bytes, "Bytes".to_string());
49        names.insert(bytes32, "Bytes32".to_string());
50        names.insert(public_key, "PublicKey".to_string());
51        names.insert(int, "Int".to_string());
52        names.insert(bool, "Bool".to_string());
53        names.insert(true_bool, "True".to_string());
54        names.insert(false_bool, "False".to_string());
55        names.insert(nil, "Nil".to_string());
56        names.insert(unmapped_list, "List".to_string());
57        names.insert(generic_list_item, "{item}".to_string());
58
59        Self {
60            arena,
61            types: StandardTypes {
62                unknown,
63                never,
64                any,
65                unmapped_list,
66                generic_list_item,
67                bytes,
68                bytes32,
69                public_key,
70                int,
71                bool,
72                true_bool,
73                false_bool,
74                nil,
75            },
76            names,
77        }
78    }
79}
80
81impl TypeSystem {
82    pub fn new() -> Self {
83        Self::default()
84    }
85
86    pub fn std(&self) -> StandardTypes {
87        self.types
88    }
89
90    pub fn alloc(&mut self, ty: Type) -> TypeId {
91        self.arena.alloc(ty)
92    }
93
94    pub fn get_raw(&self, type_id: TypeId) -> &Type {
95        &self.arena[type_id]
96    }
97
98    pub fn get_raw_mut(&mut self, type_id: TypeId) -> &mut Type {
99        &mut self.arena[type_id]
100    }
101
102    pub fn get_recursive(&self, type_id: TypeId) -> &Type {
103        match self.get(type_id) {
104            Type::Alias(ty) => self.get_recursive(ty.type_id),
105            Type::Struct(ty) => self.get_recursive(ty.type_id),
106            Type::Enum(ty) => self.get_recursive(ty.type_id),
107            Type::Variant(ty) => self.get_recursive(ty.type_id),
108            ty => ty,
109        }
110    }
111
112    pub fn get_unaliased(&self, type_id: TypeId) -> &Type {
113        match self.get(type_id) {
114            Type::Alias(ty) => self.get_unaliased(ty.type_id),
115            ty => ty,
116        }
117    }
118
119    pub fn get(&self, type_id: TypeId) -> &Type {
120        match &self.arena[type_id] {
121            Type::Ref(type_id) => self.get(*type_id),
122            ty => ty,
123        }
124    }
125
126    pub fn get_mut(&mut self, type_id: TypeId) -> &mut Type {
127        match &self.arena[type_id] {
128            Type::Ref(type_id) => self.get_mut(*type_id),
129            _ => &mut self.arena[type_id],
130        }
131    }
132
133    pub fn get_pair(&self, type_id: TypeId) -> Option<(TypeId, TypeId)> {
134        match self.get(type_id) {
135            Type::Pair(first, rest) => Some((*first, *rest)),
136            _ => None,
137        }
138    }
139
140    pub fn get_union(&self, type_id: TypeId) -> Option<&[TypeId]> {
141        match self.get(type_id) {
142            Type::Union(types) => Some(types),
143            _ => None,
144        }
145    }
146
147    pub fn get_callable(&self, type_id: TypeId) -> Option<&Callable> {
148        match self.get(type_id) {
149            Type::Callable(callable) => Some(callable),
150            _ => None,
151        }
152    }
153
154    pub fn get_callable_recursive(&mut self, type_id: TypeId) -> Option<&Callable> {
155        match self.get_recursive(type_id) {
156            Type::Callable(callable) => Some(callable),
157            _ => None,
158        }
159    }
160
161    pub fn stringify_named(&self, type_id: TypeId, mut names: HashMap<TypeId, String>) -> String {
162        for (id, name) in &self.names {
163            names.entry(*id).or_insert_with(|| name.clone());
164        }
165        stringify_type(self, type_id, &names, &mut HashSet::new())
166    }
167
168    pub fn stringify(&self, type_id: TypeId) -> String {
169        self.stringify_named(type_id, HashMap::new())
170    }
171
172    pub fn debug(&self, type_id: TypeId) -> String {
173        debug_type(self, "", type_id, 0, &mut HashSet::new())
174    }
175
176    pub fn compare(&self, lhs: TypeId, rhs: TypeId) -> Comparison {
177        self.compare_with_generics(lhs, rhs, &mut Vec::new(), false)
178    }
179
180    pub fn compare_with_generics(
181        &self,
182        lhs: TypeId,
183        rhs: TypeId,
184        substitution_stack: &mut Vec<HashMap<TypeId, TypeId>>,
185        infer_generics: bool,
186    ) -> Comparison {
187        compare_type(
188            self,
189            lhs,
190            rhs,
191            &mut ComparisonContext {
192                visited: HashSet::new(),
193                lhs_substitutions: Vec::new(),
194                rhs_substitutions: Vec::new(),
195                inferred: substitution_stack,
196                infer_generics,
197            },
198        )
199    }
200
201    pub fn substitute(
202        &mut self,
203        type_id: TypeId,
204        substitutions: HashMap<TypeId, TypeId>,
205    ) -> TypeId {
206        substitute_type(self, type_id, &mut vec![substitutions])
207    }
208
209    pub fn check(&mut self, lhs: TypeId, rhs: TypeId) -> Result<Check, CheckError> {
210        check_type(self, lhs, rhs, &mut HashSet::new()).map(simplify_check)
211    }
212
213    pub fn difference(&mut self, lhs: TypeId, rhs: TypeId) -> TypeId {
214        difference_type(self, lhs, rhs, &mut HashSet::new())
215    }
216
217    pub fn replace(
218        &mut self,
219        type_id: TypeId,
220        replace_type_id: TypeId,
221        path: &[TypePath],
222    ) -> TypeId {
223        replace_type(self, type_id, replace_type_id, path)
224    }
225}