petr_bind/
binder.rs

1use std::{collections::BTreeMap, rc::Rc};
2
3use petr_ast::{dependency::Dependency, Ast, Binding, ExprId, Expression, FunctionDeclaration, Ty, TypeDeclaration};
4use petr_utils::{idx_map_key, Identifier, IndexMap, Path, SpannedItem, SymbolId, SymbolInterner};
5// TODO:
6// - i don't know if type cons needs a scope. Might be good to remove that.
7// - replace "scope_chain.last().expect()" with "self.current_scope()" which doesn't return an option
8
idx_map_key!(
    /// The ID type of a Scope in the Binder.
    ScopeId
);

idx_map_key!(
    /// The ID type of a function parameter.
    FunctionParameterId
);

idx_map_key!(
    /// The ID type of a function.
    FunctionId
);

idx_map_key!(
    /// The ID type of a variable binding.
   BindingId
);

idx_map_key!(
    /// The ID type of a module.
   ModuleId
);
33
/// Something a symbol can resolve to inside a [`Scope`]. The `Binder` stores these wrapped in a
/// `SpannedItem`.
#[derive(Clone, Debug)]
pub enum Item {
    /// A variable binding, identified by its ID in the binder's binding table.
    Binding(BindingId),
    /// A function declaration.
    // the `ScopeId` is the scope of the function body
    Function(FunctionId, ScopeId),
    /// A type declaration, identified by its ID in the binder's type table.
    Type(petr_utils::TypeId),
    /// A function parameter; the payload is the parameter's declared type.
    FunctionParameter(Ty),
    /// A module, identified by its ID in the binder's module table.
    Module(ModuleId),
    /// An imported path with an optional alias.
    // NOTE(review): imports are produced outside this file (by a `Bind` impl); confirm there
    // how `alias` affects the name the import is registered under.
    Import { path: Path, alias: Option<Identifier> },
}
44
/// Owns every scope created while binding an AST, together with the declarations (bindings,
/// functions, types, modules) that the items in those scopes refer to by ID.
pub struct Binder {
    /// Every scope that has been created, addressable by `ScopeId`.
    scopes:      IndexMap<ScopeId, Scope<SpannedItem<Item>>>,
    /// Stack of scopes currently being bound; the last element is the current scope.
    scope_chain: Vec<ScopeId>,
    /// Some expressions define their own scopes, like expressions with bindings
    // TODO rename to expr_scopes
    exprs: BTreeMap<ExprId, ScopeId>,
    /// Variable bindings, addressable by `BindingId`.
    bindings:    IndexMap<BindingId, Binding>,
    /// Function declarations (with their spans), addressable by `FunctionId`.
    functions:   IndexMap<FunctionId, SpannedItem<FunctionDeclaration>>,
    /// Type declarations, addressable by `TypeId`.
    types:       IndexMap<petr_utils::TypeId, TypeDeclaration>,
    /// Modules, addressable by `ModuleId`.
    modules:     IndexMap<ModuleId, Module>,
    /// The single root scope; it is created first and has no parent.
    root_scope:  ScopeId,
}
57
/// A module known to the binder.
#[derive(Debug)]
pub struct Module {
    /// The scope that holds the module's top-level items.
    pub root_scope: ScopeId,
    /// Items the module exports, keyed by the identifier they are exported under.
    pub exports:    BTreeMap<Identifier, Item>,
}
63
/// A lexical scope: a table of symbols declared directly in it, plus a link to its parent.
pub struct Scope<T> {
    /// A `Scope` always has a parent, unless it is the root scope of the user code.
    /// All scopes are descendants of one single root scope.
    parent: Option<ScopeId>,
    /// A mapping of the symbols that were declared in this scope. Note that any scopes that are
    /// children of this scope inherit these symbols as well.
    items:  BTreeMap<SymbolId, T>,
    #[allow(dead_code)]
    // this will be read but is also very useful for debugging
    kind: ScopeKind,
}
75
/// Not used in the compiler heavily yet, but extremely useful for understanding what kind of scope
/// you are in.
#[derive(Clone, Copy, Debug)]
pub enum ScopeKind {
    /// A module scope. This is the top level scope for a module. The payload is the path
    /// segment the scope was created for.
    Module(Identifier),
    /// A function scope. This is the scope of a function body. Notably, function scopes are where
    /// all the function parameters are declared.
    Function,
    /// The root scope of the user code. There is only ever one ScopeKind::Root in a compilation.
    /// All scopes are descendants of the root.
    Root,
    /// This might not be needed -- the scope within a type constructor function.
    TypeConstructor,
    /// For a let... expression, this is the scope of the expression and its bindings.
    ExpressionWithBindings,
}
93
94impl<T> Scope<T> {
95    pub fn insert(
96        &mut self,
97        k: SymbolId,
98        v: T,
99    ) {
100        // TODO: error handling and/or shadowing rules for this
101        if self.items.insert(k, v).is_some() {
102            todo!("throw error for overriding symbol name {k}")
103        }
104    }
105
106    pub fn parent(&self) -> Option<ScopeId> {
107        self.parent
108    }
109
110    pub fn iter(&self) -> impl Iterator<Item = (&SymbolId, &T)> {
111        self.items.iter()
112    }
113}
114
115impl Binder {
116    fn new() -> Self {
117        let mut scopes = IndexMap::default();
118        let root_scope = Scope {
119            parent: None,
120            items:  Default::default(),
121            kind:   ScopeKind::Root,
122        };
123        let root_scope = scopes.insert(root_scope);
124        Self {
125            scopes,
126            scope_chain: vec![root_scope],
127            root_scope,
128            functions: IndexMap::default(),
129            types: IndexMap::default(),
130            bindings: IndexMap::default(),
131            modules: IndexMap::default(),
132            exprs: BTreeMap::new(),
133        }
134    }
135
136    pub fn current_scope_id(&self) -> ScopeId {
137        *self.scope_chain.last().expect("there's always at least one scope")
138    }
139
140    pub fn get_function(
141        &self,
142        function_id: FunctionId,
143    ) -> &SpannedItem<FunctionDeclaration> {
144        self.functions.get(function_id)
145    }
146
147    pub fn get_type(
148        &self,
149        type_id: petr_utils::TypeId,
150    ) -> &TypeDeclaration {
151        self.types.get(type_id)
152    }
153
154    /// Searches for a symbol in a scope or any of its parents
155    pub fn find_symbol_in_scope(
156        &self,
157        name: SymbolId,
158        scope_id: ScopeId,
159    ) -> Option<&Item> {
160        self.find_spanned_symbol_in_scope(name, scope_id).map(|item| item.item())
161    }
162
163    /// Searches for a symbol in a scope or any of its parents
164    pub fn find_spanned_symbol_in_scope(
165        &self,
166        name: SymbolId,
167        scope_id: ScopeId,
168    ) -> Option<&SpannedItem<Item>> {
169        let scope = self.scopes.get(scope_id);
170        if let Some(item) = scope.items.get(&name) {
171            return Some(item);
172        }
173
174        if let Some(parent_id) = scope.parent() {
175            return self.find_spanned_symbol_in_scope(name, parent_id);
176        }
177
178        None
179    }
180
181    /// Iterate over all scopes in the binder.
182    pub fn scope_iter(&self) -> impl Iterator<Item = (ScopeId, &Scope<SpannedItem<Item>>)> {
183        self.scopes.iter()
184    }
185
186    pub fn insert_into_current_scope(
187        &mut self,
188        name: SymbolId,
189        item: SpannedItem<Item>,
190    ) {
191        let scope_id = self.current_scope_id();
192        self.scopes.get_mut(scope_id).insert(name, item);
193    }
194
195    fn push_scope(
196        &mut self,
197        kind: ScopeKind,
198    ) -> ScopeId {
199        let id = self.create_scope(kind);
200
201        self.scope_chain.push(id);
202
203        id
204    }
205
206    pub fn get_scope(
207        &self,
208        scope: ScopeId,
209    ) -> &Scope<SpannedItem<Item>> {
210        self.scopes.get(scope)
211    }
212
213    pub fn get_scope_kind(
214        &self,
215        scope: ScopeId,
216    ) -> ScopeKind {
217        self.scopes.get(scope).kind
218    }
219
220    fn pop_scope(&mut self) {
221        let _ = self.scope_chain.pop();
222    }
223
224    pub fn with_scope<F, R>(
225        &mut self,
226        kind: ScopeKind,
227        f: F,
228    ) -> R
229    where
230        F: FnOnce(&mut Self, ScopeId) -> R,
231    {
232        let id = self.push_scope(kind);
233        let res = f(self, id);
234        self.pop_scope();
235        res
236    }
237
238    /// TODO (https://github.com/sezna/petr/issues/33)
239    pub(crate) fn insert_type(
240        &mut self,
241        ty_decl: &SpannedItem<&TypeDeclaration>,
242    ) -> Option<(Identifier, Item)> {
243        // insert a function binding for every constructor
244        // and a type binding for the parent type
245        let type_id = self.types.insert((*ty_decl.item()).clone());
246        let type_item = Item::Type(type_id);
247        self.insert_into_current_scope(ty_decl.item().name.id, ty_decl.span().with_item(type_item.clone()));
248
249        ty_decl.item().variants.iter().for_each(|variant| {
250            let span = variant.span();
251            let variant = variant.item();
252            let (fields_as_parameters, _func_scope) = self.with_scope(ScopeKind::TypeConstructor, |_, scope| {
253                (
254                    variant
255                        .fields
256                        .iter()
257                        .map(|field| petr_ast::FunctionParameter {
258                            name: field.item().name,
259                            ty:   field.item().ty,
260                        })
261                        .collect::<Vec<_>>(),
262                    scope,
263                )
264            });
265            // type constructors just access the arguments of the construction function directly
266            let type_constructor_exprs = variant
267                .fields
268                .iter()
269                .map(|field| field.span().with_item(Expression::Variable(field.item().name)))
270                .collect::<Vec<_>>();
271
272            let function = FunctionDeclaration {
273                name:        variant.name,
274                parameters:  fields_as_parameters.into_boxed_slice(),
275                return_type: Ty::Named(ty_decl.item().name),
276                body:        span.with_item(Expression::TypeConstructor(type_id, type_constructor_exprs.into_boxed_slice())),
277                visibility:  ty_decl.item().visibility,
278            };
279
280            self.insert_function(&ty_decl.span().with_item(&function));
281        });
282        if ty_decl.item().is_exported() {
283            Some((ty_decl.item().name, type_item))
284        } else {
285            None
286        }
287    }
288
289    pub(crate) fn insert_function(
290        &mut self,
291        func: &SpannedItem<&FunctionDeclaration>,
292    ) -> Option<(Identifier, Item)> {
293        let span = func.span();
294        let func = func.item();
295        let function_id = self.functions.insert(span.with_item((*func).clone()));
296        let func_body_scope = self.with_scope(ScopeKind::Function, |binder, function_body_scope| {
297            for param in func.parameters.iter() {
298                binder.insert_into_current_scope(param.name.id, param.name.span().with_item(Item::FunctionParameter(param.ty)));
299            }
300
301            func.body.bind(binder);
302            function_body_scope
303        });
304        let item = Item::Function(function_id, func_body_scope);
305        self.insert_into_current_scope(func.name.id, span.with_item(item.clone()));
306        if func.is_exported() {
307            Some((func.name, item))
308        } else {
309            None
310        }
311    }
312
313    pub(crate) fn insert_binding(
314        &mut self,
315        binding: Binding,
316    ) -> BindingId {
317        self.bindings.insert(binding)
318    }
319
320    // TODO add optional prefix here:
321    // if Some(p) then this is a dependency, and p should be prepended to the path of each module
322    // If None then this is user code, and no prefix is needed
323    pub fn from_ast(ast: &Ast) -> Self {
324        let mut binder = Self::new();
325
326        ast.modules.iter().for_each(|module| {
327            let module_scope = binder.create_scope_from_path(&module.name);
328            binder.with_specified_scope(module_scope, |binder, scope_id| {
329                let exports = module.nodes.iter().filter_map(|node| match node.item() {
330                    petr_ast::AstNode::FunctionDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
331                    petr_ast::AstNode::TypeDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
332                    petr_ast::AstNode::ImportStatement(stmt) => stmt.bind(binder),
333                });
334                let exports = BTreeMap::from_iter(exports);
335                // we don't need to track this module ID -- it just needs to exist,
336                // and all modules are iterated over in later stages of the compiler.
337                // So we can safely ignore the return value here.
338                let _module_id = binder.modules.insert(Module {
339                    root_scope: scope_id,
340                    exports,
341                });
342            });
343        });
344
345        binder
346    }
347
348    pub fn from_ast_and_deps(
349        ast: &Ast,
350        dependencies: Vec<Dependency>,
351        interner: &mut SymbolInterner,
352    ) -> Self {
353        let mut binder = Self::new();
354
355        for Dependency {
356            key: _,
357            name,
358            dependencies: _,
359            ast: dep_ast,
360        } in dependencies
361        {
362            let span = dep_ast.span_pointing_to_beginning_of_ast();
363            let id = interner.insert(Rc::from(name));
364            let name = Identifier { id, span };
365            let dep_scope = binder.create_scope_from_path(&Path::new(vec![name]));
366            binder.with_specified_scope(dep_scope, |binder, _scope_id| {
367                for module in dep_ast.modules {
368                    let module_scope = binder.create_scope_from_path(&module.name);
369                    binder.with_specified_scope(module_scope, |binder, scope_id| {
370                        let exports = module.nodes.iter().filter_map(|node| match node.item() {
371                            petr_ast::AstNode::FunctionDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
372                            petr_ast::AstNode::TypeDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
373                            petr_ast::AstNode::ImportStatement(stmt) => stmt.bind(binder),
374                        });
375                        let exports = BTreeMap::from_iter(exports);
376                        // TODO do I need to track this module id?
377                        let _module_id = binder.modules.insert(Module {
378                            root_scope: scope_id,
379                            exports,
380                        });
381                    });
382                }
383            })
384        }
385
386        for module in &ast.modules {
387            let module_scope = binder.create_scope_from_path(&module.name);
388            binder.with_specified_scope(module_scope, |binder, scope_id| {
389                let exports = module.nodes.iter().filter_map(|node| match node.item() {
390                    petr_ast::AstNode::FunctionDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
391                    petr_ast::AstNode::TypeDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
392                    petr_ast::AstNode::ImportStatement(stmt) => stmt.bind(binder),
393                });
394                let exports = BTreeMap::from_iter(exports);
395                // TODO do I need to track this module id?
396                let _module_id = binder.modules.insert(Module {
397                    root_scope: scope_id,
398                    exports,
399                });
400            });
401        }
402
403        binder
404    }
405
406    /// given a path, create a scope for each segment. The last scope is returned.
407    /// e.g. for the path "a.b.c", create scopes for "a", "b", and "c", and return the scope for "c"
408    fn create_scope_from_path(
409        &mut self,
410        path: &Path,
411    ) -> ScopeId {
412        let mut current_scope_id = self.current_scope_id();
413        for segment in path.iter() {
414            // if this scope already exists,
415            // just use that pre-existing ID
416            if let Some(Item::Module(module_id)) = self.find_symbol_in_scope(segment.id, current_scope_id) {
417                current_scope_id = self.modules.get(*module_id).root_scope;
418                continue;
419            }
420
421            let next_scope = self.create_scope(ScopeKind::Module(*segment));
422            let module = Module {
423                root_scope: next_scope,
424                exports:    BTreeMap::new(),
425            };
426            let module_id = self.modules.insert(module);
427            self.insert_into_specified_scope(current_scope_id, *segment, Item::Module(module_id));
428            current_scope_id = next_scope
429        }
430        current_scope_id
431    }
432
433    pub fn insert_into_specified_scope(
434        &mut self,
435        scope: ScopeId,
436        name: Identifier,
437        item: Item,
438    ) {
439        let scope = self.scopes.get_mut(scope);
440        scope.insert(name.id, name.span.with_item(item));
441    }
442
443    pub fn get_module(
444        &self,
445        id: ModuleId,
446    ) -> &Module {
447        self.modules.get(id)
448    }
449
450    pub fn get_binding(
451        &self,
452        binding_id: BindingId,
453    ) -> &Binding {
454        self.bindings.get(binding_id)
455    }
456
457    pub fn create_scope(
458        &mut self,
459        kind: ScopeKind,
460    ) -> ScopeId {
461        let scope = Scope {
462            parent: Some(self.current_scope_id()),
463            items: BTreeMap::new(),
464            kind,
465        };
466        self.scopes.insert(scope)
467    }
468
469    fn with_specified_scope<F, R>(
470        &mut self,
471        scope: ScopeId,
472        f: F,
473    ) -> R
474    where
475        F: FnOnce(&mut Self, ScopeId) -> R,
476    {
477        let old_scope_chain = self.scope_chain.clone();
478        self.scope_chain = vec![self.root_scope, scope];
479        let res = f(self, scope);
480        self.scope_chain = old_scope_chain;
481        res
482    }
483
484    pub fn iter_scope(
485        &self,
486        scope: ScopeId,
487    ) -> impl Iterator<Item = (&SymbolId, &SpannedItem<Item>)> {
488        self.scopes.get(scope).items.iter()
489    }
490
491    pub fn insert_expression(
492        &mut self,
493        id: ExprId,
494        scope: ScopeId,
495    ) {
496        self.exprs.insert(id, scope);
497    }
498
499    pub fn get_expr_scope(
500        &self,
501        id: ExprId,
502    ) -> Option<ScopeId> {
503        self.exprs.get(&id).copied()
504    }
505}
506
/// Implemented by AST nodes that can register themselves (and their children) in a [`Binder`].
pub trait Bind {
    /// What binding this node yields to the caller.
    // NOTE(review): the top-level declaration nodes bound in `Binder::from_ast` appear to yield
    // `Option<(Identifier, Item)>` (their export, if any); other implementors live outside
    // this file -- confirm there.
    type Output;
    /// Binds `self` into `binder`. The binder is mutable, so implementations may add scopes,
    /// items and declarations to it.
    fn bind(
        &self,
        binder: &mut Binder,
    ) -> Self::Output;
}
514
#[cfg(test)]
mod tests {
    use expect_test::{expect, Expect};
    use petr_utils::{render_error, SymbolInterner};

