use std::{cell::RefCell, rc::Rc};
use crate::{
ast::{ApplicationExpr, Ast, Bind, CondExpr, Expr, LetInExpr, LiteralExpr, Operator},
lexer::Lexer,
symbol::{NonTerminal, Rule, Span, Symbol, Terminal, TerminalClass},
};
/// Lowers a concrete syntax tree (`Symbol`) into an [`Ast`].
///
/// While lowering function bodies it tracks a stack of closure contexts so
/// that free identifiers can be recorded as captures.
pub struct AstBuilder<'a> {
    /// Innermost closure currently being lowered; `None` outside any function.
    current_closure_ctx: Option<ClosureCtx>,
    /// Lexer used to recover identifier and literal text from tokens/spans.
    lexer: &'a Lexer,
}
/// Bookkeeping for one function body that is currently being lowered.
///
/// Contexts form a stack through `parent`: entering a function pushes a new
/// context and leaving it restores the enclosing one.
struct ClosureCtx {
    /// Parameters bound by this function, as source spans.
    params: Vec<Span>,
    /// The function's own name, if set; references to it are not recorded
    /// as captures.
    recursive_name: Option<Span>,
    /// Identifiers referenced in the body that this function does not bind.
    captures: Vec<String>,
    /// Context of the lexically enclosing function, if any.
    parent: Option<Box<ClosureCtx>>,
}
impl<'a> AstBuilder<'a> {
    /// Creates a builder that reads lexemes through `lexer`.
    ///
    /// No closure context is active initially, so identifiers outside any
    /// function body are never recorded as captures.
    pub fn new(lexer: &'a Lexer) -> Self {
        Self {
            lexer,
            current_closure_ctx: None,
        }
    }

    /// Lowers a program CST into an [`Ast`].
    ///
    /// Rule 1 is the left-recursive "program, then one more binding"
    /// production; rule 2 is a program made of a single binding.
    ///
    /// # Panics
    /// Panics if `cst_root` is not a non-terminal built from rule 1 or 2.
    pub fn visit(&mut self, cst_root: &Symbol) -> Ast {
        let rule = extract_rule(cst_root);
        match rule.number {
            1 => self.visit_append_ast(&rule.components),
            2 => Ast::from(self.visit_bind(&rule.components[0])),
            _ => unreachable!(),
        }
    }

    /// Lowers the leading program (`components[0]`) and appends the binding
    /// found in `components[1]`.
    fn visit_append_ast(&mut self, components: &[Symbol]) -> Ast {
        let mut ast = self.visit(&components[0]);
        let new_bind = self.visit_bind(&components[1]);
        ast.append(new_bind);
        ast
    }

    /// Dispatches a binding non-terminal to the visitor for its kind:
    /// rule 3 variable, 4 function, 5 recursive function, 6 unit binding.
    fn visit_bind(&mut self, symbol: &Symbol) -> Bind {
        let rule = extract_rule(symbol);
        match rule.number {
            3 => self.visit_var_bind(extract_components(&rule.components[0])),
            4 => self.visit_fun_bind(extract_components(&rule.components[0]), false),
            5 => self.visit_fun_bind(extract_components(&rule.components[0]), true),
            6 => self.visit_unit_bind(extract_components(&rule.components[0])),
            _ => unreachable!(),
        }
    }

    /// Lowers an anonymous (unit) binding: only the expression at
    /// `components[3]` is kept and the binding has no name.
    fn visit_unit_bind(&mut self, components: &[Symbol]) -> Bind {
        let expr = self.visit_expr(&components[3]);
        Bind { name: None, expr }
    }

    /// Lowers `name = expr` style bindings: name at `components[1]`,
    /// expression at `components[3]`.
    fn visit_var_bind(&mut self, components: &[Symbol]) -> Bind {
        let name = extract_span(&components[1]).clone();
        let expr = self.visit_expr(&components[3]);
        Bind {
            name: Some(name),
            expr,
        }
    }

