expr_solver/
vm.rs

1use crate::ir::Instr;
2use crate::number::Number;
3use crate::symbol::{FuncError, Symbol};
4use crate::symtable::SymTable;
5use thiserror::Error;
6
7/// Virtual machine runtime errors.
8#[derive(Error, Debug, Clone)]
9pub enum VmError {
10    #[error("Stack underflow: attempted to pop from empty stack")]
11    StackUnderflow,
12    #[error("Division by zero")]
13    DivisionByZero,
14    #[error("Invalid stack state at program end: expected 1 element, found {count}")]
15    InvalidFinalStack { count: usize },
16    #[error("Invalid factorial: {value} (must be a non-negative integer)")]
17    InvalidFactorial { value: Number },
18    #[error("Arithmetic error: {message}")]
19    ArithmeticError { message: String },
20    #[error("Function error: {0}")]
21    FunctionError(FuncError),
22}
23
24/// Stack-based virtual machine for executing bytecode programs.
25///
26/// The VM evaluates programs by executing bytecode instructions on a stack,
27/// performing arithmetic operations and function calls.
28///
29/// ## Error Handling
30///
31/// Behavior varies based on the numeric backend:
32///
33/// - **f64 mode**: Relaxed error handling. Only catches errors that would cause panics.
34///   Allows `Inf` and `NaN` results from operations like `1/0` or `sqrt(-1)`.
35///
36/// - **Decimal mode**: Strict error handling. All arithmetic operations are checked
37///   for overflow/underflow. Returns errors for domain violations.
38#[derive(Debug)]
39pub struct Vm<'vm> {
40    bytecode: &'vm [Instr],
41    symtable: &'vm mut SymTable,
42    stack: Vec<Number>,
43    ip: usize,
44}
45
46impl<'vm> Vm<'vm> {
47    /// Executes bytecode and returns the result.
48    ///
49    /// # Errors
50    ///
51    /// Returns [`VmError`] if execution fails due to:
52    /// - Stack underflow
53    /// - Division by zero
54    /// - Invalid operations (e.g., factorial of non-integer)
55    /// - Function errors
56    /// - Invalid symbol indices
57    /// - Invalid jumps
58    pub fn run(bytecode: &'vm [Instr], symtable: &'vm mut SymTable) -> Result<Number, VmError> {
59        use crate::number::consts;
60
61        if bytecode.is_empty() {
62            return Ok(consts::ZERO);
63        }
64
65        let mut vm = Vm {
66            bytecode,
67            symtable,
68            stack: Vec::new(),
69            ip: 0,
70        };
71
72        vm.execute()?;
73
74        match vm.stack.as_slice() {
75            [result] => Ok(*result),
76            _ => Err(VmError::InvalidFinalStack {
77                count: vm.stack.len(),
78            }),
79        }
80    }
81
82    fn execute(&mut self) -> Result<(), VmError> {
83        use crate::number::consts;
84
85        while self.ip < self.bytecode.len() {
86            let op = &self.bytecode[self.ip];
87
88            match op {
89                Instr::Jmp(target) => {
90                    self.ip = *target;
91                    continue;
92                }
93                Instr::Jz(target) => {
94                    let cond = self.pop()?;
95                    if cond == consts::ZERO {
96                        self.ip = *target;
97                        continue;
98                    }
99                }
100                Instr::Push(v) => {
101                    self.stack.push(*v);
102                }
103                Instr::Load(idx) => {
104                    let sym = self.symtable.get_by_index(*idx).unwrap();
105                    match sym {
106                        Symbol::Const { value, .. } => {
107                            self.stack.push(*value);
108                        }
109                        _ => unreachable!(),
110                    }
111                }
112                Instr::Store(idx) => {
113                    let top = self.pop()?;
114                    let sym = self.symtable.get_mut_by_index(*idx).unwrap();
115                    match sym {
116                        Symbol::Const { value, .. } => {
117                            *value = top;
118                        }
119                        _ => unreachable!(),
120                    }
121                }
122                Instr::Neg => {
123                    let v = self.pop()?;
124                    self.stack.push(-v);
125                }
126                Instr::Add => self.add_op()?,
127                Instr::Sub => self.sub_op()?,
128                Instr::Mul => self.mul_op()?,
129                Instr::Div => self.div_op()?,
130                Instr::Pow => self.pow_op()?,
131                Instr::Fact => self.fact_op()?,
132                Instr::Call(idx, argc) => self.call_op(*idx, *argc)?,
133                Instr::Equal => self.comparison_op(|a, b| a == b)?,
134                Instr::NotEqual => self.comparison_op(|a, b| a != b)?,
135                Instr::Less => self.comparison_op(|a, b| a < b)?,
136                Instr::LessEqual => self.comparison_op(|a, b| a <= b)?,
137                Instr::Greater => self.comparison_op(|a, b| a > b)?,
138                Instr::GreaterEqual => self.comparison_op(|a, b| a >= b)?,
139            }
140
141            self.ip += 1;
142        }
143        Ok(())
144    }
145
146    fn comparison_op<F>(&mut self, f: F) -> Result<(), VmError>
147    where
148        F: FnOnce(Number, Number) -> bool,
149    {
150        use crate::number::consts;
151
152        let right = self.pop()?;
153        let left = self.pop()?;
154        let result = if f(left, right) {
155            consts::ONE
156        } else {
157            consts::ZERO
158        };
159        self.stack.push(result);
160        Ok(())
161    }
162
163    // Decimal mode: use checked arithmetic for safety
164    #[cfg(feature = "decimal-precision")]
165    fn add_op(&mut self) -> Result<(), VmError> {
166        let right = self.pop()?;
167        let left = self.pop()?;
168        let result = left
169            .checked_add(right)
170            .ok_or_else(|| VmError::ArithmeticError {
171                message: format!("Addition overflow: {} + {}", left, right),
172            })?;
173        self.stack.push(result);
174        Ok(())
175    }
176
177    // f64 mode: use simple arithmetic (Inf/NaN allowed)
178    #[cfg(feature = "f64-floats")]
179    fn add_op(&mut self) -> Result<(), VmError> {
180        let right = self.pop()?;
181        let left = self.pop()?;
182        self.stack.push(left + right);
183        Ok(())
184    }
185
186    #[cfg(feature = "decimal-precision")]
187    fn sub_op(&mut self) -> Result<(), VmError> {
188        let right = self.pop()?;
189        let left = self.pop()?;
190        let result = left
191            .checked_sub(right)
192            .ok_or_else(|| VmError::ArithmeticError {
193                message: format!("Subtraction overflow: {} - {}", left, right),
194            })?;
195        self.stack.push(result);
196        Ok(())
197    }
198
199    #[cfg(feature = "f64-floats")]
200    fn sub_op(&mut self) -> Result<(), VmError> {
201        let right = self.pop()?;
202        let left = self.pop()?;
203        self.stack.push(left - right);
204        Ok(())
205    }
206
207    #[cfg(feature = "decimal-precision")]
208    fn mul_op(&mut self) -> Result<(), VmError> {
209        let right = self.pop()?;
210        let left = self.pop()?;
211        let result = left
212            .checked_mul(right)
213            .ok_or_else(|| VmError::ArithmeticError {
214                message: format!("Multiplication overflow: {} * {}", left, right),
215            })?;
216        self.stack.push(result);
217        Ok(())
218    }
219
220    #[cfg(feature = "f64-floats")]
221    fn mul_op(&mut self) -> Result<(), VmError> {
222        let right = self.pop()?;
223        let left = self.pop()?;
224        self.stack.push(left * right);
225        Ok(())
226    }
227
228    #[cfg(feature = "decimal-precision")]
229    fn div_op(&mut self) -> Result<(), VmError> {
230        use crate::number::consts;
231
232        let right = self.pop()?;
233        let left = self.pop()?;
234        let result = left.checked_div(right).ok_or_else(|| {
235            if right == consts::ZERO {
236                VmError::DivisionByZero
237            } else {
238                VmError::ArithmeticError {
239                    message: format!("Division overflow or underflow: {} / {}", left, right),
240                }
241            }
242        })?;
243        self.stack.push(result);
244        Ok(())
245    }
246
247    #[cfg(feature = "f64-floats")]
248    fn div_op(&mut self) -> Result<(), VmError> {
249        let right = self.pop()?;
250        let left = self.pop()?;
251        // Check for division by zero to prevent undefined behavior
252        if right == 0.0 {
253            return Err(VmError::DivisionByZero);
254        }
255        self.stack.push(left / right);
256        Ok(())
257    }
258
259    #[cfg(feature = "decimal-precision")]
260    fn pow_op(&mut self) -> Result<(), VmError> {
261        use rust_decimal::prelude::{FromPrimitive, ToPrimitive};
262
263        let exponent = self.pop()?;
264        let base = self.pop()?;
265
266        // Convert to f64, compute power, convert back
267        // This is a limitation of rust_decimal which doesn't have decimal exponent support
268        let base_f64 = base.to_f64().ok_or_else(|| VmError::ArithmeticError {
269            message: format!("Failed to convert base {} to f64", base),
270        })?;
271        let exp_f64 = exponent.to_f64().ok_or_else(|| VmError::ArithmeticError {
272            message: format!("Failed to convert exponent {} to f64", exponent),
273        })?;
274
275        let result_f64 = base_f64.powf(exp_f64);
276
277        let result = Number::from_f64(result_f64).ok_or_else(|| VmError::ArithmeticError {
278            message: format!(
279                "Power operation result cannot be represented: {} ^ {}",
280                base, exponent
281            ),
282        })?;
283
284        self.stack.push(result);
285        Ok(())
286    }
287
288    #[cfg(feature = "f64-floats")]
289    fn pow_op(&mut self) -> Result<(), VmError> {
290        let exponent = self.pop()?;
291        let base = self.pop()?;
292        self.stack.push(base.powf(exponent));
293        Ok(())
294    }
295
296    #[cfg(feature = "decimal-precision")]
297    fn fact_op(&mut self) -> Result<(), VmError> {
298        use crate::number::consts;
299        use rust_decimal::prelude::*;
300
301        let n = self.pop()?;
302
303        // Check for negative numbers
304        if n.is_sign_negative() {
305            return Err(VmError::InvalidFactorial { value: n });
306        }
307
308        // Check for non-integer
309        if n.fract() != consts::ZERO {
310            return Err(VmError::InvalidFactorial { value: n });
311        }
312
313        // Calculate factorial using safe multiplication with iterator
314        let n_u64 = n.to_u64().unwrap();
315        let result = (1..=n_u64).try_fold(consts::ONE, |acc, i| {
316            acc.checked_mul(Number::from(i))
317                .ok_or_else(|| VmError::ArithmeticError {
318                    message: format!("Factorial calculation overflow at {}!", i),
319                })
320        })?;
321
322        self.stack.push(result);
323        Ok(())
324    }
325
326    #[cfg(feature = "f64-floats")]
327    fn fact_op(&mut self) -> Result<(), VmError> {
328        let n = self.pop()?;
329
330        // Check for negative numbers
331        if n < 0.0 {
332            return Err(VmError::InvalidFactorial { value: n });
333        }
334
335        // Check for non-integer
336        if n.fract() != 0.0 {
337            return Err(VmError::InvalidFactorial { value: n });
338        }
339
340        // Calculate factorial
341        let n_u64 = n as u64;
342        let mut result = 1.0;
343        for i in 1..=n_u64 {
344            result *= i as f64;
345        }
346
347        self.stack.push(result);
348        Ok(())
349    }
350
351    fn call_op(&mut self, idx: usize, argc: usize) -> Result<(), VmError> {
352        match self.symtable.get_by_index(idx).unwrap() {
353            Symbol::Func { callback, .. } => {
354                let args_start = self.stack.len() - argc;
355                let args = &self.stack[args_start..];
356                let result = callback(args).map_err(VmError::FunctionError)?;
357                self.stack.truncate(args_start);
358                self.stack.push(result);
359                Ok(())
360            }
361            Symbol::Const { .. } => unreachable!(),
362        }
363    }
364
365    fn pop(&mut self) -> Result<Number, VmError> {
366        self.stack.pop().ok_or(VmError::StackUnderflow)
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::num;
    use crate::symtable::SymTable;

    #[test]
    fn test_vm_error_stack_underflow() {
        let mut table = SymTable::stdlib();
        let bytecode = vec![Instr::Add]; // No values on stack

        let result = Vm::run(&bytecode, &mut table);
        assert!(matches!(result, Err(VmError::StackUnderflow)));
    }

    #[test]
    fn test_vm_error_division_by_zero() {
        let mut table = SymTable::stdlib();
        let bytecode = vec![Instr::Push(num!(5)), Instr::Push(num!(0)), Instr::Div];

        let result = Vm::run(&bytecode, &mut table);
        assert!(matches!(result, Err(VmError::DivisionByZero)));
    }

    #[test]
    fn test_vm_error_invalid_final_stack() {
        let mut table = SymTable::stdlib();
        let bytecode = vec![
            Instr::Push(num!(1)),
            Instr::Push(num!(2)),
            // No operation to combine them
        ];

        let result = Vm::run(&bytecode, &mut table);
        assert!(matches!(
            result,
            Err(VmError::InvalidFinalStack { count: 2 })
        ));
    }

    #[test]
    fn test_vm_error_display() {
        assert_eq!(
            VmError::StackUnderflow.to_string(),
            "Stack underflow: attempted to pop from empty stack"
        );
        assert_eq!(VmError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            VmError::InvalidFinalStack { count: 3 }.to_string(),
            "Invalid stack state at program end: expected 1 element, found 3"
        );
    }

    #[test]
    fn test_empty_program_is_zero() {
        let mut table = SymTable::stdlib();
        let bytecode: Vec<Instr> = Vec::new();

        assert_eq!(Vm::run(&bytecode, &mut table).unwrap(), num!(0));
    }

    #[test]
    fn test_binary_operations() {
        let mut table = SymTable::stdlib();

        // Every binary arithmetic instruction; the left operand is pushed first.
        let test_cases = vec![
            (
                vec![Instr::Push(num!(2)), Instr::Push(num!(3)), Instr::Add],
                num!(5),
            ),
            (
                vec![Instr::Push(num!(6)), Instr::Push(num!(2)), Instr::Sub],
                num!(4),
            ),
            (
                vec![Instr::Push(num!(3)), Instr::Push(num!(4)), Instr::Mul],
                num!(12),
            ),
            (
                vec![Instr::Push(num!(8)), Instr::Push(num!(2)), Instr::Div],
                num!(4),
            ),
            (
                vec![Instr::Push(num!(2)), Instr::Push(num!(10)), Instr::Pow],
                num!(1024),
            ),
        ];

        for (code, expected) in test_cases {
            let result = Vm::run(&code, &mut table).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_comparison_operations() {
        let mut table = SymTable::stdlib();

        // Each case evaluates `1 <op> 2`; true is 1 and false is 0.
        let test_cases = vec![
            (Instr::Equal, num!(0)),
            (Instr::NotEqual, num!(1)),
            (Instr::Less, num!(1)),
            (Instr::LessEqual, num!(1)),
            (Instr::Greater, num!(0)),
            (Instr::GreaterEqual, num!(0)),
        ];

        for (op, expected) in test_cases {
            let code = vec![Instr::Push(num!(1)), Instr::Push(num!(2)), op];
            assert_eq!(Vm::run(&code, &mut table).unwrap(), expected);
        }
    }

    #[test]
    fn test_conditional_jump() {
        let mut table = SymTable::stdlib();

        // if cond { 10 } else { 20 }, compiled as: cond; Jz else; 10; Jmp end; else: 20
        let program = |cond| {
            vec![
                Instr::Push(cond),
                Instr::Jz(4),
                Instr::Push(num!(10)),
                Instr::Jmp(5),
                Instr::Push(num!(20)),
            ]
        };

        assert_eq!(Vm::run(&program(num!(1)), &mut table).unwrap(), num!(10));
        assert_eq!(Vm::run(&program(num!(0)), &mut table).unwrap(), num!(20));
    }

    #[test]
    fn test_factorial() {
        let mut table = SymTable::stdlib();

        let five = vec![Instr::Push(num!(5)), Instr::Fact];
        assert_eq!(Vm::run(&five, &mut table).unwrap(), num!(120));

        let zero = vec![Instr::Push(num!(0)), Instr::Fact];
        assert_eq!(Vm::run(&zero, &mut table).unwrap(), num!(1));

        let negative = vec![Instr::Push(num!(1)), Instr::Neg, Instr::Fact];
        assert!(matches!(
            Vm::run(&negative, &mut table),
            Err(VmError::InvalidFactorial { .. })
        ));
    }
}
448}