expr_solver/
vm.rs

1use crate::ir::Instr;
2use crate::symbol::{FuncError, SymTable, Symbol};
3use rust_decimal::Decimal;
4use rust_decimal::prelude::*;
5use thiserror::Error;
6
7#[cfg(test)]
8use rust_decimal_macros::dec;
9
10/// Virtual machine runtime errors.
11#[derive(Error, Debug, Clone)]
12pub enum VmError {
13    #[error("Stack underflow: attempted to pop from empty stack")]
14    StackUnderflow,
15    #[error("Division by zero")]
16    DivisionByZero,
17    #[error("Invalid stack state at program end: expected 1 element, found {count}")]
18    InvalidFinalStack { count: usize },
19    #[error("Invalid load operation: cannot load '{symbol_name}' as a constant")]
20    InvalidLoad { symbol_name: String },
21    #[error("Invalid call operation: cannot call '{symbol_name}' as a function")]
22    InvalidCall { symbol_name: String },
23    #[error(
24        "Stack underflow on function call '{function_name}': expected {expected} arguments, found {found}"
25    )]
26    CallStackUnderflow {
27        function_name: String,
28        expected: usize,
29        found: usize,
30    },
31    #[error("Invalid factorial: {value} (must be a non-negative integer)")]
32    InvalidFactorial { value: Decimal },
33    #[error("Arithmetic error: {message}")]
34    ArithmeticError { message: String },
35    #[error("Function error: {0}")]
36    FunctionError(FuncError),
37    #[error("Invalid symbol index: {0}")]
38    InvalidSymbolIndex(usize),
39}
40
/// A stateless, stack-based bytecode interpreter.
///
/// Each call to [`Vm::run_bytecode`] evaluates a program on its own operand
/// stack. Instructions push and pop [`Decimal`] values, apply arithmetic and
/// comparison operators, and invoke functions from the symbol table.
#[derive(Default, Debug)]
pub struct Vm;
47
48impl Vm {
49    /// Executes bytecode directly and returns the result.
50    ///
51    /// # Errors
52    ///
53    /// Returns [`VmError`] if execution fails due to:
54    /// - Stack underflow
55    /// - Division by zero
56    /// - Invalid operations (e.g., factorial of non-integer)
57    /// - Function errors
58    /// - Invalid symbol indices
59    pub fn run_bytecode(&self, bytecode: &[Instr], table: &SymTable) -> Result<Decimal, VmError> {
60        if bytecode.is_empty() {
61            return Ok(Decimal::ZERO);
62        }
63
64        let mut stack: Vec<Decimal> = Vec::new();
65
66        for op in bytecode {
67            self.execute_instruction(op, table, &mut stack)?;
68        }
69
70        match stack.as_slice() {
71            [result] => Ok(*result),
72            _ => Err(VmError::InvalidFinalStack { count: stack.len() }),
73        }
74    }
75
76    fn execute_instruction(
77        &self,
78        op: &Instr,
79        table: &SymTable,
80        stack: &mut Vec<Decimal>,
81    ) -> Result<(), VmError> {
82        match op {
83            Instr::Push(v) => {
84                stack.push(*v);
85                Ok(())
86            }
87            Instr::Load(idx) => {
88                let sym = table
89                    .get_by_index(*idx)
90                    .ok_or(VmError::InvalidSymbolIndex(*idx))?;
91                match sym {
92                    Symbol::Const { value, .. } => {
93                        stack.push(*value);
94                        Ok(())
95                    }
96                    _ => Err(VmError::InvalidLoad {
97                        symbol_name: sym.name().to_string(),
98                    }),
99                }
100            }
101            Instr::Neg => {
102                let v = Self::pop(stack)?;
103                stack.push(-v);
104                Ok(())
105            }
106            Instr::Add => self.add_op(stack),
107            Instr::Sub => self.sub_op(stack),
108            Instr::Mul => self.mul_op(stack),
109            Instr::Div => self.div_op(stack),
110            Instr::Pow => self.pow_op(stack),
111            Instr::Fact => self.fact_op(stack),
112            Instr::Call(idx, argc) => {
113                let sym = table
114                    .get_by_index(*idx)
115                    .ok_or(VmError::InvalidSymbolIndex(*idx))?;
116                self.call_op(sym, *argc, stack)
117            }
118            // Comparison operators
119            Instr::Equal => self.comparison_op(stack, |a, b| a == b),
120            Instr::NotEqual => self.comparison_op(stack, |a, b| a != b),
121            Instr::Less => self.comparison_op(stack, |a, b| a < b),
122            Instr::LessEqual => self.comparison_op(stack, |a, b| a <= b),
123            Instr::Greater => self.comparison_op(stack, |a, b| a > b),
124            Instr::GreaterEqual => self.comparison_op(stack, |a, b| a >= b),
125        }
126    }
127
128    fn comparison_op<F>(&self, stack: &mut Vec<Decimal>, f: F) -> Result<(), VmError>
129    where
130        F: FnOnce(Decimal, Decimal) -> bool,
131    {
132        let right = Self::pop(stack)?;
133        let left = Self::pop(stack)?;
134        let result = if f(left, right) {
135            Decimal::ONE
136        } else {
137            Decimal::ZERO
138        };
139        stack.push(result);
140        Ok(())
141    }
142
143    fn add_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
144        let right = Self::pop(stack)?;
145        let left = Self::pop(stack)?;
146        let result = left
147            .checked_add(right)
148            .ok_or_else(|| VmError::ArithmeticError {
149                message: format!("Addition overflow: {} + {}", left, right),
150            })?;
151        stack.push(result);
152        Ok(())
153    }
154
155    fn sub_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
156        let right = Self::pop(stack)?;
157        let left = Self::pop(stack)?;
158        let result = left
159            .checked_sub(right)
160            .ok_or_else(|| VmError::ArithmeticError {
161                message: format!("Subtraction overflow: {} - {}", left, right),
162            })?;
163        stack.push(result);
164        Ok(())
165    }
166
167    fn mul_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
168        let right = Self::pop(stack)?;
169        let left = Self::pop(stack)?;
170        let result = left
171            .checked_mul(right)
172            .ok_or_else(|| VmError::ArithmeticError {
173                message: format!("Multiplication overflow: {} * {}", left, right),
174            })?;
175        stack.push(result);
176        Ok(())
177    }
178
179    fn div_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
180        let right = Self::pop(stack)?;
181        let left = Self::pop(stack)?;
182        let result = left.checked_div(right).ok_or_else(|| {
183            if right.is_zero() {
184                VmError::DivisionByZero
185            } else {
186                VmError::ArithmeticError {
187                    message: format!("Division overflow or underflow: {} / {}", left, right),
188                }
189            }
190        })?;
191        stack.push(result);
192        Ok(())
193    }
194
195    fn pow_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
196        let exponent = Self::pop(stack)?;
197        let base = Self::pop(stack)?;
198
199        // Use Decimal's powd with error handling
200        let result =
201            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| base.powd(exponent))) {
202                Ok(result) => result,
203                Err(_) => {
204                    return Err(VmError::ArithmeticError {
205                        message: format!("Power operation failed: {} ^ {}", base, exponent),
206                    });
207                }
208            };
209
210        stack.push(result);
211        Ok(())
212    }
213
214    fn fact_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
215        let n = Self::pop(stack)?;
216
217        // Check for negative numbers
218        if n.is_sign_negative() {
219            return Err(VmError::InvalidFactorial { value: n });
220        }
221
222        // Check for non-integer
223        if n.fract() != Decimal::ZERO {
224            return Err(VmError::InvalidFactorial { value: n });
225        }
226
227        // Calculate factorial using safe multiplication with iterator
228        let n_u64 = n.to_u64().unwrap();
229        let result = (1..=n_u64).try_fold(Decimal::ONE, |acc, i| {
230            acc.checked_mul(Decimal::from(i))
231                .ok_or_else(|| VmError::ArithmeticError {
232                    message: format!("Factorial calculation overflow at {}!", i),
233                })
234        })?;
235
236        stack.push(result);
237        Ok(())
238    }
239
240    fn call_op(&self, sym: &Symbol, argc: usize, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
241        match sym {
242            Symbol::Func {
243                name,
244                args: min_args,
245                variadic,
246                callback,
247                ..
248            } => {
249                if argc != *min_args && (!*variadic || argc < *min_args) {
250                    return Err(VmError::CallStackUnderflow {
251                        function_name: name.to_string(),
252                        expected: *min_args,
253                        found: argc,
254                    });
255                }
256
257                // Check if we have enough values on the stack
258                if stack.len() < argc {
259                    return Err(VmError::CallStackUnderflow {
260                        function_name: name.to_string(),
261                        expected: argc,
262                        found: stack.len(),
263                    });
264                }
265
266                let args_start = stack.len() - argc;
267                let args = &stack[args_start..];
268                let result = callback(args).map_err(VmError::FunctionError)?;
269                stack.truncate(args_start);
270                stack.push(result);
271                Ok(())
272            }
273            Symbol::Const { .. } => Err(VmError::InvalidCall {
274                symbol_name: sym.name().to_string(),
275            }),
276        }
277    }
278
279    fn pop(stack: &mut Vec<Decimal>) -> Result<Decimal, VmError> {
280        stack.pop().ok_or(VmError::StackUnderflow)
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymTable;

    #[test]
    fn test_vm_error_stack_underflow() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let bytecode = vec![Instr::Add]; // No values on stack

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::StackUnderflow)));
    }

    #[test]
    fn test_vm_error_division_by_zero() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let bytecode = vec![Instr::Push(dec!(5)), Instr::Push(dec!(0)), Instr::Div];

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::DivisionByZero)));
    }

    #[test]
    fn test_vm_error_invalid_final_stack() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let bytecode = vec![
            Instr::Push(dec!(1)),
            Instr::Push(dec!(2)),
            // No operation to combine them
        ];

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(
            result,
            Err(VmError::InvalidFinalStack { count: 2 })
        ));
    }

    #[test]
    fn test_vm_error_invalid_load() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let (sin_idx, _) = table.get_with_index("sin").unwrap();

        let bytecode = vec![Instr::Load(sin_idx)]; // Trying to load a function as constant

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::InvalidLoad { .. })));
    }

    #[test]
    fn test_vm_error_invalid_call() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let (pi_idx, _) = table.get_with_index("pi").unwrap();

        let bytecode = vec![Instr::Call(pi_idx, 0)]; // Trying to call a constant as function

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::InvalidCall { .. })));
    }

    #[test]
    fn test_vm_error_call_stack_underflow() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let (sin_idx, _) = table.get_with_index("sin").unwrap();

        let bytecode = vec![Instr::Call(sin_idx, 0)]; // No arguments for sin function

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::CallStackUnderflow { .. })));
    }

    #[test]
    fn test_vm_error_invalid_factorial() {
        let vm = Vm;
        let table = SymTable::stdlib();

        // -3! : negative operand
        let negative = vec![Instr::Push(dec!(3)), Instr::Neg, Instr::Fact];
        assert!(matches!(
            vm.run_bytecode(&negative, &table),
            Err(VmError::InvalidFactorial { .. })
        ));

        // 2.5! : non-integer operand
        let fractional = vec![Instr::Push(dec!(2.5)), Instr::Fact];
        assert!(matches!(
            vm.run_bytecode(&fractional, &table),
            Err(VmError::InvalidFactorial { .. })
        ));
    }

    #[test]
    fn test_vm_error_arithmetic_overflow() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let bytecode = vec![
            Instr::Push(Decimal::MAX),
            Instr::Push(Decimal::MAX),
            Instr::Add,
        ];

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::ArithmeticError { .. })));
    }

    #[test]
    fn test_vm_error_display() {
        assert_eq!(
            VmError::StackUnderflow.to_string(),
            "Stack underflow: attempted to pop from empty stack"
        );
        assert_eq!(VmError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            VmError::InvalidFinalStack { count: 3 }.to_string(),
            "Invalid stack state at program end: expected 1 element, found 3"
        );
        assert_eq!(
            VmError::InvalidLoad {
                symbol_name: "test".to_string(),
            }
            .to_string(),
            "Invalid load operation: cannot load 'test' as a constant"
        );
        assert_eq!(
            VmError::InvalidCall {
                symbol_name: "test".to_string(),
            }
            .to_string(),
            "Invalid call operation: cannot call 'test' as a function"
        );
        assert_eq!(
            VmError::CallStackUnderflow {
                function_name: "sin".to_string(),
                expected: 1,
                found: 0
            }
            .to_string(),
            "Stack underflow on function call 'sin': expected 1 arguments, found 0"
        );
    }

    #[test]
    fn test_empty_program_evaluates_to_zero() {
        let vm = Vm;
        let table = SymTable::stdlib();

        assert_eq!(vm.run_bytecode(&[], &table).unwrap(), Decimal::ZERO);
    }

    #[test]
    fn test_negation() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let bytecode = vec![Instr::Push(dec!(5)), Instr::Neg];

        assert_eq!(vm.run_bytecode(&bytecode, &table).unwrap(), -dec!(5));
    }

    #[test]
    fn test_binary_operations() {
        let vm = Vm;
        let table = SymTable::stdlib();

        // Every arithmetic operator that combines the top two stack values.
        let test_cases = vec![
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(3)), Instr::Add],
                dec!(5),
            ),
            (
                vec![Instr::Push(dec!(6)), Instr::Push(dec!(2)), Instr::Sub],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(3)), Instr::Push(dec!(4)), Instr::Mul],
                dec!(12),
            ),
            (
                vec![Instr::Push(dec!(8)), Instr::Push(dec!(2)), Instr::Div],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(10)), Instr::Pow],
                dec!(1024),
            ),
        ];

        for (code, expected) in test_cases {
            assert_eq!(vm.run_bytecode(&code, &table).unwrap(), expected);
        }
    }

    #[test]
    fn test_comparison_operations() {
        let vm = Vm;
        let table = SymTable::stdlib();

        // Comparisons push 1 for true and 0 for false.
        let test_cases = vec![
            (Instr::Equal, dec!(2), dec!(2), dec!(1)),
            (Instr::NotEqual, dec!(2), dec!(2), dec!(0)),
            (Instr::Less, dec!(1), dec!(2), dec!(1)),
            (Instr::LessEqual, dec!(3), dec!(2), dec!(0)),
            (Instr::Greater, dec!(3), dec!(2), dec!(1)),
            (Instr::GreaterEqual, dec!(2), dec!(2), dec!(1)),
        ];

        for (op, left, right, expected) in test_cases {
            let code = vec![Instr::Push(left), Instr::Push(right), op];
            assert_eq!(vm.run_bytecode(&code, &table).unwrap(), expected);
        }
    }

    #[test]
    fn test_factorial() {
        let vm = Vm;
        let table = SymTable::stdlib();

        let five = vec![Instr::Push(dec!(5)), Instr::Fact];
        assert_eq!(vm.run_bytecode(&five, &table).unwrap(), dec!(120));

        let zero = vec![Instr::Push(dec!(0)), Instr::Fact];
        assert_eq!(vm.run_bytecode(&zero, &table).unwrap(), dec!(1));
    }
}