// claw_resolver/function.rs

use std::collections::{HashMap, VecDeque};

use ast::{ExpressionId, NameId, Span, StatementId, TypeId};
use claw_ast as ast;
use claw_common::{Source, StackMap};
use cranelift_entity::{entity_impl, EntityList, ListPool, PrimaryMap};

#[cfg(test)]
use miette::{Diagnostic, Report, SourceSpan};
#[cfg(test)]
use thiserror::Error;

use crate::expression::*;
use crate::imports::ImportResolver;
use crate::statement::*;
use crate::types::ResolvedType;
use crate::{ItemId, ResolverError};

19pub(crate) struct FunctionResolver<'ctx> {
20    pub(crate) src: Source,
21    pub(crate) component: &'ctx ast::Component,
22    pub(crate) imports: &'ctx ImportResolver,
23    pub(crate) function: &'ctx ast::Function,
24
25    pub(crate) params: PrimaryMap<ParamId, TypeId>,
26
27    // Name Resolution
28    /// Entries for each unique local
29    pub(crate) locals: PrimaryMap<LocalId, LocalInfo>,
30    /// The span for each unique local
31    pub(crate) local_spans: HashMap<LocalId, Span>,
32    /// The association between identifiers and their subjects during resolving
33    pub(crate) mapping: StackMap<String, ItemId>,
34    /// The resolved bindings of expressions to subjects
35    pub(crate) bindings: HashMap<NameId, ItemId>,
36
37    // Type Resolution
38    resolver_queue: VecDeque<(ResolvedType, ResolverItem)>,
39
40    // The parent expression (if there is one) for each expression
41    pub(crate) expr_parent_map: HashMap<ExpressionId, ExpressionId>,
42    /// The type of each expression
43    pub(crate) expression_types: HashMap<ExpressionId, ResolvedType>,
44
45    local_uses_list_pool: ListPool<ExpressionId>,
46    // The expressions which use a given local
47    local_uses: HashMap<LocalId, EntityList<ExpressionId>>,
48
49    // Tye type of each local
50    pub local_types: HashMap<LocalId, ResolvedType>,
51}
52
53#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
54enum ResolverItem {
55    Local(LocalId),
56    Expression(ExpressionId),
57}
58
59#[allow(dead_code)]
60#[derive(Debug, Clone)]
61pub struct LocalInfo {
62    pub ident: NameId,
63    pub mutable: bool,
64    pub annotation: Option<TypeId>,
65}
66
67#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
68pub struct ParamId(u32);
69entity_impl!(ParamId, "param");
70
71#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
72pub struct LocalId(u32);
73entity_impl!(LocalId, "local");
74
75impl<'ctx> FunctionResolver<'ctx> {
76    pub(crate) fn new(
77        src: Source,
78        component: &'ctx ast::Component,
79        imports: &'ctx ImportResolver,
80        function: &'ctx ast::Function,
81        mappings: &'ctx HashMap<String, ItemId>,
82    ) -> Self {
83        let mut params = PrimaryMap::new();
84        let mut mapping: StackMap<String, ItemId> = mappings.clone().into();
85
86        for (ident, valtype) in function.params.iter() {
87            let param = params.push(*valtype);
88            let name = component.get_name(*ident).to_owned();
89            mapping.insert(name, ItemId::Param(param));
90        }
91
92        FunctionResolver {
93            src,
94            component,
95            imports,
96            function,
97            params,
98            mapping,
99            locals: Default::default(),
100            local_spans: Default::default(),
101            bindings: Default::default(),
102            resolver_queue: Default::default(),
103            expr_parent_map: Default::default(),
104            expression_types: Default::default(),
105            local_uses_list_pool: Default::default(),
106            local_uses: Default::default(),
107            local_types: Default::default(),
108        }
109    }
110
111    pub(crate) fn resolve(mut self) -> Result<ResolvedFunction, ResolverError> {
112        self.setup_block(&self.function.body)?;
113        self.resolve_types()?;
114
115        Ok(ResolvedFunction {
116            params: self.params,
117            locals: self.locals,
118            local_spans: self.local_spans,
119            local_types: self.local_types,
120            bindings: self.bindings,
121            expression_types: self.expression_types,
122        })
123    }
124
125    pub(crate) fn setup_block(&mut self, statements: &[StatementId]) -> Result<(), ResolverError> {
126        // Take a checkpoint at the state of the mappings before this block
127        let checkpoint = self.mapping.checkpoint();
128        // Resolve all of the inner statements
129        for statement in statements {
130            self.setup_statement(*statement)?;
131        }
132        // Restore the state of the mappings from before the block
133        self.mapping.restore(checkpoint);
134        Ok(())
135    }
136
137    pub(crate) fn setup_statement(&mut self, statement: StatementId) -> Result<(), ResolverError> {
138        self.component.get_statement(statement).setup_resolve(self)
139    }
140
141    pub(crate) fn setup_expression(
142        &mut self,
143        expression: ExpressionId,
144    ) -> Result<(), ResolverError> {
145        self.component
146            .expr()
147            .get_exp(expression)
148            .setup_resolve(expression, self)
149    }
150
151    pub(crate) fn setup_child_expression(
152        &mut self,
153        parent: ExpressionId,
154        child: ExpressionId,
155    ) -> Result<(), ResolverError> {
156        self.setup_expression(child)?;
157        self.expr_parent_map.insert(child, parent);
158        Ok(())
159    }
160
161    pub(crate) fn define_name(&mut self, ident: NameId, item: ItemId) -> Result<(), ResolverError> {
162        self.bindings.insert(ident, item);
163        let name = self.component.get_name(ident);
164        self.mapping.insert(name.to_owned(), item);
165        Ok(())
166    }
167
168    pub(crate) fn use_name(&mut self, ident: NameId) -> Result<ItemId, ResolverError> {
169        let name = self.component.get_name(ident);
170        let item = match self.mapping.lookup(&name.to_owned()) {
171            Some(item) => *item,
172            None => return self.name_error(ident),
173        };
174        self.bindings.insert(ident, item);
175        Ok(item)
176    }
177
178    pub(crate) fn lookup_name(&self, ident: NameId) -> Result<ItemId, ResolverError> {
179        match self.bindings.get(&ident) {
180            Some(item) => Ok(*item),
181            None => self.name_error(ident),
182        }
183    }
184
185    fn name_error<T>(&self, ident: NameId) -> Result<T, ResolverError> {
186        let span = self.component.name_span(ident);
187        let ident = self.component.get_name(ident).to_owned();
188        Err(ResolverError::NameError {
189            src: self.src.clone(),
190            span,
191            ident,
192        })
193    }
194
195    pub(crate) fn use_local(&mut self, local: LocalId, expression: ExpressionId) {
196        let existing_uses = self.local_uses.get_mut(&local);
197        if let Some(uses) = existing_uses {
198            uses.push(expression, &mut self.local_uses_list_pool);
199        } else {
200            let mut uses = EntityList::new();
201            uses.push(expression, &mut self.local_uses_list_pool);
202            self.local_uses.insert(local, uses);
203        }
204    }
205
206    pub(crate) fn set_expr_type(&mut self, id: ExpressionId, rtype: ResolvedType) {
207        self.resolver_queue
208            .push_back((rtype, ResolverItem::Expression(id)));
209    }
210
211    pub(crate) fn set_local_type(&mut self, id: LocalId, rtype: ResolvedType) {
212        self.resolver_queue
213            .push_back((rtype, ResolverItem::Local(id)));
214    }
215
216    fn resolve_types(&mut self) -> Result<(), ResolverError> {
217        while let Some((next_type, next_item)) = self.resolver_queue.pop_front() {
218            match next_item {
219                ResolverItem::Expression(expression) => {
220                    // Apply the inferred type and detect conflicts
221                    if let Some(existing_type) = self.expression_types.get(&expression) {
222                        if !next_type.type_eq(existing_type, self.component) {
223                            let span = self.component.expr().get_span(expression);
224                            return Err(ResolverError::TypeConflict {
225                                src: self.src.clone(),
226                                span,
227                                type_a: *existing_type,
228                                type_b: next_type,
229                            });
230                        } else {
231                            #[cfg(test)]
232                            self.notify_skipped_expression(expression);
233                            continue;
234                        }
235                    } else {
236                        self.expression_types.insert(expression, next_type);
237                    }
238
239                    #[cfg(test)]
240                    self.notify_resolved_expression(expression);
241
242                    let expression_val = self.component.expr().get_exp(expression);
243                    expression_val.on_resolved(next_type, expression, self)?;
244
245                    if let Some(parent_id) = self.expr_parent_map.get(&expression) {
246                        let parent = self.component.expr().get_exp(*parent_id);
247                        parent.on_child_resolved(next_type, *parent_id, self)?;
248                    } else {
249                        #[cfg(test)]
250                        self.notify_ophaned_expression(expression);
251                    }
252                }
253                ResolverItem::Local(local) => {
254                    if let Some(existing_type) = self.local_types.get(&local) {
255                        if !next_type.type_eq(existing_type, self.component) {
256                            panic!("Local type error!!!");
257                        } else {
258                            #[cfg(test)]
259                            self.notify_skipped_local(local);
260                            continue;
261                        }
262                    } else {
263                        self.local_types.insert(local, next_type);
264                    }
265
266                    #[cfg(test)]
267                    self.notify_resolved_local(local);
268
269                    if self.local_uses.contains_key(&local) {
270                        let uses_len = {
271                            let uses = self.local_uses.get(&local).unwrap();
272                            uses.len(&self.local_uses_list_pool)
273                        };
274                        for i in 0..uses_len {
275                            let local_use = {
276                                let uses = self.local_uses.get(&local).unwrap();
277                                uses.get(i, &self.local_uses_list_pool).unwrap()
278                            };
279                            self.set_expr_type(local_use, next_type);
280                        }
281                    }
282                }
283            }
284        }
285
286        Ok(())
287    }
288
289    #[cfg(test)]
290    fn notify_skipped_expression(&self, expression: ExpressionId) {
291        let src = self.src.clone();
292        let span = self.component.expr().get_span(expression);
293        let notification = Notification::ExpressionSkipped { src, span };
294        println!("{:?}", Report::new(notification));
295    }
296
297    #[cfg(test)]
298    fn notify_resolved_expression(&self, expression: ExpressionId) {
299        let src = self.src.clone();
300        let span = self.component.expr().get_span(expression);
301        let notification = Notification::ExpressionResolved { src, span };
302        println!("{:?}", Report::new(notification));
303    }
304
305    #[cfg(test)]
306    fn notify_ophaned_expression(&self, expression: ExpressionId) {
307        let src = self.src.clone();
308        let span = self.component.expr().get_span(expression);
309        let notification = Notification::ExpressionOrphan { src, span };
310        println!("{:?}", Report::new(notification));
311    }
312
313    #[cfg(test)]
314    fn notify_skipped_local(&self, local: LocalId) {
315        let src = self.src.clone();
316        let span = *self.local_spans.get(&local).unwrap();
317        let notification = Notification::LocalSkipped { src, span };
318        println!("{:?}", Report::new(notification));
319    }
320
321    #[cfg(test)]
322    fn notify_resolved_local(&self, local: LocalId) {
323        let src = self.src.clone();
324        let span = *self.local_spans.get(&local).unwrap();
325        let notification = Notification::LocalResolved { src, span };
326        println!("{:?}", Report::new(notification));
327    }
328}
329
/// Trace diagnostics printed by the resolver's worklist in test builds.
#[cfg(test)]
#[derive(Debug, Error, Diagnostic)]
pub enum Notification {
    /// An expression was dequeued again with a type equal to its current one.
    #[error("Skipping already resolved expression")]
    ExpressionSkipped {
        #[source_code]
        src: Source,
        #[label("here")]
        span: SourceSpan,
    },
    /// An expression received its type for the first time.
    #[error("Resolved type of expression")]
    ExpressionResolved {
        #[source_code]
        src: Source,
        #[label("here")]
        span: SourceSpan,
    },
    /// A resolved expression had no recorded parent to notify.
    #[error("No parent exists to be updated for")]
    ExpressionOrphan {
        #[source_code]
        src: Source,
        #[label("here")]
        span: SourceSpan,
    },
    /// A local was dequeued again with a type equal to its current one.
    #[error("Skipping already resolved local")]
    LocalSkipped {
        #[source_code]
        src: Source,
        #[label("here")]
        span: SourceSpan,
    },
    /// A local received its type for the first time.
    #[error("Resolved type of local")]
    LocalResolved {
        #[source_code]
        src: Source,
        #[label("here")]
        span: SourceSpan,
    },
}

