use std::{cell::RefCell, collections::HashSet, rc::Rc};
use crate::{
ast::{
ApplicationExpr, Ast, Bind, CondExpr, ConstructExpr, Expr, LetInExpr, LiteralExpr,
Operator, Pattern, PatternMatchExpr, Stmt, TupleExpr,
},
custom_types::{Constructor, CustomTypes, Variant},
lexer::Lexer,
symbol::{NonTerminal, Rule, Span, Symbol, Terminal, TerminalClass},
typ::{Primitive, Type, Variable},
};
/// Walks a concrete syntax tree and builds an [`Ast`] plus the [`CustomTypes`] declared in it.
pub struct AstBuilder<'a> {
/// Lexer used to turn spans and terminals back into source text.
lexer: &'a Lexer,
/// Context of the innermost function currently being visited, if any.
current_closure_ctx: Option<ClosureCtx>,
/// Innermost scope of names introduced by `let … in` expressions and match patterns, if any.
current_local_bindings: Option<LocalBindings>,
/// Bindings collected from the visited statements.
ast: Ast,
/// Type declarations collected from the visited statements.
custom_types: CustomTypes,
}
/// Bookkeeping for one function whose body is being visited.
struct ClosureCtx {
/// Context of the enclosing function; restored when this one is popped.
parent: Option<Box<ClosureCtx>>,
/// Spans of the function's parameter names.
params: Vec<Span>,
/// Span of a name that is treated as the function itself and therefore never captured.
recursive_name: Option<Span>,
/// Names referenced in the body that are neither parameters, local bindings, nor the
/// recursive name. A name is pushed once per reference, so duplicates are possible.
captures: Vec<String>,
}
/// One scope of locally bound names, linked to the scope that encloses it.
#[derive(Default)]
struct LocalBindings {
/// Enclosing scope; consulted when a name is not found in `bindings`.
parent: Option<Box<LocalBindings>>,
/// Names introduced directly in this scope.
bindings: HashSet<String>,
}
impl<'a> AstBuilder<'a> {
/// Creates a builder that reads source text through `lexer`.
///
/// The builder starts with an empty AST, no custom types, and no active
/// closure context or local-binding scope.
pub fn new(lexer: &'a Lexer) -> Self {
    let ast = Ast::default();
    let custom_types = CustomTypes::default();
    Self {
        lexer,
        ast,
        custom_types,
        current_closure_ctx: None,
        current_local_bindings: None,
    }
}
pub fn build(mut self, cst_root: &Symbol) -> (Ast, CustomTypes) {
self.visit_stmts(cst_root);
(self.ast, self.custom_types)
}
/// Visits every statement of a statement list in source order.
///
/// Rule 1 is `list stmt` (components 0 and 1) and rule 2 is a single `stmt`
/// (component 0). The list is unwound iteratively: statements are gathered
/// from last to first while walking down the left spine, then visited in
/// reverse so the first statement is processed first.
fn visit_stmts(&mut self, cst_root: &Symbol) {
    let mut pending = Vec::new();
    let mut node = cst_root;
    loop {
        let rule = extract_rule(node);
        match rule.number {
            1 => {
                pending.push(&rule.components[1]);
                node = &rule.components[0];
            }
            2 => {
                pending.push(&rule.components[0]);
                break;
            }
            _ => unreachable!(),
        }
    }
    for stmt in pending.into_iter().rev() {
        self.visit_stmt(stmt);
    }
}
/// Visits a single statement and records its result.
///
/// Rules 3–6 produce a binding (variable, function, recursive function, unit)
/// that is inserted into the AST; rules 7 and 8 produce a type declaration
/// that is added to the custom types.
fn visit_stmt(&mut self, symbol: &Symbol) {
    let rule = extract_rule(symbol);
    let stmt = match rule.number {
        number @ 3..=6 => {
            let parts = extract_components(&rule.components[0]);
            let bind = match number {
                3 => self.visit_var_bind(parts),
                4 => self.visit_fun_bind(parts, false),
                5 => self.visit_fun_bind(parts, true),
                _ => self.visit_unit_bind(parts),
            };
            Stmt::Bind(bind)
        }
        7 => {
            let parts = extract_components(&rule.components[0]);
            Stmt::TypeDecl(self.visit_type_decl(parts))
        }
        8 => Stmt::TypeDecl(self.visit_param_type_decl(&rule.components[0])),
        _ => unreachable!(),
    };
    match stmt {
        Stmt::TypeDecl(variant) => self.custom_types.add_variant(variant),
        Stmt::Bind(bind) => self.ast.insert_stmt(bind),
    }
}
/// Builds a nameless binding whose expression is component 3.
fn visit_unit_bind(&mut self, components: &[Symbol]) -> Bind {
    Bind {
        name: None,
        expr: self.visit_expr(&components[3]),
    }
}
/// Builds a named binding: the name is component 1 and the bound expression
/// is component 3.
fn visit_var_bind(&mut self, components: &[Symbol]) -> Bind {
    let name_span = extract_span(&components[1]);
    let name = Some(name_span.clone());
    let expr = self.visit_expr(&components[3]);
    Bind { name, expr }
}
/// Builds a binding for a named function definition.
///
/// For the non-recursive form the name, parameters and body are components
/// 1, 2 and 4; for the recursive form every index is one higher (presumably
/// because of an extra keyword — confirm against the grammar).
///
/// The function's own name is registered as the closure's recursive name
/// only when `recursive` is true. In a non-recursive definition a reference
/// to the same name denotes some other, earlier binding and must therefore
/// be treated like any other free variable instead of being silently
/// skipped.
fn visit_fun_bind(&mut self, components: &[Symbol], recursive: bool) -> Bind {
    let offset = usize::from(recursive);
    let name = extract_span(&components[1 + offset]).clone();
    let params = self.visit_params(&components[2 + offset]);

    let recursive_name = recursive.then(|| name.clone());
    self.push_closure_ctx(params, recursive_name);
    let body = self.visit_expr(&components[4 + offset]);

    let recursive_bind = recursive.then(|| self.lexer.str_from_span(&name).to_string());
    // `new_fun_expr` pops the closure context pushed above.
    let expr = self.new_fun_expr(body, recursive_bind);
    Bind {
        name: Some(name),
        expr,
    }
}
/// Builds a variant type without type parameters: the type name is
/// component 1 and the constructor list is the last component.
fn visit_type_decl(&mut self, components: &[Symbol]) -> Variant {
    let name = self
        .lexer
        .str_from_span(extract_span(&components[1]))
        .to_string();
    let constructor_list = components.last().unwrap();
    let constructors = self.visit_constructors(constructor_list);
    Variant::new(name, Vec::new(), constructors)
}
/// Builds a variant type that has type parameters.
///
/// Rules 15/16 carry a single type variable (component 1) with the type name
/// at component 2; rules 17/18 carry a type-variable list (component 2) with
/// the type name at component 4. The constructor list is always the last
/// component.
fn visit_param_type_decl(&mut self, symbol: &Symbol) -> Variant {
    let rule = extract_rule(symbol);
    let (name_idx, params) = match rule.number {
        15 | 16 => (2, vec![self.type_variable_id(&rule.components[1])]),
        17 | 18 => (4, self.visit_type_variable_list(&rule.components[2])),
        _ => unreachable!(),
    };
    let name = self
        .lexer
        .str_from_span(extract_span(&rule.components[name_idx]))
        .to_string();
    let constructors = self.visit_constructors(rule.components.last().unwrap());
    Variant::new(name, params, constructors)
}
/// Collects the ids of a list of type variables, in source order.
///
/// Rule 22 is `list <sep> var` (components 0 and 2) and rule 23 is a single
/// variable (component 0). Ids are gathered from last to first while walking
/// down the list and reversed at the end.
fn visit_type_variable_list(&mut self, symbol: &Symbol) -> Vec<usize> {
    let mut ids = Vec::new();
    let mut node = symbol;
    loop {
        let rule = extract_rule(node);
        match rule.number {
            22 => {
                ids.push(self.type_variable_id(&rule.components[2]));
                node = &rule.components[0];
            }
            23 => {
                ids.push(self.type_variable_id(&rule.components[0]));
                break;
            }
            _ => unreachable!(),
        }
    }
    ids.reverse();
    ids
}
/// Builds an expression from either kind of CST symbol by dispatching to the
/// terminal or non-terminal visitor.
fn visit_expr(&mut self, symbol: &Symbol) -> Rc<RefCell<Expr>> {
    match symbol {
        Symbol::Terminal(leaf) => self.visit_terminal_expr(leaf),
        Symbol::NonTerminal(node) => self.visit_non_terminal_expr(node),
    }
}
/// Builds an expression from a non-terminal by dispatching on its rule number.
///
/// Pass-through rules forward to their first component, rule 54 forwards to
/// its second component, and every other rule is handled by a dedicated
/// visitor that receives the rule's components.
fn visit_non_terminal_expr(&mut self, non_terminal: &NonTerminal) -> Rc<RefCell<Expr>> {
    let components = &non_terminal.rule.components;
    match non_terminal.rule.number {
        // Rules that simply wrap another expression.
        37..=46 | 83..=86 | 88 | 89 => self.visit_expr(&components[0]),
        54 => self.visit_expr(&components[1]),
        // Rules with their own AST node.
        47 | 48 | 87 => self.visit_construction_expr(components),
        49 => self.visit_if_then_else_expr(components),
        50 => self.visit_tuple_expr(components),
        53 => self.visit_anonymous_fun(components),
        55 | 56 => self.visit_let_in_expr(components),
        57 | 58 => self.visit_pattern_match_expr(components),
        71..=79 => self.visit_binop_expr(components),
        80 => self.visit_append_application(components),
        81 | 82 => self.visit_application(components),
        _ => unreachable!(),
    }
}
/// Builds a constructor application.
///
/// Component 0 is the constructor name. When there are exactly two
/// components the second one is the argument and the span runs from the
/// constructor to the end of the argument; otherwise there is no argument
/// and the span is the constructor's own span.
fn visit_construction_expr(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
    let cons = extract_span(&components[0]).clone();
    let arg = if components.len() == 2 {
        Some(self.visit_expr(&components[1]))
    } else {
        None
    };
    let span = match &arg {
        Some(arg) => Span::new(cons.start_pos(), arg.borrow().span().end_pos()),
        None => cons.clone(),
    };
    Rc::new(RefCell::new(Expr::Construction(ConstructExpr {
        cons,
        arg,
        span,
    })))
}
/// Builds a conditional expression.
///
/// Components 1, 3 and 5 are the condition and the two branches; the span
/// runs from the start of component 0 to the end of the `no` branch.
fn visit_if_then_else_expr(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
    let start = extract_span(&components[0]).start_pos();
    let cond = self.visit_expr(&components[1]);
    let yes = self.visit_expr(&components[3]);
    let no = self.visit_expr(&components[5]);
    let span = Span::new(start, no.borrow().span().end_pos());
    Rc::new(RefCell::new(Expr::Conditional(CondExpr {
        cond,
        yes,
        no,
        span,
    })))
}
/// Builds a pattern-match expression.
///
/// Component 1 is the matched expression and the last component holds the
/// branches. The span runs from the start of component 0 to the end of the
/// last branch's body.
fn visit_pattern_match_expr(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
    let matched = self.visit_expr(&components[1]);
    let branches = self.visit_branches(components.last().unwrap());
    let (_, last_body) = branches.last().unwrap();
    let span = Span::new(
        extract_span(&components[0]).start_pos(),
        last_body.borrow().span().end_pos(),
    );
    Rc::new(RefCell::new(Expr::PatternMatch(PatternMatchExpr {
        matched,
        branches,
        span,
    })))
}
/// Builds a tuple expression from the expression list in component 1; the
/// span runs from the start of component 0 to the end of component 2.
fn visit_tuple_expr(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
    let elements = self.visit_expr_list(&components[1]);
    let span = Span::new(
        extract_span(&components[0]).start_pos(),
        extract_span(&components[2]).end_pos(),
    );
    Rc::new(RefCell::new(Expr::Tuple(TupleExpr { elements, span })))
}
/// Builds an anonymous function: parameters are component 1, the body is
/// component 3.
///
/// A closure context without a recursive name is pushed before the body is
/// visited; `new_fun_expr` pops it again.
fn visit_anonymous_fun(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
    let params = self.visit_params(&components[1]);
    self.push_closure_ctx(params, None);
    let body = self.visit_expr(&components[3]);
    self.new_fun_expr(body, None)
}
/// Builds a `let … in` expression.
///
/// Component 1 is the bound name, component 3 its defining expression and
/// component 5 the body. A fresh local-binding scope is active while both
/// expressions are visited, but the bound name is registered only after the
/// defining expression has been visited, so it is local to the body alone.
/// The span runs from the start of component 0 to the end of the body.
fn visit_let_in_expr(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
    let bind_name = extract_span(&components[1]).clone();

    self.push_local_bindings();
    let bound = self.visit_expr(&components[3]);
    self.insert_name_to_local_bindings(&bind_name);
    let body = self.visit_expr(&components[5]);
    self.pop_local_bindings();

    let span = Span::new(
        extract_span(&components[0]).start_pos(),
        body.borrow().span().end_pos(),
    );
    Rc::new(RefCell::new(Expr::LetIn(LetInExpr {
        bind: (bind_name, bound),
        expr: body,
        span,
    })))
}
/// Builds an application of component 0 (the function) to component 1 (its
/// first argument).
fn visit_application(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
    let callee = self.visit_expr(&components[0]);
    let argument = self.visit_expr(&components[1]);
    self.new_application_expr(callee, argument)
}
/// Appends one more argument (component 1) to the application built from
/// component 0 and widens that application's span to the end of the new
/// argument.
///
/// Component 0 must evaluate to an `Expr::Application`; anything else is
/// treated as unreachable.
fn visit_append_application(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
    let app = self.visit_expr(&components[0]);
    let arg = self.visit_expr(&components[1]);
    {
        // Scope the mutable borrow so `app` can be returned afterwards.
        let mut node = app.borrow_mut();
        let Expr::Application(app_expr) = &mut *node else {
            unreachable!()
        };
        let widened = Span::new(app_expr.span.start_pos(), arg.borrow().span().end_pos());
        app_expr.span = widened;
        app_expr.binds.push(arg);
    }
    app
}
/// Builds a binary-operator expression from `lhs op rhs` (components 0, 1, 2).
///
/// The operator is derived from the terminal class of component 1; the span
/// runs from the start of the left operand to the end of the right operand.
fn visit_binop_expr(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
    let lhs = self.visit_expr(&components[0]);
    let class = extract_terminal_class(&components[1]);
    let op = match class {
        // Arithmetic
        TerminalClass::Plus => Operator::Plus,
        TerminalClass::Minus => Operator::Minus,
        TerminalClass::Star => Operator::Star,
        TerminalClass::Slash => Operator::Slash,
        // Comparison
        TerminalClass::Eq => Operator::Eq,
        TerminalClass::Lt => Operator::Lt,
        TerminalClass::Lte => Operator::Lte,
        TerminalClass::Gt => Operator::Gt,
        TerminalClass::Gte => Operator::Gte,
        _ => unreachable!(),
    };
    let rhs = self.visit_expr(&components[2]);
    let span = Span::new(
        lhs.borrow().span().start_pos(),
        rhs.borrow().span().end_pos(),
    );
    Rc::new(RefCell::new(Expr::binop(op, lhs, rhs, span)))
}
fn visit_terminal_expr(&mut self, terminal: &Terminal) -> Rc<RefCell<Expr>> {
match terminal.class() {
TerminalClass::Identifier => self.new_var_expr(terminal),
TerminalClass::Number | TerminalClass::Unit => {
let literal_expr = self.visit_literal_expr(terminal);
let expr = Expr::Literal(literal_expr);
Rc::new(RefCell::new(expr))
}
TerminalClass::ConstructorIdentifier => todo!(),
_ => unreachable!(),
}
}
/// Builds a literal from a `Number` or `Unit` terminal; any other terminal
/// class is treated as unreachable.
fn visit_literal_expr(&mut self, terminal: &Terminal) -> LiteralExpr {
    let class = terminal.class();
    match class {
        TerminalClass::Unit => self.new_unit_expr(terminal),
        TerminalClass::Number => self.new_integer_expr(terminal),
        _ => unreachable!(),
    }
}
/// Collects the constructors of a type declaration, in source order.
///
/// Rule 24 is `list <sep> constructor` (components 0 and 2) and rule 25 is a
/// single constructor (component 0). Constructors are gathered from last to
/// first while walking down the list and reversed at the end.
fn visit_constructors(&self, symbol: &Symbol) -> Vec<Constructor> {
    let mut collected = Vec::new();
    let mut node = symbol;
    loop {
        let rule = extract_rule(node);
        match rule.number {
            24 => {
                collected.push(self.visit_constructor(&rule.components[2]));
                node = &rule.components[0];
            }
            25 => {
                collected.push(self.visit_constructor(&rule.components[0]));
                break;
            }
            _ => unreachable!(),
        }
    }
    collected.reverse();
    collected
}
/// Builds one constructor: the name is component 0; rule 27 additionally
/// carries a payload type at component 2, rule 26 has none.
fn visit_constructor(&self, symbol: &Symbol) -> Constructor {
    let rule = extract_rule(symbol);
    let payload = match rule.number {
        26 => None,
        27 => Some(self.visit_type_string(&rule.components[2])),
        _ => unreachable!(),
    };
    let name = self
        .lexer
        .str_from_span(extract_span(&rule.components[0]))
        .to_string();
    Constructor::new(name, payload)
}
/// Builds a [`Type`] from a type expression.
///
/// * rule 28 — tuple of the types in component 1,
/// * rule 29 — a bare name: `int` and `bool` are primitives, anything else is
///   a custom type without arguments,
/// * rule 30 — a custom type (component 1) applied to one argument (component 0),
/// * rule 31 — a custom type (component 3) applied to the arguments in component 1,
/// * rule 32 — an unbound type variable.
fn visit_type_string(&self, symbol: &Symbol) -> Rc<RefCell<Type>> {
    let rule = extract_rule(symbol);
    let typ = match rule.number {
        28 => Type::Tuple(self.visit_type_strings(&rule.components[1])),
        29 => match self.lexer.str_from_span(extract_span(&rule.components[0])) {
            "int" => Type::Primitive(Primitive::Integer),
            "bool" => Type::Primitive(Primitive::Bool),
            other => Type::Custom(other.to_string(), vec![]),
        },
        30 => {
            let arg = self.visit_type_string(&rule.components[0]);
            let id = self.lexer.str_from_span(extract_span(&rule.components[1]));
            Type::Custom(id.to_string(), vec![arg])
        }
        31 => {
            let args = self.visit_type_args(&rule.components[1]);
            let id = self.lexer.str_from_span(extract_span(&rule.components[3]));
            Type::Custom(id.to_string(), args)
        }
        32 => Type::Variable(Variable::Unbound(
            self.type_variable_id(&rule.components[0]),
        )),
        _ => unreachable!(),
    };
    Rc::new(RefCell::new(typ))
}
/// Collects the element types of a tuple type, in source order.
///
/// Rule 33 is `list <sep> type` (components 0 and 2); rule 34 is the base
/// case with exactly two types (components 0 and 2). Types are gathered from
/// last to first and reversed at the end.
fn visit_type_strings(&self, symbol: &Symbol) -> Vec<Rc<RefCell<Type>>> {
    let mut pending = Vec::new();
    let mut node = symbol;
    loop {
        let rule = extract_rule(node);
        match rule.number {
            33 => {
                pending.push(&rule.components[2]);
                node = &rule.components[0];
            }
            34 => {
                pending.push(&rule.components[2]);
                pending.push(&rule.components[0]);
                break;
            }
            _ => unreachable!(),
        }
    }
    pending
        .into_iter()
        .rev()
        .map(|typ| self.visit_type_string(typ))
        .collect()
}
/// Collects the arguments of a parameterised type, in source order.
///
/// Rule 35 is `list <sep> type` (components 0 and 2) and rule 36 is a single
/// type (component 0). Arguments are gathered from last to first and
/// reversed at the end.
fn visit_type_args(&self, symbol: &Symbol) -> Vec<Rc<RefCell<Type>>> {
    let mut pending = Vec::new();
    let mut node = symbol;
    loop {
        let rule = extract_rule(node);
        match rule.number {
            35 => {
                pending.push(&rule.components[2]);
                node = &rule.components[0];
            }
            36 => {
                pending.push(&rule.components[0]);
                break;
            }
            _ => unreachable!(),
        }
    }
    pending
        .into_iter()
        .rev()
        .map(|arg| self.visit_type_string(arg))
        .collect()
}
/// Returns the span of a parameter, i.e. of the first component of the
/// parameter's rule.
fn visit_param(&self, symbol: &Symbol) -> Span {
    let components = extract_components(symbol);
    let name = extract_span(&components[0]);
    name.clone()
}
/// Collects the spans of a parameter list, in source order.
///
/// Rule 19 is `list param` (components 0 and 1) and rule 20 is a single
/// parameter (component 0). Spans are gathered from last to first while
/// walking down the list and reversed at the end.
fn visit_params(&self, symbol: &Symbol) -> Vec<Span> {
    let mut spans = Vec::new();
    let mut node = symbol;
    loop {
        let rule = extract_rule(node);
        match rule.number {
            19 => {
                spans.push(self.visit_param(&rule.components[1]));
                node = &rule.components[0];
            }
            20 => {
                spans.push(self.visit_param(&rule.components[0]));
                break;
            }
            _ => unreachable!(),
        }
    }
    spans.reverse();
    spans
}
/// Builds the expressions of an expression list, in source order.
///
/// Rule 51 is `list <sep> expr` (components 0 and 2); rule 52 is the base
/// case with exactly two expressions (components 0 and 2). The symbols are
/// first gathered from last to first, then visited left to right so that
/// side effects of visiting (such as recorded captures) keep their order.
fn visit_expr_list(&mut self, symbol: &Symbol) -> Vec<Rc<RefCell<Expr>>> {
    let mut pending = Vec::new();
    let mut node = symbol;
    loop {
        let rule = extract_rule(node);
        match rule.number {
            51 => {
                pending.push(&rule.components[2]);
                node = &rule.components[0];
            }
            52 => {
                pending.push(&rule.components[2]);
                pending.push(&rule.components[0]);
                break;
            }
            _ => unreachable!(),
        }
    }
    pending
        .into_iter()
        .rev()
        .map(|expr| self.visit_expr(expr))
        .collect()
}
/// Builds the branches of a pattern match, in source order.
///
/// Rule 59 is `list <sep> branch` (components 0 and 2) and rule 60 is a
/// single branch (component 0). The branch symbols are first gathered from
/// last to first, then visited from first to last.
fn visit_branches(&mut self, symbol: &Symbol) -> Vec<(Pattern, Rc<RefCell<Expr>>)> {
    let mut pending = Vec::new();
    let mut node = symbol;
    loop {
        let rule = extract_rule(node);
        match rule.number {
            59 => {
                pending.push(&rule.components[2]);
                node = &rule.components[0];
            }
            60 => {
                pending.push(&rule.components[0]);
                break;
            }
            _ => unreachable!(),
        }
    }
    pending
        .into_iter()
        .rev()
        .map(|branch| self.visit_branch(branch))
        .collect()
}
/// Builds one match branch as a `(pattern, body)` pair.
///
/// Component 0 is the pattern (rule 61: a single pattern, rule 62: a pattern
/// list that becomes a tuple pattern) and component 2 is the body. A fresh
/// local-binding scope is active while the pattern and the body are visited,
/// so names introduced by the pattern are local to this branch.
fn visit_branch(&mut self, symbol: &Symbol) -> (Pattern, Rc<RefCell<Expr>>) {
    self.push_local_bindings();
    let rule = extract_rule(symbol);
    let lhs = &rule.components[0];
    let pattern = match rule.number {
        61 => self.visit_pattern(lhs),
        62 => Pattern::Tuple(self.visit_patterns(lhs)),
        _ => unreachable!(),
    };
    let body = self.visit_expr(&rule.components[2]);
    self.pop_local_bindings();
    (pattern, body)
}
/// Builds a pattern.
///
/// * rule 63 — tuple pattern from the pattern list in component 1,
/// * rule 64 — constructor (component 0) with an argument pattern (component 1),
/// * rule 65 — constructor (component 0) without an argument,
/// * rule 66 — identifier, wildcard or literal, delegated to
///   `visit_singular_pattern`.
fn visit_pattern(&mut self, symbol: &Symbol) -> Pattern {
    let rule = extract_rule(symbol);
    let parts = &rule.components;
    match rule.number {
        63 => Pattern::Tuple(self.visit_patterns(&parts[1])),
        64 => {
            let cons = extract_span(&parts[0]).clone();
            let inner = Box::new(self.visit_pattern(&parts[1]));
            Pattern::Constructor(cons, Some(inner))
        }
        65 => Pattern::Constructor(extract_span(&parts[0]).clone(), None),
        66 => self.visit_singular_pattern(&parts[0]),
        _ => unreachable!(),
    }
}
/// Builds the patterns of a pattern list, in source order.
///
/// Rule 67 is `list <sep> pattern` (components 0 and 2); rule 68 is the base
/// case with exactly two patterns (components 0 and 2). The symbols are
/// first gathered from last to first, then visited left to right.
fn visit_patterns(&mut self, symbol: &Symbol) -> Vec<Pattern> {
    let mut pending = Vec::new();
    let mut node = symbol;
    loop {
        let rule = extract_rule(node);
        match rule.number {
            67 => {
                pending.push(&rule.components[2]);
                node = &rule.components[0];
            }
            68 => {
                pending.push(&rule.components[2]);
                pending.push(&rule.components[0]);
                break;
            }
            _ => unreachable!(),
        }
    }
    pending
        .into_iter()
        .rev()
        .map(|pattern| self.visit_pattern(pattern))
        .collect()
}
/// Builds an identifier, wildcard or literal pattern.
///
/// * rule 69 — an identifier-like terminal: `_` becomes `Pattern::None`,
///   anything else becomes `Pattern::Identifier` and is registered in the
///   current local-binding scope,
/// * rule 70 — a literal, taken from the first component of the wrapped rule.
///
/// The wildcard is deliberately *not* registered as a local binding: it
/// introduces no name, so recording `_` would only pollute the scope.
fn visit_singular_pattern(&mut self, symbol: &Symbol) -> Pattern {
    let rule = extract_rule(symbol);
    match rule.number {
        69 => {
            let span = extract_span(&rule.components[0]);
            if self.lexer.str_from_span(span) == "_" {
                Pattern::None
            } else {
                self.insert_name_to_local_bindings(span);
                Pattern::Identifier(span.clone())
            }
        }
        70 => {
            let literal = &rule.components[0];
            let terminal = &extract_components(literal)[0];
            Pattern::Literal(self.visit_literal_expr(extract_terminal(terminal)))
        }
        _ => unreachable!(),
    }
}
fn new_fun_expr(
&mut self,
body: Rc<RefCell<Expr>>,
recursive_bind: Option<String>,
) -> Rc<RefCell<Expr>> {
let closure_ctx = self.pop_closure_ctx();
let mut params = closure_ctx.params;
let mut captures = closure_ctx.captures;
let span = body.borrow().span().clone();
let body = if let Expr::Fun(fun_expr) = &mut *body.borrow_mut() {
params.append(&mut fun_expr.params);
captures.append(&mut fun_expr.captures);
fun_expr.body.clone()
} else {
body
};
Rc::new(RefCell::new(Expr::fun(
params,
body,
captures,
recursive_bind,
span,
)))
}
/// Wraps `fun` applied to the single argument `arg` in an application
/// expression whose span runs from the start of `fun` to the end of `arg`.
fn new_application_expr(
    &mut self,
    fun: Rc<RefCell<Expr>>,
    arg: Rc<RefCell<Expr>>,
) -> Rc<RefCell<Expr>> {
    let span = Span::new(
        fun.borrow().span().start_pos(),
        arg.borrow().span().end_pos(),
    );
    let binds = vec![arg];
    Rc::new(RefCell::new(Expr::Application(ApplicationExpr {
        fun,
        binds,
        span,
    })))
}
/// Builds a unit literal carrying a clone of the terminal's span.
fn new_unit_expr(&self, terminal: &Terminal) -> LiteralExpr {
    LiteralExpr::Unit(terminal.span().clone())
}
/// Builds an integer literal from the terminal's lexeme and span.
///
/// # Panics
///
/// Panics if the lexeme cannot be parsed into the literal's integer type,
/// for example because the number written in the source is out of range.
/// The panic message names the offending lexeme so the failure can be traced
/// back to the source text instead of surfacing as a bare `unwrap` failure.
fn new_integer_expr(&self, terminal: &Terminal) -> LiteralExpr {
    let lexeme = self.lexer.get_lexeme(terminal);
    let span = terminal.span().clone();
    let value = lexeme.parse().unwrap_or_else(|_| {
        panic!(
            "integer literal `{}` cannot be represented by the literal's integer type",
            lexeme
        )
    });
    LiteralExpr::Integer(value, span)
}
/// Builds a variable expression for an identifier terminal.
///
/// When a closure context is active, the name is pushed onto that context's
/// captures unless it is found in the local-binding scopes, is one of the
/// context's parameters, or is its recursive name. No duplicate check is
/// made, so a name is pushed once per reference.
fn new_var_expr(&mut self, terminal: &Terminal) -> Rc<RefCell<Expr>> {
    let name = self.lexer.get_lexeme(terminal);
    let bound_locally = self.does_local_bindings_have_name(name);
    let lexer = self.lexer;
    match &mut self.current_closure_ctx {
        Some(ctx)
            if !bound_locally
                && !ctx.is_in_params(name, lexer)
                && !ctx.is_recursive_name(name, lexer) =>
        {
            ctx.captures.push(name.to_string());
        }
        _ => {}
    }
    Rc::new(RefCell::new(Expr::var(terminal.span().clone())))
}
/// Makes a fresh closure context current, keeping the previous one (if any)
/// as its parent. The new context starts without captures.
fn push_closure_ctx(&mut self, params: Vec<Span>, recursive_name: Option<Span>) {
    let enclosing = self.current_closure_ctx.take();
    let ctx = ClosureCtx {
        parent: enclosing.map(Box::new),
        params,
        recursive_name,
        captures: Vec::new(),
    };
    self.current_closure_ctx = Some(ctx);
}
/// Removes and returns the current closure context and makes its parent (if
/// any) current again. The returned context no longer holds its parent.
///
/// # Panics
///
/// Panics if there is no current closure context.
fn pop_closure_ctx(&mut self) -> ClosureCtx {
    let mut popped = self.current_closure_ctx.take().unwrap();
    self.current_closure_ctx = popped.parent.take().map(|parent| *parent);
    popped
}
fn push_local_bindings(&mut self) {
let parent = self.current_local_bindings.take().map(Box::new);
let local_bindings = LocalBindings {
parent,
..Default::default()
};
self.current_local_bindings = Some(local_bindings);
}
/// Closes the current local-binding scope and makes its parent (if any)
/// current again. Does nothing when no scope is open.
fn pop_local_bindings(&mut self) {
    let Some(closed) = self.current_local_bindings.take() else {
        return;
    };
    self.current_local_bindings = closed.parent.map(|scope| *scope);
}
/// Registers the source text of `span` as a name in the current
/// local-binding scope. Does nothing when no scope is open.
fn insert_name_to_local_bindings(&mut self, span: &Span) {
    let Some(scope) = self.current_local_bindings.as_mut() else {
        return;
    };
    let name = self.lexer.str_from_span(span).to_string();
    scope.insert_name(name);
}
/// Returns `true` when `name` is registered in the current local-binding
/// scope or any of its ancestors; `false` when no scope is open.
fn does_local_bindings_have_name(&self, name: &str) -> bool {
    self.current_local_bindings
        .as_ref()
        .is_some_and(|scope| scope.has_name(name))
}
/// Maps a type-variable terminal to a numeric id: the offset of the last
/// character of its lexeme from `'a'` (so a lexeme ending in `a` is 0, in
/// `b` is 1, and so on).
///
/// # Panics
///
/// Panics if the lexeme is empty or if its last character sorts before
/// `'a'`. The subtraction is checked explicitly so that such input fails
/// loudly in every build profile instead of wrapping around to a huge id
/// when overflow checks are disabled.
fn type_variable_id(&self, type_variable: &Symbol) -> usize {
    let span = extract_span(type_variable);
    let text = self.lexer.str_from_span(span);
    let var_char = text
        .chars()
        .last()
        .expect("type variable lexeme must not be empty");
    (var_char as usize)
        .checked_sub('a' as usize)
        .unwrap_or_else(|| {
            panic!(
                "type variable `{}` must end in a character at or after 'a'",
                text
            )
        })
}
}
/// Returns the span of a terminal symbol.
///
/// # Panics
///
/// Panics (inside `extract_terminal`) if `symbol` is not a terminal.
fn extract_span(symbol: &Symbol) -> &Span {
    let terminal = extract_terminal(symbol);
    terminal.span()
}
/// Returns the class of a terminal symbol.
///
/// # Panics
///
/// Panics (inside `extract_terminal`) if `symbol` is not a terminal.
fn extract_terminal_class(symbol: &Symbol) -> TerminalClass {
    let terminal = extract_terminal(symbol);
    terminal.class()
}
/// Returns the terminal wrapped by `symbol`.
///
/// # Panics
///
/// Panics if `symbol` is not a `Symbol::Terminal`.
fn extract_terminal(symbol: &Symbol) -> &Terminal {
    let Symbol::Terminal(terminal) = symbol else {
        panic!("extract_terminal should only be used with Terminal")
    };
    terminal
}
/// Returns the components of the rule of a non-terminal symbol.
///
/// # Panics
///
/// Panics (inside `extract_rule`) if `symbol` is not a non-terminal.
fn extract_components(symbol: &Symbol) -> &Vec<Symbol> {
    let rule = extract_rule(symbol);
    &rule.components
}
/// Returns the rule of a non-terminal symbol.
///
/// # Panics
///
/// Panics if `symbol` is not a `Symbol::NonTerminal`.
fn extract_rule(symbol: &Symbol) -> &Rule {
    let Symbol::NonTerminal(NonTerminal { rule, .. }) = symbol else {
        panic!("extract_rule should only be used with NonTerminal")
    };
    rule
}
impl ClosureCtx {
    /// Returns `true` when the source text of any parameter span equals `name`.
    fn is_in_params(&self, name: &str, lexer: &Lexer) -> bool {
        self.params
            .iter()
            .any(|param| lexer.str_from_span(param) == name)
    }

    /// Returns `true` when a recursive name is set and its source text equals
    /// `name`.
    fn is_recursive_name(&self, name: &str, lexer: &Lexer) -> bool {
        self.recursive_name
            .as_ref()
            .is_some_and(|span| lexer.str_from_span(span) == name)
    }
}
impl LocalBindings {
    /// Registers `name` in this scope.
    fn insert_name(&mut self, name: String) {
        self.bindings.insert(name);
    }

    /// Returns `true` when `name` is registered in this scope or in any
    /// enclosing scope reachable through `parent`.
    fn has_name(&self, name: &str) -> bool {
        let mut scope = Some(self);
        while let Some(current) = scope {
            if current.bindings.contains(name) {
                return true;
            }
            scope = current.parent.as_deref();
        }
        false
    }
}