expr_solver/
vm.rs

1use crate::ir::Instr;
2use crate::program::Program;
3use crate::symbol::{FuncError, Symbol};
4use rust_decimal::Decimal;
5use rust_decimal::prelude::*;
6use std::borrow::Cow;
7use thiserror::Error;
8
9#[cfg(test)]
10use rust_decimal_macros::dec;
11
12/// Virtual machine runtime errors.
13#[derive(Error, Debug, Clone)]
14pub enum VmError {
15    #[error("Stack underflow: attempted to pop from empty stack")]
16    StackUnderflow,
17    #[error("Division by zero")]
18    DivisionByZero,
19    #[error("Invalid stack state at program end: expected 1 element, found {count}")]
20    InvalidFinalStack { count: usize },
21    #[error("Invalid load operation: cannot load '{symbol_name}' as a constant")]
22    InvalidLoad { symbol_name: Cow<'static, str> },
23    #[error("Invalid call operation: cannot call '{symbol_name}' as a function")]
24    InvalidCall { symbol_name: Cow<'static, str> },
25    #[error(
26        "Stack underflow on function call '{function_name}': expected {expected} arguments, found {found}"
27    )]
28    CallStackUnderflow {
29        function_name: Cow<'static, str>,
30        expected: usize,
31        found: usize,
32    },
33    #[error("Invalid factorial: {value} (must be a non-negative integer)")]
34    InvalidFactorial { value: Decimal },
35    #[error("Arithmetic error: {message}")]
36    ArithmeticError { message: String },
37    #[error("Function error: {0}")]
38    FunctionError(FuncError),
39}
40
/// A simple stack-based virtual machine for evaluating programs.
///
/// `Vm` holds no data (each run uses its own fresh value stack), so it is `Copy`
/// and any two instances are interchangeable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Vm;
44
45impl Vm {
46    /// Executes the given program and returns the result or a VmError.
47    pub fn run(&self, prog: &Program) -> Result<Decimal, VmError> {
48        if prog.code.is_empty() {
49            return Ok(Decimal::ZERO);
50        }
51
52        let mut stack: Vec<Decimal> = Vec::new();
53
54        for op in &prog.code {
55            self.execute_instruction(op, &mut stack)?;
56        }
57
58        match stack.as_slice() {
59            [result] => Ok(*result),
60            _ => Err(VmError::InvalidFinalStack { count: stack.len() }),
61        }
62    }
63
64    fn execute_instruction(&self, op: &Instr, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
65        match op {
66            Instr::Push(v) => {
67                stack.push(*v);
68                Ok(())
69            }
70            Instr::Load(sym) => match sym {
71                Symbol::Const { name: _, value, .. } => {
72                    stack.push(*value);
73                    Ok(())
74                }
75                _ => Err(VmError::InvalidLoad {
76                    symbol_name: Cow::Owned(sym.name().to_string()),
77                }),
78            },
79            Instr::Neg => {
80                let v = Self::pop(stack)?;
81                stack.push(-v);
82                Ok(())
83            }
84            Instr::Add => self.add_op(stack),
85            Instr::Sub => self.sub_op(stack),
86            Instr::Mul => self.mul_op(stack),
87            Instr::Div => self.div_op(stack),
88            Instr::Pow => self.pow_op(stack),
89            Instr::Fact => self.fact_op(stack),
90            Instr::Call(sym, argc) => self.call_op(sym, *argc, stack),
91            // Comparison operators
92            Instr::Equal => self.comparison_op(stack, |a, b| a == b),
93            Instr::NotEqual => self.comparison_op(stack, |a, b| a != b),
94            Instr::Less => self.comparison_op(stack, |a, b| a < b),
95            Instr::LessEqual => self.comparison_op(stack, |a, b| a <= b),
96            Instr::Greater => self.comparison_op(stack, |a, b| a > b),
97            Instr::GreaterEqual => self.comparison_op(stack, |a, b| a >= b),
98        }
99    }
100
101    fn comparison_op<F>(&self, stack: &mut Vec<Decimal>, f: F) -> Result<(), VmError>
102    where
103        F: FnOnce(Decimal, Decimal) -> bool,
104    {
105        let right = Self::pop(stack)?;
106        let left = Self::pop(stack)?;
107        let result = if f(left, right) {
108            Decimal::ONE
109        } else {
110            Decimal::ZERO
111        };
112        stack.push(result);
113        Ok(())
114    }
115
116    fn add_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
117        let right = Self::pop(stack)?;
118        let left = Self::pop(stack)?;
119        let result = left
120            .checked_add(right)
121            .ok_or_else(|| VmError::ArithmeticError {
122                message: format!("Addition overflow: {} + {}", left, right),
123            })?;
124        stack.push(result);
125        Ok(())
126    }
127
128    fn sub_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
129        let right = Self::pop(stack)?;
130        let left = Self::pop(stack)?;
131        let result = left
132            .checked_sub(right)
133            .ok_or_else(|| VmError::ArithmeticError {
134                message: format!("Subtraction overflow: {} - {}", left, right),
135            })?;
136        stack.push(result);
137        Ok(())
138    }
139
140    fn mul_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
141        let right = Self::pop(stack)?;
142        let left = Self::pop(stack)?;
143        let result = left
144            .checked_mul(right)
145            .ok_or_else(|| VmError::ArithmeticError {
146                message: format!("Multiplication overflow: {} * {}", left, right),
147            })?;
148        stack.push(result);
149        Ok(())
150    }
151
152    fn div_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
153        let right = Self::pop(stack)?;
154        let left = Self::pop(stack)?;
155        let result = left.checked_div(right).ok_or_else(|| {
156            if right.is_zero() {
157                VmError::DivisionByZero
158            } else {
159                VmError::ArithmeticError {
160                    message: format!("Division overflow or underflow: {} / {}", left, right),
161                }
162            }
163        })?;
164        stack.push(result);
165        Ok(())
166    }
167
168    fn pow_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
169        let exponent = Self::pop(stack)?;
170        let base = Self::pop(stack)?;
171
172        // Use Decimal's powd with error handling
173        let result =
174            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| base.powd(exponent))) {
175                Ok(result) => result,
176                Err(_) => {
177                    return Err(VmError::ArithmeticError {
178                        message: format!("Power operation failed: {} ^ {}", base, exponent),
179                    });
180                }
181            };
182
183        stack.push(result);
184        Ok(())
185    }
186
187    fn fact_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
188        let n = Self::pop(stack)?;
189
190        // Check for negative numbers
191        if n.is_sign_negative() {
192            return Err(VmError::InvalidFactorial { value: n });
193        }
194
195        // Check for non-integer
196        if n.fract() != Decimal::ZERO {
197            return Err(VmError::InvalidFactorial { value: n });
198        }
199
200        // Calculate factorial using safe multiplication
201        let n_u64 = n.to_u64().unwrap();
202        let mut result = Decimal::ONE;
203        for i in 1..=n_u64 {
204            result =
205                result
206                    .checked_mul(Decimal::from(i))
207                    .ok_or_else(|| VmError::ArithmeticError {
208                        message: format!("Factorial calculation overflow at {}!", i),
209                    })?;
210        }
211
212        stack.push(result);
213        Ok(())
214    }
215
216    fn call_op(&self, sym: &Symbol, argc: usize, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
217        match sym {
218            Symbol::Func {
219                name,
220                args: min_args,
221                variadic,
222                callback,
223                ..
224            } => {
225                if argc != *min_args && (!*variadic || argc < *min_args) {
226                    return Err(VmError::CallStackUnderflow {
227                        function_name: name.clone(),
228                        expected: *min_args,
229                        found: argc,
230                    });
231                }
232
233                // Check if we have enough values on the stack
234                if stack.len() < argc {
235                    return Err(VmError::CallStackUnderflow {
236                        function_name: name.clone(),
237                        expected: argc,
238                        found: stack.len(),
239                    });
240                }
241
242                let args_start = stack.len() - argc;
243                let args = &stack[args_start..];
244                let result = callback(args).map_err(VmError::FunctionError)?;
245                stack.truncate(args_start);
246                stack.push(result);
247                Ok(())
248            }
249            Symbol::Const { .. } => Err(VmError::InvalidCall {
250                symbol_name: Cow::Owned(sym.name().to_string()),
251            }),
252        }
253    }
254
255    fn pop(stack: &mut Vec<Decimal>) -> Result<Decimal, VmError> {
256        stack.pop().ok_or(VmError::StackUnderflow)
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymTable;
    use std::borrow::Cow;

    /// Builds a program containing exactly `code`.
    fn make(code: Vec<Instr>) -> Program {
        let mut program = Program::new();
        program.code = code;
        program
    }

    #[test]
    fn test_vm_error_stack_underflow() {
        let vm = Vm::default();
        let program = make(
            vec![Instr::Add], // No values on stack
        );

        let result = vm.run(&program);
        assert!(matches!(result, Err(VmError::StackUnderflow)));
    }

    #[test]
    fn test_vm_error_division_by_zero() {
        let vm = Vm::default();
        let program = make(vec![Instr::Push(dec!(5)), Instr::Push(dec!(0)), Instr::Div]);

        let result = vm.run(&program);
        assert!(matches!(result, Err(VmError::DivisionByZero)));
    }

    #[test]
    fn test_vm_error_invalid_final_stack() {
        let vm = Vm::default();
        let program = make(vec![
            Instr::Push(dec!(1)),
            Instr::Push(dec!(2)),
            // No operation to combine them
        ]);

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::InvalidFinalStack { count: 2 })
        ));
    }

    #[test]
    fn test_vm_error_invalid_load() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let sin_func = table.get("sin").unwrap();

        let program = make(
            vec![Instr::Load(sin_func)], // Trying to load a function as constant
        );

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::InvalidLoad { symbol_name: _ })
        ));
    }

    #[test]
    fn test_vm_error_invalid_call() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let pi_const = table.get("pi").unwrap();

        let program = make(
            vec![Instr::Call(pi_const, 0)], // Trying to call a constant as function
        );

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::InvalidCall { symbol_name: _ })
        ));
    }

    #[test]
    fn test_vm_error_call_stack_underflow() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let sin_func = table.get("sin").unwrap();

        let program = make(
            vec![Instr::Call(sin_func, 0)], // No arguments for sin function
        );

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::CallStackUnderflow {
                function_name: _,
                expected: _,
                found: _
            })
        ));
    }

    #[test]
    fn test_vm_error_display() {
        assert_eq!(
            VmError::StackUnderflow.to_string(),
            "Stack underflow: attempted to pop from empty stack"
        );
        assert_eq!(VmError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            VmError::InvalidFinalStack { count: 3 }.to_string(),
            "Invalid stack state at program end: expected 1 element, found 3"
        );
        assert_eq!(
            VmError::InvalidLoad {
                symbol_name: Cow::Borrowed("test"),
            }
            .to_string(),
            "Invalid load operation: cannot load 'test' as a constant"
        );
        assert_eq!(
            VmError::InvalidCall {
                symbol_name: Cow::Borrowed("test"),
            }
            .to_string(),
            "Invalid call operation: cannot call 'test' as a function"
        );
        assert_eq!(
            VmError::CallStackUnderflow {
                function_name: Cow::Borrowed("sin"),
                expected: 1,
                found: 0
            }
            .to_string(),
            "Stack underflow on function call 'sin': expected 1 arguments, found 0"
        );
    }

    #[test]
    fn test_binary_operations() {
        let vm = Vm::default();

        // Every arithmetic binary operator, each with a result that is exact in `Decimal`.
        let test_cases = vec![
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(3)), Instr::Add],
                dec!(5),
            ),
            (
                vec![Instr::Push(dec!(6)), Instr::Push(dec!(2)), Instr::Sub],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(3)), Instr::Push(dec!(4)), Instr::Mul],
                dec!(12),
            ),
            (
                vec![Instr::Push(dec!(8)), Instr::Push(dec!(2)), Instr::Div],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(10)), Instr::Pow],
                dec!(1024),
            ),
        ];

        for (code, expected) in test_cases {
            let program = make(code);
            assert_eq!(vm.run(&program).unwrap(), expected);
        }
    }

    #[test]
    fn test_comparison_operations() {
        let vm = Vm::default();

        // (operator, left, right, expected): true is pushed as 1, false as 0.
        let test_cases = vec![
            (Instr::Equal, dec!(2), dec!(2), dec!(1)),
            (Instr::NotEqual, dec!(2), dec!(2), dec!(0)),
            (Instr::Less, dec!(1), dec!(2), dec!(1)),
            (Instr::LessEqual, dec!(3), dec!(2), dec!(0)),
            (Instr::Greater, dec!(3), dec!(2), dec!(1)),
            (Instr::GreaterEqual, dec!(2), dec!(3), dec!(0)),
        ];

        for (op, left, right, expected) in test_cases {
            let program = make(vec![Instr::Push(left), Instr::Push(right), op]);
            assert_eq!(vm.run(&program).unwrap(), expected);
        }
    }

    #[test]
    fn test_unary_operations() {
        let vm = Vm::default();

        let negated = make(vec![Instr::Push(dec!(3)), Instr::Neg]);
        assert_eq!(vm.run(&negated).unwrap(), dec!(-3));

        let factorial = make(vec![Instr::Push(dec!(5)), Instr::Fact]);
        assert_eq!(vm.run(&factorial).unwrap(), dec!(120));
    }

    #[test]
    fn test_empty_program_evaluates_to_zero() {
        let vm = Vm::default();
        assert_eq!(vm.run(&make(vec![])).unwrap(), Decimal::ZERO);
    }

    #[test]
    fn test_factorial_rejects_negative_and_fractional() {
        let vm = Vm::default();
        for value in [dec!(-1), dec!(2.5)].iter().copied() {
            let program = make(vec![Instr::Push(value), Instr::Fact]);
            assert!(matches!(
                vm.run(&program),
                Err(VmError::InvalidFactorial { .. })
            ));
        }
    }

    #[test]
    fn test_overflow_is_reported_as_arithmetic_error() {
        let vm = Vm::default();
        let overflowing = vec![
            vec![Instr::Push(dec!(30)), Instr::Fact],
            vec![Instr::Push(dec!(10)), Instr::Push(dec!(40)), Instr::Pow],
            vec![
                Instr::Push(Decimal::MAX),
                Instr::Push(Decimal::MAX),
                Instr::Add,
            ],
        ];

        for code in overflowing {
            let program = make(code);
            assert!(matches!(
                vm.run(&program),
                Err(VmError::ArithmeticError { .. })
            ));
        }
    }
}