1use crate::ast::{BinOp, Expr, ExprKind, UnOp};
2use crate::program::Program;
3use crate::span::Span;
4use crate::symbol::Symbol;
5use rust_decimal::Decimal;
6use thiserror::Error;
7
8#[derive(Error, Debug, Clone)]
10pub enum IrError {
11 #[error("Undefined symbol {0}")]
12 UndefinedSymbol(String, Span),
13}
14
15#[derive(Debug, Clone)]
16pub enum Instr<'sym> {
17 Push(Decimal),
18 Load(&'sym Symbol),
19 Neg,
20 Add,
21 Sub,
22 Mul,
23 Div,
24 Pow,
25 Fact,
26 Call(&'sym Symbol, usize), Equal,
29 NotEqual,
30 Less,
31 LessEqual,
32 Greater,
33 GreaterEqual,
34}
35
36pub struct IrBuilder<'sym> {
37 prog: Program<'sym>,
38}
39
40impl<'src, 'sym> IrBuilder<'sym> {
41 pub fn new() -> Self {
42 Self {
43 prog: Program::new(),
44 }
45 }
46
47 pub fn build(mut self, expr: &Expr<'src, 'sym>) -> Result<Program<'sym>, IrError> {
48 self.emit(expr)?;
49 Ok(self.prog)
50 }
51
52 fn emit(&mut self, e: &Expr<'src, 'sym>) -> Result<(), IrError> {
53 match &e.kind {
54 ExprKind::Literal(v) => {
55 self.prog.code.push(Instr::Push(*v));
56 }
57 ExprKind::Ident { name, sym } => {
58 if sym.is_none() {
59 return Err(IrError::UndefinedSymbol(name.to_string(), e.span));
60 }
61 self.prog.code.push(Instr::Load(sym.unwrap()));
62 }
63 ExprKind::Unary { op, expr } => {
64 self.emit(expr)?;
65 match op {
66 UnOp::Neg => self.prog.code.push(Instr::Neg),
67 UnOp::Fact => self.prog.code.push(Instr::Fact),
68 }
69 }
70 ExprKind::Binary { op, left, right } => {
71 self.emit(left)?;
72 self.emit(right)?;
73 self.prog.code.push(match op {
74 BinOp::Add => Instr::Add,
75 BinOp::Sub => Instr::Sub,
76 BinOp::Mul => Instr::Mul,
77 BinOp::Div => Instr::Div,
78 BinOp::Pow => Instr::Pow,
79 BinOp::Equal => Instr::Equal,
80 BinOp::NotEqual => Instr::NotEqual,
81 BinOp::Less => Instr::Less,
82 BinOp::LessEqual => Instr::LessEqual,
83 BinOp::Greater => Instr::Greater,
84 BinOp::GreaterEqual => Instr::GreaterEqual,
85 });
86 }
87 ExprKind::Call { name, args, sym } => {
88 if sym.is_none() {
89 return Err(IrError::UndefinedSymbol(name.to_string(), e.span));
90 }
91 for a in args.iter() {
92 self.emit(a)?;
93 }
94 self.prog.code.push(Instr::Call(sym.unwrap(), args.len()));
95 }
96 }
97 Ok(())
98 }
99}