1use std::{collections::BTreeMap, rc::Rc};
2
3use petr_ast::{dependency::Dependency, Ast, Binding, ExprId, Expression, FunctionDeclaration, Ty, TypeDeclaration};
4use petr_utils::{idx_map_key, Identifier, IndexMap, Path, SpannedItem, SymbolId, SymbolInterner};
// Key type for the binder's scope table (`Binder::scopes`) and for scope references in `Item`.
idx_map_key!(
    ScopeId
);

// Key type for function parameters.
// NOTE(review): not referenced anywhere in this file — confirm it is still used elsewhere.
idx_map_key!(
    FunctionParameterId
);

// Key type for the binder's function table (`Binder::functions`).
idx_map_key!(
    FunctionId
);

// Key type for the binder's binding table (`Binder::bindings`), carried by `Item::Binding`.
idx_map_key!(
    BindingId
);

// Key type for the binder's module table (`Binder::modules`), carried by `Item::Module`.
idx_map_key!(
    ModuleId
);
33
34#[derive(Clone, Debug)]
35pub enum Item {
36 Binding(BindingId),
37 Function(FunctionId, ScopeId),
39 Type(petr_utils::TypeId),
40 FunctionParameter(Ty),
41 Module(ModuleId),
42 Import { path: Path, alias: Option<Identifier> },
43}
44
45pub struct Binder {
46 scopes: IndexMap<ScopeId, Scope<SpannedItem<Item>>>,
47 scope_chain: Vec<ScopeId>,
48 exprs: BTreeMap<ExprId, ScopeId>,
51 bindings: IndexMap<BindingId, Binding>,
52 functions: IndexMap<FunctionId, SpannedItem<FunctionDeclaration>>,
53 types: IndexMap<petr_utils::TypeId, TypeDeclaration>,
54 modules: IndexMap<ModuleId, Module>,
55 root_scope: ScopeId,
56}
57
58#[derive(Debug)]
59pub struct Module {
60 pub root_scope: ScopeId,
61 pub exports: BTreeMap<Identifier, Item>,
62}
63
64pub struct Scope<T> {
65 parent: Option<ScopeId>,
68 items: BTreeMap<SymbolId, T>,
71 #[allow(dead_code)]
72 kind: ScopeKind,
74}
75
76#[derive(Clone, Copy, Debug)]
79pub enum ScopeKind {
80 Module(Identifier),
82 Function,
85 Root,
88 TypeConstructor,
90 ExpressionWithBindings,
92}
93
94impl<T> Scope<T> {
95 pub fn insert(
96 &mut self,
97 k: SymbolId,
98 v: T,
99 ) {
100 if self.items.insert(k, v).is_some() {
102 todo!("throw error for overriding symbol name {k}")
103 }
104 }
105
106 pub fn parent(&self) -> Option<ScopeId> {
107 self.parent
108 }
109
110 pub fn iter(&self) -> impl Iterator<Item = (&SymbolId, &T)> {
111 self.items.iter()
112 }
113}
114
115impl Binder {
116 fn new() -> Self {
117 let mut scopes = IndexMap::default();
118 let root_scope = Scope {
119 parent: None,
120 items: Default::default(),
121 kind: ScopeKind::Root,
122 };
123 let root_scope = scopes.insert(root_scope);
124 Self {
125 scopes,
126 scope_chain: vec![root_scope],
127 root_scope,
128 functions: IndexMap::default(),
129 types: IndexMap::default(),
130 bindings: IndexMap::default(),
131 modules: IndexMap::default(),
132 exprs: BTreeMap::new(),
133 }
134 }
135
136 pub fn current_scope_id(&self) -> ScopeId {
137 *self.scope_chain.last().expect("there's always at least one scope")
138 }
139
140 pub fn get_function(
141 &self,
142 function_id: FunctionId,
143 ) -> &SpannedItem<FunctionDeclaration> {
144 self.functions.get(function_id)
145 }
146
147 pub fn get_type(
148 &self,
149 type_id: petr_utils::TypeId,
150 ) -> &TypeDeclaration {
151 self.types.get(type_id)
152 }
153
154 pub fn find_symbol_in_scope(
156 &self,
157 name: SymbolId,
158 scope_id: ScopeId,
159 ) -> Option<&Item> {
160 self.find_spanned_symbol_in_scope(name, scope_id).map(|item| item.item())
161 }
162
163 pub fn find_spanned_symbol_in_scope(
165 &self,
166 name: SymbolId,
167 scope_id: ScopeId,
168 ) -> Option<&SpannedItem<Item>> {
169 let scope = self.scopes.get(scope_id);
170 if let Some(item) = scope.items.get(&name) {
171 return Some(item);
172 }
173
174 if let Some(parent_id) = scope.parent() {
175 return self.find_spanned_symbol_in_scope(name, parent_id);
176 }
177
178 None
179 }
180
181 pub fn scope_iter(&self) -> impl Iterator<Item = (ScopeId, &Scope<SpannedItem<Item>>)> {
183 self.scopes.iter()
184 }
185
186 pub fn insert_into_current_scope(
187 &mut self,
188 name: SymbolId,
189 item: SpannedItem<Item>,
190 ) {
191 let scope_id = self.current_scope_id();
192 self.scopes.get_mut(scope_id).insert(name, item);
193 }
194
195 fn push_scope(
196 &mut self,
197 kind: ScopeKind,
198 ) -> ScopeId {
199 let id = self.create_scope(kind);
200
201 self.scope_chain.push(id);
202
203 id
204 }
205
206 pub fn get_scope(
207 &self,
208 scope: ScopeId,
209 ) -> &Scope<SpannedItem<Item>> {
210 self.scopes.get(scope)
211 }
212
213 pub fn get_scope_kind(
214 &self,
215 scope: ScopeId,
216 ) -> ScopeKind {
217 self.scopes.get(scope).kind
218 }
219
220 fn pop_scope(&mut self) {
221 let _ = self.scope_chain.pop();
222 }
223
224 pub fn with_scope<F, R>(
225 &mut self,
226 kind: ScopeKind,
227 f: F,
228 ) -> R
229 where
230 F: FnOnce(&mut Self, ScopeId) -> R,
231 {
232 let id = self.push_scope(kind);
233 let res = f(self, id);
234 self.pop_scope();
235 res
236 }
237
238 pub(crate) fn insert_type(
240 &mut self,
241 ty_decl: &SpannedItem<&TypeDeclaration>,
242 ) -> Option<(Identifier, Item)> {
243 let type_id = self.types.insert((*ty_decl.item()).clone());
246 let type_item = Item::Type(type_id);
247 self.insert_into_current_scope(ty_decl.item().name.id, ty_decl.span().with_item(type_item.clone()));
248
249 ty_decl.item().variants.iter().for_each(|variant| {
250 let span = variant.span();
251 let variant = variant.item();
252 let (fields_as_parameters, _func_scope) = self.with_scope(ScopeKind::TypeConstructor, |_, scope| {
253 (
254 variant
255 .fields
256 .iter()
257 .map(|field| petr_ast::FunctionParameter {
258 name: field.item().name,
259 ty: field.item().ty,
260 })
261 .collect::<Vec<_>>(),
262 scope,
263 )
264 });
265 let type_constructor_exprs = variant
267 .fields
268 .iter()
269 .map(|field| field.span().with_item(Expression::Variable(field.item().name)))
270 .collect::<Vec<_>>();
271
272 let function = FunctionDeclaration {
273 name: variant.name,
274 parameters: fields_as_parameters.into_boxed_slice(),
275 return_type: Ty::Named(ty_decl.item().name),
276 body: span.with_item(Expression::TypeConstructor(type_id, type_constructor_exprs.into_boxed_slice())),
277 visibility: ty_decl.item().visibility,
278 };
279
280 self.insert_function(&ty_decl.span().with_item(&function));
281 });
282 if ty_decl.item().is_exported() {
283 Some((ty_decl.item().name, type_item))
284 } else {
285 None
286 }
287 }
288
289 pub(crate) fn insert_function(
290 &mut self,
291 func: &SpannedItem<&FunctionDeclaration>,
292 ) -> Option<(Identifier, Item)> {
293 let span = func.span();
294 let func = func.item();
295 let function_id = self.functions.insert(span.with_item((*func).clone()));
296 let func_body_scope = self.with_scope(ScopeKind::Function, |binder, function_body_scope| {
297 for param in func.parameters.iter() {
298 binder.insert_into_current_scope(param.name.id, param.name.span().with_item(Item::FunctionParameter(param.ty)));
299 }
300
301 func.body.bind(binder);
302 function_body_scope
303 });
304 let item = Item::Function(function_id, func_body_scope);
305 self.insert_into_current_scope(func.name.id, span.with_item(item.clone()));
306 if func.is_exported() {
307 Some((func.name, item))
308 } else {
309 None
310 }
311 }
312
313 pub(crate) fn insert_binding(
314 &mut self,
315 binding: Binding,
316 ) -> BindingId {
317 self.bindings.insert(binding)
318 }
319
320 pub fn from_ast(ast: &Ast) -> Self {
324 let mut binder = Self::new();
325
326 ast.modules.iter().for_each(|module| {
327 let module_scope = binder.create_scope_from_path(&module.name);
328 binder.with_specified_scope(module_scope, |binder, scope_id| {
329 let exports = module.nodes.iter().filter_map(|node| match node.item() {
330 petr_ast::AstNode::FunctionDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
331 petr_ast::AstNode::TypeDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
332 petr_ast::AstNode::ImportStatement(stmt) => stmt.bind(binder),
333 });
334 let exports = BTreeMap::from_iter(exports);
335 let _module_id = binder.modules.insert(Module {
339 root_scope: scope_id,
340 exports,
341 });
342 });
343 });
344
345 binder
346 }
347
348 pub fn from_ast_and_deps(
349 ast: &Ast,
350 dependencies: Vec<Dependency>,
351 interner: &mut SymbolInterner,
352 ) -> Self {
353 let mut binder = Self::new();
354
355 for Dependency {
356 key: _,
357 name,
358 dependencies: _,
359 ast: dep_ast,
360 } in dependencies
361 {
362 let span = dep_ast.span_pointing_to_beginning_of_ast();
363 let id = interner.insert(Rc::from(name));
364 let name = Identifier { id, span };
365 let dep_scope = binder.create_scope_from_path(&Path::new(vec![name]));
366 binder.with_specified_scope(dep_scope, |binder, _scope_id| {
367 for module in dep_ast.modules {
368 let module_scope = binder.create_scope_from_path(&module.name);
369 binder.with_specified_scope(module_scope, |binder, scope_id| {
370 let exports = module.nodes.iter().filter_map(|node| match node.item() {
371 petr_ast::AstNode::FunctionDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
372 petr_ast::AstNode::TypeDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
373 petr_ast::AstNode::ImportStatement(stmt) => stmt.bind(binder),
374 });
375 let exports = BTreeMap::from_iter(exports);
376 let _module_id = binder.modules.insert(Module {
378 root_scope: scope_id,
379 exports,
380 });
381 });
382 }
383 })
384 }
385
386 for module in &ast.modules {
387 let module_scope = binder.create_scope_from_path(&module.name);
388 binder.with_specified_scope(module_scope, |binder, scope_id| {
389 let exports = module.nodes.iter().filter_map(|node| match node.item() {
390 petr_ast::AstNode::FunctionDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
391 petr_ast::AstNode::TypeDeclaration(decl) => node.span().with_item(decl.item()).bind(binder),
392 petr_ast::AstNode::ImportStatement(stmt) => stmt.bind(binder),
393 });
394 let exports = BTreeMap::from_iter(exports);
395 let _module_id = binder.modules.insert(Module {
397 root_scope: scope_id,
398 exports,
399 });
400 });
401 }
402
403 binder
404 }
405
406 fn create_scope_from_path(
409 &mut self,
410 path: &Path,
411 ) -> ScopeId {
412 let mut current_scope_id = self.current_scope_id();
413 for segment in path.iter() {
414 if let Some(Item::Module(module_id)) = self.find_symbol_in_scope(segment.id, current_scope_id) {
417 current_scope_id = self.modules.get(*module_id).root_scope;
418 continue;
419 }
420
421 let next_scope = self.create_scope(ScopeKind::Module(*segment));
422 let module = Module {
423 root_scope: next_scope,
424 exports: BTreeMap::new(),
425 };
426 let module_id = self.modules.insert(module);
427 self.insert_into_specified_scope(current_scope_id, *segment, Item::Module(module_id));
428 current_scope_id = next_scope
429 }
430 current_scope_id
431 }
432
433 pub fn insert_into_specified_scope(
434 &mut self,
435 scope: ScopeId,
436 name: Identifier,
437 item: Item,
438 ) {
439 let scope = self.scopes.get_mut(scope);
440 scope.insert(name.id, name.span.with_item(item));
441 }
442
443 pub fn get_module(
444 &self,
445 id: ModuleId,
446 ) -> &Module {
447 self.modules.get(id)
448 }
449
450 pub fn get_binding(
451 &self,
452 binding_id: BindingId,
453 ) -> &Binding {
454 self.bindings.get(binding_id)
455 }
456
457 pub fn create_scope(
458 &mut self,
459 kind: ScopeKind,
460 ) -> ScopeId {
461 let scope = Scope {
462 parent: Some(self.current_scope_id()),
463 items: BTreeMap::new(),
464 kind,
465 };
466 self.scopes.insert(scope)
467 }
468
469 fn with_specified_scope<F, R>(
470 &mut self,
471 scope: ScopeId,
472 f: F,
473 ) -> R
474 where
475 F: FnOnce(&mut Self, ScopeId) -> R,
476 {
477 let old_scope_chain = self.scope_chain.clone();
478 self.scope_chain = vec![self.root_scope, scope];
479 let res = f(self, scope);
480 self.scope_chain = old_scope_chain;
481 res
482 }
483
484 pub fn iter_scope(
485 &self,
486 scope: ScopeId,
487 ) -> impl Iterator<Item = (&SymbolId, &SpannedItem<Item>)> {
488 self.scopes.get(scope).items.iter()
489 }
490
491 pub fn insert_expression(
492 &mut self,
493 id: ExprId,
494 scope: ScopeId,
495 ) {
496 self.exprs.insert(id, scope);
497 }
498
499 pub fn get_expr_scope(
500 &self,
501 id: ExprId,
502 ) -> Option<ScopeId> {
503 self.exprs.get(&id).copied()
504 }
505}
506
507pub trait Bind {
508 type Output;
509 fn bind(
510 &self,
511 binder: &mut Binder,
512 ) -> Self::Output;
513}
514
#[cfg(test)]
mod tests {
    use expect_test::{expect, Expect};
    use petr_utils::{render_error, SymbolInterner};

