1use crate::ir::Instr;
2use crate::number::Number;
3use crate::symbol::{FuncError, Symbol};
4use crate::symtable::SymTable;
5use thiserror::Error;
6
7#[derive(Error, Debug, Clone)]
9pub enum VmError {
10 #[error("Stack underflow: attempted to pop from empty stack")]
11 StackUnderflow,
12 #[error("Division by zero")]
13 DivisionByZero,
14 #[error("Invalid stack state at program end: expected 1 element, found {count}")]
15 InvalidFinalStack { count: usize },
16 #[error("Invalid factorial: {value} (must be a non-negative integer)")]
17 InvalidFactorial { value: Number },
18 #[error("Arithmetic error: {message}")]
19 ArithmeticError { message: String },
20 #[error("Function error: {0}")]
21 FunctionError(FuncError),
22}
23
24#[derive(Debug)]
39pub struct Vm<'vm> {
40 bytecode: &'vm [Instr],
41 symtable: &'vm mut SymTable,
42 stack: Vec<Number>,
43 ip: usize,
44}
45
46impl<'vm> Vm<'vm> {
47 pub fn run(bytecode: &'vm [Instr], symtable: &'vm mut SymTable) -> Result<Number, VmError> {
59 use crate::number::consts;
60
61 if bytecode.is_empty() {
62 return Ok(consts::ZERO);
63 }
64
65 let mut vm = Vm {
66 bytecode,
67 symtable,
68 stack: Vec::new(),
69 ip: 0,
70 };
71
72 vm.execute()?;
73
74 match vm.stack.as_slice() {
75 [result] => Ok(*result),
76 _ => Err(VmError::InvalidFinalStack {
77 count: vm.stack.len(),
78 }),
79 }
80 }
81
82 fn execute(&mut self) -> Result<(), VmError> {
83 use crate::number::consts;
84
85 while self.ip < self.bytecode.len() {
86 let op = &self.bytecode[self.ip];
87
88 match op {
89 Instr::Jmp(target) => {
90 self.ip = *target;
91 continue;
92 }
93 Instr::Jz(target) => {
94 let cond = self.pop()?;
95 if cond == consts::ZERO {
96 self.ip = *target;
97 continue;
98 }
99 }
100 Instr::Push(v) => {
101 self.stack.push(*v);
102 }
103 Instr::Load(idx) => {
104 let sym = self.symtable.get_by_index(*idx).unwrap();
105 match sym {
106 Symbol::Const { value, .. } => {
107 self.stack.push(*value);
108 }
109 _ => unreachable!(),
110 }
111 }
112 Instr::Store(idx) => {
113 let top = self.pop()?;
114 let sym = self.symtable.get_mut_by_index(*idx).unwrap();
115 match sym {
116 Symbol::Const { value, .. } => {
117 *value = top;
118 }
119 _ => unreachable!(),
120 }
121 }
122 Instr::Neg => {
123 let v = self.pop()?;
124 self.stack.push(-v);
125 }
126 Instr::Add => self.add_op()?,
127 Instr::Sub => self.sub_op()?,
128 Instr::Mul => self.mul_op()?,
129 Instr::Div => self.div_op()?,
130 Instr::Pow => self.pow_op()?,
131 Instr::Fact => self.fact_op()?,
132 Instr::Call(idx, argc) => self.call_op(*idx, *argc)?,
133 Instr::Equal => self.comparison_op(|a, b| a == b)?,
134 Instr::NotEqual => self.comparison_op(|a, b| a != b)?,
135 Instr::Less => self.comparison_op(|a, b| a < b)?,
136 Instr::LessEqual => self.comparison_op(|a, b| a <= b)?,
137 Instr::Greater => self.comparison_op(|a, b| a > b)?,
138 Instr::GreaterEqual => self.comparison_op(|a, b| a >= b)?,
139 }
140
141 self.ip += 1;
142 }
143 Ok(())
144 }
145
146 fn comparison_op<F>(&mut self, f: F) -> Result<(), VmError>
147 where
148 F: FnOnce(Number, Number) -> bool,
149 {
150 use crate::number::consts;
151
152 let right = self.pop()?;
153 let left = self.pop()?;
154 let result = if f(left, right) {
155 consts::ONE
156 } else {
157 consts::ZERO
158 };
159 self.stack.push(result);
160 Ok(())
161 }
162
163 #[cfg(feature = "decimal-precision")]
165 fn add_op(&mut self) -> Result<(), VmError> {
166 let right = self.pop()?;
167 let left = self.pop()?;
168 let result = left
169 .checked_add(right)
170 .ok_or_else(|| VmError::ArithmeticError {
171 message: format!("Addition overflow: {} + {}", left, right),
172 })?;
173 self.stack.push(result);
174 Ok(())
175 }
176
177 #[cfg(feature = "f64-floats")]
179 fn add_op(&mut self) -> Result<(), VmError> {
180 let right = self.pop()?;
181 let left = self.pop()?;
182 self.stack.push(left + right);
183 Ok(())
184 }
185
186 #[cfg(feature = "decimal-precision")]
187 fn sub_op(&mut self) -> Result<(), VmError> {
188 let right = self.pop()?;
189 let left = self.pop()?;
190 let result = left
191 .checked_sub(right)
192 .ok_or_else(|| VmError::ArithmeticError {
193 message: format!("Subtraction overflow: {} - {}", left, right),
194 })?;
195 self.stack.push(result);
196 Ok(())
197 }
198
199 #[cfg(feature = "f64-floats")]
200 fn sub_op(&mut self) -> Result<(), VmError> {
201 let right = self.pop()?;
202 let left = self.pop()?;
203 self.stack.push(left - right);
204 Ok(())
205 }
206
207 #[cfg(feature = "decimal-precision")]
208 fn mul_op(&mut self) -> Result<(), VmError> {
209 let right = self.pop()?;
210 let left = self.pop()?;
211 let result = left
212 .checked_mul(right)
213 .ok_or_else(|| VmError::ArithmeticError {
214 message: format!("Multiplication overflow: {} * {}", left, right),
215 })?;
216 self.stack.push(result);
217 Ok(())
218 }
219
220 #[cfg(feature = "f64-floats")]
221 fn mul_op(&mut self) -> Result<(), VmError> {
222 let right = self.pop()?;
223 let left = self.pop()?;
224 self.stack.push(left * right);
225 Ok(())
226 }
227
228 #[cfg(feature = "decimal-precision")]
229 fn div_op(&mut self) -> Result<(), VmError> {
230 use crate::number::consts;
231
232 let right = self.pop()?;
233 let left = self.pop()?;
234 let result = left.checked_div(right).ok_or_else(|| {
235 if right == consts::ZERO {
236 VmError::DivisionByZero
237 } else {
238 VmError::ArithmeticError {
239 message: format!("Division overflow or underflow: {} / {}", left, right),
240 }
241 }
242 })?;
243 self.stack.push(result);
244 Ok(())
245 }
246
247 #[cfg(feature = "f64-floats")]
248 fn div_op(&mut self) -> Result<(), VmError> {
249 let right = self.pop()?;
250 let left = self.pop()?;
251 if right == 0.0 {
253 return Err(VmError::DivisionByZero);
254 }
255 self.stack.push(left / right);
256 Ok(())
257 }
258
259 #[cfg(feature = "decimal-precision")]
260 fn pow_op(&mut self) -> Result<(), VmError> {
261 use rust_decimal::prelude::{FromPrimitive, ToPrimitive};
262
263 let exponent = self.pop()?;
264 let base = self.pop()?;
265
266 let base_f64 = base.to_f64().ok_or_else(|| VmError::ArithmeticError {
269 message: format!("Failed to convert base {} to f64", base),
270 })?;
271 let exp_f64 = exponent.to_f64().ok_or_else(|| VmError::ArithmeticError {
272 message: format!("Failed to convert exponent {} to f64", exponent),
273 })?;
274
275 let result_f64 = base_f64.powf(exp_f64);
276
277 let result = Number::from_f64(result_f64).ok_or_else(|| VmError::ArithmeticError {
278 message: format!(
279 "Power operation result cannot be represented: {} ^ {}",
280 base, exponent
281 ),
282 })?;
283
284 self.stack.push(result);
285 Ok(())
286 }
287
288 #[cfg(feature = "f64-floats")]
289 fn pow_op(&mut self) -> Result<(), VmError> {
290 let exponent = self.pop()?;
291 let base = self.pop()?;
292 self.stack.push(base.powf(exponent));
293 Ok(())
294 }
295
296 #[cfg(feature = "decimal-precision")]
297 fn fact_op(&mut self) -> Result<(), VmError> {
298 use crate::number::consts;
299 use rust_decimal::prelude::*;
300
301 let n = self.pop()?;
302
303 if n.is_sign_negative() {
305 return Err(VmError::InvalidFactorial { value: n });
306 }
307
308 if n.fract() != consts::ZERO {
310 return Err(VmError::InvalidFactorial { value: n });
311 }
312
313 let n_u64 = n.to_u64().unwrap();
315 let result = (1..=n_u64).try_fold(consts::ONE, |acc, i| {
316 acc.checked_mul(Number::from(i))
317 .ok_or_else(|| VmError::ArithmeticError {
318 message: format!("Factorial calculation overflow at {}!", i),
319 })
320 })?;
321
322 self.stack.push(result);
323 Ok(())
324 }
325
326 #[cfg(feature = "f64-floats")]
327 fn fact_op(&mut self) -> Result<(), VmError> {
328 let n = self.pop()?;
329
330 if n < 0.0 {
332 return Err(VmError::InvalidFactorial { value: n });
333 }
334
335 if n.fract() != 0.0 {
337 return Err(VmError::InvalidFactorial { value: n });
338 }
339
340 let n_u64 = n as u64;
342 let mut result = 1.0;
343 for i in 1..=n_u64 {
344 result *= i as f64;
345 }
346
347 self.stack.push(result);
348 Ok(())
349 }
350
351 fn call_op(&mut self, idx: usize, argc: usize) -> Result<(), VmError> {
352 match self.symtable.get_by_index(idx).unwrap() {
353 Symbol::Func { callback, .. } => {
354 let args_start = self.stack.len() - argc;
355 let args = &self.stack[args_start..];
356 let result = callback(args).map_err(VmError::FunctionError)?;
357 self.stack.truncate(args_start);
358 self.stack.push(result);
359 Ok(())
360 }
361 Symbol::Const { .. } => unreachable!(),
362 }
363 }
364
365 fn pop(&mut self) -> Result<Number, VmError> {
366 self.stack.pop().ok_or(VmError::StackUnderflow)
367 }
368}
369
#[cfg(test)]
mod tests {
    //! Unit tests for the VM. They only use the standard library symbol table and
    //! instructions whose behaviour is the same under both number backends.

