1use crate::ir::Instr;
2use crate::symbol::{FuncError, SymTable, Symbol};
3use rust_decimal::Decimal;
4use rust_decimal::prelude::*;
5use thiserror::Error;
6
7#[cfg(test)]
8use rust_decimal_macros::dec;
9
/// Errors that can occur while executing bytecode with [`Vm`].
#[derive(Error, Debug, Clone)]
pub enum VmError {
    /// An instruction tried to pop an operand from an empty stack.
    #[error("Stack underflow: attempted to pop from empty stack")]
    StackUnderflow,
    /// The divisor of a division was zero.
    #[error("Division by zero")]
    DivisionByZero,
    /// The program finished without leaving exactly one value on the stack;
    /// `count` is the number of values that were left.
    #[error("Invalid stack state at program end: expected 1 element, found {count}")]
    InvalidFinalStack { count: usize },
    /// A load instruction referenced a symbol that is not a constant.
    #[error("Invalid load operation: cannot load '{symbol_name}' as a constant")]
    InvalidLoad { symbol_name: String },
    /// A call instruction referenced a symbol that is not a function.
    #[error("Invalid call operation: cannot call '{symbol_name}' as a function")]
    InvalidCall { symbol_name: String },
    /// A function call could not be satisfied.
    ///
    /// Raised in two situations: the requested argument count does not match the
    /// function's arity (`expected` = declared minimum, `found` = requested count),
    /// or the stack holds fewer values than the call requested (`expected` =
    /// requested count, `found` = stack length).
    #[error(
        "Stack underflow on function call '{function_name}': expected {expected} arguments, found {found}"
    )]
    CallStackUnderflow {
        function_name: String,
        expected: usize,
        found: usize,
    },
    /// The factorial operand was negative or had a fractional part.
    #[error("Invalid factorial: {value} (must be a non-negative integer)")]
    InvalidFactorial { value: Decimal },
    /// A checked arithmetic operation failed; `message` names the operation and operands.
    #[error("Arithmetic error: {message}")]
    ArithmeticError { message: String },
    /// A function callback returned an error.
    #[error("Function error: {0}")]
    FunctionError(FuncError),
    /// An instruction referenced a symbol index that the symbol table does not contain.
    #[error("Invalid symbol index: {0}")]
    InvalidSymbolIndex(usize),
}
40
/// A stateless stack machine that evaluates [`Instr`] bytecode.
///
/// `Vm` holds no fields: each call to [`Vm::run_bytecode`] creates its own operand
/// stack, so one `Vm` value can evaluate any number of programs.
#[derive(Debug, Default)]
pub struct Vm;
47
48impl Vm {
49 pub fn run_bytecode(&self, bytecode: &[Instr], table: &SymTable) -> Result<Decimal, VmError> {
60 if bytecode.is_empty() {
61 return Ok(Decimal::ZERO);
62 }
63
64 let mut stack: Vec<Decimal> = Vec::new();
65
66 for op in bytecode {
67 self.execute_instruction(op, table, &mut stack)?;
68 }
69
70 match stack.as_slice() {
71 [result] => Ok(*result),
72 _ => Err(VmError::InvalidFinalStack { count: stack.len() }),
73 }
74 }
75
76 fn execute_instruction(
77 &self,
78 op: &Instr,
79 table: &SymTable,
80 stack: &mut Vec<Decimal>,
81 ) -> Result<(), VmError> {
82 match op {
83 Instr::Push(v) => {
84 stack.push(*v);
85 Ok(())
86 }
87 Instr::Load(idx) => {
88 let sym = table
89 .get_by_index(*idx)
90 .ok_or(VmError::InvalidSymbolIndex(*idx))?;
91 match sym {
92 Symbol::Const { value, .. } => {
93 stack.push(*value);
94 Ok(())
95 }
96 _ => Err(VmError::InvalidLoad {
97 symbol_name: sym.name().to_string(),
98 }),
99 }
100 }
101 Instr::Neg => {
102 let v = Self::pop(stack)?;
103 stack.push(-v);
104 Ok(())
105 }
106 Instr::Add => self.add_op(stack),
107 Instr::Sub => self.sub_op(stack),
108 Instr::Mul => self.mul_op(stack),
109 Instr::Div => self.div_op(stack),
110 Instr::Pow => self.pow_op(stack),
111 Instr::Fact => self.fact_op(stack),
112 Instr::Call(idx, argc) => {
113 let sym = table
114 .get_by_index(*idx)
115 .ok_or(VmError::InvalidSymbolIndex(*idx))?;
116 self.call_op(sym, *argc, stack)
117 }
118 Instr::Equal => self.comparison_op(stack, |a, b| a == b),
120 Instr::NotEqual => self.comparison_op(stack, |a, b| a != b),
121 Instr::Less => self.comparison_op(stack, |a, b| a < b),
122 Instr::LessEqual => self.comparison_op(stack, |a, b| a <= b),
123 Instr::Greater => self.comparison_op(stack, |a, b| a > b),
124 Instr::GreaterEqual => self.comparison_op(stack, |a, b| a >= b),
125 }
126 }
127
128 fn comparison_op<F>(&self, stack: &mut Vec<Decimal>, f: F) -> Result<(), VmError>
129 where
130 F: FnOnce(Decimal, Decimal) -> bool,
131 {
132 let right = Self::pop(stack)?;
133 let left = Self::pop(stack)?;
134 let result = if f(left, right) {
135 Decimal::ONE
136 } else {
137 Decimal::ZERO
138 };
139 stack.push(result);
140 Ok(())
141 }
142
143 fn add_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
144 let right = Self::pop(stack)?;
145 let left = Self::pop(stack)?;
146 let result = left
147 .checked_add(right)
148 .ok_or_else(|| VmError::ArithmeticError {
149 message: format!("Addition overflow: {} + {}", left, right),
150 })?;
151 stack.push(result);
152 Ok(())
153 }
154
155 fn sub_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
156 let right = Self::pop(stack)?;
157 let left = Self::pop(stack)?;
158 let result = left
159 .checked_sub(right)
160 .ok_or_else(|| VmError::ArithmeticError {
161 message: format!("Subtraction overflow: {} - {}", left, right),
162 })?;
163 stack.push(result);
164 Ok(())
165 }
166
167 fn mul_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
168 let right = Self::pop(stack)?;
169 let left = Self::pop(stack)?;
170 let result = left
171 .checked_mul(right)
172 .ok_or_else(|| VmError::ArithmeticError {
173 message: format!("Multiplication overflow: {} * {}", left, right),
174 })?;
175 stack.push(result);
176 Ok(())
177 }
178
179 fn div_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
180 let right = Self::pop(stack)?;
181 let left = Self::pop(stack)?;
182 let result = left.checked_div(right).ok_or_else(|| {
183 if right.is_zero() {
184 VmError::DivisionByZero
185 } else {
186 VmError::ArithmeticError {
187 message: format!("Division overflow or underflow: {} / {}", left, right),
188 }
189 }
190 })?;
191 stack.push(result);
192 Ok(())
193 }
194
195 fn pow_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
196 let exponent = Self::pop(stack)?;
197 let base = Self::pop(stack)?;
198
199 let result =
201 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| base.powd(exponent))) {
202 Ok(result) => result,
203 Err(_) => {
204 return Err(VmError::ArithmeticError {
205 message: format!("Power operation failed: {} ^ {}", base, exponent),
206 });
207 }
208 };
209
210 stack.push(result);
211 Ok(())
212 }
213
214 fn fact_op(&self, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
215 let n = Self::pop(stack)?;
216
217 if n.is_sign_negative() {
219 return Err(VmError::InvalidFactorial { value: n });
220 }
221
222 if n.fract() != Decimal::ZERO {
224 return Err(VmError::InvalidFactorial { value: n });
225 }
226
227 let n_u64 = n.to_u64().unwrap();
229 let result = (1..=n_u64).try_fold(Decimal::ONE, |acc, i| {
230 acc.checked_mul(Decimal::from(i))
231 .ok_or_else(|| VmError::ArithmeticError {
232 message: format!("Factorial calculation overflow at {}!", i),
233 })
234 })?;
235
236 stack.push(result);
237 Ok(())
238 }
239
240 fn call_op(&self, sym: &Symbol, argc: usize, stack: &mut Vec<Decimal>) -> Result<(), VmError> {
241 match sym {
242 Symbol::Func {
243 name,
244 args: min_args,
245 variadic,
246 callback,
247 ..
248 } => {
249 if argc != *min_args && (!*variadic || argc < *min_args) {
250 return Err(VmError::CallStackUnderflow {
251 function_name: name.to_string(),
252 expected: *min_args,
253 found: argc,
254 });
255 }
256
257 if stack.len() < argc {
259 return Err(VmError::CallStackUnderflow {
260 function_name: name.to_string(),
261 expected: argc,
262 found: stack.len(),
263 });
264 }
265
266 let args_start = stack.len() - argc;
267 let args = &stack[args_start..];
268 let result = callback(args).map_err(VmError::FunctionError)?;
269 stack.truncate(args_start);
270 stack.push(result);
271 Ok(())
272 }
273 Symbol::Const { .. } => Err(VmError::InvalidCall {
274 symbol_name: sym.name().to_string(),
275 }),
276 }
277 }
278
279 fn pop(stack: &mut Vec<Decimal>) -> Result<Decimal, VmError> {
280 stack.pop().ok_or(VmError::StackUnderflow)
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymTable;

    /// Runs `bytecode` against the standard-library symbol table.
    fn run(bytecode: &[Instr]) -> Result<Decimal, VmError> {
        Vm.run_bytecode(bytecode, &SymTable::stdlib())
    }

    #[test]
    fn test_vm_error_stack_underflow() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let bytecode = vec![Instr::Add];

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::StackUnderflow)));
    }

    #[test]
    fn test_vm_error_division_by_zero() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let bytecode = vec![Instr::Push(dec!(5)), Instr::Push(dec!(0)), Instr::Div];

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::DivisionByZero)));
    }

    #[test]
    fn test_vm_error_invalid_final_stack() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let bytecode = vec![Instr::Push(dec!(1)), Instr::Push(dec!(2))];

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(
            result,
            Err(VmError::InvalidFinalStack { count: 2 })
        ));
    }

    #[test]
    fn test_vm_error_invalid_load() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let (sin_idx, _) = table.get_with_index("sin").unwrap();
        let bytecode = vec![Instr::Load(sin_idx)];

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::InvalidLoad { .. })));
    }

    #[test]
    fn test_vm_error_invalid_call() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let (pi_idx, _) = table.get_with_index("pi").unwrap();
        let bytecode = vec![Instr::Call(pi_idx, 0)];

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::InvalidCall { .. })));
    }

    #[test]
    fn test_vm_error_call_stack_underflow() {
        let vm = Vm;
        let table = SymTable::stdlib();
        let (sin_idx, _) = table.get_with_index("sin").unwrap();
        let bytecode = vec![Instr::Call(sin_idx, 0)];

        let result = vm.run_bytecode(&bytecode, &table);
        assert!(matches!(result, Err(VmError::CallStackUnderflow { .. })));
    }

    #[test]
    fn test_vm_error_invalid_factorial() {
        for value in [dec!(-1), dec!(2.5)].iter() {
            let bytecode = vec![Instr::Push(*value), Instr::Fact];
            assert!(matches!(
                run(&bytecode),
                Err(VmError::InvalidFactorial { .. })
            ));
        }
    }

    #[test]
    fn test_vm_error_arithmetic_overflow() {
        let overflowing = vec![
            vec![Instr::Push(Decimal::MAX), Instr::Push(Decimal::ONE), Instr::Add],
            vec![Instr::Push(Decimal::MAX), Instr::Push(dec!(2)), Instr::Mul],
            vec![Instr::Push(dec!(30)), Instr::Fact],
        ];

        for code in overflowing {
            assert!(matches!(run(&code), Err(VmError::ArithmeticError { .. })));
        }
    }

    #[test]
    fn test_vm_error_display() {
        assert_eq!(
            VmError::StackUnderflow.to_string(),
            "Stack underflow: attempted to pop from empty stack"
        );
        assert_eq!(VmError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            VmError::InvalidFinalStack { count: 3 }.to_string(),
            "Invalid stack state at program end: expected 1 element, found 3"
        );
        assert_eq!(
            VmError::InvalidLoad {
                symbol_name: "test".to_string(),
            }
            .to_string(),
            "Invalid load operation: cannot load 'test' as a constant"
        );
        assert_eq!(
            VmError::InvalidCall {
                symbol_name: "test".to_string(),
            }
            .to_string(),
            "Invalid call operation: cannot call 'test' as a function"
        );
        assert_eq!(
            VmError::CallStackUnderflow {
                function_name: "sin".to_string(),
                expected: 1,
                found: 0
            }
            .to_string(),
            "Stack underflow on function call 'sin': expected 1 arguments, found 0"
        );
    }

    #[test]
    fn test_empty_bytecode_evaluates_to_zero() {
        assert_eq!(run(&[]).unwrap(), Decimal::ZERO);
    }

    #[test]
    fn test_binary_operations() {
        let vm = Vm;
        let table = SymTable::stdlib();

        let test_cases = vec![
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(3)), Instr::Add],
                dec!(5),
            ),
            (
                vec![Instr::Push(dec!(6)), Instr::Push(dec!(2)), Instr::Sub],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(3)), Instr::Push(dec!(4)), Instr::Mul],
                dec!(12),
            ),
            (
                vec![Instr::Push(dec!(8)), Instr::Push(dec!(2)), Instr::Div],
                dec!(4),
            ),
            (
                vec![Instr::Push(dec!(2)), Instr::Push(dec!(10)), Instr::Pow],
                dec!(1024),
            ),
        ];

        for (code, expected) in test_cases {
            assert_eq!(vm.run_bytecode(&code, &table).unwrap(), expected);
        }
    }

    #[test]
    fn test_unary_operations() {
        let test_cases = vec![
            (vec![Instr::Push(dec!(3)), Instr::Neg], dec!(-3)),
            (vec![Instr::Push(dec!(5)), Instr::Fact], dec!(120)),
            (vec![Instr::Push(dec!(0)), Instr::Fact], dec!(1)),
        ];

        for (code, expected) in test_cases {
            assert_eq!(run(&code).unwrap(), expected);
        }
    }

    #[test]
    fn test_comparison_operations() {
        let test_cases = vec![
            (Instr::Equal, dec!(2), dec!(2.0), dec!(1)),
            (Instr::NotEqual, dec!(2), dec!(3), dec!(1)),
            (Instr::Less, dec!(2), dec!(3), dec!(1)),
            (Instr::LessEqual, dec!(3), dec!(3), dec!(1)),
            (Instr::Greater, dec!(2), dec!(3), dec!(0)),
            (Instr::GreaterEqual, dec!(2), dec!(3), dec!(0)),
        ];

        for (op, left, right, expected) in test_cases {
            let code = vec![Instr::Push(left), Instr::Push(right), op];
            assert_eq!(run(&code).unwrap(), expected);
        }
    }
}