expr_solver/
ir.rs

1use crate::ast::{BinOp, Expr, ExprKind, UnOp};
2use crate::program::Program;
3use crate::span::Span;
4use crate::symbol::Symbol;
5use rust_decimal::Decimal;
6use thiserror::Error;
7
8/// IR building errors.
9#[derive(Error, Debug, Clone)]
10pub enum IrError {
11    #[error("Undefined symbol {0}")]
12    UndefinedSymbol(String, Span),
13}
14
15#[derive(Debug, Clone)]
16pub enum Instr<'sym> {
17    Push(Decimal),
18    Load(&'sym Symbol),
19    Neg,
20    Add,
21    Sub,
22    Mul,
23    Div,
24    Pow,
25    Fact,
26    Call(&'sym Symbol, usize), // Symbol and argument count
27    // Comparison operators
28    Equal,
29    NotEqual,
30    Less,
31    LessEqual,
32    Greater,
33    GreaterEqual,
34}
35
36pub struct IrBuilder<'sym> {
37    prog: Program<'sym>,
38}
39
40impl<'src, 'sym> IrBuilder<'sym> {
41    pub fn new() -> Self {
42        Self {
43            prog: Program::new(),
44        }
45    }
46
47    pub fn build(mut self, expr: &Expr<'src, 'sym>) -> Result<Program<'sym>, IrError> {
48        self.emit(expr)?;
49        Ok(self.prog)
50    }
51
52    fn emit(&mut self, e: &Expr<'src, 'sym>) -> Result<(), IrError> {
53        match &e.kind {
54            ExprKind::Literal(v) => {
55                self.prog.code.push(Instr::Push(*v));
56            }
57            ExprKind::Ident { name, sym } => {
58                if sym.is_none() {
59                    return Err(IrError::UndefinedSymbol(name.to_string(), e.span));
60                }
61                self.prog.code.push(Instr::Load(sym.unwrap()));
62            }
63            ExprKind::Unary { op, expr } => {
64                self.emit(expr)?;
65                match op {
66                    UnOp::Neg => self.prog.code.push(Instr::Neg),
67                    UnOp::Fact => self.prog.code.push(Instr::Fact),
68                }
69            }
70            ExprKind::Binary { op, left, right } => {
71                self.emit(left)?;
72                self.emit(right)?;
73                self.prog.code.push(match op {
74                    BinOp::Add => Instr::Add,
75                    BinOp::Sub => Instr::Sub,
76                    BinOp::Mul => Instr::Mul,
77                    BinOp::Div => Instr::Div,
78                    BinOp::Pow => Instr::Pow,
79                    BinOp::Equal => Instr::Equal,
80                    BinOp::NotEqual => Instr::NotEqual,
81                    BinOp::Less => Instr::Less,
82                    BinOp::LessEqual => Instr::LessEqual,
83                    BinOp::Greater => Instr::Greater,
84                    BinOp::GreaterEqual => Instr::GreaterEqual,
85                });
86            }
87            ExprKind::Call { name, args, sym } => {
88                if sym.is_none() {
89                    return Err(IrError::UndefinedSymbol(name.to_string(), e.span));
90                }
91                for a in args.iter() {
92                    self.emit(a)?;
93                }
94                self.prog.code.push(Instr::Call(sym.unwrap(), args.len()));
95            }
96        }
97        Ok(())
98    }
99}