expr_solver/
vm.rs

1use crate::ir::Instr;
2use crate::program::Program;
3use crate::symbol::{FuncError, Symbol};
4use rust_decimal::Decimal;
5use rust_decimal::prelude::*;
6use std::borrow::Cow;
7use thiserror::Error;
8
9#[cfg(test)]
10use rust_decimal_macros::dec;
11
12/// Virtual machine runtime errors.
13#[derive(Error, Debug, Clone)]
14pub enum VmError {
15    #[error("Stack underflow: attempted to pop from empty stack")]
16    StackUnderflow,
17    #[error("Division by zero")]
18    DivisionByZero,
19    #[error("Invalid stack state at program end: expected 1 element, found {count}")]
20    InvalidFinalStack { count: usize },
21    #[error("Invalid load operation: cannot load '{symbol_name}' as a constant")]
22    InvalidLoad { symbol_name: Cow<'static, str> },
23    #[error("Invalid call operation: cannot call '{symbol_name}' as a function")]
24    InvalidCall { symbol_name: Cow<'static, str> },
25    #[error(
26        "Stack underflow on function call '{function_name}': expected {expected} arguments, found {found}"
27    )]
28    CallStackUnderflow {
29        function_name: Cow<'static, str>,
30        expected: usize,
31        found: usize,
32    },
33    #[error("Invalid factorial: {value} (must be a non-negative integer)")]
34    InvalidFactorial { value: Decimal },
35    #[error("Arithmetic error: {message}")]
36    ArithmeticError { message: String },
37    #[error("Function error: {0}")]
38    FunctionError(FuncError),
39}
40
/// A stateless, stack-based interpreter for bytecode programs.
///
/// Every call to [`Vm::run`] starts from a fresh operand stack, executes the
/// program's instructions in order (arithmetic, comparisons, constant loads
/// and function calls), and yields the value left on the stack.
#[derive(Default, Debug)]
pub struct Vm;
47
48impl Vm {
49    /// Executes a program and returns the result.
50    ///
51    /// # Errors
52    ///
53    /// Returns [`VmError`] if execution fails due to:
54    /// - Stack underflow
55    /// - Division by zero
56    /// - Invalid operations (e.g., factorial of non-integer)
57    /// - Function errors
58    pub fn run(&self, prog: &Program) -> Result<Decimal, VmError> {
59        if prog.code.is_empty() {
60            return Ok(Decimal::ZERO);
61        }
62
63        let mut stack: Vec<Decimal> = Vec::new();
64
65        for op in &prog.code {
66            self.execute_instruction(op, &mut stack)?;
67        }
68
69        match stack.as_slice() {
70            [result] => Ok(*result),
71            _ => Err(VmError::InvalidFinalStack { count: stack.len() }),
72        }
73    }
74
75    fn execute_instruction(&self, op: &Instr, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
76        match op {
77            Instr::Push(v) => {
78                stack.push(*v);
79                Ok(())
80            }
81            Instr::Load(sym) => match sym {
82                Symbol::Const { name: _, value, .. } => {
83                    stack.push(*value);
84                    Ok(())
85                }
86                _ => Err(VmError::InvalidLoad {
87                    symbol_name: Cow::Owned(sym.name().to_string()),
88                }),
89            },
90            Instr::Neg => {
91                let v = Self::pop(stack)?;
92                stack.push(-v);
93                Ok(())
94            }
95            Instr::Add => self.add_op(stack),
96            Instr::Sub => self.sub_op(stack),
97            Instr::Mul => self.mul_op(stack),
98            Instr::Div => self.div_op(stack),
99            Instr::Pow => self.pow_op(stack),
100            Instr::Fact => self.fact_op(stack),
101            Instr::Call(sym, argc) => self.call_op(sym, *argc, stack),
102            // Comparison operators
103            Instr::Equal => self.comparison_op(stack, |a, b| a == b),
104            Instr::NotEqual => self.comparison_op(stack, |a, b| a != b),
105            Instr::Less => self.comparison_op(stack, |a, b| a < b),
106            Instr::LessEqual => self.comparison_op(stack, |a, b| a <= b),
107            Instr::Greater => self.comparison_op(stack, |a, b| a > b),
108            Instr::GreaterEqual => self.comparison_op(stack, |a, b| a >= b),
109        }
110    }
111
112    fn comparison_op<F>(&self, stack: &mut Vec<Decimal>, f: F) -> Result<(), VmError>
113    where
114        F: FnOnce(Decimal, Decimal) -> bool,
115    {
116        let right = Self::pop(stack)?;
117        let left = Self::pop(stack)?;
118        let result = if f(left, right) {
119            Decimal::ONE
120        } else {
121            Decimal::ZERO
122        };
123        stack.push(result);
124        Ok(())
125    }
126
127    fn add_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
128        let right = Self::pop(stack)?;
129        let left = Self::pop(stack)?;
130        let result = left
131            .checked_add(right)
132            .ok_or_else(|| VmError::ArithmeticError {
133                message: format!("Addition overflow: {} + {}", left, right),
134            })?;
135        stack.push(result);
136        Ok(())
137    }
138
139    fn sub_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
140        let right = Self::pop(stack)?;
141        let left = Self::pop(stack)?;
142        let result = left
143            .checked_sub(right)
144            .ok_or_else(|| VmError::ArithmeticError {
145                message: format!("Subtraction overflow: {} - {}", left, right),
146            })?;
147        stack.push(result);
148        Ok(())
149    }
150
151    fn mul_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
152        let right = Self::pop(stack)?;
153        let left = Self::pop(stack)?;
154        let result = left
155            .checked_mul(right)
156            .ok_or_else(|| VmError::ArithmeticError {
157                message: format!("Multiplication overflow: {} * {}", left, right),
158            })?;
159        stack.push(result);
160        Ok(())
161    }
162
163    fn div_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
164        let right = Self::pop(stack)?;
165        let left = Self::pop(stack)?;
166        let result = left.checked_div(right).ok_or_else(|| {
167            if right.is_zero() {
168                VmError::DivisionByZero
169            } else {
170                VmError::ArithmeticError {
171                    message: format!("Division overflow or underflow: {} / {}", left, right),
172                }
173            }
174        })?;
175        stack.push(result);
176        Ok(())
177    }
178
179    fn pow_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
180        let exponent = Self::pop(stack)?;
181        let base = Self::pop(stack)?;
182
183        // Use Decimal's powd with error handling
184        let result =
185            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| base.powd(exponent))) {
186                Ok(result) => result,
187                Err(_) => {
188                    return Err(VmError::ArithmeticError {
189                        message: format!("Power operation failed: {} ^ {}", base, exponent),
190                    });
191                }
192            };
193
194        stack.push(result);
195        Ok(())
196    }
197
198    fn fact_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
199        let n = Self::pop(stack)?;
200
201        // Check for negative numbers
202        if n.is_sign_negative() {
203            return Err(VmError::InvalidFactorial { value: n });
204        }
205
206        // Check for non-integer
207        if n.fract() != Decimal::ZERO {
208            return Err(VmError::InvalidFactorial { value: n });
209        }
210
211        // Calculate factorial using safe multiplication with iterator
212        let n_u64 = n.to_u64().unwrap();
213        let result = (1..=n_u64).try_fold(Decimal::ONE, |acc, i| {
214            acc.checked_mul(Decimal::from(i))
215                .ok_or_else(|| VmError::ArithmeticError {
216                    message: format!("Factorial calculation overflow at {}!", i),
217                })
218        })?;
219
220        stack.push(result);
221        Ok(())
222    }
223
224    fn call_op(&self, sym: &Symbol, argc: usize, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
225        match sym {
226            Symbol::Func {
227                name,
228                args: min_args,
229                variadic,
230                callback,
231                ..
232            } => {
233                if argc != *min_args && (!*variadic || argc < *min_args) {
234                    return Err(VmError::CallStackUnderflow {
235                        function_name: name.clone(),
236                        expected: *min_args,
237                        found: argc,
238                    });
239                }
240
241                // Check if we have enough values on the stack
242                if stack.len() < argc {
243                    return Err(VmError::CallStackUnderflow {
244                        function_name: name.clone(),
245                        expected: argc,
246                        found: stack.len(),
247                    });
248                }
249
250                let args_start = stack.len() - argc;
251                let args = &stack[args_start..];
252                let result = callback(args).map_err(VmError::FunctionError)?;
253                stack.truncate(args_start);
254                stack.push(result);
255                Ok(())
256            }
257            Symbol::Const { .. } => Err(VmError::InvalidCall {
258                symbol_name: Cow::Owned(sym.name().to_string()),
259            }),
260        }
261    }
262
263    fn pop(stack: &mut Vec<Decimal>) -> Result<Decimal, VmError> {
264        stack.pop().ok_or(VmError::StackUnderflow)
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymTable;
    use std::borrow::Cow;

    /// Builds a program whose bytecode is exactly `code`.
    fn make(code: Vec<Instr>) -> Program {
        let mut program = Program::new();
        program.code = code;
        program
    }

    #[test]
    fn test_vm_error_stack_underflow() {
        let vm = Vm::default();
        let program = make(
            vec![Instr::Add], // No values on stack
        );

        let result = vm.run(&program);
        assert!(matches!(result, Err(VmError::StackUnderflow)));
    }

    #[test]
    fn test_vm_error_division_by_zero() {
        let vm = Vm::default();
        let program = make(vec![Instr::Push(dec!(5)), Instr::Push(dec!(0)), Instr::Div]);

        let result = vm.run(&program);
        assert!(matches!(result, Err(VmError::DivisionByZero)));
    }

    #[test]
    fn test_vm_error_invalid_final_stack() {
        let vm = Vm::default();
        let program = make(vec![
            Instr::Push(dec!(1)),
            Instr::Push(dec!(2)),
            // No operation to combine them
        ]);

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::InvalidFinalStack { count: 2 })
        ));
    }

    #[test]
    fn test_vm_error_invalid_load() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let sin_func = table.get("sin").unwrap();

        let program = make(
            vec![Instr::Load(sin_func)], // Trying to load a function as constant
        );

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::InvalidLoad { symbol_name: _ })
        ));
    }

    #[test]
    fn test_vm_error_invalid_call() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let pi_const = table.get("pi").unwrap();

        let program = make(
            vec![Instr::Call(pi_const, 0)], // Trying to call a constant as function
        );

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::InvalidCall { symbol_name: _ })
        ));
    }

    #[test]
    fn test_vm_error_call_stack_underflow() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let sin_func = table.get("sin").unwrap();

        let program = make(
            vec![Instr::Call(sin_func, 0)], // No arguments for sin function
        );

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::CallStackUnderflow {
                function_name: _,
                expected: _,
                found: _
            })
        ));
    }

    #[test]
    fn test_vm_error_invalid_factorial() {
        let vm = Vm::default();

        // Negative and non-integer operands have no factorial.
        for value in vec![dec!(-1), dec!(2.5)] {
            let program = make(vec![Instr::Push(value), Instr::Fact]);
            let result = vm.run(&program);
            assert!(matches!(result, Err(VmError::InvalidFactorial { .. })));
        }
    }

    #[test]
    fn test_vm_error_display() {
        assert_eq!(
            VmError::StackUnderflow.to_string(),
            "Stack underflow: attempted to pop from empty stack"
        );
        assert_eq!(VmError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            VmError::InvalidFinalStack { count: 3 }.to_string(),
            "Invalid stack state at program end: expected 1 element, found 3"
        );
        assert_eq!(
            VmError::InvalidLoad {
                symbol_name: Cow::Borrowed("test"),
            }
            .to_string(),
            "Invalid load operation: cannot load 'test' as a constant"
        );
        assert_eq!(
            VmError::InvalidCall {
                symbol_name: Cow::Borrowed("test"),
            }
            .to_string(),
            "Invalid call operation: cannot call 'test' as a function"
        );
        assert_eq!(
            VmError::CallStackUnderflow {
                function_name: Cow::Borrowed("sin"),
                expected: 1,
                found: 0
            }
            .to_string(),
            "Stack underflow on function call 'sin': expected 1 arguments, found 0"
        );
    }

    #[test]
    fn test_empty_program_evaluates_to_zero() {
        let vm = Vm::default();
        assert_eq!(vm.run(&make(vec![])).unwrap(), Decimal::ZERO);
    }

    #[test]
    fn test_binary_operations() {
        let vm = Vm::default();

        // Exact-result arithmetic binary operations (left operand pushed first).
        let test_cases = vec![
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(3)), Instr::Add],
                dec!(5),
            ),
            (
                vec![Instr::Push(dec!(6)), Instr::Push(dec!(2)), Instr::Sub],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(3)), Instr::Push(dec!(4)), Instr::Mul],
                dec!(12),
            ),
            (
                vec![Instr::Push(dec!(8)), Instr::Push(dec!(2)), Instr::Div],
                dec!(4),
            ),
        ];

        for (code, expected) in test_cases {
            let program = make(code);
            assert_eq!(vm.run(&program).unwrap(), expected);
        }
    }

    #[test]
    fn test_power() {
        let vm = Vm::default();
        let program = make(vec![Instr::Push(dec!(2)), Instr::Push(dec!(3)), Instr::Pow]);

        let result = vm.run(&program).unwrap();
        // Tolerate approximation inside the decimal power implementation.
        assert!((result - dec!(8)).abs() < dec!(0.000001));
    }

    #[test]
    fn test_comparison_operations() {
        let vm = Vm::default();

        // (operator, left, right, expected) — true is 1, false is 0.
        let test_cases = vec![
            (Instr::Equal, dec!(2), dec!(2), dec!(1)),
            (Instr::NotEqual, dec!(2), dec!(2), dec!(0)),
            (Instr::Less, dec!(1), dec!(2), dec!(1)),
            (Instr::LessEqual, dec!(2), dec!(2), dec!(1)),
            (Instr::Greater, dec!(1), dec!(2), dec!(0)),
            (Instr::GreaterEqual, dec!(3), dec!(2), dec!(1)),
        ];

        for (op, left, right, expected) in test_cases {
            let program = make(vec![Instr::Push(left), Instr::Push(right), op]);
            assert_eq!(vm.run(&program).unwrap(), expected);
        }
    }

    #[test]
    fn test_negation() {
        let vm = Vm::default();
        let program = make(vec![Instr::Push(dec!(3)), Instr::Neg]);
        assert_eq!(vm.run(&program).unwrap(), dec!(-3));
    }

    #[test]
    fn test_factorial() {
        let vm = Vm::default();

        let zero = make(vec![Instr::Push(dec!(0)), Instr::Fact]);
        assert_eq!(vm.run(&zero).unwrap(), dec!(1));

        let five = make(vec![Instr::Push(dec!(5)), Instr::Fact]);
        assert_eq!(vm.run(&five).unwrap(), dec!(120));
    }

    #[test]
    fn test_load_constant() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let pi_const = table.get("pi").unwrap();

        let result = vm.run(&make(vec![Instr::Load(pi_const)])).unwrap();
        assert!(result > dec!(3.14) && result < dec!(3.15));
    }
}