    use super::*;
    use crate::num;
    use crate::symtable::SymTable;

    #[test]
    fn test_vm_error_stack_underflow() {
        let mut table = SymTable::stdlib();
        let bytecode = vec![Instr::Add];

        let result = Vm::run(&bytecode, &mut table);
        assert!(matches!(result, Err(VmError::StackUnderflow)));
    }

    #[test]
    fn test_vm_error_division_by_zero() {
        let mut table = SymTable::stdlib();
        let bytecode = vec![Instr::Push(num!(5)), Instr::Push(num!(0)), Instr::Div];

        let result = Vm::run(&bytecode, &mut table);
        assert!(matches!(result, Err(VmError::DivisionByZero)));
    }

    #[test]
    fn test_vm_error_invalid_final_stack() {
        let mut table = SymTable::stdlib();
        let bytecode = vec![Instr::Push(num!(1)), Instr::Push(num!(2))];

        let result = Vm::run(&bytecode, &mut table);
        assert!(matches!(
            result,
            Err(VmError::InvalidFinalStack { count: 2 })
        ));
    }

    #[test]
    fn test_vm_error_display() {
        assert_eq!(
            VmError::StackUnderflow.to_string(),
            "Stack underflow: attempted to pop from empty stack"
        );
        assert_eq!(VmError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            VmError::InvalidFinalStack { count: 3 }.to_string(),
            "Invalid stack state at program end: expected 1 element, found 3"
        );
        assert_eq!(
            VmError::ArithmeticError {
                message: "boom".to_string()
            }
            .to_string(),
            "Arithmetic error: boom"
        );
    }

