rue_compiler/
compiler.rs

1use std::{
2    cmp::Reverse,
3    collections::HashMap,
4    mem,
5    ops::{Deref, DerefMut, Range},
6    sync::Arc,
7};
8
9use rowan::{TextRange, TextSize};
10use rue_diagnostic::{Diagnostic, DiagnosticKind, Source, SourceKind, SrcLoc};
11use rue_hir::{
12    Builtins, Constraint, Database, Declaration, Scope, ScopeId, Symbol, SymbolId, TypePath, Value,
13    replace_type,
14};
15use rue_options::CompilerOptions;
16use rue_parser::{SyntaxNode, SyntaxToken};
17use rue_types::{Check, CheckError, Comparison, Type, TypeId};
18
19use crate::{SyntaxItem, SyntaxItemKind, SyntaxMap};
20
/// Front-end compiler state shared by every compilation pass.
///
/// Owns the semantic [`Database`] (also reachable through `Deref`/`DerefMut`), the diagnostics
/// produced so far, per-source syntax maps, and the stacks that track the current lexical scope
/// and the declaration currently being compiled.
#[derive(Debug, Clone)]
pub struct Compiler {
    // Options supplied to `Compiler::new`; exposed through `options()`.
    options: CompilerOptions,
    // The source currently being compiled; replaced by `set_source`.
    source: Source,
    // Diagnostics collected so far; drained by `take_diagnostics`.
    diagnostics: Vec<Diagnostic>,
    // Semantic database; target of the `Deref`/`DerefMut` impls.
    db: Database,
    // One syntax map per source kind, recording scope, declaration and reference spans.
    syntax_maps: HashMap<SourceKind, SyntaxMap>,
    // Active scopes, innermost last, each paired with the offset at which it was pushed.
    // `new` seeds it with the builtins scope at offset 0.
    scope_stack: Vec<(TextSize, ScopeId)>,
    // Builtin declarations allocated into `db` by `new`.
    builtins: Builtins,
    // Default field values, keyed by type and then by field name.
    defaults: HashMap<TypeId, HashMap<String, Value>>,
    // Declarations currently being compiled, innermost last; nested declarations and
    // references are attributed to the top entry.
    declaration_stack: Vec<Declaration>,
}
33
34impl Deref for Compiler {
35    type Target = Database;
36
37    fn deref(&self) -> &Self::Target {
38        &self.db
39    }
40}
41
42impl DerefMut for Compiler {
43    fn deref_mut(&mut self) -> &mut Self::Target {
44        &mut self.db
45    }
46}
47
48impl Compiler {
49    pub fn new(options: CompilerOptions) -> Self {
50        let mut db = Database::new();
51
52        let builtins = Builtins::new(&mut db);
53
54        Self {
55            options,
56            source: Source::new(Arc::from(""), SourceKind::Std),
57            diagnostics: Vec::new(),
58            db,
59            syntax_maps: HashMap::new(),
60            scope_stack: vec![(TextSize::from(0), builtins.scope)],
61            builtins,
62            defaults: HashMap::new(),
63            declaration_stack: Vec::new(),
64        }
65    }
66
67    pub fn source(&self) -> &Source {
68        &self.source
69    }
70
71    pub fn options(&self) -> &CompilerOptions {
72        &self.options
73    }
74
75    #[allow(clippy::cast_possible_truncation)]
76    pub fn set_source(&mut self, source: Source) {
77        self.syntax_maps
78            .entry(source.kind.clone())
79            .or_default()
80            .add_item(SyntaxItem::new(
81                SyntaxItemKind::Scope(self.builtins.scope),
82                TextRange::new(TextSize::from(0), TextSize::from(source.text.len() as u32)),
83            ));
84        self.source = source;
85    }
86
87    pub fn syntax_map(&self, source_kind: &SourceKind) -> Option<&SyntaxMap> {
88        self.syntax_maps.get(source_kind)
89    }
90
91    pub fn syntax_map_mut(&mut self) -> &mut SyntaxMap {
92        self.syntax_maps
93            .entry(self.source.kind.clone())
94            .or_default()
95    }
96
97    pub fn take_diagnostics(&mut self) -> Vec<Diagnostic> {
98        mem::take(&mut self.diagnostics)
99    }
100
101    pub fn builtins(&self) -> &Builtins {
102        &self.builtins
103    }
104
105    pub fn diagnostic(&mut self, node: &impl GetTextRange, kind: DiagnosticKind) {
106        let range = node.text_range();
107        let span: Range<usize> = range.start().into()..range.end().into();
108        self.diagnostics.push(Diagnostic::new(
109            SrcLoc::new(self.source.clone(), span),
110            kind,
111        ));
112    }
113
114    pub fn alloc_child_scope(&mut self) -> ScopeId {
115        let parent_scope = self.last_scope_id();
116        self.alloc_scope(Scope::new(Some(parent_scope)))
117    }
118
119    pub fn push_scope(&mut self, scope: ScopeId, start: TextSize) {
120        self.scope_stack.push((start, scope));
121    }
122
123    pub fn pop_scope(&mut self, end: TextSize) {
124        let (start, scope) = self.scope_stack.pop().unwrap();
125
126        self.syntax_map_mut().add_item(SyntaxItem::new(
127            SyntaxItemKind::Scope(scope),
128            TextRange::new(start, end),
129        ));
130    }
131
132    pub fn last_scope(&self) -> &Scope {
133        let scope = *self.scope_stack.last().unwrap();
134        self.scope(scope.1)
135    }
136
137    pub fn last_scope_mut(&mut self) -> &mut Scope {
138        let scope = *self.scope_stack.last().unwrap();
139        self.scope_mut(scope.1)
140    }
141
142    pub fn last_scope_id(&self) -> ScopeId {
143        self.scope_stack.last().unwrap().1
144    }
145
146    pub fn resolve_symbol(&self, name: &str) -> Option<SymbolId> {
147        let mut current = Some(self.last_scope_id());
148
149        while let Some(scope) = current {
150            if let Some(symbol) = self.scope(scope).symbol(name) {
151                return Some(symbol);
152            }
153            current = self.scope(scope).parent();
154        }
155
156        None
157    }
158
159    pub fn resolve_type(&self, name: &str) -> Option<TypeId> {
160        let mut current = Some(self.last_scope_id());
161
162        while let Some(scope) = current {
163            if let Some(ty) = self.scope(scope).ty(name) {
164                return Some(ty);
165            }
166            current = self.scope(scope).parent();
167        }
168
169        None
170    }
171
172    pub fn type_name(&mut self, ty: TypeId) -> String {
173        let mut current = Some(self.last_scope_id());
174
175        while let Some(scope) = current {
176            if let Some(name) = self.scope(scope).type_name(ty) {
177                return name.to_string();
178            }
179            current = self.scope(scope).parent();
180        }
181
182        rue_types::stringify(self.db.types_mut(), ty)
183    }
184
185    pub fn symbol_name(&self, symbol: SymbolId) -> String {
186        let mut current = Some(self.last_scope_id());
187
188        while let Some(scope) = current {
189            if let Some(name) = self.scope(scope).symbol_name(symbol) {
190                return name.to_string();
191            }
192            current = self.scope(scope).parent();
193        }
194
195        match self.symbol(symbol) {
196            Symbol::Binding(binding) => binding.name.as_ref().map(|name| name.text().to_string()),
197            Symbol::Constant(constant) => {
198                constant.name.as_ref().map(|name| name.text().to_string())
199            }
200            Symbol::Function(function) => {
201                function.name.as_ref().map(|name| name.text().to_string())
202            }
203            Symbol::Module(module) => module.name.as_ref().map(|name| name.text().to_string()),
204            Symbol::Builtin(builtin) => Some(builtin.to_string()),
205            Symbol::Parameter(parameter) => {
206                parameter.name.as_ref().map(|name| name.text().to_string())
207            }
208            Symbol::Unresolved => None,
209        }
210        .unwrap_or("{unknown}".to_string())
211    }
212
213    pub fn symbol_type(&self, symbol: SymbolId) -> TypeId {
214        let mut current = Some(self.last_scope_id());
215
216        while let Some(scope) = current {
217            if let Some(ty) = self.scope(scope).symbol_override_type(symbol) {
218                return ty;
219            }
220            current = self.scope(scope).parent();
221        }
222
223        match self.symbol(symbol) {
224            Symbol::Unresolved | Symbol::Module(_) | Symbol::Builtin(_) => {
225                self.builtins().unresolved.ty
226            }
227            Symbol::Function(function) => function.ty,
228            Symbol::Parameter(parameter) => parameter.ty,
229            Symbol::Constant(constant) => constant.value.ty,
230            Symbol::Binding(binding) => binding.value.ty,
231        }
232    }
233
234    pub fn push_mappings(
235        &mut self,
236        mappings: HashMap<SymbolId, HashMap<Vec<TypePath>, TypeId>>,
237        start: TextSize,
238    ) -> usize {
239        let scope = self.alloc_child_scope();
240
241        for (symbol, paths) in mappings {
242            let mut ty = self.symbol_type(symbol);
243
244            let mut paths = paths.into_iter().collect::<Vec<_>>();
245            paths.sort_by_key(|(path, _)| (Reverse(path.len()), path.last().copied()));
246
247            for (path, replacement) in paths {
248                ty = replace_type(&mut self.db, ty, replacement, &path);
249            }
250
251            self.scope_mut(scope).override_symbol_type(symbol, ty);
252        }
253
254        let index = self.scope_stack.len();
255        self.push_scope(scope, start);
256        index
257    }
258
259    pub fn mapping_checkpoint(&self) -> usize {
260        self.scope_stack.len()
261    }
262
263    pub fn revert_mappings(&mut self, index: usize, end: TextSize) {
264        while self.scope_stack.len() > index {
265            self.pop_scope(end);
266        }
267    }
268
269    pub fn is_assignable(&mut self, from: TypeId, to: TypeId) -> bool {
270        let comparison = rue_types::compare(self.db.types_mut(), &self.builtins.types, from, to);
271        comparison == Comparison::Assign
272    }
273
274    pub fn is_castable(&mut self, from: TypeId, to: TypeId) -> bool {
275        let comparison = rue_types::compare(self.db.types_mut(), &self.builtins.types, from, to);
276        matches!(comparison, Comparison::Assign | Comparison::Cast)
277    }
278
279    pub fn assign_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) {
280        self.compare_type(node, from, to, false, None);
281    }
282
283    pub fn cast_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) {
284        self.compare_type(node, from, to, true, None);
285    }
286
287    pub fn guard_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) -> Constraint {
288        let check = match rue_types::check(self.db.types_mut(), &self.builtins.types, from, to) {
289            Ok(check) => check,
290            Err(CheckError::DepthExceeded) => {
291                self.diagnostic(node, DiagnosticKind::TypeCheckDepthExceeded);
292                return Constraint::new(Check::Impossible);
293            }
294            Err(CheckError::FunctionType) => {
295                self.diagnostic(node, DiagnosticKind::FunctionTypeCheck);
296                return Constraint::new(Check::Impossible);
297            }
298        };
299
300        let from_name = self.type_name(from);
301        let to_name = self.type_name(to);
302
303        if check == Check::None {
304            self.diagnostic(node, DiagnosticKind::UnnecessaryGuard(from_name, to_name));
305        } else if check == Check::Impossible {
306            self.diagnostic(node, DiagnosticKind::IncompatibleGuard(from_name, to_name));
307        }
308
309        let else_id = rue_types::subtract(self.db.types_mut(), &self.builtins.types, from, to);
310
311        Constraint::new(check).with_else(else_id)
312    }
313
314    pub fn check_condition(&mut self, node: &impl GetTextRange, ty: TypeId) {
315        if self.is_castable(ty, self.builtins().types.bool_true) {
316            self.diagnostic(node, DiagnosticKind::AlwaysTrueCondition);
317        } else if self.is_castable(ty, self.builtins().types.bool_false) {
318            self.diagnostic(node, DiagnosticKind::AlwaysFalseCondition);
319        } else {
320            self.assign_type(node, ty, self.builtins().types.bool);
321        }
322    }
323
324    pub fn infer_type(
325        &mut self,
326        node: &impl GetTextRange,
327        from: TypeId,
328        to: TypeId,
329        infer: &mut HashMap<TypeId, TypeId>,
330    ) {
331        self.compare_type(node, from, to, false, Some(infer));
332    }
333
334    fn compare_type(
335        &mut self,
336        node: &impl GetTextRange,
337        from: TypeId,
338        to: TypeId,
339        cast: bool,
340        infer: Option<&mut HashMap<TypeId, TypeId>>,
341    ) {
342        let comparison = rue_types::compare_with_inference(
343            self.db.types_mut(),
344            &self.builtins.types,
345            from,
346            to,
347            infer,
348        );
349
350        match comparison {
351            Comparison::Assign => {
352                if cast
353                    && rue_types::compare_with_inference(
354                        self.db.types_mut(),
355                        &self.builtins.types,
356                        to,
357                        from,
358                        None,
359                    ) == Comparison::Assign
360                {
361                    let from = self.type_name(from);
362                    let to = self.type_name(to);
363
364                    self.diagnostic(node, DiagnosticKind::UnnecessaryCast(from, to));
365                }
366            }
367            Comparison::Cast => {
368                if !cast {
369                    let from = self.type_name(from);
370                    let to = self.type_name(to);
371                    self.diagnostic(node, DiagnosticKind::UnassignableType(from, to));
372                }
373            }
374            Comparison::Invalid => {
375                let check =
376                    match rue_types::check(self.db.types_mut(), &self.builtins.types, from, to) {
377                        Ok(check) => check,
378                        Err(CheckError::DepthExceeded | CheckError::FunctionType) => {
379                            Check::Impossible
380                        }
381                    };
382
383                let from = self.type_name(from);
384                let to = self.type_name(to);
385
386                if check != Check::Impossible {
387                    self.diagnostic(node, DiagnosticKind::UnconstrainableComparison(from, to));
388                } else if cast {
389                    self.diagnostic(node, DiagnosticKind::IncompatibleCast(from, to));
390                } else {
391                    self.diagnostic(node, DiagnosticKind::IncompatibleType(from, to));
392                }
393            }
394        }
395    }
396
397    pub fn insert_default_field(&mut self, ty: TypeId, name: String, value: Value) {
398        self.defaults.entry(ty).or_default().insert(name, value);
399    }
400
401    pub fn default_field(&self, ty: TypeId, name: &str) -> Option<Value> {
402        self.defaults
403            .get(&ty)
404            .and_then(|map| map.get(name).cloned())
405    }
406
407    pub fn push_declaration(&mut self, declaration: Declaration) {
408        if let Some(last) = self.declaration_stack.last() {
409            self.db.add_declaration(*last, declaration);
410        }
411
412        self.declaration_stack.push(declaration);
413
414        if self.source.kind.check_unused() {
415            self.db.add_relevant_declaration(declaration);
416        }
417    }
418
419    pub fn pop_declaration(&mut self) {
420        self.declaration_stack.pop().unwrap();
421    }
422
423    pub fn reference(&mut self, reference: Declaration) {
424        if let Some(last) = self.declaration_stack.last() {
425            self.db.add_reference(*last, reference);
426        }
427    }
428
429    pub fn declaration_span(&mut self, declaration: Declaration, span: TextRange) {
430        self.syntax_map_mut().add_item(SyntaxItem::new(
431            match declaration {
432                Declaration::Symbol(symbol) => SyntaxItemKind::SymbolDeclaration(symbol),
433                Declaration::Type(ty) => SyntaxItemKind::TypeDeclaration(ty),
434            },
435            span,
436        ));
437    }
438
439    pub fn reference_span(&mut self, reference: Declaration, span: TextRange) {
440        self.syntax_map_mut().add_item(SyntaxItem::new(
441            match reference {
442                Declaration::Symbol(symbol) => SyntaxItemKind::SymbolReference(symbol),
443                Declaration::Type(ty) => SyntaxItemKind::TypeReference(ty),
444            },
445            span,
446        ));
447    }
448
449    pub fn is_unresolved(&mut self, ty: TypeId) -> bool {
450        let semantic = rue_types::unwrap_semantic(self.db.types_mut(), ty, true);
451        matches!(self.ty(semantic), Type::Unresolved)
452    }
453}
454
/// Anything that can report the source span a diagnostic should point at.
///
/// `Compiler::diagnostic` and the type-checking helpers accept `&impl GetTextRange`, so they
/// can be given syntax nodes, tokens, or raw ranges.
pub trait GetTextRange {
    /// The span of `self` within the current source.
    fn text_range(&self) -> TextRange;
}
458
459impl GetTextRange for TextRange {
460    fn text_range(&self) -> TextRange {
461        *self
462    }
463}
464
/// A syntax node reports the range it covers in the source.
impl GetTextRange for SyntaxNode {
    fn text_range(&self) -> TextRange {
        // This is not self-recursion: the node type's own inherent `text_range` method
        // (presumably rowan's `SyntaxNode::text_range`) takes precedence over this trait method
        // during method resolution.
        self.text_range()
    }
}
470
/// A syntax token reports the range it covers in the source.
impl GetTextRange for SyntaxToken {
    fn text_range(&self) -> TextRange {
        // This is not self-recursion: the token type's own inherent `text_range` method
        // (presumably rowan's `SyntaxToken::text_range`) takes precedence over this trait method
        // during method resolution.
        self.text_range()
    }
}