expr_solver/
program.rs

1use crate::ir::Instr;
2use crate::symbol::{SymTable, Symbol};
3use bincode::config;
4use colored::Colorize;
5use rust_decimal::Decimal;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use thiserror::Error;
9
10/// Current version of the program format
11const PROGRAM_VERSION: &str = env!("CARGO_PKG_VERSION");
12
13/// Expression parsing and evaluation errors.
14#[derive(Error, Debug)]
15pub enum ProgramError {
16    #[error("Compilation error: {0}")]
17    CompileError(String),
18    #[error("Decoding error: {0}")]
19    DecodingError(#[from] bincode::error::DecodeError),
20    #[error("incompatible program version: expected {0}, got {1}")]
21    IncompatibleVersions(String, String),
22    #[error("Unknown symbol: {0}")]
23    UnknownSymbol(String),
24    #[error("Symbol '{0}' is not a {1}")]
25    SymbolKindMismatch(String, String),
26    #[error("Function '{0}' incorrect arity")]
27    InvalidFuncArity(String),
28    #[error("Corrupted instruction: {0}")]
29    CorrupedInstruction(String),
30}
31
32/// Executable program containing bytecode instructions.
33///
34/// Programs reference symbols from a [`SymTable`] and can be serialized
35/// to binary format for storage or transmission.
36#[derive(Default)]
37pub struct Program<'sym> {
38    pub version: String,
39    pub code: Vec<Instr<'sym>>,
40}
41
42/// A compact, fully-owned form of a program that can be serialized without lifetimes.
43/// Symbols are stored by name and kind; instructions reference symbols by index.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45struct Binary {
46    version: String,
47    symbols: Vec<BinarySymbol>,
48    code: Vec<BinaryInstr>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52enum BinarySymbol {
53    /// Named constant
54    Const(String),
55    /// Named function
56    Func {
57        name: String,
58        args: usize,
59        variadic: bool,
60    },
61}
62
63impl BinarySymbol {
64    fn name(&self) -> String {
65        match self {
66            BinarySymbol::Const(name) => name.clone(),
67            BinarySymbol::Func { name, .. } => name.clone(),
68        }
69    }
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
73enum BinaryInstr {
74    Push(Decimal),
75    Load(u32), // index into `symbols`
76    Neg,
77    Add,
78    Sub,
79    Mul,
80    Div,
81    Pow,
82    Fact,
83    Call(u32, usize), // index into `symbols` and argument count
84    Equal,
85    NotEqual,
86    Less,
87    LessEqual,
88    Greater,
89    GreaterEqual,
90}
91
92impl<'sym> Program<'sym> {
93    /// Creates a new empty program.
94    pub fn new() -> Self {
95        Self {
96            version: PROGRAM_VERSION.to_string(),
97            code: Vec::new(),
98        }
99    }
100
101    /// Compiles the program to binary format for serialization.
102    pub fn compile(&self) -> Result<Vec<u8>, ProgramError> {
103        let binary = self.to_binary();
104        let config = config::standard();
105        bincode::serde::encode_to_vec(&binary, config)
106            .map_err(|err| ProgramError::CompileError(format!("failed to encode program: {}", err)))
107    }
108
109    /// Loads a program from binary data with the given symbol table.
110    ///
111    /// The binary data must have been created with [`compile`](Self::compile).
112    pub fn load(data: &[u8], table: &'sym SymTable) -> Result<Program<'sym>, ProgramError> {
113        let config = config::standard();
114        let (decoded, _): (Binary, usize) = bincode::serde::decode_from_slice(&data, config)
115            .map_err(ProgramError::DecodingError)?;
116
117        Self::validate_version(&decoded.version)?;
118
119        let get_sym = |bin_sym: &BinarySymbol| -> Result<&'sym Symbol, ProgramError> {
120            let name = bin_sym.name();
121            table.get(&name).ok_or(ProgramError::UnknownSymbol(name))
122        };
123
124        let mut program = Program::new();
125        program.version = decoded.version.clone();
126
127        for instr in &decoded.code {
128            match instr {
129                BinaryInstr::Push(v) => {
130                    program.code.push(Instr::Push(*v));
131                }
132                BinaryInstr::Load(idx) => {
133                    let bin_sym = &decoded.symbols[*idx as usize];
134                    let sym = get_sym(&bin_sym)?;
135                    match bin_sym {
136                        BinarySymbol::Const(_) => {
137                            if !matches!(sym, Symbol::Const { .. }) {
138                                return Err(ProgramError::SymbolKindMismatch(
139                                    sym.name().to_string(),
140                                    "constant".to_string(),
141                                ));
142                            }
143                        }
144                        _ => {
145                            return Err(ProgramError::CorrupedInstruction("LOAD".to_string()));
146                        }
147                    }
148                    program.code.push(Instr::Load(sym))
149                }
150                BinaryInstr::Neg => program.code.push(Instr::Neg),
151                BinaryInstr::Add => program.code.push(Instr::Add),
152                BinaryInstr::Sub => program.code.push(Instr::Sub),
153                BinaryInstr::Mul => program.code.push(Instr::Mul),
154                BinaryInstr::Div => program.code.push(Instr::Div),
155                BinaryInstr::Pow => program.code.push(Instr::Pow),
156                BinaryInstr::Fact => program.code.push(Instr::Fact),
157                BinaryInstr::Call(idx, argc) => {
158                    let bin_sym = &decoded.symbols[*idx as usize];
159                    let sym = get_sym(&bin_sym)?;
160                    if !matches!(sym, Symbol::Func { .. }) {
161                        return Err(ProgramError::SymbolKindMismatch(
162                            sym.name().to_string(),
163                            "function".to_string(),
164                        ));
165                    }
166                    program.code.push(Instr::Call(sym, *argc));
167                }
168                // Comparison operators
169                BinaryInstr::Equal => program.code.push(Instr::Equal),
170                BinaryInstr::NotEqual => program.code.push(Instr::NotEqual),
171                BinaryInstr::Less => program.code.push(Instr::Less),
172                BinaryInstr::LessEqual => program.code.push(Instr::LessEqual),
173                BinaryInstr::Greater => program.code.push(Instr::Greater),
174                BinaryInstr::GreaterEqual => program.code.push(Instr::GreaterEqual),
175            }
176        }
177
178        Ok(program)
179    }
180
181    fn validate_version(version: &String) -> Result<(), ProgramError> {
182        if version != PROGRAM_VERSION {
183            return Err(ProgramError::IncompatibleVersions(
184                PROGRAM_VERSION.to_string(),
185                version.clone(),
186            ));
187        }
188        Ok(())
189    }
190
191    /// Returns a human-readable assembly representation of the program.
192    pub fn get_assembly(&self) -> String {
193        use std::fmt::Write as _;
194
195        let mut out = String::new();
196        out += &format!("; VERSION {}\n", self.version)
197            .bright_black()
198            .to_string();
199
200        let emit = |mnemonic: &str| -> String { format!("{}", mnemonic.magenta()) };
201        let emit1 = |mnemonic: &str, op: &str| -> String {
202            format!("{} {}", mnemonic.magenta(), op.green())
203        };
204
205        for (i, instr) in self.code.iter().enumerate() {
206            let _ = write!(out, "{} ", format!("{:04X}", i).yellow());
207            let line = match instr {
208                Instr::Push(v) => emit1("PUSH", &v.to_string().green()),
209                Instr::Load(sym) => emit1("LOAD", &sym.name().blue()),
210                Instr::Neg => emit("NEG"),
211                Instr::Add => emit("ADD"),
212                Instr::Sub => emit("SUB"),
213                Instr::Mul => emit("MUL"),
214                Instr::Div => emit("DIV"),
215                Instr::Pow => emit("POW"),
216                Instr::Fact => emit("FACT"),
217                Instr::Call(sym, argc) => format!(
218                    "{} {} args: {}",
219                    emit("CALL"),
220                    sym.name().cyan(),
221                    argc.to_string().bright_blue()
222                ),
223                Instr::Equal => emit("EQ"),
224                Instr::NotEqual => emit("NEQ"),
225                Instr::Less => emit("LT"),
226                Instr::LessEqual => emit("LTE"),
227                Instr::Greater => emit("GT"),
228                Instr::GreaterEqual => emit("GTE"),
229            };
230            let _ = writeln!(out, "{}", line);
231        }
232        out
233    }
234
235    fn to_binary(&self) -> Binary {
236        let mut map: HashMap<String, u32> = HashMap::new();
237        let mut binary = Binary {
238            version: self.version.clone(),
239            symbols: Vec::new(),
240            code: Vec::new(),
241        };
242
243        let mut get_index = |sym: &'sym Symbol| -> u32 {
244            map.get(sym.name()).map(|val| *val).unwrap_or_else(|| {
245                let i = binary.symbols.len() as u32;
246                map.insert(sym.name().to_string(), i);
247                binary.symbols.push(match sym {
248                    Symbol::Const { .. } => BinarySymbol::Const(sym.name().to_string()),
249                    Symbol::Func { args, variadic, .. } => BinarySymbol::Func {
250                        name: sym.name().to_string(),
251                        args: *args,
252                        variadic: *variadic,
253                    },
254                });
255                i
256            })
257        };
258
259        for instr in &self.code {
260            match instr {
261                Instr::Push(v) => {
262                    binary.code.push(BinaryInstr::Push(*v));
263                }
264                Instr::Load(sym) => {
265                    let idx = get_index(sym);
266                    binary.code.push(BinaryInstr::Load(idx));
267                }
268                Instr::Neg => {
269                    binary.code.push(BinaryInstr::Neg);
270                }
271                Instr::Add => {
272                    binary.code.push(BinaryInstr::Add);
273                }
274                Instr::Sub => {
275                    binary.code.push(BinaryInstr::Sub);
276                }
277                Instr::Mul => {
278                    binary.code.push(BinaryInstr::Mul);
279                }
280                Instr::Div => {
281                    binary.code.push(BinaryInstr::Div);
282                }
283                Instr::Pow => {
284                    binary.code.push(BinaryInstr::Pow);
285                }
286                Instr::Fact => {
287                    binary.code.push(BinaryInstr::Fact);
288                }
289                Instr::Call(sym, argc) => {
290                    let idx = get_index(sym);
291                    binary.code.push(BinaryInstr::Call(idx, *argc));
292                }
293                Instr::Equal => {
294                    binary.code.push(BinaryInstr::Equal);
295                }
296                Instr::NotEqual => {
297                    binary.code.push(BinaryInstr::NotEqual);
298                }
299                Instr::Less => {
300                    binary.code.push(BinaryInstr::Less);
301                }
302                Instr::LessEqual => {
303                    binary.code.push(BinaryInstr::LessEqual);
304                }
305                Instr::Greater => {
306                    binary.code.push(BinaryInstr::Greater);
307                }
308                Instr::GreaterEqual => {
309                    binary.code.push(BinaryInstr::GreaterEqual);
310                }
311            }
312        }
313
314        binary
315    }
316}