rue_compiler/
compiler.rs

1use std::{
2    cmp::Reverse,
3    collections::HashMap,
4    ops::{Deref, DerefMut, Range},
5    sync::Arc,
6};
7
8use rowan::TextRange;
9use rue_diagnostic::{Diagnostic, DiagnosticKind, Source, SourceKind, SrcLoc};
10use rue_hir::{
11    Builtins, Constraint, Database, Scope, ScopeId, Symbol, SymbolId, TypePath, Value, replace_type,
12};
13use rue_options::CompilerOptions;
14use rue_parser::{SyntaxNode, SyntaxToken};
15use rue_types::{Check, CheckError, Comparison, TypeId};
16
/// State for a compilation: collected diagnostics, the HIR [`Database`], the stack of lexical
/// scopes, layered per-symbol type overrides, and default struct-field values.
///
/// `Compiler` dereferences to [`Database`], so database methods can be called on it directly.
#[derive(Debug, Clone)]
pub struct Compiler {
    /// Options passed to [`Compiler::new`]; stored but not read anywhere in this file.
    _options: CompilerOptions,
    /// Source attached to every diagnostic recorded; replaced via [`Compiler::set_source`].
    source: Source,
    /// Diagnostics accumulated so far, in the order they were reported.
    diagnostics: Vec<Diagnostic>,
    /// The HIR database that owns scopes, symbols and types.
    db: Database,
    /// Active scopes, innermost last; name lookups walk it from the end. Starts with the
    /// builtins scope.
    scope_stack: Vec<ScopeId>,
    /// Layers of symbol -> type overrides; later layers shadow earlier ones
    /// (presumably type narrowing from guards — TODO confirm with callers).
    mapping_stack: Vec<HashMap<SymbolId, TypeId>>,
    /// Builtin scope and types registered into `db` at construction.
    builtins: Builtins,
    /// Default field values, keyed by the owning type and then by field name.
    defaults: HashMap<TypeId, HashMap<String, Value>>,
}
28
29impl Deref for Compiler {
30    type Target = Database;
31
32    fn deref(&self) -> &Self::Target {
33        &self.db
34    }
35}
36
37impl DerefMut for Compiler {
38    fn deref_mut(&mut self) -> &mut Self::Target {
39        &mut self.db
40    }
41}
42
43impl Compiler {
44    pub fn new(options: CompilerOptions) -> Self {
45        let mut db = Database::new();
46
47        let builtins = Builtins::new(&mut db);
48
49        Self {
50            _options: options,
51            source: Source::new(Arc::from(""), SourceKind::Std),
52            diagnostics: Vec::new(),
53            db,
54            scope_stack: vec![builtins.scope],
55            mapping_stack: vec![],
56            builtins,
57            defaults: HashMap::new(),
58        }
59    }
60
61    pub fn diagnostics(&self) -> &[Diagnostic] {
62        &self.diagnostics
63    }
64
65    pub fn builtins(&self) -> &Builtins {
66        &self.builtins
67    }
68
69    pub fn set_source(&mut self, source: Source) {
70        self.source = source;
71    }
72
73    pub fn diagnostic(&mut self, node: &impl GetTextRange, kind: DiagnosticKind) {
74        let range = node.text_range();
75        let span: Range<usize> = range.start().into()..range.end().into();
76        self.diagnostics.push(Diagnostic::new(
77            SrcLoc::new(self.source.clone(), span),
78            kind,
79        ));
80    }
81
82    pub fn push_scope(&mut self, scope: ScopeId) {
83        self.scope_stack.push(scope);
84    }
85
86    pub fn pop_scope(&mut self) {
87        self.scope_stack.pop().unwrap();
88    }
89
90    pub fn last_scope(&self) -> &Scope {
91        let scope = *self.scope_stack.last().unwrap();
92        self.scope(scope)
93    }
94
95    pub fn last_scope_mut(&mut self) -> &mut Scope {
96        let scope = *self.scope_stack.last().unwrap();
97        self.scope_mut(scope)
98    }
99
100    pub fn resolve_symbol(&self, name: &str) -> Option<SymbolId> {
101        for scope in self.scope_stack.iter().rev() {
102            if let Some(symbol) = self.scope(*scope).symbol(name) {
103                return Some(symbol);
104            }
105        }
106        None
107    }
108
109    pub fn resolve_type(&self, name: &str) -> Option<TypeId> {
110        for scope in self.scope_stack.iter().rev() {
111            if let Some(ty) = self.scope(*scope).ty(name) {
112                return Some(ty);
113            }
114        }
115        None
116    }
117
118    pub fn type_name(&mut self, ty: TypeId) -> String {
119        for scope in self.scope_stack.iter().rev() {
120            if let Some(name) = self.scope(*scope).type_name(ty) {
121                return name.to_string();
122            }
123        }
124
125        rue_types::stringify(self.db.types_mut(), ty)
126    }
127
128    pub fn symbol_type(&self, symbol: SymbolId) -> TypeId {
129        for map in self.mapping_stack.iter().rev() {
130            if let Some(ty) = map.get(&symbol) {
131                return *ty;
132            }
133        }
134
135        match self.symbol(symbol) {
136            Symbol::Module(_) => self.builtins().unresolved.ty,
137            Symbol::Function(function) => function.ty,
138            Symbol::Parameter(parameter) => parameter.ty,
139            Symbol::Constant(constant) => constant.ty,
140            Symbol::Binding(binding) => binding.ty,
141        }
142    }
143
144    pub fn push_mappings(
145        &mut self,
146        mappings: HashMap<SymbolId, HashMap<Vec<TypePath>, TypeId>>,
147    ) -> usize {
148        let mut result = HashMap::new();
149
150        for (symbol, paths) in mappings {
151            let mut ty = self.symbol_type(symbol);
152
153            let mut paths = paths.into_iter().collect::<Vec<_>>();
154            paths.sort_by_key(|(path, _)| (Reverse(path.len()), path.last().copied()));
155
156            for (path, replacement) in paths {
157                ty = replace_type(&mut self.db, ty, replacement, &path);
158            }
159
160            result.insert(symbol, ty);
161        }
162
163        let index = self.mapping_stack.len();
164        self.mapping_stack.push(result);
165        index
166    }
167
168    pub fn mapping_checkpoint(&self) -> usize {
169        self.mapping_stack.len()
170    }
171
172    pub fn revert_mappings(&mut self, index: usize) {
173        self.mapping_stack.truncate(index);
174    }
175
176    pub fn is_assignable(&mut self, from: TypeId, to: TypeId) -> bool {
177        let comparison = rue_types::compare(self.db.types_mut(), &self.builtins.types, from, to);
178        comparison == Comparison::Assign
179    }
180
181    pub fn is_castable(&mut self, from: TypeId, to: TypeId) -> bool {
182        let comparison = rue_types::compare(self.db.types_mut(), &self.builtins.types, from, to);
183        matches!(comparison, Comparison::Assign | Comparison::Cast)
184    }
185
186    pub fn assign_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) {
187        self.compare_type(node, from, to, false, None);
188    }
189
190    pub fn cast_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) {
191        self.compare_type(node, from, to, true, None);
192    }
193
194    pub fn guard_type(&mut self, node: &impl GetTextRange, from: TypeId, to: TypeId) -> Constraint {
195        let check = match rue_types::check(self.db.types_mut(), &self.builtins.types, from, to) {
196            Ok(check) => check,
197            Err(CheckError::DepthExceeded) => {
198                self.diagnostic(node, DiagnosticKind::TypeCheckDepthExceeded);
199                return Constraint::new(Check::Impossible);
200            }
201            Err(CheckError::FunctionType) => {
202                self.diagnostic(node, DiagnosticKind::FunctionTypeCheck);
203                return Constraint::new(Check::Impossible);
204            }
205        };
206
207        let from_name = self.type_name(from);
208        let to_name = self.type_name(to);
209
210        if check == Check::None {
211            self.diagnostic(node, DiagnosticKind::UnnecessaryGuard(from_name, to_name));
212        } else if check == Check::Impossible {
213            self.diagnostic(node, DiagnosticKind::IncompatibleGuard(from_name, to_name));
214        }
215
216        let else_id = rue_types::subtract(self.db.types_mut(), &self.builtins.types, from, to);
217
218        Constraint::new(check).with_else(else_id)
219    }
220
221    pub fn infer_type(
222        &mut self,
223        node: &impl GetTextRange,
224        from: TypeId,
225        to: TypeId,
226        infer: &mut HashMap<TypeId, TypeId>,
227    ) {
228        self.compare_type(node, from, to, false, Some(infer));
229    }
230
231    fn compare_type(
232        &mut self,
233        node: &impl GetTextRange,
234        from: TypeId,
235        to: TypeId,
236        cast: bool,
237        infer: Option<&mut HashMap<TypeId, TypeId>>,
238    ) {
239        let comparison = rue_types::compare_with_inference(
240            self.db.types_mut(),
241            &self.builtins.types,
242            from,
243            to,
244            infer,
245        );
246
247        match comparison {
248            Comparison::Assign => {
249                if cast {
250                    let from = self.type_name(from);
251                    let to = self.type_name(to);
252                    self.diagnostic(node, DiagnosticKind::UnnecessaryCast(from, to));
253                }
254            }
255            Comparison::Cast => {
256                if !cast {
257                    let from = self.type_name(from);
258                    let to = self.type_name(to);
259                    self.diagnostic(node, DiagnosticKind::UnassignableType(from, to));
260                }
261            }
262            Comparison::Invalid => {
263                let check =
264                    match rue_types::check(self.db.types_mut(), &self.builtins.types, from, to) {
265                        Ok(check) => check,
266                        Err(CheckError::DepthExceeded | CheckError::FunctionType) => {
267                            Check::Impossible
268                        }
269                    };
270
271                let from = self.type_name(from);
272                let to = self.type_name(to);
273
274                if check != Check::Impossible {
275                    self.diagnostic(node, DiagnosticKind::UnconstrainableComparison(from, to));
276                } else if cast {
277                    self.diagnostic(node, DiagnosticKind::IncompatibleCast(from, to));
278                } else {
279                    self.diagnostic(node, DiagnosticKind::IncompatibleType(from, to));
280                }
281            }
282        }
283    }
284
285    pub fn insert_default_field(&mut self, ty: TypeId, name: String, value: Value) {
286        self.defaults.entry(ty).or_default().insert(name, value);
287    }
288
289    pub fn default_field(&self, ty: TypeId, name: &str) -> Option<Value> {
290        self.defaults
291            .get(&ty)
292            .and_then(|map| map.get(name).cloned())
293    }
294}
295
/// Something that can report the source span a diagnostic should point at.
///
/// Accepted by [`Compiler::diagnostic`] and the type-checking helpers, so callers can pass
/// either a syntax node or a single token.
pub trait GetTextRange {
    /// Returns the text range this item covers in the source.
    fn text_range(&self) -> TextRange;
}
299
300impl GetTextRange for SyntaxNode {
301    fn text_range(&self) -> TextRange {
302        self.text_range()
303    }
304}
305
306impl GetTextRange for SyntaxToken {
307    fn text_range(&self) -> TextRange {
308        self.text_range()
309    }
310}