1use std::{
2 cmp::Reverse,
3 collections::HashMap,
4 ops::{Deref, DerefMut, Range},
5 sync::Arc,
6};
7
8use rowan::TextRange;
9use rue_diagnostic::{Diagnostic, DiagnosticKind, Source, SourceKind, SrcLoc};
10use rue_hir::{
11 Builtins, Constraint, Database, Scope, ScopeId, Symbol, SymbolId, TypePath, Value, replace_type,
12};
13use rue_options::CompilerOptions;
14use rue_parser::{SyntaxNode, SyntaxToken};
15use rue_types::{Check, CheckError, Comparison, TypeId};
16
/// Compilation state: collected diagnostics, the HIR [`Database`], and the scope and
/// type-mapping stacks used during name and type resolution.
///
/// `Compiler` derefs to its [`Database`], so database methods can be called on it directly.
#[derive(Debug, Clone)]
pub struct Compiler {
    /// Compiler options. Not read by any code visible in this file (hence the leading
    /// underscore).
    _options: CompilerOptions,
    /// Source that newly reported diagnostics are attributed to; replaced via `set_source`.
    source: Source,
    /// Diagnostics reported so far, in report order.
    diagnostics: Vec<Diagnostic>,
    /// Underlying HIR database.
    db: Database,
    /// Active scopes. The innermost scope is last, and lookups search from last to first.
    scope_stack: Vec<ScopeId>,
    /// Stack of per-symbol type overrides; `symbol_type` consults the newest map first.
    mapping_stack: Vec<HashMap<SymbolId, TypeId>>,
    /// Builtins registered in `db` when the compiler was constructed.
    builtins: Builtins,
    /// Default field values, keyed first by type and then by field name.
    defaults: HashMap<TypeId, HashMap<String, Value>>,
}
28
29impl Deref for Compiler {
30 type Target = Database;
31
32 fn deref(&self) -> &Self::Target {
33 &self.db
34 }
35}
36
37impl DerefMut for Compiler {
38 fn deref_mut(&mut self) -> &mut Self::Target {
39 &mut self.db
40 }
41}
42
43impl Compiler {
44 pub fn new(options: CompilerOptions) -> Self {
45 let mut db = Database::new();
46
47 let builtins = Builtins::new(&mut db);
48
49 Self {
50 _options: options,
51 source: Source::new(Arc::from(""), SourceKind::Std),
52 diagnostics: Vec::new(),
53 db,
54 scope_stack: vec![builtins.scope],
55 mapping_stack: vec![],
56 builtins,
57 defaults: HashMap::new(),
58 }
59 }
60
61 pub fn diagnostics(&self) -> &[Diagnostic] {
62 &self.diagnostics
63 }
64
65 pub fn builtins(&self) -> &Builtins {
66 &self.builtins
67 }
68
69 pub fn set_source(&mut self, source: Source) {
70 self.source = source;
71 }
72
73 pub fn diagnostic(&mut self, node: &impl GetTextRange, kind: DiagnosticKind) {
74 let range = node.text_range();
75 let span: Range<usize> = range.start().into()..range.end().into();
76 self.diagnostics.push(Diagnostic::new(
77 SrcLoc::new(self.source.clone(), span),
78 kind,
79 ));
80 }
81
82 pub fn push_scope(&mut self, scope: ScopeId) {
83 self.scope_stack.push(scope);
84 }
85
86 pub fn pop_scope(&mut self) {
87 self.scope_stack.pop().unwrap();
88 }
89
90 pub fn last_scope(&self) -> &Scope {
91 let scope = *self.scope_stack.last().unwrap();
92 self.scope(scope)
93 }
94
95 pub fn last_scope_mut(&mut self) -> &mut Scope {
96 let scope = *self.scope_stack.last().unwrap();
97 self.scope_mut(scope)
98 }
99
100 pub fn resolve_symbol(&self, name: &str) -> Option<SymbolId> {
101 for scope in self.scope_stack.iter().rev() {
102 if let Some(symbol) = self.scope(*scope).symbol(name) {
103 return Some(symbol);
104 }
105 }
106 None
107 }
108
109 pub fn resolve_type(&self, name: &str) -> Option<TypeId> {
110 for scope in self.scope_stack.iter().rev() {
111 if let Some(ty) = self.scope(*scope).ty(name) {
112 return Some(ty);
113 }
114 }
115 None
116 }
117
118 pub fn type_name(&mut self, ty: TypeId) -> String {
119 for scope in self.scope_stack.iter().rev() {
120 if let Some(name) = self.scope(*scope).type_name(ty) {
121 return name.to_string();
122 }
123 }
124
125 rue_types::stringify(self.db.types_mut(), ty)
126 }
127
128 pub fn symbol_type(&self, symbol: SymbolId) -> TypeId {
129 for map in self.mapping_stack.iter().rev() {
130 if let Some(ty) = map.get(&symbol) {
131 return *ty;
132 }
133 }
134
135 match self.symbol(symbol) {
136 Symbol::Module(_) => self.builtins().unresolved.ty,
137 Symbol::Function(function) => function.ty,
138 Symbol::Parameter(parameter) => parameter.ty,
139 Symbol::Constant(constant) => constant.ty,
140 Symbol::Binding(binding) => binding.ty,
141 }
142 }
143
144 pub fn push_mappings(
145 &mut self,
146 mappings: HashMap<SymbolId, HashMap<Vec<TypePath>, TypeId>>,
147 ) -> usize {
148 let mut result = HashMap::new();
149
150 for (symbol, paths) in mappings {
151 let mut ty = self.symbol_type(symbol);
152
153 let mut paths = paths.into_iter().collect::<Vec<_>>();
154 paths.sort_by_key(|(path, _)| (Reverse(path.len()), path.last().copied()));
155
156 for (path, replacement) in paths {
157 ty = replace_type(&mut self.db, ty, replacement, &path);
158 }
159
160 result.insert(symbol, ty);
161 }
162
163 let index = self.mapping_stack.len();
164 self.mapping_stack.push(result);
165 index
166 }
167
168 pub fn mapping_checkpoint(&self) -> usize {
169 self.mapping_stack.len()
170 }
171
172 pub fn revert_mappings(&mut self, index: usize) {
173 self.mapping_stack.truncate(index);
174 }
175
176 pub fn is_assignable(&mut self, from: TypeId, to: TypeId) -> bool {
177 let comparison = rue_types::compare(self.db.types_mut(), &self.builtins.types, from, to);
178 comparison == Comparison::Assign
179 }
180
181 pub fn is_castable(&mut self, from: TypeId, to: TypeId) -> bool {
182 let comparison = rue_types::compare(self.db.types_mut(), &self.builtins.types, from, to);
183 matches!(comparison, Comparison::Assign | Comparison::Cast)
184 }
185
186 pub fn assign_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) {
187 self.compare_type(node, from, to, false, None);
188 }
189
190 pub fn cast_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) {
191 self.compare_type(node, from, to, true, None);
192 }
193
194 pub fn guard_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) -> Constraint {
195 let check = match rue_types::check(self.db.types_mut(), &self.builtins.types, from, to) {
196 Ok(check) => check,
197 Err(CheckError::DepthExceeded) => {
198 self.diagnostic(node, DiagnosticKind::TypeCheckDepthExceeded);
199 return Constraint::new(Check::Impossible);
200 }
201 Err(CheckError::FunctionType) => {
202 self.diagnostic(node, DiagnosticKind::FunctionTypeCheck);
203 return Constraint::new(Check::Impossible);
204 }
205 };
206
207 let from_name = self.type_name(from);
208 let to_name = self.type_name(to);
209
210 if check == Check::None {
211 self.diagnostic(node, DiagnosticKind::UnnecessaryGuard(from_name, to_name));
212 } else if check == Check::Impossible {
213 self.diagnostic(node, DiagnosticKind::IncompatibleGuard(from_name, to_name));
214 }
215
216 let else_id = rue_types::subtract(self.db.types_mut(), &self.builtins.types, from, to);
217
218 Constraint::new(check).with_else(else_id)
219 }
220
221 pub fn infer_type(
222 &mut self,
223 node: &impl GetTextRange,
224 from: TypeId,
225 to: TypeId,
226 infer: &mut HashMap<TypeId, TypeId>,
227 ) {
228 self.compare_type(node, from, to, false, Some(infer));
229 }
230
231 fn compare_type(
232 &mut self,
233 node: &impl GetTextRange,
234 from: TypeId,
235 to: TypeId,
236 cast: bool,
237 infer: Option<&mut HashMap<TypeId, TypeId>>,
238 ) {
239 let comparison = rue_types::compare_with_inference(
240 self.db.types_mut(),
241 &self.builtins.types,
242 from,
243 to,
244 infer,
245 );
246
247 match comparison {
248 Comparison::Assign => {
249 if cast {
250 let from = self.type_name(from);
251 let to = self.type_name(to);
252 self.diagnostic(node, DiagnosticKind::UnnecessaryCast(from, to));
253 }
254 }
255 Comparison::Cast => {
256 if !cast {
257 let from = self.type_name(from);
258 let to = self.type_name(to);
259 self.diagnostic(node, DiagnosticKind::UnassignableType(from, to));
260 }
261 }
262 Comparison::Invalid => {
263 let check =
264 match rue_types::check(self.db.types_mut(), &self.builtins.types, from, to) {
265 Ok(check) => check,
266 Err(CheckError::DepthExceeded | CheckError::FunctionType) => {
267 Check::Impossible
268 }
269 };
270
271 let from = self.type_name(from);
272 let to = self.type_name(to);
273
274 if check != Check::Impossible {
275 self.diagnostic(node, DiagnosticKind::UnconstrainableComparison(from, to));
276 } else if cast {
277 self.diagnostic(node, DiagnosticKind::IncompatibleCast(from, to));
278 } else {
279 self.diagnostic(node, DiagnosticKind::IncompatibleType(from, to));
280 }
281 }
282 }
283 }
284
285 pub fn insert_default_field(&mut self, ty: TypeId, name: String, value: Value) {
286 self.defaults.entry(ty).or_default().insert(name, value);
287 }
288
289 pub fn default_field(&self, ty: TypeId, name: &str) -> Option<Value> {
290 self.defaults
291 .get(&ty)
292 .and_then(|map| map.get(name).cloned())
293 }
294}
295
/// Anything that can report the source [`TextRange`] it covers, so it can be used as the
/// location of a diagnostic (`Compiler::diagnostic` accepts any `&impl GetTextRange`).
pub trait GetTextRange {
    /// Returns the range of source text this item spans.
    fn text_range(&self) -> TextRange;
}
299
300impl GetTextRange for SyntaxNode {
301 fn text_range(&self) -> TextRange {
302 self.text_range()
303 }
304}
305
306impl GetTextRange for SyntaxToken {
307 fn text_range(&self) -> TextRange {
308 self.text_range()
309 }
310}