    /// Lowers a named function binding into a `Bind` holding a function
    /// expression.
    ///
    /// A recursive binding has one extra leading token (presumably `rec`),
    /// which shifts the name, parameter list and body one slot to the right.
    fn visit_fun_bind(&mut self, components: &[Symbol], recursive: bool) -> Bind {
        let offset = usize::from(recursive);
        let name = extract_span(&components[1 + offset]).clone();
        let params = self.visit_params(&components[2 + offset]);
        // Only a recursive binding can refer to itself inside its body. For a
        // plain binding, occurrences of `name` refer to an earlier binding and
        // must be recorded as captures like any other free identifier.
        self.push_closure_ctx(params, recursive.then(|| name.clone()));
        let expr = self.visit_expr(&components[4 + offset]);
        let recursive_bind = recursive.then(|| self.lexer.str_from_span(&name).to_string());
        let expr = self.new_fun_expr(expr, recursive_bind);
        Bind {
            name: Some(name),
            expr,
        }
    }

    /// Lowers any expression symbol, terminal or non-terminal.
    fn visit_expr(&mut self, symbol: &Symbol) -> Rc<RefCell<Expr>> {
        match symbol {
            Symbol::NonTerminal(non_terminal) => self.visit_non_terminal_expr(non_terminal),
            Symbol::Terminal(terminal) => self.visit_terminal_expr(terminal),
        }
    }

    /// Dispatches an expression non-terminal on its grammar rule number.
    fn visit_non_terminal_expr(&mut self, non_terminal: &NonTerminal) -> Rc<RefCell<Expr>> {
        match non_terminal.rule.number {
            // Single-child pass-through productions.
            14..=21 | 38..=41 => self.visit_expr(&non_terminal.rule.components[0]),
            22 => self.visit_if_then_else_expr(&non_terminal.rule.components),
            23 => self.visit_anonymous_fun(&non_terminal.rule.components),
            // Presumably a parenthesised expression: keep only the middle child.
            24 => self.visit_expr(&non_terminal.rule.components[1]),
            25 => self.visit_let_in_expr(&non_terminal.rule.components),
            26..=34 => self.visit_binop_expr(&non_terminal.rule.components),
            35 => self.visit_append_application(&non_terminal.rule.components),
            36 | 37 => self.visit_application(&non_terminal.rule.components),
            _ => unreachable!(),
        }
    }

    /// Lowers a conditional; its span runs from the leading keyword token
    /// to the end of the else branch.
    fn visit_if_then_else_expr(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
        let cond = self.visit_expr(&components[1]);
        let yes = self.visit_expr(&components[3]);
        let no = self.visit_expr(&components[5]);
        let span = Span::new(
            extract_span(&components[0]).start_pos(),
            no.borrow().span().end_pos(),
        );
        let cond_expr = Expr::Conditional(CondExpr {
            cond,
            yes,
            no,
            span,
        });
        Rc::new(RefCell::new(cond_expr))
    }

    /// Lowers an anonymous function: parameters at `components[1]`, body at
    /// `components[3]`. The body is lowered inside a fresh closure context.
    fn visit_anonymous_fun(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
        let params = self.visit_params(&components[1]);
        self.push_closure_ctx(params, None);
        let expr = self.visit_expr(&components[3]);
        self.new_fun_expr(expr, None)
    }

    /// Lowers a let-in expression: bound name at `components[1]`, bound
    /// expression at `components[3]`, body at `components[5]`.
    ///
    /// NOTE(review): the let-bound name is not registered with the current
    /// closure context, so uses of it inside a function body are recorded as
    /// captures. Confirm whether later passes expect that.
    fn visit_let_in_expr(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
        let bind_name = extract_span(&components[1]).clone();
        let bind_expr = self.visit_expr(&components[3]);
        let expr = self.visit_expr(&components[5]);
        let span = Span::new(
            extract_span(&components[0]).start_pos(),
            expr.borrow().span().end_pos(),
        );
        let let_in_expr = Expr::LetIn(LetInExpr {
            bind: (bind_name, bind_expr),
            expr,
            span,
        });
        Rc::new(RefCell::new(let_in_expr))
    }

    /// Lowers `f x` into a fresh single-argument application.
    fn visit_application(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
        let fun = self.visit_expr(&components[0]);
        let arg = self.visit_expr(&components[1]);
        self.new_application_expr(fun, arg)
    }

    /// Extends an existing application with one more argument, widening its
    /// span to cover the new argument.
    fn visit_append_application(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
        let app = self.visit_expr(&components[0]);
        let arg = self.visit_expr(&components[1]);
        if let Expr::Application(app_expr) = &mut *app.borrow_mut() {
            let span = Span::new(app_expr.span.start_pos(), arg.borrow().span().end_pos());
            app_expr.span = span;
            app_expr.binds.push(arg);
        } else {
            unreachable!()
        }
        app
    }

