1use crate::ir::Instr;
2use crate::program::Program;
3use crate::symbol::{FuncError, Symbol};
4use rust_decimal::Decimal;
5use rust_decimal::prelude::*;
6use std::borrow::Cow;
7use thiserror::Error;
8
9#[cfg(test)]
10use rust_decimal_macros::dec;
11
12#[derive(Error, Debug, Clone)]
14pub enum VmError {
15 #[error("Stack underflow: attempted to pop from empty stack")]
16 StackUnderflow,
17 #[error("Division by zero")]
18 DivisionByZero,
19 #[error("Invalid stack state at program end: expected 1 element, found {count}")]
20 InvalidFinalStack { count: usize },
21 #[error("Invalid load operation: cannot load '{symbol_name}' as a constant")]
22 InvalidLoad { symbol_name: Cow<'static, str> },
23 #[error("Invalid call operation: cannot call '{symbol_name}' as a function")]
24 InvalidCall { symbol_name: Cow<'static, str> },
25 #[error(
26 "Stack underflow on function call '{function_name}': expected {expected} arguments, found {found}"
27 )]
28 CallStackUnderflow {
29 function_name: Cow<'static, str>,
30 expected: usize,
31 found: usize,
32 },
33 #[error("Invalid factorial: {value} (must be a non-negative integer)")]
34 InvalidFactorial { value: Decimal },
35 #[error("Arithmetic error: {message}")]
36 ArithmeticError { message: String },
37 #[error("Function error: {0}")]
38 FunctionError(FuncError),
39}
40
/// Stack-based evaluator for compiled [`Program`]s.
///
/// `Vm` carries no state of its own: every call to [`Vm::run`] starts from a fresh
/// operand stack, so a single instance can be reused or freely copied.
#[derive(Debug, Default, Clone, Copy)]
pub struct Vm;
47
48impl Vm {
49 pub fn run(&self, prog: &Program) -> Result<Decimal, VmError> {
59 if prog.code.is_empty() {
60 return Ok(Decimal::ZERO);
61 }
62
63 let mut stack: Vec<Decimal> = Vec::new();
64
65 for op in &prog.code {
66 self.execute_instruction(op, &mut stack)?;
67 }
68
69 match stack.as_slice() {
70 [result] => Ok(*result),
71 _ => Err(VmError::InvalidFinalStack { count: stack.len() }),
72 }
73 }
74
75 fn execute_instruction(&self, op: &Instr, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
76 match op {
77 Instr::Push(v) => {
78 stack.push(*v);
79 Ok(())
80 }
81 Instr::Load(sym) => match sym {
82 Symbol::Const { name: _, value, .. } => {
83 stack.push(*value);
84 Ok(())
85 }
86 _ => Err(VmError::InvalidLoad {
87 symbol_name: Cow::Owned(sym.name().to_string()),
88 }),
89 },
90 Instr::Neg => {
91 let v = Self::pop(stack)?;
92 stack.push(-v);
93 Ok(())
94 }
95 Instr::Add => self.add_op(stack),
96 Instr::Sub => self.sub_op(stack),
97 Instr::Mul => self.mul_op(stack),
98 Instr::Div => self.div_op(stack),
99 Instr::Pow => self.pow_op(stack),
100 Instr::Fact => self.fact_op(stack),
101 Instr::Call(sym, argc) => self.call_op(sym, *argc, stack),
102 Instr::Equal => self.comparison_op(stack, |a, b| a == b),
104 Instr::NotEqual => self.comparison_op(stack, |a, b| a != b),
105 Instr::Less => self.comparison_op(stack, |a, b| a < b),
106 Instr::LessEqual => self.comparison_op(stack, |a, b| a <= b),
107 Instr::Greater => self.comparison_op(stack, |a, b| a > b),
108 Instr::GreaterEqual => self.comparison_op(stack, |a, b| a >= b),
109 }
110 }
111
112 fn comparison_op<F>(&self, stack: &mut Vec<Decimal>, f: F) -> Result<(), VmError>
113 where
114 F: FnOnce(Decimal, Decimal) -> bool,
115 {
116 let right = Self::pop(stack)?;
117 let left = Self::pop(stack)?;
118 let result = if f(left, right) {
119 Decimal::ONE
120 } else {
121 Decimal::ZERO
122 };
123 stack.push(result);
124 Ok(())
125 }
126
127 fn add_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
128 let right = Self::pop(stack)?;
129 let left = Self::pop(stack)?;
130 let result = left
131 .checked_add(right)
132 .ok_or_else(|| VmError::ArithmeticError {
133 message: format!("Addition overflow: {} + {}", left, right),
134 })?;
135 stack.push(result);
136 Ok(())
137 }
138
139 fn sub_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
140 let right = Self::pop(stack)?;
141 let left = Self::pop(stack)?;
142 let result = left
143 .checked_sub(right)
144 .ok_or_else(|| VmError::ArithmeticError {
145 message: format!("Subtraction overflow: {} - {}", left, right),
146 })?;
147 stack.push(result);
148 Ok(())
149 }
150
151 fn mul_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
152 let right = Self::pop(stack)?;
153 let left = Self::pop(stack)?;
154 let result = left
155 .checked_mul(right)
156 .ok_or_else(|| VmError::ArithmeticError {
157 message: format!("Multiplication overflow: {} * {}", left, right),
158 })?;
159 stack.push(result);
160 Ok(())
161 }
162
163 fn div_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
164 let right = Self::pop(stack)?;
165 let left = Self::pop(stack)?;
166 let result = left.checked_div(right).ok_or_else(|| {
167 if right.is_zero() {
168 VmError::DivisionByZero
169 } else {
170 VmError::ArithmeticError {
171 message: format!("Division overflow or underflow: {} / {}", left, right),
172 }
173 }
174 })?;
175 stack.push(result);
176 Ok(())
177 }
178
179 fn pow_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
180 let exponent = Self::pop(stack)?;
181 let base = Self::pop(stack)?;
182
183 let result =
185 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| base.powd(exponent))) {
186 Ok(result) => result,
187 Err(_) => {
188 return Err(VmError::ArithmeticError {
189 message: format!("Power operation failed: {} ^ {}", base, exponent),
190 });
191 }
192 };
193
194 stack.push(result);
195 Ok(())
196 }
197
198 fn fact_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
199 let n = Self::pop(stack)?;
200
201 if n.is_sign_negative() {
203 return Err(VmError::InvalidFactorial { value: n });
204 }
205
206 if n.fract() != Decimal::ZERO {
208 return Err(VmError::InvalidFactorial { value: n });
209 }
210
211 let n_u64 = n.to_u64().unwrap();
213 let result = (1..=n_u64).try_fold(Decimal::ONE, |acc, i| {
214 acc.checked_mul(Decimal::from(i))
215 .ok_or_else(|| VmError::ArithmeticError {
216 message: format!("Factorial calculation overflow at {}!", i),
217 })
218 })?;
219
220 stack.push(result);
221 Ok(())
222 }
223
224 fn call_op(&self, sym: &Symbol, argc: usize, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
225 match sym {
226 Symbol::Func {
227 name,
228 args: min_args,
229 variadic,
230 callback,
231 ..
232 } => {
233 if argc != *min_args && (!*variadic || argc < *min_args) {
234 return Err(VmError::CallStackUnderflow {
235 function_name: name.clone(),
236 expected: *min_args,
237 found: argc,
238 });
239 }
240
241 if stack.len() < argc {
243 return Err(VmError::CallStackUnderflow {
244 function_name: name.clone(),
245 expected: argc,
246 found: stack.len(),
247 });
248 }
249
250 let args_start = stack.len() - argc;
251 let args = &stack[args_start..];
252 let result = callback(args).map_err(VmError::FunctionError)?;
253 stack.truncate(args_start);
254 stack.push(result);
255 Ok(())
256 }
257 Symbol::Const { .. } => Err(VmError::InvalidCall {
258 symbol_name: Cow::Owned(sym.name().to_string()),
259 }),
260 }
261 }
262
263 fn pop(stack: &mut Vec<Decimal>) -> Result<Decimal, VmError> {
264 stack.pop().ok_or(VmError::StackUnderflow)
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymTable;
    use std::borrow::Cow;

    /// Builds a `Program` whose code is exactly `code`.
    fn make(code: Vec<Instr>) -> Program {
        let mut program = Program::new();
        program.code = code;
        program
    }

    #[test]
    fn test_vm_error_stack_underflow() {
        let vm = Vm::default();
        let program = make(vec![Instr::Add]);

        let result = vm.run(&program);
        assert!(matches!(result, Err(VmError::StackUnderflow)));
    }

    #[test]
    fn test_vm_error_division_by_zero() {
        let vm = Vm::default();
        let program = make(vec![Instr::Push(dec!(5)), Instr::Push(dec!(0)), Instr::Div]);

        let result = vm.run(&program);
        assert!(matches!(result, Err(VmError::DivisionByZero)));
    }

    #[test]
    fn test_vm_error_invalid_final_stack() {
        let vm = Vm::default();
        let program = make(vec![Instr::Push(dec!(1)), Instr::Push(dec!(2))]);

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::InvalidFinalStack { count: 2 })
        ));
    }

    #[test]
    fn test_vm_error_invalid_load() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let sin_func = table.get("sin").unwrap();

        let program = make(vec![Instr::Load(sin_func)]);

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::InvalidLoad { symbol_name: _ })
        ));
    }

    #[test]
    fn test_vm_error_invalid_call() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let pi_const = table.get("pi").unwrap();

        let program = make(vec![Instr::Call(pi_const, 0)]);

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::InvalidCall { symbol_name: _ })
        ));
    }

    #[test]
    fn test_vm_error_call_stack_underflow() {
        let vm = Vm::default();
        let table = SymTable::stdlib();
        let sin_func = table.get("sin").unwrap();

        let program = make(vec![Instr::Call(sin_func, 0)]);

        let result = vm.run(&program);
        assert!(matches!(
            result,
            Err(VmError::CallStackUnderflow {
                function_name: _,
                expected: _,
                found: _
            })
        ));
    }

    #[test]
    fn test_vm_error_display() {
        assert_eq!(
            VmError::StackUnderflow.to_string(),
            "Stack underflow: attempted to pop from empty stack"
        );
        assert_eq!(VmError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            VmError::InvalidFinalStack { count: 3 }.to_string(),
            "Invalid stack state at program end: expected 1 element, found 3"
        );
        assert_eq!(
            VmError::InvalidLoad {
                symbol_name: Cow::Borrowed("test"),
            }
            .to_string(),
            "Invalid load operation: cannot load 'test' as a constant"
        );
        assert_eq!(
            VmError::InvalidCall {
                symbol_name: Cow::Borrowed("test"),
            }
            .to_string(),
            "Invalid call operation: cannot call 'test' as a function"
        );
        assert_eq!(
            VmError::CallStackUnderflow {
                function_name: Cow::Borrowed("sin"),
                expected: 1,
                found: 0
            }
            .to_string(),
            "Stack underflow on function call 'sin': expected 1 arguments, found 0"
        );
    }

    #[test]
    fn test_binary_operations() {
        let vm = Vm::default();

        let test_cases = vec![
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(3)), Instr::Add],
                dec!(5),
            ),
            (
                vec![Instr::Push(dec!(6)), Instr::Push(dec!(2)), Instr::Sub],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(3)), Instr::Push(dec!(4)), Instr::Mul],
                dec!(12),
            ),
            (
                vec![Instr::Push(dec!(8)), Instr::Push(dec!(2)), Instr::Div],
                dec!(4),
            ),
        ];

        for (code, expected) in test_cases {
            let program = make(code);
            assert_eq!(vm.run(&program).unwrap(), expected);
        }
    }

    #[test]
    fn test_empty_program_returns_zero() {
        let vm = Vm::default();
        assert_eq!(vm.run(&make(vec![])).unwrap(), Decimal::ZERO);
    }

    #[test]
    fn test_neg() {
        let vm = Vm::default();
        let program = make(vec![Instr::Push(dec!(5)), Instr::Neg]);
        assert_eq!(vm.run(&program).unwrap(), dec!(-5));
    }

    #[test]
    fn test_comparison_operations() {
        let vm = Vm::default();

        // (operator, left, right, expected 1/0 result)
        let cases = vec![
            (Instr::Equal, dec!(2), dec!(2), dec!(1)),
            (Instr::NotEqual, dec!(2), dec!(2), dec!(0)),
            (Instr::Less, dec!(1), dec!(2), dec!(1)),
            (Instr::LessEqual, dec!(3), dec!(2), dec!(0)),
            (Instr::Greater, dec!(3), dec!(2), dec!(1)),
            (Instr::GreaterEqual, dec!(1), dec!(2), dec!(0)),
        ];

        for (op, left, right, expected) in cases {
            let program = make(vec![Instr::Push(left), Instr::Push(right), op]);
            assert_eq!(vm.run(&program).unwrap(), expected);
        }
    }

    #[test]
    fn test_factorial() {
        let vm = Vm::default();
        let zero = make(vec![Instr::Push(dec!(0)), Instr::Fact]);
        let five = make(vec![Instr::Push(dec!(5)), Instr::Fact]);
        assert_eq!(vm.run(&zero).unwrap(), dec!(1));
        assert_eq!(vm.run(&five).unwrap(), dec!(120));
    }

    #[test]
    fn test_factorial_rejects_invalid_operands() {
        let vm = Vm::default();
        for value in vec![dec!(-1), dec!(2.5)] {
            let result = vm.run(&make(vec![Instr::Push(value), Instr::Fact]));
            assert!(matches!(result, Err(VmError::InvalidFactorial { .. })));
        }
    }

    #[test]
    fn test_factorial_overflow() {
        let vm = Vm::default();
        let result = vm.run(&make(vec![Instr::Push(dec!(30)), Instr::Fact]));
        assert!(matches!(result, Err(VmError::ArithmeticError { .. })));
    }

    #[test]
    fn test_pow() {
        let vm = Vm::default();
        let program = make(vec![Instr::Push(dec!(2)), Instr::Push(dec!(10)), Instr::Pow]);
        assert_eq!(vm.run(&program).unwrap(), dec!(1024));
    }

    #[test]
    fn test_pow_overflow() {
        let vm = Vm::default();
        let program = make(vec![Instr::Push(dec!(10)), Instr::Push(dec!(100)), Instr::Pow]);
        assert!(matches!(
            vm.run(&program),
            Err(VmError::ArithmeticError { .. })
        ));
    }
}