1use std::{
2 cmp::Reverse,
3 collections::HashMap,
4 mem,
5 ops::{Deref, DerefMut, Range},
6 sync::Arc,
7};
8
9use rowan::TextRange;
10use rue_diagnostic::{Diagnostic, DiagnosticKind, Source, SourceKind, SrcLoc};
11use rue_hir::{
12 Builtins, Constraint, Database, Declaration, Scope, ScopeId, Symbol, SymbolId, TypePath, Value,
13 replace_type,
14};
15use rue_options::CompilerOptions;
16use rue_parser::{SyntaxNode, SyntaxToken};
17use rue_types::{Check, CheckError, Comparison, TypeId};
18
/// State shared by the compiler passes: the source currently being compiled, the
/// diagnostics collected so far, the HIR [`Database`], and the scope, type-mapping and
/// declaration stacks used while walking the syntax tree.
///
/// The compiler dereferences to its [`Database`], so database methods can be called on it
/// directly.
#[derive(Debug, Clone)]
pub struct Compiler {
    /// Options supplied at construction.
    /// NOTE(review): Stored but not read anywhere in this file; confirm whether other code uses it.
    _options: CompilerOptions,
    /// The source being compiled; every reported diagnostic is attributed to it.
    source: Source,
    /// Diagnostics accumulated so far; drained by `take_diagnostics`.
    diagnostics: Vec<Diagnostic>,
    /// The HIR database that all symbols, scopes and types live in.
    db: Database,
    /// Active scopes with the innermost last; the builtins scope sits at the bottom.
    scope_stack: Vec<ScopeId>,
    /// Layers of per-symbol type overrides; later layers take priority in `symbol_type`.
    mapping_stack: Vec<HashMap<SymbolId, TypeId>>,
    /// Builtin scope and types registered when the compiler was created.
    builtins: Builtins,
    /// Default field values, keyed by type and then by field name.
    defaults: HashMap<TypeId, HashMap<String, Value>>,
    /// Declarations currently being compiled, innermost last; used to attribute nested
    /// declarations and references to their enclosing declaration.
    declaration_stack: Vec<Declaration>,
}
31
32impl Deref for Compiler {
33 type Target = Database;
34
35 fn deref(&self) -> &Self::Target {
36 &self.db
37 }
38}
39
40impl DerefMut for Compiler {
41 fn deref_mut(&mut self) -> &mut Self::Target {
42 &mut self.db
43 }
44}
45
46impl Compiler {
47 pub fn new(options: CompilerOptions) -> Self {
48 let mut db = Database::new();
49
50 let builtins = Builtins::new(&mut db);
51
52 Self {
53 _options: options,
54 source: Source::new(Arc::from(""), SourceKind::Std),
55 diagnostics: Vec::new(),
56 db,
57 scope_stack: vec![builtins.scope],
58 mapping_stack: vec![],
59 builtins,
60 defaults: HashMap::new(),
61 declaration_stack: Vec::new(),
62 }
63 }
64
65 pub fn source(&self) -> &Source {
66 &self.source
67 }
68
69 pub fn set_source(&mut self, source: Source) {
70 self.source = source;
71 }
72
73 pub fn take_diagnostics(&mut self) -> Vec<Diagnostic> {
74 mem::take(&mut self.diagnostics)
75 }
76
77 pub fn builtins(&self) -> &Builtins {
78 &self.builtins
79 }
80
81 pub fn diagnostic(&mut self, node: &impl GetTextRange, kind: DiagnosticKind) {
82 let range = node.text_range();
83 let span: Range<usize> = range.start().into()..range.end().into();
84 self.diagnostics.push(Diagnostic::new(
85 SrcLoc::new(self.source.clone(), span),
86 kind,
87 ));
88 }
89
90 pub fn push_scope(&mut self, scope: ScopeId) {
91 self.scope_stack.push(scope);
92 }
93
94 pub fn pop_scope(&mut self) {
95 self.scope_stack.pop().unwrap();
96 }
97
98 pub fn last_scope(&self) -> &Scope {
99 let scope = *self.scope_stack.last().unwrap();
100 self.scope(scope)
101 }
102
103 pub fn last_scope_mut(&mut self) -> &mut Scope {
104 let scope = *self.scope_stack.last().unwrap();
105 self.scope_mut(scope)
106 }
107
108 pub fn resolve_symbol(&self, name: &str) -> Option<SymbolId> {
109 for scope in self.scope_stack.iter().rev() {
110 if let Some(symbol) = self.scope(*scope).symbol(name) {
111 return Some(symbol);
112 }
113 }
114 None
115 }
116
117 pub fn resolve_type(&self, name: &str) -> Option<TypeId> {
118 for scope in self.scope_stack.iter().rev() {
119 if let Some(ty) = self.scope(*scope).ty(name) {
120 return Some(ty);
121 }
122 }
123 None
124 }
125
126 pub fn type_name(&mut self, ty: TypeId) -> String {
127 for scope in self.scope_stack.iter().rev() {
128 if let Some(name) = self.scope(*scope).type_name(ty) {
129 return name.to_string();
130 }
131 }
132
133 rue_types::stringify(self.db.types_mut(), ty)
134 }
135
136 pub fn symbol_type(&self, symbol: SymbolId) -> TypeId {
137 for map in self.mapping_stack.iter().rev() {
138 if let Some(ty) = map.get(&symbol) {
139 return *ty;
140 }
141 }
142
143 match self.symbol(symbol) {
144 Symbol::Unresolved | Symbol::Module(_) | Symbol::Builtin(_) => {
145 self.builtins().unresolved.ty
146 }
147 Symbol::Function(function) => function.ty,
148 Symbol::Parameter(parameter) => parameter.ty,
149 Symbol::Constant(constant) => constant.value.ty,
150 Symbol::Binding(binding) => binding.value.ty,
151 }
152 }
153
154 pub fn push_mappings(
155 &mut self,
156 mappings: HashMap<SymbolId, HashMap<Vec<TypePath>, TypeId>>,
157 ) -> usize {
158 let mut result = HashMap::new();
159
160 for (symbol, paths) in mappings {
161 let mut ty = self.symbol_type(symbol);
162
163 let mut paths = paths.into_iter().collect::<Vec<_>>();
164 paths.sort_by_key(|(path, _)| (Reverse(path.len()), path.last().copied()));
165
166 for (path, replacement) in paths {
167 ty = replace_type(&mut self.db, ty, replacement, &path);
168 }
169
170 result.insert(symbol, ty);
171 }
172
173 let index = self.mapping_stack.len();
174 self.mapping_stack.push(result);
175 index
176 }
177
178 pub fn mapping_checkpoint(&self) -> usize {
179 self.mapping_stack.len()
180 }
181
182 pub fn revert_mappings(&mut self, index: usize) {
183 self.mapping_stack.truncate(index);
184 }
185
186 pub fn is_assignable(&mut self, from: TypeId, to: TypeId) -> bool {
187 let comparison = rue_types::compare(self.db.types_mut(), &self.builtins.types, from, to);
188 comparison == Comparison::Assign
189 }
190
191 pub fn is_castable(&mut self, from: TypeId, to: TypeId) -> bool {
192 let comparison = rue_types::compare(self.db.types_mut(), &self.builtins.types, from, to);
193 matches!(comparison, Comparison::Assign | Comparison::Cast)
194 }
195
196 pub fn assign_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) {
197 self.compare_type(node, from, to, false, None);
198 }
199
200 pub fn cast_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) {
201 self.compare_type(node, from, to, true, None);
202 }
203
204 pub fn guard_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) -> Constraint {
205 let check = match rue_types::check(self.db.types_mut(), &self.builtins.types, from, to) {
206 Ok(check) => check,
207 Err(CheckError::DepthExceeded) => {
208 self.diagnostic(node, DiagnosticKind::TypeCheckDepthExceeded);
209 return Constraint::new(Check::Impossible);
210 }
211 Err(CheckError::FunctionType) => {
212 self.diagnostic(node, DiagnosticKind::FunctionTypeCheck);
213 return Constraint::new(Check::Impossible);
214 }
215 };
216
217 let from_name = self.type_name(from);
218 let to_name = self.type_name(to);
219
220 if check == Check::None {
221 self.diagnostic(node, DiagnosticKind::UnnecessaryGuard(from_name, to_name));
222 } else if check == Check::Impossible {
223 self.diagnostic(node, DiagnosticKind::IncompatibleGuard(from_name, to_name));
224 }
225
226 let else_id = rue_types::subtract(self.db.types_mut(), &self.builtins.types, from, to);
227
228 Constraint::new(check).with_else(else_id)
229 }
230
231 pub fn check_condition(&mut self, node: &impl GetTextRange, ty: TypeId) {
232 if self.is_castable(ty, self.builtins().types.bool_true) {
233 self.diagnostic(node, DiagnosticKind::AlwaysTrueCondition);
234 } else if self.is_castable(ty, self.builtins().types.bool_false) {
235 self.diagnostic(node, DiagnosticKind::AlwaysFalseCondition);
236 } else {
237 self.assign_type(node, ty, self.builtins().types.bool);
238 }
239 }
240
241 pub fn infer_type(
242 &mut self,
243 node: &impl GetTextRange,
244 from: TypeId,
245 to: TypeId,
246 infer: &mut HashMap<TypeId, TypeId>,
247 ) {
248 self.compare_type(node, from, to, false, Some(infer));
249 }
250
251 fn compare_type(
252 &mut self,
253 node: &impl GetTextRange,
254 from: TypeId,
255 to: TypeId,
256 cast: bool,
257 infer: Option<&mut HashMap<TypeId, TypeId>>,
258 ) {
259 let comparison = rue_types::compare_with_inference(
260 self.db.types_mut(),
261 &self.builtins.types,
262 from,
263 to,
264 infer,
265 );
266
267 match comparison {
268 Comparison::Assign => {
269 if cast
270 && rue_types::compare_with_inference(
271 self.db.types_mut(),
272 &self.builtins.types,
273 to,
274 from,
275 None,
276 ) == Comparison::Assign
277 {
278 let from = self.type_name(from);
279 let to = self.type_name(to);
280
281 self.diagnostic(node, DiagnosticKind::UnnecessaryCast(from, to));
282 }
283 }
284 Comparison::Cast => {
285 if !cast {
286 let from = self.type_name(from);
287 let to = self.type_name(to);
288 self.diagnostic(node, DiagnosticKind::UnassignableType(from, to));
289 }
290 }
291 Comparison::Invalid => {
292 let check =
293 match rue_types::check(self.db.types_mut(), &self.builtins.types, from, to) {
294 Ok(check) => check,
295 Err(CheckError::DepthExceeded | CheckError::FunctionType) => {
296 Check::Impossible
297 }
298 };
299
300 let from = self.type_name(from);
301 let to = self.type_name(to);
302
303 if check != Check::Impossible {
304 self.diagnostic(node, DiagnosticKind::UnconstrainableComparison(from, to));
305 } else if cast {
306 self.diagnostic(node, DiagnosticKind::IncompatibleCast(from, to));
307 } else {
308 self.diagnostic(node, DiagnosticKind::IncompatibleType(from, to));
309 }
310 }
311 }
312 }
313
314 pub fn insert_default_field(&mut self, ty: TypeId, name: String, value: Value) {
315 self.defaults.entry(ty).or_default().insert(name, value);
316 }
317
318 pub fn default_field(&self, ty: TypeId, name: &str) -> Option<Value> {
319 self.defaults
320 .get(&ty)
321 .and_then(|map| map.get(name).cloned())
322 }
323
324 pub fn push_declaration(&mut self, declaration: Declaration) {
325 if let Some(last) = self.declaration_stack.last() {
326 self.db.add_declaration(*last, declaration);
327 }
328
329 self.declaration_stack.push(declaration);
330
331 if self.source.kind.check_unused() {
332 self.db.add_relevant_declaration(declaration);
333 }
334 }
335
336 pub fn pop_declaration(&mut self) {
337 self.declaration_stack.pop().unwrap();
338 }
339
340 pub fn reference(&mut self, reference: Declaration) {
341 if let Some(last) = self.declaration_stack.last() {
342 self.db.add_reference(*last, reference);
343 }
344 }
345}
346
347pub trait GetTextRange {
348 fn text_range(&self) -> TextRange;
349}
350
351impl GetTextRange for TextRange {
352 fn text_range(&self) -> TextRange {
353 *self
354 }
355}
356
357impl GetTextRange for SyntaxNode {
358 fn text_range(&self) -> TextRange {
359 self.text_range()
360 }
361}
362
363impl GetTextRange for SyntaxToken {
364 fn text_range(&self) -> TextRange {
365 self.text_range()
366 }
367}