    use super::*;

    /// Parses `input` as a single source named `test`, binds it with [`Binder::from_ast`], and
    /// compares the pretty-printed scopes against the `expect` snapshot.
    ///
    /// # Panics
    /// Panics (after printing every parse error) if `input` does not parse, or if the snapshot
    /// does not match.
    fn check(
        input: impl Into<String>,
        expect: Expect,
    ) {
        let input = input.into();
        let parser = petr_parse::Parser::new(vec![("test", input)]);
        let (ast, errs, interner, source_map) = parser.into_result();
        if !errs.is_empty() {
            errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err)));
            // This helper exercises the binder, not the formatter.
            panic!("binder test failed: code didn't parse");
        }
        let binder = Binder::from_ast(&ast);
        let result = pretty_print_bindings(&binder, &interner);
        expect.assert_eq(&result);
    }

    /// Renders every scope (id, kind, parent) followed by one line per name declared in it.
    fn pretty_print_bindings(
        binder: &Binder,
        interner: &SymbolInterner,
    ) -> String {
        let mut result = String::new();
        result.push_str("__Scopes__\n");
        for (scope_id, scope) in binder.scopes.iter() {
            result.push_str(&format!(
                "{}: {} (parent {}):\n",
                Into::<usize>::into(scope_id),
                match scope.kind {
                    ScopeKind::Module(name) => format!("Module {}", interner.get(name.id)),
                    ScopeKind::Function => "Function".into(),
                    ScopeKind::Root => "Root".into(),
                    ScopeKind::TypeConstructor => "Type Cons".into(),
                    ScopeKind::ExpressionWithBindings => "Expr w/ Bindings".into(),
                },
                scope.parent.map(|x| x.to_string()).unwrap_or_else(|| "none".into())
            ));
            for (symbol_id, item) in &scope.items {
                let symbol_name = interner.get(*symbol_id);
                let item_description = match item.item() {
                    Item::Binding(bind_id) => format!("Binding {:?}", bind_id),
                    Item::Function(function_id, _function_scope) => {
                        format!("Function {:?}", function_id)
                    },
                    Item::Type(type_id) => format!("Type {:?}", type_id),
                    Item::FunctionParameter(param) => {
                        format!("FunctionParameter {:?}", param)
                    },
                    Item::Module(a) => {
                        format!("Module {:?}", binder.modules.get(*a))
                    },
                    // Describe imports instead of panicking, so sources with imports can be snapshotted.
                    Item::Import { path, alias } => format!("Import {:?} (alias {:?})", path, alias),
                };
                result.push_str(&format!(" {}: {}\n", symbol_name, item_description));
            }
        }
        result
    }

    #[test]
    fn bind_type_decl() {
        check(
            "type trinary_boolean = True | False | maybe ",
            expect![[r#"
 __Scopes__
 0: Root (parent none):
 test: Module Module { root_scope: ScopeId(1), exports: {} }
 1: Module test (parent scopeid0):
 trinary_boolean: Type TypeId(0)
 True: Function FunctionId(0)
 False: Function FunctionId(1)
 maybe: Function FunctionId(2)
 2: Type Cons (parent scopeid1):
 3: Function (parent scopeid1):
 4: Type Cons (parent scopeid1):
 5: Function (parent scopeid1):
 6: Type Cons (parent scopeid1):
 7: Function (parent scopeid1):
 "#]],
        );
    }

    #[test]
    fn bind_function_decl() {
        check(
            "fn add(a in 'Int, b in 'Int) returns 'Int + 1 2",
            expect![[r#"
 __Scopes__
 0: Root (parent none):
 test: Module Module { root_scope: ScopeId(1), exports: {} }
 1: Module test (parent scopeid0):
 add: Function FunctionId(0)
 2: Function (parent scopeid1):
 a: FunctionParameter Named(Identifier { id: SymbolId(3), span: Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(13), length: 3 } } })
 b: FunctionParameter Named(Identifier { id: SymbolId(3), span: Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(24), length: 3 } } })
 "#]],
        );
    }

    #[test]
    fn bind_list_new_scope() {
        check(
            "fn add(a in 'Int, b in 'Int) returns 'Int [ 1, 2, 3, 4, 5, 6 ]",
            expect![[r#"
 __Scopes__
 0: Root (parent none):
 test: Module Module { root_scope: ScopeId(1), exports: {} }
 1: Module test (parent scopeid0):
 add: Function FunctionId(0)
 2: Function (parent scopeid1):
 a: FunctionParameter Named(Identifier { id: SymbolId(3), span: Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(13), length: 3 } } })
 b: FunctionParameter Named(Identifier { id: SymbolId(3), span: Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(25), length: 3 } } })
 "#]],
        );
    }
}