    /// Lowers `lhs <op> rhs` where the operator token sits at `components[1]`.
    fn visit_binop_expr(&mut self, components: &[Symbol]) -> Rc<RefCell<Expr>> {
        let lhs = self.visit_expr(&components[0]);
        let op = match extract_terminal_class(&components[1]) {
            TerminalClass::Plus => Operator::Plus,
            TerminalClass::Minus => Operator::Minus,
            TerminalClass::Star => Operator::Star,
            TerminalClass::Slash => Operator::Slash,
            TerminalClass::Eq => Operator::Eq,
            TerminalClass::Lte => Operator::Lte,
            TerminalClass::Lt => Operator::Lt,
            TerminalClass::Gte => Operator::Gte,
            TerminalClass::Gt => Operator::Gt,
            _ => unreachable!(),
        };
        let rhs = self.visit_expr(&components[2]);
        let span = Span::new(
            lhs.borrow().span().start_pos(),
            rhs.borrow().span().end_pos(),
        );
        Rc::new(RefCell::new(Expr::binop(op, lhs, rhs, span)))
    }

    /// Lowers a leaf token: number literal, identifier or unit literal.
    fn visit_terminal_expr(&mut self, terminal: &Terminal) -> Rc<RefCell<Expr>> {
        match terminal.class() {
            TerminalClass::Number => self.new_integer_expr(terminal),
            TerminalClass::Identifier => self.new_var_expr(terminal),
            TerminalClass::Unit => self.new_unit_expr(terminal),
            _ => unreachable!(),
        }
    }

    /// Returns the span of the single token making up a parameter.
    fn visit_param(&self, symbol: &Symbol) -> Span {
        extract_span(&extract_components(symbol)[0]).clone()
    }

    /// Flattens the left-recursive parameter list (rule 11 = list + param,
    /// rule 12 = single param) into source order.
    fn visit_params(&self, symbol: &Symbol) -> Vec<Span> {
        let rule = extract_rule(symbol);
        match rule.number {
            11 => {
                let mut param_list = self.visit_params(&rule.components[0]);
                let param = self.visit_param(&rule.components[1]);
                param_list.push(param);
                param_list
            }
            12 => vec![self.visit_param(&rule.components[0])],
            _ => unreachable!(),
        }
    }

    /// Pops the current closure context and builds a function expression.
    ///
    /// If `body` is itself a function (`fun x -> fun y -> e`), the two are
    /// merged into one function taking both parameter lists. The resulting
    /// capture list contains each name once and never contains a name the
    /// merged function binds itself (a parameter or `recursive_bind`).
    fn new_fun_expr(
        &mut self,
        body: Rc<RefCell<Expr>>,
        recursive_bind: Option<String>,
    ) -> Rc<RefCell<Expr>> {
        let ClosureCtx {
            mut params,
            mut captures,
            ..
        } = self.pop_closure_ctx();
        let span = body.borrow().span().clone();
        let body = if let Expr::Fun(fun_expr) = &mut *body.borrow_mut() {
            params.append(&mut fun_expr.params);
            for capture in fun_expr.captures.drain(..) {
                if !captures.contains(&capture) {
                    captures.push(capture);
                }
            }
            fun_expr.body.clone()
        } else {
            body
        };
        // After flattening, the inner function's captures may name the outer
        // parameters or the recursive name; those are now bound locally.
        let lexer = self.lexer;
        captures.retain(|capture| {
            let is_param = params
                .iter()
                .any(|param| lexer.str_from_span(param) == capture.as_str());
            let is_self_ref = recursive_bind.as_deref() == Some(capture.as_str());
            !is_param && !is_self_ref
        });
        Rc::new(RefCell::new(Expr::fun(
            params,
            body,
            captures,
            recursive_bind,
            span,
        )))
    }

    /// Builds a single-argument application spanning `fun` through `arg`.
    fn new_application_expr(
        &mut self,
        fun: Rc<RefCell<Expr>>,
        arg: Rc<RefCell<Expr>>,
    ) -> Rc<RefCell<Expr>> {
        let span = Span::new(
            fun.borrow().span().start_pos(),
            arg.borrow().span().end_pos(),
        );
        let app_expr = ApplicationExpr {
            fun,
            binds: vec![arg],
            span,
        };
        Rc::new(RefCell::new(Expr::Application(app_expr)))
    }

