1use crate::ir::Instr;
2use crate::program::Program;
3use crate::symbol::{FuncError, Symbol};
4use rust_decimal::Decimal;
5use rust_decimal::prelude::*;
6use std::borrow::Cow;
7use thiserror::Error;
8
9#[cfg(test)]
10use rust_decimal_macros::dec;
11
12#[derive(Error, Debug, Clone)]
14pub enum VmError {
15 #[error("Stack underflow: attempted to pop from empty stack")]
16 StackUnderflow,
17 #[error("Division by zero")]
18 DivisionByZero,
19 #[error("Invalid stack state at program end: expected 1 element, found {count}")]
20 InvalidFinalStack { count: usize },
21 #[error("Invalid load operation: cannot load '{symbol_name}' as a constant")]
22 InvalidLoad { symbol_name: Cow<'static, str> },
23 #[error("Invalid call operation: cannot call '{symbol_name}' as a function")]
24 InvalidCall { symbol_name: Cow<'static, str> },
25 #[error(
26 "Stack underflow on function call '{function_name}': expected {expected} arguments, found {found}"
27 )]
28 CallStackUnderflow {
29 function_name: Cow<'static, str>,
30 expected: usize,
31 found: usize,
32 },
33 #[error("Invalid factorial: {value} (must be a non-negative integer)")]
34 InvalidFactorial { value: Decimal },
35 #[error("Arithmetic error: {message}")]
36 ArithmeticError { message: String },
37 #[error("Function error: {0}")]
38 FunctionError(FuncError),
39}
40
/// A stack-based interpreter for compiled [`Program`]s.
///
/// The VM itself holds no state; every call to `run` works on a fresh value stack.
#[derive(Default, Debug)]
pub struct Vm;
44
45impl Vm {
46 pub fn run(&self, prog: &Program) -> Result<Decimal, VmError> {
48 if prog.code.is_empty() {
49 return Ok(Decimal::ZERO);
50 }
51
52 let mut stack: Vec<Decimal> = Vec::new();
53
54 for op in &prog.code {
55 self.execute_instruction(op, &mut stack)?;
56 }
57
58 match stack.as_slice() {
59 [result] => Ok(*result),
60 _ => Err(VmError::InvalidFinalStack { count: stack.len() }),
61 }
62 }
63
64 fn execute_instruction(&self, op: &Instr, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
65 match op {
66 Instr::Push(v) => {
67 stack.push(*v);
68 Ok(())
69 }
70 Instr::Load(sym) => match sym {
71 Symbol::Const { name: _, value, .. } => {
72 stack.push(*value);
73 Ok(())
74 }
75 _ => Err(VmError::InvalidLoad {
76 symbol_name: Cow::Owned(sym.name().to_string()),
77 }),
78 },
79 Instr::Neg => {
80 let v = Self::pop(stack)?;
81 stack.push(-v);
82 Ok(())
83 }
84 Instr::Add => self.add_op(stack),
85 Instr::Sub => self.sub_op(stack),
86 Instr::Mul => self.mul_op(stack),
87 Instr::Div => self.div_op(stack),
88 Instr::Pow => self.pow_op(stack),
89 Instr::Fact => self.fact_op(stack),
90 Instr::Call(sym, argc) => self.call_op(sym, *argc, stack),
91 Instr::Equal => self.comparison_op(stack, |a, b| a == b),
93 Instr::NotEqual => self.comparison_op(stack, |a, b| a != b),
94 Instr::Less => self.comparison_op(stack, |a, b| a < b),
95 Instr::LessEqual => self.comparison_op(stack, |a, b| a <= b),
96 Instr::Greater => self.comparison_op(stack, |a, b| a > b),
97 Instr::GreaterEqual => self.comparison_op(stack, |a, b| a >= b),
98 }
99 }
100
101 fn comparison_op<F>(&self, stack: &mut Vec<Decimal>, f: F) -> Result<(), VmError>
102 where
103 F: FnOnce(Decimal, Decimal) -> bool,
104 {
105 let right = Self::pop(stack)?;
106 let left = Self::pop(stack)?;
107 let result = if f(left, right) {
108 Decimal::ONE
109 } else {
110 Decimal::ZERO
111 };
112 stack.push(result);
113 Ok(())
114 }
115
116 fn add_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
117 let right = Self::pop(stack)?;
118 let left = Self::pop(stack)?;
119 let result = left
120 .checked_add(right)
121 .ok_or_else(|| VmError::ArithmeticError {
122 message: format!("Addition overflow: {} + {}", left, right),
123 })?;
124 stack.push(result);
125 Ok(())
126 }
127
128 fn sub_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
129 let right = Self::pop(stack)?;
130 let left = Self::pop(stack)?;
131 let result = left
132 .checked_sub(right)
133 .ok_or_else(|| VmError::ArithmeticError {
134 message: format!("Subtraction overflow: {} - {}", left, right),
135 })?;
136 stack.push(result);
137 Ok(())
138 }
139
140 fn mul_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
141 let right = Self::pop(stack)?;
142 let left = Self::pop(stack)?;
143 let result = left
144 .checked_mul(right)
145 .ok_or_else(|| VmError::ArithmeticError {
146 message: format!("Multiplication overflow: {} * {}", left, right),
147 })?;
148 stack.push(result);
149 Ok(())
150 }
151
152 fn div_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
153 let right = Self::pop(stack)?;
154 let left = Self::pop(stack)?;
155 let result = left.checked_div(right).ok_or_else(|| {
156 if right.is_zero() {
157 VmError::DivisionByZero
158 } else {
159 VmError::ArithmeticError {
160 message: format!("Division overflow or underflow: {} / {}", left, right),
161 }
162 }
163 })?;
164 stack.push(result);
165 Ok(())
166 }
167
168 fn pow_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
169 let exponent = Self::pop(stack)?;
170 let base = Self::pop(stack)?;
171
172 let result =
174 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| base.powd(exponent))) {
175 Ok(result) => result,
176 Err(_) => {
177 return Err(VmError::ArithmeticError {
178 message: format!("Power operation failed: {} ^ {}", base, exponent),
179 });
180 }
181 };
182
183 stack.push(result);
184 Ok(())
185 }
186
187 fn fact_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
188 let n = Self::pop(stack)?;
189
190 if n.is_sign_negative() {
192 return Err(VmError::InvalidFactorial { value: n });
193 }
194
195 if n.fract() != Decimal::ZERO {
197 return Err(VmError::InvalidFactorial { value: n });
198 }
199
200 let n_u64 = n.to_u64().unwrap();
202 let mut result = Decimal::ONE;
203 for i in 1..=n_u64 {
204 result =
205 result
206 .checked_mul(Decimal::from(i))
207 .ok_or_else(|| VmError::ArithmeticError {
208 message: format!("Factorial calculation overflow at {}!", i),
209 })?;
210 }
211
212 stack.push(result);
213 Ok(())
214 }
215
216 fn call_op(&self, sym: &Symbol, argc: usize, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
217 match sym {
218 Symbol::Func {
219 name,
220 args: min_args,
221 variadic,
222 callback,
223 ..
224 } => {
225 if argc != *min_args && (!*variadic || argc < *min_args) {
226 return Err(VmError::CallStackUnderflow {
227 function_name: name.clone(),
228 expected: *min_args,
229 found: argc,
230 });
231 }
232
233 if stack.len() < argc {
235 return Err(VmError::CallStackUnderflow {
236 function_name: name.clone(),
237 expected: argc,
238 found: stack.len(),
239 });
240 }
241
242 let args_start = stack.len() - argc;
243 let args = &stack[args_start..];
244 let result = callback(args).map_err(VmError::FunctionError)?;
245 stack.truncate(args_start);
246 stack.push(result);
247 Ok(())
248 }
249 Symbol::Const { .. } => Err(VmError::InvalidCall {
250 symbol_name: Cow::Owned(sym.name().to_string()),
251 }),
252 }
253 }
254
255 fn pop(stack: &mut Vec<Decimal>) -> Result<Decimal, VmError> {
256 stack.pop().ok_or(VmError::StackUnderflow)
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymTable;
    use std::borrow::Cow;

    /// Builds a program whose code is exactly `code`.
    fn make(code: Vec<Instr>) -> Program {
        let mut program = Program::new();
        program.code = code;
        program
    }

    /// Runs `code` on a fresh VM.
    fn eval(code: Vec<Instr>) -> Result<Decimal, VmError> {
        Vm::default().run(&make(code))
    }

    #[test]
    fn test_empty_program_returns_zero() {
        assert_eq!(eval(vec![]).unwrap(), Decimal::ZERO);
    }

    #[test]
    fn test_vm_error_stack_underflow() {
        let result = eval(vec![Instr::Add]);
        assert!(matches!(result, Err(VmError::StackUnderflow)));
    }

    #[test]
    fn test_vm_error_division_by_zero() {
        let result = eval(vec![Instr::Push(dec!(5)), Instr::Push(dec!(0)), Instr::Div]);
        assert!(matches!(result, Err(VmError::DivisionByZero)));
    }

    #[test]
    fn test_vm_error_invalid_final_stack() {
        let result = eval(vec![Instr::Push(dec!(1)), Instr::Push(dec!(2))]);
        assert!(matches!(
            result,
            Err(VmError::InvalidFinalStack { count: 2 })
        ));
    }

    #[test]
    fn test_vm_error_invalid_load() {
        let table = SymTable::stdlib();
        let sin_func = table.get("sin").unwrap();

        let result = eval(vec![Instr::Load(sin_func)]);
        assert!(matches!(result, Err(VmError::InvalidLoad { .. })));
    }

    #[test]
    fn test_load_constant() {
        let table = SymTable::stdlib();
        let pi_const = table.get("pi").unwrap();

        let value = eval(vec![Instr::Load(pi_const)]).unwrap();
        assert!(value > dec!(3.14) && value < dec!(3.15));
    }

    #[test]
    fn test_vm_error_invalid_call() {
        let table = SymTable::stdlib();
        let pi_const = table.get("pi").unwrap();

        let result = eval(vec![Instr::Call(pi_const, 0)]);
        assert!(matches!(result, Err(VmError::InvalidCall { .. })));
    }

    #[test]
    fn test_vm_error_call_stack_underflow() {
        let table = SymTable::stdlib();
        let sin_func = table.get("sin").unwrap();

        let result = eval(vec![Instr::Call(sin_func, 0)]);
        assert!(matches!(result, Err(VmError::CallStackUnderflow { .. })));
    }

    #[test]
    fn test_vm_error_display() {
        assert_eq!(
            VmError::StackUnderflow.to_string(),
            "Stack underflow: attempted to pop from empty stack"
        );
        assert_eq!(VmError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            VmError::InvalidFinalStack { count: 3 }.to_string(),
            "Invalid stack state at program end: expected 1 element, found 3"
        );
        assert_eq!(
            VmError::InvalidLoad {
                symbol_name: Cow::Borrowed("test"),
            }
            .to_string(),
            "Invalid load operation: cannot load 'test' as a constant"
        );
        assert_eq!(
            VmError::InvalidCall {
                symbol_name: Cow::Borrowed("test"),
            }
            .to_string(),
            "Invalid call operation: cannot call 'test' as a function"
        );
        assert_eq!(
            VmError::CallStackUnderflow {
                function_name: Cow::Borrowed("sin"),
                expected: 1,
                found: 0
            }
            .to_string(),
            "Stack underflow on function call 'sin': expected 1 arguments, found 0"
        );
    }

    #[test]
    fn test_binary_operations() {
        let test_cases = vec![
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(3)), Instr::Add],
                dec!(5),
            ),
            (
                vec![Instr::Push(dec!(6)), Instr::Push(dec!(2)), Instr::Sub],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(3)), Instr::Push(dec!(4)), Instr::Mul],
                dec!(12),
            ),
            (
                vec![Instr::Push(dec!(8)), Instr::Push(dec!(2)), Instr::Div],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(10)), Instr::Pow],
                dec!(1024),
            ),
        ];

        for (code, expected) in test_cases {
            assert_eq!(eval(code).unwrap(), expected);
        }
    }

    #[test]
    fn test_addition_overflow_is_arithmetic_error() {
        let result = eval(vec![
            Instr::Push(Decimal::MAX),
            Instr::Push(Decimal::MAX),
            Instr::Add,
        ]);
        assert!(matches!(result, Err(VmError::ArithmeticError { .. })));
    }

    #[test]
    fn test_negation() {
        assert_eq!(eval(vec![Instr::Push(dec!(5)), Instr::Neg]).unwrap(), dec!(-5));
    }

    #[test]
    fn test_comparison_operations() {
        // (operator, left, right, expected) — `left` is pushed first.
        let cases = vec![
            (Instr::Equal, dec!(2), dec!(2), dec!(1)),
            (Instr::NotEqual, dec!(2), dec!(2), dec!(0)),
            (Instr::Less, dec!(1), dec!(2), dec!(1)),
            (Instr::LessEqual, dec!(2), dec!(2), dec!(1)),
            (Instr::Greater, dec!(1), dec!(2), dec!(0)),
            (Instr::GreaterEqual, dec!(1), dec!(2), dec!(0)),
        ];

        for (op, left, right, expected) in cases {
            let result = eval(vec![Instr::Push(left), Instr::Push(right), op]).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_factorial() {
        assert_eq!(eval(vec![Instr::Push(dec!(5)), Instr::Fact]).unwrap(), dec!(120));
        assert_eq!(eval(vec![Instr::Push(dec!(0)), Instr::Fact]).unwrap(), dec!(1));
    }

    #[test]
    fn test_factorial_rejects_negative_and_fractional() {
        let negative = eval(vec![Instr::Push(dec!(-1)), Instr::Fact]);
        assert!(matches!(negative, Err(VmError::InvalidFactorial { .. })));

        let fractional = eval(vec![Instr::Push(dec!(2.5)), Instr::Fact]);
        assert!(matches!(fractional, Err(VmError::InvalidFactorial { .. })));
    }
}