mod ast_printer;
use std::{cell::RefCell, rc::Rc};
use crate::{ast::ast_printer::AstPrinter, custom_types::Variant, lexer::Lexer, symbol::Span};
#[derive(Default)]
pub struct Ast {
pub binds: Vec<Bind>,
}
pub enum Stmt {
Bind(Bind),
TypeDecl(Variant),
}
pub struct Bind {
pub name: Option<Span>,
pub expr: Rc<RefCell<Expr>>,
}
pub enum Expr {
Literal(LiteralExpr),
Construction(ConstructExpr),
Var(VarExpr),
Fun(FunExpr),
Application(ApplicationExpr),
LetIn(LetInExpr),
BinOp(BinOpExpr),
Conditional(CondExpr),
PatternMatch(PatternMatchExpr),
Tuple(TupleExpr),
}
#[derive(Clone)]
pub enum LiteralExpr {
Integer(i64, Span),
Unit(Span),
}
#[derive(Clone)]
pub struct VarExpr {
pub id: Span,
pub poly_args: Option<String>,
}
pub struct FunExpr {
pub params: Vec<Span>,
pub body: Rc<RefCell<Expr>>,
pub captures: Vec<String>,
pub recursive_bind: Option<String>,
pub span: Span,
}
pub struct ApplicationExpr {
pub fun: Rc<RefCell<Expr>>,
pub binds: Vec<Rc<RefCell<Expr>>>,
pub span: Span,
}
pub struct ConstructExpr {
pub cons: Span,
pub arg: Option<Rc<RefCell<Expr>>>,
pub span: Span,
}
pub struct LetInExpr {
pub bind: (Span, Rc<RefCell<Expr>>),
pub expr: Rc<RefCell<Expr>>,
pub span: Span,
}
pub struct BinOpExpr {
pub op: Operator,
pub lhs: Rc<RefCell<Expr>>,
pub rhs: Rc<RefCell<Expr>>,
pub span: Span,
}
pub struct CondExpr {
pub cond: Rc<RefCell<Expr>>,
pub yes: Rc<RefCell<Expr>>,
pub no: Rc<RefCell<Expr>>,
pub span: Span,
}
pub struct PatternMatchExpr {
pub matched: Rc<RefCell<Expr>>,
pub branches: Vec<(Pattern, Rc<RefCell<Expr>>)>,
pub span: Span,
}
#[derive(Clone)]
pub enum Pattern {
Tuple(Vec<Pattern>),
Constructor(Span, Option<Box<Pattern>>),
Identifier(Span),
Literal(LiteralExpr),
None,
}
pub struct TupleExpr {
pub elements: Vec<Rc<RefCell<Expr>>>,
pub span: Span,
}
#[derive(Copy, Clone, PartialEq)]
pub enum Operator {
Plus,
Minus,
Star,
Slash,
Eq,
Neq,
Lte,
Lt,
Gte,
Gt,
}
impl From<Bind> for Ast {
fn from(bind: Bind) -> Self {
let mut ast = Self::default();
ast.insert_stmt(bind);
ast
}
}
impl Ast {
pub fn insert_stmt(&mut self, bind: Bind) {
self.binds.push(bind);
}
}
impl Expr {
pub fn span(&self) -> &Span {
match self {
Expr::Literal(LiteralExpr::Integer(_, span)) => span,
Expr::Literal(LiteralExpr::Unit(span)) => span,
Expr::Var(VarExpr { id, .. }) => id,
Expr::Fun(FunExpr { span, .. }) => span,
Expr::Application(ApplicationExpr { span, .. }) => span,
Expr::LetIn(LetInExpr { span, .. }) => span,
Expr::BinOp(BinOpExpr { span, .. }) => span,
Expr::Conditional(CondExpr { span, .. }) => span,
Expr::Tuple(TupleExpr { span, .. }) => span,
Expr::PatternMatch(PatternMatchExpr { span, .. }) => span,
Expr::Construction(ConstructExpr { span, .. }) => span,
}
}
pub fn var(span: Span) -> Self {
Self::Var(VarExpr {
id: span,
poly_args: None,
})
}
pub fn fun(
params: Vec<Span>,
body: Rc<RefCell<Expr>>,
captures: Vec<String>,
recursive_bind: Option<String>,
span: Span,
) -> Self {
Self::Fun(FunExpr {
params,
body,
captures,
recursive_bind,
span,
})
}
pub fn binop(op: Operator, lhs: Rc<RefCell<Expr>>, rhs: Rc<RefCell<Expr>>, span: Span) -> Self {
Self::BinOp(BinOpExpr { op, lhs, rhs, span })
}
}
impl Ast {
pub fn pretty_print(&self, lexer: &Lexer) {
let printer = AstPrinter::new(lexer);
printer.pretty_print(self);
}
}
impl ApplicationExpr {
pub fn pretty_print(&self, lexer: &Lexer) {
let mut printer = AstPrinter::new(lexer);
printer.pretty_print_application_expr(self);
}
}
impl VarExpr {
pub fn base_name<'a>(&self, lexer: &'a Lexer) -> &'a str {
lexer.str_from_span(&self.id)
}
pub fn mono_name(&self, lexer: &Lexer) -> String {
let mut var_name = self.base_name(lexer).to_string();
if let Some(poly_args) = &self.poly_args {
var_name += ".";
var_name += poly_args;
}
var_name
}
}
impl Pattern {
pub fn has_literal(&self) -> bool {
match self {
Pattern::Tuple(elements) => elements.iter().any(|e| e.has_literal()),
Pattern::Identifier(_) => false,
Pattern::Literal(_) => true,
Pattern::None => false,
Pattern::Constructor(_, _) => true,
}
}
pub fn has_identifier(&self) -> bool {
match self {
Pattern::Tuple(elements) => elements.iter().any(|e| e.has_identifier()),
Pattern::Identifier(_) => true,
Pattern::Literal(_) => false,
Pattern::None => false,
Pattern::Constructor(_, None) => false,
Pattern::Constructor(_, Some(arg)) => arg.has_identifier(),
}
}
}
impl Operator {
pub fn cmp_fun_prefix(&self) -> &'static str {
match self {
Operator::Eq => "oonta.$eq",
Operator::Neq => "oonta.$eq",
Operator::Lte => "oonta.$lte",
Operator::Lt => "oonta.$lt",
Operator::Gte => "oonta.$gte",
Operator::Gt => "oonta.$gt",
_ => panic!("Only use cmp name on comparison operators"),
}
}
}
impl std::fmt::Display for Operator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Operator::Plus => write!(f, "+"),
Operator::Minus => write!(f, "-"),
Operator::Star => write!(f, "*"),
Operator::Slash => write!(f, "/"),
Operator::Eq => write!(f, "="),
Operator::Neq => write!(f, "<>"),
Operator::Lte => write!(f, "<="),
Operator::Lt => write!(f, "<"),
Operator::Gte => write!(f, ">="),
Operator::Gt => write!(f, ">"),
}
}
}