    /// Builds a unit literal located at `terminal`.
    fn new_unit_expr(&self, terminal: &Terminal) -> Rc<RefCell<Expr>> {
        let span = terminal.span().clone();
        let literal_expr = LiteralExpr::Unit(span);
        Rc::new(RefCell::new(Expr::Literal(literal_expr)))
    }

    /// Builds an integer literal from the token's lexeme.
    ///
    /// # Panics
    /// Panics, naming the lexeme, if it does not parse into the integer type
    /// expected by `Expr::integer` (e.g. an out-of-range literal).
    fn new_integer_expr(&self, terminal: &Terminal) -> Rc<RefCell<Expr>> {
        let lexeme = self.lexer.get_lexeme(terminal);
        let span = terminal.span().clone();
        let value = lexeme
            .parse()
            .unwrap_or_else(|err| panic!("invalid integer literal `{lexeme}`: {err:?}"));
        Rc::new(RefCell::new(Expr::integer(value, span)))
    }

    /// Builds a variable reference. Inside a function body, an identifier
    /// that is neither a parameter nor the recursive name is recorded once
    /// as a capture of the innermost closure.
    fn new_var_expr(&mut self, terminal: &Terminal) -> Rc<RefCell<Expr>> {
        let name = self.lexer.get_lexeme(terminal);
        if let Some(ctx) = &mut self.current_closure_ctx
            && !ctx.is_in_params(name, self.lexer)
            && !ctx.is_recursive_name(name, self.lexer)
            && !ctx.captures.iter().any(|capture| capture == name)
        {
            ctx.captures.push(name.to_string());
        }
        let id = terminal.span().clone();
        Rc::new(RefCell::new(Expr::var(id)))
    }

    /// Enters a new function body, making the current context its parent.
    fn push_closure_ctx(&mut self, params: Vec<Span>, recursive_name: Option<Span>) {
        let parent = self.current_closure_ctx.take().map(Box::new);
        self.current_closure_ctx = Some(ClosureCtx {
            parent,
            params,
            recursive_name,
            captures: vec![],
        });
    }

    /// Leaves the current function body, restoring its parent context, and
    /// returns the finished context.
    ///
    /// # Panics
    /// Panics if no context is active (unbalanced push/pop).
    fn pop_closure_ctx(&mut self) -> ClosureCtx {
        let mut ctx = self
            .current_closure_ctx
            .take()
            .expect("pop_closure_ctx called without a matching push_closure_ctx");
        self.current_closure_ctx = ctx.parent.take().map(|parent| *parent);
        ctx
    }
}
/// Returns the source span of a terminal symbol.
///
/// # Panics
/// Panics if `symbol` is a non-terminal.
fn extract_span(symbol: &Symbol) -> &Span {
    match symbol {
        Symbol::Terminal(terminal) => terminal.span(),
        Symbol::NonTerminal(_) => panic!("extract_span should only be used with Terminal"),
    }
}
/// Returns the token class of a terminal symbol.
///
/// # Panics
/// Panics if `symbol` is a non-terminal.
fn extract_terminal_class(symbol: &Symbol) -> TerminalClass {
    let Symbol::Terminal(terminal) = symbol else {
        panic!("extract_terminal_class should only be used with Terminal")
    };
    terminal.class()
}
/// Returns the child symbols of a non-terminal's rule.
///
/// # Panics
/// Panics (via `extract_rule`) if `symbol` is a terminal.
fn extract_components(symbol: &Symbol) -> &Vec<Symbol> {
    let rule = extract_rule(symbol);
    &rule.components
}
/// Returns the grammar rule that produced a non-terminal symbol.
///
/// # Panics
/// Panics if `symbol` is a terminal.
fn extract_rule(symbol: &Symbol) -> &Rule {
    match symbol {
        Symbol::NonTerminal(non_terminal) => &non_terminal.rule,
        Symbol::Terminal(_) => panic!("extract_rule should only be used with NonTerminal"),
    }
}
impl ClosureCtx {
    /// Returns `true` if `name` has the same text as one of this closure's
    /// parameters (resolved through `lexer`).
    fn is_in_params(&self, name: &str, lexer: &Lexer) -> bool {
        self.params
            .iter()
            .any(|param| lexer.str_from_span(param) == name)
    }

    /// Returns `true` if a recursive name is set and `name` has the same text.
    fn is_recursive_name(&self, name: &str, lexer: &Lexer) -> bool {
        self.recursive_name
            .as_ref()
            .is_some_and(|span| lexer.str_from_span(span) == name)
    }
}