expr_solver/
program.rs

1use crate::ir::Instr;
2use crate::symbol::{SymTable, Symbol};
3use bincode::config;
4use colored::Colorize;
5use rust_decimal::Decimal;
6use std::collections::HashMap;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9
10/// Expression parsing and evaluation errors.
11#[derive(Error, Debug)]
12pub enum ProgramError {
13    #[error("Compilation error: {0}")]
14    CompileError(String),
15    #[error("Decoding error: {0}")]
16    DecodingError(#[from] bincode::error::DecodeError),
17    #[error("incompatible program version: expected {0}, got {1}")]
18    IncompatibleVersions(String, String),
19    #[error("Unknown symbol: {0}")]
20    UnknownSymbol(String),
21    #[error("Symbol '{0}' is not a {1}")]
22    SymbolKindMismatch(String, String),
23    #[error("Function '{0}' incorrect arity")]
24    InvalidFuncArity(String),
25    #[error("Corrupted instruction: {0}")]
26    CorrupedInstruction(String),
27}
28
29/// Executable program is still a sequence of `Instr<'sym>` referencing symbols
30/// inside a provided `SymTable`.
31#[derive(Default)]
32pub struct Program<'sym> {
33    pub version: String,
34    pub code: Vec<Instr<'sym>>,
35}
36
37/// A compact, fully-owned form of a program that can be serialized without lifetimes.
38/// Symbols are stored by name and kind; instructions reference symbols by index.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40struct Binary {
41    version: String,
42    symbols: Vec<BinarySymbol>,
43    code: Vec<BinaryInstr>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47enum BinarySymbol {
48    /// Named constant
49    Const(String),
50    /// Named function
51    Func {
52        name: String,
53        args: usize,
54        variadic: bool,
55    },
56}
57
58impl BinarySymbol {
59    fn name(&self) -> String {
60        match self {
61            BinarySymbol::Const(name) => name.clone(),
62            BinarySymbol::Func { name, .. } => name.clone(),
63        }
64    }
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68enum BinaryInstr {
69    Push(Decimal),
70    Load(u32), // index into `symbols`
71    Neg,
72    Add,
73    Sub,
74    Mul,
75    Div,
76    Pow,
77    Fact,
78    Call(u32, usize), // index into `symbols` and argument count
79    Equal,
80    NotEqual,
81    Less,
82    LessEqual,
83    Greater,
84    GreaterEqual,
85}
86
87impl<'sym> Program<'sym> {
88    pub fn new() -> Self {
89        Self {
90            version: env!("CARGO_PKG_VERSION").to_string(),
91            code: Vec::new(),
92        }
93    }
94
95    pub fn compile(&self) -> Result<Vec<u8>, ProgramError> {
96        let binary = self.to_binary();
97        let config = config::standard();
98        bincode::serde::encode_to_vec(&binary, config)
99            .map_err(|err| ProgramError::CompileError(format!("failed to encode program: {}", err)))
100    }
101
102    pub fn load(data: &[u8], table: &'sym SymTable) -> Result<Program<'sym>, ProgramError> {
103        let config = config::standard();
104        let (decoded, _): (Binary, usize) = bincode::serde::decode_from_slice(&data, config)
105            .map_err(ProgramError::DecodingError)?;
106
107        Self::validate_version(&decoded.version)?;
108
109        let get_sym = |bin_sym: &BinarySymbol| -> Result<&'sym Symbol, ProgramError> {
110            let name = bin_sym.name();
111            table.get(&name).ok_or(ProgramError::UnknownSymbol(name))
112        };
113
114        let mut program = Program::new();
115        program.version = decoded.version.clone();
116
117        for instr in &decoded.code {
118            match instr {
119                BinaryInstr::Push(v) => {
120                    program.code.push(Instr::Push(*v));
121                }
122                BinaryInstr::Load(idx) => {
123                    let bin_sym = &decoded.symbols[*idx as usize];
124                    let sym = get_sym(&bin_sym)?;
125                    match bin_sym {
126                        BinarySymbol::Const(_) => {
127                            if !matches!(sym, Symbol::Const { .. }) {
128                                return Err(ProgramError::SymbolKindMismatch(
129                                    sym.name().to_string(),
130                                    "constant".to_string(),
131                                ));
132                            }
133                        }
134                        _ => {
135                            return Err(ProgramError::CorrupedInstruction("LOAD".to_string()));
136                        }
137                    }
138                    program.code.push(Instr::Load(sym))
139                }
140                BinaryInstr::Neg => program.code.push(Instr::Neg),
141                BinaryInstr::Add => program.code.push(Instr::Add),
142                BinaryInstr::Sub => program.code.push(Instr::Sub),
143                BinaryInstr::Mul => program.code.push(Instr::Mul),
144                BinaryInstr::Div => program.code.push(Instr::Div),
145                BinaryInstr::Pow => program.code.push(Instr::Pow),
146                BinaryInstr::Fact => program.code.push(Instr::Fact),
147                BinaryInstr::Call(idx, argc) => {
148                    let bin_sym = &decoded.symbols[*idx as usize];
149                    let sym = get_sym(&bin_sym)?;
150                    if !matches!(sym, Symbol::Func { .. }) {
151                        return Err(ProgramError::SymbolKindMismatch(
152                            sym.name().to_string(),
153                            "function".to_string(),
154                        ));
155                    }
156                    program.code.push(Instr::Call(sym, *argc));
157                }
158                // Comparison operators
159                BinaryInstr::Equal => program.code.push(Instr::Equal),
160                BinaryInstr::NotEqual => program.code.push(Instr::NotEqual),
161                BinaryInstr::Less => program.code.push(Instr::Less),
162                BinaryInstr::LessEqual => program.code.push(Instr::LessEqual),
163                BinaryInstr::Greater => program.code.push(Instr::Greater),
164                BinaryInstr::GreaterEqual => program.code.push(Instr::GreaterEqual),
165            }
166        }
167
168        Ok(program)
169    }
170
171    fn validate_version(version: &String) -> Result<(), ProgramError> {
172        let current_version = env!("CARGO_PKG_VERSION");
173        if version != current_version {
174            return Err(ProgramError::IncompatibleVersions(
175                current_version.to_string(),
176                version.clone(),
177            ));
178        }
179        Ok(())
180    }
181
182    pub fn get_assembly(&self) -> String {
183        use std::fmt::Write as _;
184
185        let mut out = String::new();
186        out += &format!("; VERSION {}\n", self.version).bright_black().to_string();
187
188        let emit = |mnemonic: &str| -> String { format!("{}", mnemonic.magenta()) };
189        let emit1 = |mnemonic: &str, op: &str| -> String {
190            format!("{} {}", mnemonic.magenta(), op.green())
191        };
192
193        for (i, instr) in self.code.iter().enumerate() {
194            let _ = write!(out, "{} ", format!("{:04X}", i).yellow());
195            let line = match instr {
196                Instr::Push(v) => emit1("PUSH", &v.to_string().green()),
197                Instr::Load(sym) => emit1("LOAD", &sym.name().blue()),
198                Instr::Neg => emit("NEG"),
199                Instr::Add => emit("ADD"),
200                Instr::Sub => emit("SUB"),
201                Instr::Mul => emit("MUL"),
202                Instr::Div => emit("DIV"),
203                Instr::Pow => emit("POW"),
204                Instr::Fact => emit("FACT"),
205                Instr::Call(sym, argc) => format!(
206                    "{} {} args: {}",
207                    emit("CALL"),
208                    sym.name().cyan(),
209                    argc.to_string().bright_blue()
210                ),
211                Instr::Equal => emit("EQ"),
212                Instr::NotEqual => emit("NEQ"),
213                Instr::Less => emit("LT"),
214                Instr::LessEqual => emit("LTE"),
215                Instr::Greater => emit("GT"),
216                Instr::GreaterEqual => emit("GTE"),
217            };
218            let _ = writeln!(out, "{}", line);
219        }
220        out
221    }
222
223    fn to_binary(&self) -> Binary {
224        let mut map: HashMap<String, u32> = HashMap::new();
225        let mut binary = Binary {
226            version: self.version.clone(),
227            symbols: Vec::new(),
228            code: Vec::new(),
229        };
230
231        let mut get_index = |sym: &'sym Symbol| -> u32 {
232            map.get(sym.name()).map(|val| *val).unwrap_or_else(|| {
233                let i = binary.symbols.len() as u32;
234                map.insert(sym.name().to_string(), i);
235                binary.symbols.push(match sym {
236                    Symbol::Const { .. } => BinarySymbol::Const(sym.name().to_string()),
237                    Symbol::Func { args, variadic, .. } => BinarySymbol::Func {
238                        name: sym.name().to_string(),
239                        args: *args,
240                        variadic: *variadic,
241                    },
242                });
243                i
244            })
245        };
246
247        for instr in &self.code {
248            match instr {
249                Instr::Push(v) => {
250                    binary.code.push(BinaryInstr::Push(*v));
251                }
252                Instr::Load(sym) => {
253                    let idx = get_index(sym);
254                    binary.code.push(BinaryInstr::Load(idx));
255                }
256                Instr::Neg => {
257                    binary.code.push(BinaryInstr::Neg);
258                }
259                Instr::Add => {
260                    binary.code.push(BinaryInstr::Add);
261                }
262                Instr::Sub => {
263                    binary.code.push(BinaryInstr::Sub);
264                }
265                Instr::Mul => {
266                    binary.code.push(BinaryInstr::Mul);
267                }
268                Instr::Div => {
269                    binary.code.push(BinaryInstr::Div);
270                }
271                Instr::Pow => {
272                    binary.code.push(BinaryInstr::Pow);
273                }
274                Instr::Fact => {
275                    binary.code.push(BinaryInstr::Fact);
276                }
277                Instr::Call(sym, argc) => {
278                    let idx = get_index(sym);
279                    binary.code.push(BinaryInstr::Call(idx, *argc));
280                }
281                Instr::Equal => {
282                    binary.code.push(BinaryInstr::Equal);
283                }
284                Instr::NotEqual => {
285                    binary.code.push(BinaryInstr::NotEqual);
286                }
287                Instr::Less => {
288                    binary.code.push(BinaryInstr::Less);
289                }
290                Instr::LessEqual => {
291                    binary.code.push(BinaryInstr::LessEqual);
292                }
293                Instr::Greater => {
294                    binary.code.push(BinaryInstr::Greater);
295                }
296                Instr::GreaterEqual => {
297                    binary.code.push(BinaryInstr::GreaterEqual);
298                }
299            }
300        }
301
302        binary
303    }
304}