    #[test]
    fn test_empty_program_evaluates_to_zero() {
        let mut table = SymTable::stdlib();

        let result = Vm::run(&[], &mut table).unwrap();
        assert_eq!(result, num!(0));
    }

    #[test]
    fn test_binary_operations() {
        let mut table = SymTable::stdlib();

        let test_cases = vec![
            (
                vec![Instr::Push(num!(2)), Instr::Push(num!(3)), Instr::Add],
                num!(5),
            ),
            (
                vec![Instr::Push(num!(6)), Instr::Push(num!(2)), Instr::Sub],
                num!(4),
            ),
            (
                vec![Instr::Push(num!(3)), Instr::Push(num!(4)), Instr::Mul],
                num!(12),
            ),
            (
                vec![Instr::Push(num!(8)), Instr::Push(num!(2)), Instr::Div],
                num!(4),
            ),
            (
                vec![Instr::Push(num!(2)), Instr::Push(num!(10)), Instr::Pow],
                num!(1024),
            ),
        ];

        for (code, expected) in test_cases {
            let result = Vm::run(&code, &mut table).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_comparison_operations() {
        let mut table = SymTable::stdlib();

        let test_cases = vec![
            (Instr::Less, num!(1), num!(2), num!(1)),
            (Instr::Greater, num!(1), num!(2), num!(0)),
            (Instr::Equal, num!(2), num!(2), num!(1)),
            (Instr::NotEqual, num!(2), num!(2), num!(0)),
            (Instr::LessEqual, num!(2), num!(2), num!(1)),
            (Instr::GreaterEqual, num!(1), num!(2), num!(0)),
        ];

        for (op, left, right, expected) in test_cases {
            let code = vec![Instr::Push(left), Instr::Push(right), op];
            let result = Vm::run(&code, &mut table).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_conditional_jumps() {
        let mut table = SymTable::stdlib();

        // if cond { 7 } else { 9 }
        let program = |cond| {
            vec![
                Instr::Push(cond),
                Instr::Jz(4),
                Instr::Push(num!(7)),
                Instr::Jmp(5),
                Instr::Push(num!(9)),
            ]
        };

        assert_eq!(Vm::run(&program(num!(1)), &mut table).unwrap(), num!(7));
        assert_eq!(Vm::run(&program(num!(0)), &mut table).unwrap(), num!(9));
    }

    #[test]
    fn test_factorial() {
        let mut table = SymTable::stdlib();
        let bytecode = vec![Instr::Push(num!(5)), Instr::Fact];

        let result = Vm::run(&bytecode, &mut table).unwrap();
        assert_eq!(result, num!(120));
    }

    #[test]
    fn test_factorial_rejects_negative_and_fractional_operands() {
        let mut table = SymTable::stdlib();

        let negative = vec![Instr::Push(num!(5)), Instr::Neg, Instr::Fact];
        let result = Vm::run(&negative, &mut table);
        assert!(matches!(result, Err(VmError::InvalidFactorial { .. })));

        // 5 / 2 = 2.5
        let fractional = vec![
            Instr::Push(num!(5)),
            Instr::Push(num!(2)),
            Instr::Div,
            Instr::Fact,
        ];
        let result = Vm::run(&fractional, &mut table);
        assert!(matches!(result, Err(VmError::InvalidFactorial { .. })));
    }
}