    use super::*;

    /// Parses `input` as a single source file named "test", binds it, and compares the
    /// pretty-printed scopes against `expect`.
    ///
    /// Panics (after printing the rendered parse errors) if the input does not parse.
    fn check(
        input: impl Into<String>,
        expect: Expect,
    ) {
        let input = input.into();
        let parser = petr_parse::Parser::new(vec![("test", input)]);
        let (ast, errs, interner, source_map) = parser.into_result();
        if !errs.is_empty() {
            errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err)));
            panic!("binder test failed: code didn't parse");
        }
        let binder = Binder::from_ast(&ast);
        let result = pretty_print_bindings(&binder, &interner);
        expect.assert_eq(&result);
    }

    /// Renders every scope of `binder` (ID, kind, parent) followed by the items declared in it,
    /// one per line, resolving symbol names through `interner`.
    fn pretty_print_bindings(
        binder: &Binder,
        interner: &SymbolInterner,
    ) -> String {
        let mut result = String::new();
        result.push_str("__Scopes__\n");
        for (scope_id, scope) in binder.scopes.iter() {
            result.push_str(&format!(
                "{}: {} (parent {}):\n",
                Into::<usize>::into(scope_id),
                match scope.kind {
                    ScopeKind::Module(name) => format!("Module {}", interner.get(name.id)),
                    ScopeKind::Function => "Function".into(),
                    ScopeKind::Root => "Root".into(),
                    ScopeKind::TypeConstructor => "Type Cons".into(),
                    ScopeKind::ExpressionWithBindings => "Expr w/ Bindings".into(),
                },
                scope.parent.map(|x| x.to_string()).unwrap_or_else(|| "none".into())
            ));
            for (symbol_id, item) in &scope.items {
                let symbol_name = interner.get(*symbol_id);
                let item_description = match item.item() {
                    Item::Binding(bind_id) => format!("Binding {:?}", bind_id),
                    Item::Function(function_id, _function_scope) => {
                        format!("Function {:?}", function_id)
                    },
                    Item::Type(type_id) => format!("Type {:?}", type_id),
                    Item::FunctionParameter(param) => {
                        format!("FunctionParameter {:?}", param)
                    },
                    Item::Module(a) => {
                        format!("Module {:?}", binder.modules.get(*a))
                    },
                    // Render imports instead of panicking, so sources with imports can be
                    // snapshot-tested too.
                    Item::Import { path, alias } => {
                        format!("Import {:?} (alias {:?})", path, alias)
                    },
                };
                result.push_str(&format!("  {}: {}\n", symbol_name, item_description));
            }
        }
        result
    }

    #[test]
    fn bind_type_decl() {
        check(
            "type trinary_boolean = True | False | maybe ",
            expect![[r#"
                __Scopes__
                0: Root (parent none):
                  test: Module Module { root_scope: ScopeId(1), exports: {} }
                1: Module test (parent scopeid0):
                  trinary_boolean: Type TypeId(0)
                  True: Function FunctionId(0)
                  False: Function FunctionId(1)
                  maybe: Function FunctionId(2)
                2: Type Cons (parent scopeid1):
                3: Function (parent scopeid1):
                4: Type Cons (parent scopeid1):
                5: Function (parent scopeid1):
                6: Type Cons (parent scopeid1):
                7: Function (parent scopeid1):
            "#]],
        );
    }

    #[test]
    fn bind_function_decl() {
        check(
            "fn add(a in 'Int, b in 'Int) returns 'Int + 1 2",
            expect![[r#"
                __Scopes__
                0: Root (parent none):
                  test: Module Module { root_scope: ScopeId(1), exports: {} }
                1: Module test (parent scopeid0):
                  add: Function FunctionId(0)
                2: Function (parent scopeid1):
                  a: FunctionParameter Named(Identifier { id: SymbolId(3), span: Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(13), length: 3 } } })
                  b: FunctionParameter Named(Identifier { id: SymbolId(3), span: Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(24), length: 3 } } })
            "#]],
        );
    }

    #[test]
    fn bind_list_new_scope() {
        check(
            "fn add(a in 'Int, b in  'Int) returns 'Int [ 1, 2, 3, 4, 5, 6 ]",
            expect![[r#"
                __Scopes__
                0: Root (parent none):
                  test: Module Module { root_scope: ScopeId(1), exports: {} }
                1: Module test (parent scopeid0):
                  add: Function FunctionId(0)
                2: Function (parent scopeid1):
                  a: FunctionParameter Named(Identifier { id: SymbolId(3), span: Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(13), length: 3 } } })
                  b: FunctionParameter Named(Identifier { id: SymbolId(3), span: Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(25), length: 3 } } })
            "#]],
        );
    }
}