370pub struct ResolvedFunction {
371    pub params: PrimaryMap<ParamId, TypeId>,
372
373    /// Entries for each unique local
374    pub locals: PrimaryMap<LocalId, LocalInfo>,
375    /// The span for each unique local
376    pub local_spans: HashMap<LocalId, Span>,
377    // Tye type of each local
378    pub local_types: HashMap<LocalId, ResolvedType>,
379
380    /// The resolved bindings of expressions to subjects
381    pub bindings: HashMap<NameId, ItemId>,
382    /// The type of each expression
383    pub expression_types: HashMap<ExpressionId, ResolvedType>,
384}
385
386impl ResolvedFunction {
387    pub fn local_type(
388        &self,
389        local: LocalId,
390        context: &ast::Component,
391    ) -> Result<ResolvedType, ResolverError> {
392        let rtype = self.local_types.get(&local);
393        match rtype {
394            Some(rtype) => Ok(*rtype),
395            None => {
396                let span = self.local_spans.get(&local).unwrap().to_owned();
397                Err(ResolverError::Base {
398                    src: context.src.clone(),
399                    span,
400                })
401            }
402        }
403    }
404
405    pub fn expression_type(
406        &self,
407        expression: ExpressionId,
408        context: &ast::Component,
409    ) -> Result<ResolvedType, ResolverError> {
410        let rtype = self.expression_types.get(&expression);
411        match rtype {
412            Some(rtype) => Ok(*rtype),
413            None => {
414                let span = context.expr().get_span(expression);
415                Err(ResolverError::Base {
416                    src: context.src.clone(),
417                    span,
418                })
419            }
420        }
421    }
422}