cas_vm/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod error;
4mod frame;
5mod instruction;
6mod register;
7
8use cas_compute::{
9    consts::{E, I, PHI, PI, TAU},
10    funcs::all as all_funcs,
11    numerical::{builtin::error::BuiltinError, func::Function, trig_mode::TrigMode, value::Value},
12    primitive::{complex, float},
13};
14use cas_compiler::{
15    expr::compile_stmts,
16    instruction::{Instruction, InstructionKind},
17    item::Symbol,
18    sym_table::SymbolTable,
19    Chunk,
20    Compile,
21    Compiler,
22    Label,
23};
24use cas_error::Error;
25use cas_parser::parser::ast::Stmt;
26use error::{
27    ConditionalNotBoolean,
28    IndexOutOfBounds,
29    IndexOutOfRange,
30    InternalError,
31    InvalidDifferentiation,
32    InvalidIndexTarget,
33    InvalidIndexType,
34    InvalidLengthType,
35    LengthOutOfRange,
36    MissingArgument,
37    StackOverflow,
38    TooManyArguments,
39    TypeMismatch,
40};
41use frame::Frame;
42use instruction::{
43    exec_binary_instruction,
44    exec_unary_instruction,
45    Derivative,
46};
47use register::Registers;
48use std::{cell::RefCell, collections::HashMap, ops::Range, rc::Rc};
49
/// The maximum number of stack frames before a stack overflow error is thrown in the [`Vm`].
///
/// The value is `2^16` (65,536). It is an arbitrary limit whose only purpose is to stop runaway
/// recursion; the call stack length is compared against it every time a user function is called.
const MAX_STACK_FRAMES: usize = 2usize.pow(16);
52
/// After executing [`Vm::run_one`], indicates whether execution continues sequentially, or
/// whether something else, such as a jump, has occurred.
#[derive(Debug)]
enum ControlFlow {
    /// Continue executing the program sequentially. The caller is responsible for advancing the
    /// instruction pointer to the next instruction.
    Continue,

    /// A jump has occurred and the instruction pointer has already been updated. Do not
    /// increment the instruction pointer.
    Jump,
}
63
/// A virtual machine that executes bytecode instructions generated by the compiler (see
/// [`Compiler`]).
#[derive(Clone, Debug)]
pub struct Vm {
    /// The trigonometric mode used when calling native functions.
    trig_mode: TrigMode,

    /// The bytecode chunks to execute.
    ///
    /// Execution starts at the first instruction of the first chunk.
    pub chunks: Vec<Chunk>,

    /// Labels generated by the compiler, mapped to the location they reference, expressed as a
    /// `(chunk index, instruction index)` pair.
    labels: HashMap<Label, (usize, usize)>,

    /// A symbol table that maps identifiers to information about the values they represent.
    ///
    /// This is used to store information about variables and functions that are defined in the
    /// program.
    pub sym_table: SymbolTable,

    /// Variables in the global scope.
    ///
    /// While the program is running, these are moved into the global stack frame; they are moved
    /// back here once execution finishes or fails.
    pub variables: HashMap<usize, Value>,

    /// Registers used by the VM.
    registers: Registers,
}
89
90impl Default for Vm {
91    fn default() -> Self {
92        Self {
93            trig_mode: TrigMode::default(),
94            chunks: vec![Chunk::default()], // add main chunk
95            labels: HashMap::new(),
96            sym_table: SymbolTable::default(),
97            variables: HashMap::new(),
98            registers: Registers::default(),
99        }
100    }
101}
102
103impl From<Compiler> for Vm {
104    fn from(compiler: Compiler) -> Self {
105        Self {
106            trig_mode: TrigMode::default(),
107            chunks: compiler.chunks,
108            labels: compiler.labels
109                .into_iter()
110                .map(|(label, location)| (label, location.unwrap()))
111                .collect(),
112            sym_table: compiler.sym_table,
113            variables: HashMap::new(),
114            registers: Registers::default(),
115        }
116    }
117}
118
119impl Vm {
120    /// Creates a blank [`Vm`] with no chunks. This is used for `cas-repl` to compile code on the
121    /// fly.
122    pub fn new() -> Self {
123        Self::default()
124    }
125
126    /// Creates a [`Vm`] by compiling the given source AST.
127    pub fn compile<T: Compile>(expr: T) -> Result<Self, Error> {
128        let compiler = Compiler::compile(expr)?;
129        Ok(Self {
130            trig_mode: TrigMode::default(),
131            chunks: compiler.chunks,
132            labels: compiler.labels
133                .into_iter()
134                .map(|(label, location)| (label, location.unwrap()))
135                .collect(),
136            sym_table: compiler.sym_table,
137            variables: HashMap::new(),
138            registers: Registers::default(),
139        })
140    }
141
142    /// Creates a [`Vm`] by compiling multiple statements.
143    pub fn compile_program(stmts: Vec<Stmt>) -> Result<Self, Error> {
144        let compiler = Compiler::compile_program(stmts)?;
145        Ok(Self {
146            trig_mode: TrigMode::default(),
147            chunks: compiler.chunks,
148            labels: compiler.labels
149                .into_iter()
150                .map(|(label, location)| (label, location.unwrap()))
151                .collect(),
152            sym_table: compiler.sym_table,
153            variables: HashMap::new(),
154            registers: Registers::default(),
155        })
156    }
157
158    /// Sets the trigonometric mode used when calling native functions.
159    pub fn with_trig_mode(mut self, mode: TrigMode) -> Self {
160        self.trig_mode = mode;
161        self
162    }
163
164    /// Executes one instruction, updating the given state.
165    ///
166    /// Returns a [`ControlFlow`] value that indicates whether the program should continue running
167    fn run_one(
168        &mut self,
169        value_stack: &mut Vec<Value>,
170        call_stack: &mut Vec<Frame>,
171        derivative_stack: &mut Vec<Derivative>,
172        instruction_pointer: &mut (usize, usize),
173    ) -> Result<ControlFlow, Error> {
174        /// Check for a stack overflow and return an error if one is detected.
175        fn check_stack_overflow(call_span: &[Range<usize>], call_stack: &[Frame]) -> Result<(), Error> {
176            // MAX_STACK_FRAMES is an arbitrary limit to prevent infinite recursion
177            if call_stack.len() > MAX_STACK_FRAMES {
178                return Err(Error::new(call_span.to_vec(), StackOverflow));
179            }
180
181            Ok(())
182        }
183
184        /// Extracts the `usize` index from the given [`Value`] for the indexing instructions.
185        fn extract_index(value: Value, spans: Vec<Range<usize>>) -> Result<usize, Error> {
186            let typename = value.typename();
187            let Value::Integer(int) = value.coerce_integer() else {
188                return Err(Error::new(
189                    spans,
190                    InvalidIndexType {
191                        expr_type: typename,
192                    },
193                ));
194            };
195
196            int.to_usize().ok_or_else(|| Error::new(
197                spans,
198                IndexOutOfRange,
199            ))
200        }
201
202        /// Extracts the `usize` length from the given [`Value`] for the
203        /// [`InstructionKind::CreateListRepeat`] instruction, which uses slightly different error
204        /// types than [`extract_index`].
205        fn extract_length(value: Value, spans: Vec<Range<usize>>) -> Result<usize, Error> {
206            let typename = value.typename();
207            let Value::Integer(int) = value.coerce_integer() else {
208                return Err(Error::new(
209                    spans,
210                    InvalidLengthType {
211                        expr_type: typename,
212                    },
213                ));
214            };
215
216            int.to_usize().ok_or_else(|| Error::new(
217                spans,
218                LengthOutOfRange,
219            ))
220        }
221
222        /// Convert a [`BuiltinError`] to an [`EvalError`].
223        fn from_builtin_error(err: BuiltinError, spans: Vec<Range<usize>>) -> Error {
224            match err {
225                BuiltinError::TooManyArguments(err) => {
226                    // remove spans of all arguments, add span starting from extraneous argument
227                    let spans = vec![
228                        spans[0].clone(),
229                        spans[1].clone(),
230                        spans[2 + err.expected].start..spans.last().unwrap().end,
231                    ];
232                    Error::new(spans, TooManyArguments::from(err))
233                },
234                BuiltinError::MissingArgument(err) => {
235                    Error::new(spans, MissingArgument::from(err))
236                },
237                BuiltinError::TypeMismatch(err) => {
238                    // remove spans of all arguments but the mentioned one
239                    let spans = vec![
240                        spans[0].clone(),
241                        spans[1].clone(),
242                        spans[2 + err.index].clone(),
243                    ];
244                    Error::new(spans, TypeMismatch::from(err))
245                },
246                _ => todo!(),
247            }
248        }
249
250        /// Helper to build an internal error.
251        fn internal_err(instruction: &Instruction, data: impl Into<String>) -> Error {
252            Error::new(
253                instruction.spans.clone(),
254                InternalError {
255                    instruction: format!("{:?}", instruction.kind),
256                    data: data.into(),
257                },
258            )
259        }
260
261        let instruction = &self.chunks[instruction_pointer.0].instructions[instruction_pointer.1];
262        // println!("value stack: {:?}", value_stack.iter().map(|v: &Value| v.to_string()).collect::<Vec<_>>());
263        // println!("call stack: {:?}", call_stack);
264        // println!("instruction to execute: {:?} (chunk/inst: {}/{})", instruction, instruction_pointer.0, instruction_pointer.1);
265        // println!("registers: {:?}", self.registers);
266        // println!();
267
268        match &instruction.kind {
269            InstructionKind::InitFunc(fn_name, fn_signature, num_params, num_default_params) => {
270                self.registers.fn_name = fn_name.clone();
271                self.registers.fn_signature = fn_signature.clone();
272                self.registers.num_params = *num_params;
273                self.registers.num_default_params = *num_default_params;
274
275                if call_stack.last().unwrap().derivative && self.registers.num_params != 1 {
276                    return Err(Error::new(
277                        self.registers.call_site_spans.clone(),
278                        InvalidDifferentiation {
279                            name: self.registers.fn_name.clone(),
280                            actual: self.registers.num_params,
281                        },
282                    ));
283                } else if self.registers.num_args > self.registers.num_params {
284                    let spans = &self.registers.call_site_spans;
285                    let expected = self.registers.num_params;
286                    return Err(Error::new(
287                        // take spans starting from extraneous argument, matching builtin function
288                        // error behavior
289                        vec![
290                            spans[0].clone(),
291                            spans[1].clone(),
292                            spans[2 + expected].start..spans.last().unwrap().end,
293                        ],
294                        TooManyArguments {
295                            name: self.registers.fn_name.clone(),
296                            expected,
297                            given: self.registers.num_args,
298                            signature: self.registers.fn_signature.clone(),
299                        },
300                    ));
301                }
302            },
303            InstructionKind::CheckExecReady => {
304                value_stack.push(Value::Boolean(
305                    self.registers.num_args == self.registers.num_params,
306                ));
307            },
308            InstructionKind::NextArg => {
309                // increment the argument counter
310                self.registers.num_args += 1;
311            },
312            InstructionKind::ErrorIfMissingArgs => {
313                if self.registers.num_args != self.registers.num_params {
314                    return Err(Error::new(
315                        self.registers.call_site_spans.clone(),
316                        MissingArgument {
317                            name: self.registers.fn_name.clone(),
318                            indices: {
319                                let offset = self.registers.num_default_params;
320                                let start_idx = self.registers.num_args - offset;
321                                let end_idx = self.registers.num_params - offset;
322                                start_idx..end_idx
323                            },
324                            expected: self.registers.num_params,
325                            given: self.registers.num_args,
326                            signature: self.registers.fn_signature.clone(),
327                        },
328                    ));
329                }
330            },
331            InstructionKind::LoadConst(value) => {
332                let mut value = value.clone();
333
334                // if the value created is a function, store its environment, which is the
335                // variables in the current stack frame
336                if let Value::Function(Function::User(user)) = &mut value {
337                    user.environment = call_stack.last().unwrap().variables.clone();
338                }
339
340                value_stack.push(value);
341            },
342            InstructionKind::CreateList(len) => {
343                let elements = value_stack.split_off(value_stack.len() - *len);
344                value_stack.push(Value::List(Rc::new(RefCell::new(elements))));
345            },
346            InstructionKind::CreateListRepeat => {
347                let count = value_stack.pop().ok_or_else(|| internal_err(
348                    instruction,
349                    "missing # of repetitions",
350                ))?;
351                let value = value_stack.pop().ok_or_else(|| internal_err(
352                    instruction,
353                    "missing value to repeat",
354                ))?;
355                let list = vec![value; extract_length(count, instruction.spans.clone())?];
356                value_stack.push(Value::List(Rc::new(RefCell::new(list))));
357            },
358            InstructionKind::CreateRange(kind) => {
359                let end = value_stack.pop().ok_or_else(|| internal_err(
360                    instruction,
361                    "missing end of range",
362                ))?;
363                let start = value_stack.pop().ok_or_else(|| internal_err(
364                    instruction,
365                    "missing start of range",
366                ))?;
367                value_stack.push(Value::Range(
368                    Box::new(start),
369                    *kind,
370                    Box::new(end),
371                ));
372            },
373            InstructionKind::LoadVar(symbol) => match symbol {
374                Symbol::User(id) => {
375                    let value = call_stack
376                        .iter()
377                        .rev()
378                        .find_map(|frame| frame.get_variable(*id))
379                        .cloned()
380                        .ok_or_else(|| internal_err(
381                            instruction,
382                            format!("user variable `{}` not initialized", id),
383                        ))?;
384                    value_stack.push(value);
385                }
386                Symbol::Builtin(name) => {
387                    value_stack.push(match *name {
388                        "i" => Value::Complex(complex(&*I)),
389                        "e" => Value::Float(float(&*E)),
390                        "phi" => Value::Float(float(&*PHI)),
391                        "pi" => Value::Float(float(&*PI)),
392                        "tau" => Value::Float(float(&*TAU)),
393                        other => Value::Function(Function::Builtin(
394                            all_funcs()
395                                .get(other)
396                                .ok_or_else(|| internal_err(
397                                    instruction,
398                                    format!("builtin function `{}` not found", other),
399                                ))?
400                                .as_ref()
401                        )),
402                    });
403                }
404            },
405            InstructionKind::StoreVar(id) => {
406                let last_frame = call_stack.last_mut().ok_or_else(|| internal_err(
407                    instruction,
408                    "no stack frame to store variable in",
409                ))?;
410                let value = value_stack.last().cloned().ok_or_else(|| internal_err(
411                    instruction,
412                    "no value to store in variable",
413                ))?;
414                last_frame.add_variable(id.to_owned(), value);
415            },
416            InstructionKind::AssignVar(id) => {
417                let last_frame = call_stack.last_mut().ok_or_else(|| internal_err(
418                    instruction,
419                    "no stack frame to store variable in",
420                ))?;
421                last_frame.add_variable(id.to_owned(), value_stack.pop().ok_or_else(|| internal_err(
422                    instruction,
423                    "no value to store in variable",
424                ))?);
425            },
426            InstructionKind::StoreIndexed => {
427                let index = value_stack.pop().ok_or_else(|| internal_err(
428                    instruction,
429                    "missing index to store value at",
430                ))?;
431                let list = value_stack.pop().ok_or_else(|| internal_err(
432                    instruction,
433                    "missing list to store value in",
434                ))?;
435                let value = value_stack.last().cloned().ok_or_else(|| internal_err(
436                    instruction,
437                    "missing value to store",
438                ))?;
439                let Value::List(list) = list else {
440                    return Err(Error::new(instruction.spans.clone(), InvalidIndexTarget {
441                        expr_type: list.typename(),
442                    }));
443                };
444
445                let index = extract_index(index, instruction.spans.clone())?;
446
447                let mut list = list.borrow_mut();
448                let len = list.len();
449                *list.get_mut(index).ok_or_else(|| {
450                    Error::new(instruction.spans.clone(), IndexOutOfBounds { len, index })
451                })? = value;
452            },
453            InstructionKind::LoadIndexed => {
454                let index = value_stack.pop().ok_or_else(|| internal_err(
455                    instruction,
456                    "missing index to load value at",
457                ))?;
458                let list = value_stack.pop().ok_or_else(|| internal_err(
459                    instruction,
460                    "missing list to load value from",
461                ))?;
462                let Value::List(list) = list else {
463                    return Err(Error::new(instruction.spans.clone(), InvalidIndexTarget {
464                        expr_type: list.typename(),
465                    }));
466                };
467
468                let index = extract_index(index, instruction.spans.clone())?;
469
470                let list = list.borrow();
471                let len = list.len();
472                let value = list.get(index).cloned().ok_or_else(|| {
473                    Error::new(instruction.spans.clone(), IndexOutOfBounds { len, index })
474                })?;
475                value_stack.push(value);
476            },
477            // erroring here helps us verify that exactly the right number of values are produced
478            // and popped through the program
479            InstructionKind::Drop => {
480                value_stack.pop().ok_or_else(|| internal_err(
481                    instruction,
482                    "nothing to drop",
483                ))?;
484            },
485            InstructionKind::Binary(op) => {
486                if let Err(err) = exec_binary_instruction(*op, value_stack) {
487                    return Err(err.into_error(instruction.spans.clone()));
488                }
489            },
490            InstructionKind::Unary(op) => {
491                if let Err(err) = exec_unary_instruction(*op, value_stack) {
492                    return Err(err.into_error(instruction.spans.clone()));
493                }
494            },
495            InstructionKind::Call(args_given) => {
496                let Value::Function(func) = value_stack.pop().ok_or_else(|| internal_err(
497                    instruction,
498                    "missing function to call",
499                ))? else {
500                    return Err(internal_err(instruction, "cannot call non-function"));
501                };
502                match func {
503                    Function::User(user) => {
504                        self.registers.call_site_spans = instruction.spans.clone();
505                        self.registers.num_args = *args_given;
506
507                        call_stack.push(Frame::new((
508                            instruction_pointer.0,
509                            instruction_pointer.1 + 1,
510                        )).with_variables(user.environment));
511                        check_stack_overflow(&instruction.spans, call_stack)?;
512                        *instruction_pointer = (user.index, 0);
513                        return Ok(ControlFlow::Jump);
514                    },
515                    Function::Builtin(func) => {
516                        let args = value_stack.split_off(value_stack.len() - *args_given);
517                        let value = func
518                            .eval(self.trig_mode, args)
519                            .map_err(|err| from_builtin_error(err, instruction.spans.clone()))?;
520                        value_stack.push(value);
521                    },
522                }
523            },
524            InstructionKind::CallDerivative(derivatives, num_args) => {
525                let Value::Function(func) = value_stack.pop().ok_or_else(|| internal_err(
526                    instruction,
527                    "missing function for prime notation",
528                ))? else {
529                    return Err(internal_err(instruction, "cannot use prime notation on non-function"));
530                };
531                match func {
532                    Function::User(user) => {
533                        self.registers.call_site_spans = instruction.spans.clone();
534                        self.registers.num_args = *num_args;
535
536                        let initial = value_stack.pop().ok_or_else(|| internal_err(
537                            instruction,
538                            "missing initial value for derivative computation",
539                        ))?;
540                        let derivative = Derivative::new(*derivatives, initial)
541                            .map_err(|err| err.into_error(instruction.spans.clone()))?;
542                        call_stack.push(Frame::new((
543                            instruction_pointer.0,
544                            instruction_pointer.1 + 1,
545                        )).with_derivative());
546                        check_stack_overflow(&instruction.spans, call_stack)?;
547                        value_stack.push(derivative.next_eval().ok_or_else(|| internal_err(
548                            instruction,
549                            "missing next value in derivative computation",
550                        ))?);
551                        derivative_stack.push(derivative);
552                        *instruction_pointer = (user.index, 0);
553                        return Ok(ControlFlow::Jump);
554                    },
555                    Function::Builtin(builtin) => {
556                        if builtin.sig().len() != 1 {
557                            return Err(Error::new(
558                                instruction.spans.clone(),
559                                InvalidDifferentiation {
560                                    name: builtin.name().to_string(),
561                                    actual: builtin.sig().len(),
562                                },
563                            ));
564                        }
565
566                        let initial = value_stack.pop().ok_or_else(|| internal_err(
567                            instruction,
568                            "missing initial value for derivative computation",
569                        ))?;
570                        let value = Derivative::new(*derivatives, initial)
571                            .and_then(|mut derv| derv.eval_builtin(builtin))
572                            .map_err(|err| err.into_error(instruction.spans.clone()))?;
573                        value_stack.push(value);
574                    },
575                }
576            },
577            InstructionKind::Return => {
578                let frame = call_stack.pop().ok_or_else(|| internal_err(
579                    instruction,
580                    "no frame to return to",
581                ))?;
582
583                if frame.derivative {
584                    // progress the derivative computation if needed
585                    let derivative = derivative_stack.last_mut().ok_or_else(|| internal_err(
586                        instruction,
587                        "no derivative computation to return to",
588                    ))?;
589                    let value = value_stack.pop().ok_or_else(|| internal_err(
590                        instruction,
591                        "missing value to feed in derivative computation",
592                    ))?;
593                    if let Some(value) = derivative.advance(value)
594                        .map_err(|err| err.into_error(instruction.spans.clone()))?
595                    {
596                        value_stack.push(value);
597                        *instruction_pointer = (frame.return_instruction.0, frame.return_instruction.1);
598                    } else {
599                        // back to the top; put the stack frame back
600                        // no need to check for stack overflow here, since we're not adding a new frame
601                        call_stack.push(frame);
602
603                        value_stack.push(derivative.next_eval().ok_or_else(|| internal_err(
604                            instruction,
605                            "missing next value in derivative computation",
606                        ))?);
607                        *instruction_pointer = (instruction_pointer.0, 0);
608                    }
609                } else {
610                    // return to the previous caller
611                    *instruction_pointer = frame.return_instruction;
612                }
613
614                return Ok(ControlFlow::Jump);
615            },
616            InstructionKind::Jump(label) => {
617                *instruction_pointer = self.labels[label];
618                return Ok(ControlFlow::Jump);
619            },
620            InstructionKind::JumpIfTrue(label) => {
621                let b = match value_stack.pop() {
622                    Some(Value::Boolean(b)) => b,
623                    Some(other) => return Err(Error::new(instruction.spans.clone(), ConditionalNotBoolean {
624                        expr_type: other.typename(),
625                    })),
626                    None => return Err(internal_err(instruction, "missing value to check")),
627                };
628
629                if b {
630                    *instruction_pointer = self.labels[label];
631                    return Ok(ControlFlow::Jump);
632                }
633            },
634            InstructionKind::JumpIfFalse(label) => {
635                let b = match value_stack.pop() {
636                    Some(Value::Boolean(b)) => b,
637                    Some(other) => return Err(Error::new(instruction.spans.clone(), ConditionalNotBoolean {
638                        expr_type: other.typename(),
639                    })),
640                    None => return Err(internal_err(instruction, "missing value to check")),
641                };
642
643                if !b {
644                    *instruction_pointer = self.labels[label];
645                    return Ok(ControlFlow::Jump);
646                }
647            },
648        }
649
650        Ok(ControlFlow::Continue)
651    }
652
653    /// Executes the bytecode instructions.
654    pub fn run(&mut self) -> Result<Value, Error> {
655        let mut call_stack = vec![Frame::new((0, 0)).with_variables(std::mem::take(&mut self.variables))];
656        let mut derivative_stack = vec![];
657        let mut value_stack = vec![];
658        let mut instruction_pointer = (0, 0);
659
660        self.registers = Registers::default();
661
662        while instruction_pointer.1 < self.chunks[instruction_pointer.0].instructions.len() {
663            match self.run_one(
664                &mut value_stack,
665                &mut call_stack,
666                &mut derivative_stack,
667                &mut instruction_pointer,
668            ).inspect_err(|_| {
669                // get the variables from the global stack frame, regardless of whether it changed
670                // since we started running the program
671                // this behavior is consistent with other VMs
672                self.variables = std::mem::take(&mut call_stack[0].variables);
673            })? {
674                ControlFlow::Continue => instruction_pointer.1 += 1,
675                ControlFlow::Jump => (),
676            }
677        }
678
679        assert_eq!(value_stack.len(), 1);
680        assert_eq!(call_stack.len(), 1);
681
682        self.variables = call_stack.pop().unwrap().variables;
683        Ok(value_stack.pop().unwrap())
684    }
685}
686
/// A virtual machine that can compile and execute bytecode instructions in a REPL-like environment
/// by maintaining state.
///
/// Each call to [`ReplVm::execute`] compiles new statements on top of the chunks, labels, and
/// symbol table left behind by previous calls.
#[derive(Debug, Default)]
pub struct ReplVm {
    /// Compiler used to hold the current state of the VM.
    ///
    /// The chunks, labels, and symbol table are handed back and forth between this compiler and
    /// [`ReplVm::vm`] around every compilation.
    compiler: Compiler,

    /// The current VM state.
    vm: Vm,
}
697
698impl ReplVm {
699    /// Creates a new [`ReplVm`] with the default context.
700    pub fn new() -> Self {
701        Self::default()
702    }
703
704    /// Compiles the given source code and executes it, updating the VM state.
705    pub fn execute(&mut self, stmts: Vec<Stmt>) -> Result<Value, Error> {
706        // TODO: this feels kinda hacky
707        let compiler_clone = self.compiler.clone();
708        let vm_clone = self.vm.clone();
709
710        self.compiler.chunks = std::mem::replace(&mut self.vm.chunks, vec![Default::default()]);
711        self.compiler.chunks[0].instructions.clear();
712        self.compiler.labels = std::mem::take(&mut self.vm.labels)
713            .into_iter()
714            .map(|(label, location)| (label, Some(location)))
715            .collect();
716        self.compiler.sym_table = std::mem::take(&mut self.vm.sym_table);
717
718        compile_stmts(&stmts, &mut self.compiler).inspect_err(|_| {
719            // restore the previous state of the VM to ensure compilation errors don't affect the
720            // current state
721            self.compiler = compiler_clone;
722            self.vm = vm_clone;
723        })?;
724
725        self.vm.chunks = std::mem::replace(&mut self.compiler.chunks, vec![Default::default()]);
726        self.vm.labels = std::mem::take(&mut self.compiler.labels)
727            .into_iter()
728            .map(|(label, location)| (label, location.unwrap()))
729            .collect();
730        self.vm.sym_table = std::mem::take(&mut self.compiler.sym_table);
731
732        self.vm.run()
733    }
734}
735
#[cfg(test)]
mod tests {
    use super::*;
    use cas_compute::{
        funcs::miscellaneous::{Abs, Factorial},
        numerical::{builtin::Builtin, value::Value},
        primitive::{float, float_from_str, int},
    };
    use cas_parser::parser::{ast::stmt::Stmt, Parser};
    use rug::ops::Pow;

    /// Parses `source` as a sequence of statements, compiles them into a [`Vm`] and runs it.
    ///
    /// Panics if parsing fails; compilation and runtime errors are returned to the caller.
    fn run_program(source: &str) -> Result<Value, Error> {
        let mut parser = Parser::new(source);
        let program = parser.try_parse_full_many::<Stmt>().unwrap();

        Vm::compile_program(program)?.run()
    }

    /// Same as [`run_program`], but the [`Vm`] is switched to [`TrigMode::Degrees`] before it is
    /// run.
    fn run_program_degrees(source: &str) -> Result<Value, Error> {
        let mut parser = Parser::new(source);
        let program = parser.try_parse_full_many::<Stmt>().unwrap();

        Vm::compile_program(program)?
            .with_trig_mode(TrigMode::Degrees)
            .run()
    }

    /// A single binary operation on integer literals.
    #[test]
    fn binary_expr() {
        let sum = run_program("1 + 2").unwrap();
        assert_eq!(sum, Value::Integer(int(3)));
    }

    /// Two binary operators in one expression; the expected result is `7`.
    #[test]
    fn binary_expr_2() {
        assert_eq!(run_program("1 + 2 * 3").unwrap(), Value::Integer(int(7)));
    }

    /// Binary operators mixed with unary negation and factorial.
    #[test]
    fn binary_and_unary() {
        assert_eq!(
            run_program("3 * -5 / 5! + 6").unwrap(),
            Value::Float(float(5.875)),
        );
    }

    /// Nested parentheses.
    #[test]
    fn parenthesized() {
        let value = run_program("((1 + 9) / 5) * 3").unwrap();
        assert_eq!(value, Value::Integer(int(6)));
    }

    /// Evaluates a degree-to-radian conversion with the VM in degrees mode; the expected
    /// result is `pi / 2`.
    #[test]
    fn degree_to_radian() {
        let radians = run_program_degrees("90 * 2 * pi / 360").unwrap();
        assert_eq!(radians, Value::Float(float(&*PI) / 2));
    }

    /// The VM result must be exactly equal to the same computation done directly on the
    /// constants.
    #[test]
    fn precision() {
        let actual = run_program("e^2 - tau").unwrap();
        let expected = Value::Float(float(&*E).pow(2) - float(&*TAU));
        assert_eq!(actual, expected);
    }

    /// Like `precision`, with a factorial and a decimal literal in the mix.
    #[test]
    fn precision_2() {
        let actual = run_program("pi^2 * 17! / -4.9 + e").unwrap();

        let Value::Integer(fac_17) = Factorial::eval_static(float(17)) else {
            unreachable!("factorial of 17 is an integer")
        };
        let expected = float(&*PI).pow(2) * fac_17 / -float_from_str("4.9") + float(&*E);
        assert_eq!(actual, Value::Float(expected));
    }

    /// Defines a function in one REPL input and calls it in the next.
    #[test]
    fn func_call() {
        let inputs = [
            ("f(x) = x^2 + 5x + 6;", Value::Unit),
            ("f(7)", 90.into()),
        ];

        let mut repl = ReplVm::new();
        for (line, expected) in inputs {
            let mut parser = Parser::new(line);
            let stmt = parser.try_parse_full::<Stmt>().unwrap();
            let actual = repl.execute(vec![stmt]).unwrap();
            assert_eq!(actual, expected);
        }
    }

    /// Calls a function whose parameters all have defaults with zero, one and two arguments,
    /// across separate REPL inputs.
    #[test]
    fn complicated_func_call() {
        let inputs = [
            ("f(n = 3, k = 6) = n * k;", Value::Unit),
            ("f()", 18.into()),
            ("f(9)", 54.into()),
            ("f(8, 14)", 112.into()),
        ];

        let mut repl = ReplVm::new();
        for (line, expected) in inputs {
            println!("executing: {}", line);
            let mut parser = Parser::new(line);
            let stmt = parser.try_parse_full::<Stmt>().unwrap();
            let actual = repl.execute(vec![stmt]).unwrap();
            assert_eq!(actual, expected);
        }
    }

    /// A default parameter is used when the argument is omitted and overridden when given.
    #[test]
    fn user_func_default_param() {
        let program = "f(x = 2) = x; f() + f(3)";
        assert_eq!(run_program(program).unwrap(), Value::Integer(int(5)));
    }

    /// Required parameters followed by default parameters, called with only the required ones.
    #[test]
    fn user_func_mixed_param() {
        let program = "f(a, b, c = 3, d = 4) = a b c d; f(1, 2)";
        assert_eq!(run_program(program).unwrap(), Value::Integer(int(24)));
    }

    /// Omitting the required parameters is an error.
    #[test]
    fn user_func_bad_mixed_param() {
        assert!(run_program("f(a, b, c = 3, d = 4) = a b c d; f()").is_err());
    }

    /// Calls the builtin `Abs` directly: a float argument is accepted, a unit argument is
    /// rejected.
    #[test]
    fn builtin_func_arg_check() {
        let accepted = Abs
            .eval(Default::default(), vec![Value::from(4.0)])
            .unwrap()
            .coerce_float();
        assert_eq!(accepted, 4.0.into());
        assert!(Abs.eval(Default::default(), vec![Value::Unit]).is_err());
    }

    /// A lone integer literal evaluates to itself.
    #[test]
    fn exec_literal_number() {
        assert_eq!(run_program("42").unwrap(), Value::Integer(int(42)));
    }

    /// A chained assignment evaluates to the assigned value.
    #[test]
    fn exec_multiple_assignment() {
        let value = run_program("x=y=z=5").unwrap();
        assert_eq!(value, Value::Integer(int(5)));
    }

    /// A `while` loop that counts up to 10.
    #[test]
    fn exec_loop() {
        let source = "a = 0
while a < 10 {
    a += 1
}; a";
        assert_eq!(run_program(source).unwrap(), Value::Integer(int(10)));
    }

    /// A `while` loop that breaks immediately; the expected result is unit.
    #[test]
    fn exec_dumb_loop() {
        assert_eq!(run_program("while true break").unwrap(), Value::Unit);
    }

    /// A `while` loop with a compound condition and an `if`/`else` body.
    #[test]
    fn exec_loop_with_conditions() {
        let source = "a = 0
j = 2
while a < 10 && j < 15 {
    if a < 5 {
        a += 2
    } else {
        a += 1
        j = -j + 4
    }
}; j";
        assert_eq!(run_program(source).unwrap(), Value::Integer(int(2)));
    }

    /// An assignment followed by an expression using implicit multiplication.
    #[test]
    fn exec_simple_program() {
        let source = "x = 4.5
3x + 45 (x + 2) (1 + 3)";
        assert_eq!(run_program(source).unwrap(), Value::Float(float(1183.5)));
    }

    /// `sin` gives 1 for `pi/2` in the default mode and for `90` in degrees mode.
    #[test]
    fn exec_trig_mode() {
        let default_mode = run_program("sin(pi/2)").unwrap();
        let degrees_mode = run_program_degrees("sin(90)").unwrap();
        assert_eq!(default_mode.coerce_float(), Value::Float(float(1)));
        assert_eq!(degrees_mode.coerce_float(), Value::Float(float(1)));
    }

    /// Computes 8! with a `loop` that yields its result through `break`.
    #[test]
    fn exec_factorial() {
        let source = "n = result = 8
loop {
    n -= 1
    result *= n
    if n <= 1 break result
}";
        assert_eq!(run_program(source).unwrap(), Value::Integer(int(40320)));
    }

    /// A user function with a block body containing a `while` loop.
    #[test]
    fn exec_partial_factorial() {
        let source = "partial_factorial(n, k) = {
    result = 1
    while n > k {
        result *= n
        n -= 1
    }
    result
}

partial_factorial(10, 7)";
        assert_eq!(run_program(source).unwrap(), Value::Integer(int(720)));
    }

    /// Sums the even numbers up to 200, skipping odd ones with `continue`.
    #[test]
    fn exec_sum_even() {
        let source = "n = 200
c = 0
total = 0
while c < n {
    c += 1
    if c & 1 == 1 continue
    total += c
}; total";
        assert_eq!(run_program(source).unwrap(), Value::Integer(int(10100)));
    }

    /// Assigning to a list element, then reading all elements back.
    #[test]
    fn exec_list_index() {
        let source = "arr = [1, 2, 3]
arr[0] = 5
arr[0] + arr[1] + arr[2] == 10";
        assert_eq!(run_program(source).unwrap(), Value::Boolean(true));
    }

    /// Indexing a list literal directly, with a computed index expression.
    #[test]
    fn exec_inline_indexing() {
        let element = run_program("[1, 2, 3][im(2sqrt(-1))]").unwrap();
        assert_eq!(element, Value::Integer(int(3)));
    }

    /// A user function that calls another user function.
    #[test]
    fn exec_call_func() {
        let source = "g() = 6
f(x) = x^2 + 5x + g()
f(32)";
        assert_eq!(run_program(source).unwrap(), Value::Integer(int(1190)));
    }

    /// Prime notation on a user function, compared with `~==` against a hand-written
    /// derivative.
    #[test]
    fn exec_valid_prime_notation() {
        let source = "f(x) = log(x, 2)
g(x) = 1/(x ln(2))
f'(64) ~== g(64)";
        assert_eq!(run_program(source).unwrap(), Value::Boolean(true));
    }

    /// Prime notation on a one-parameter function called with two arguments.
    #[test]
    fn exec_valid_prime_notation_but_wrong_args() {
        let err = run_program("f(x) = x^2 + 5x + 6; f'(64, 32)").unwrap_err();

        // error: too many arguments passed to f
        // spans[0] = start of the call, "f'("
        // spans[1] = closing parenthesis of the call
        // spans[2] = points to the second argument, "32"
        let expected_spans = vec![
            21..24,
            30..31,
            28..30,
        ];
        assert_eq!(err.spans, expected_spans);
    }

    /// Prime notation on a call with two arguments is rejected.
    #[test]
    fn exec_unsupported_prime_notation() {
        let err = run_program("log'(32, 2)").unwrap_err();

        // error: can only differentiate functions with exactly 1 argument
        // spans[0] = start of the call, "log'("
        assert_eq!(err.spans[0], 0..5);
    }

    /// A function cannot see a variable that is only defined inside another function.
    #[test]
    fn exec_scoping() {
        let source = "f() = j + 6
g() = {
    j = 10
    f()
}";
        let failure = run_program(source).unwrap_err();

        // error: undefined variable `j`
        // `j` is defined inside `g`, so `f` can only access it if `j` is passed as an argument,
        // or if `j` lives in a higher scope
        // spans[0] = variable `j` in `j + 6`
        assert_eq!(failure.spans, vec![6..7]);
    }

    /// Functions defined in terms of each other; the float result is compared within a
    /// tolerance of `1e-6`.
    #[test]
    fn exec_define_and_call() {
        let source = "f(x) = 2/sqrt(x)
g(x, y) = f(x) + f(y)
g(2, 3)";
        let result = match run_program(source).unwrap() {
            Value::Float(inner) => inner,
            other => panic!("expected float, got {:?}", other),
        };

        // 2/sqrt(2) + 2/sqrt(3), written over the common denominator 6
        let first_term = int(6) * float(2).sqrt();
        let second_term = int(4) * float(3).sqrt();
        let value = (first_term + second_term) / 6;
        assert!(float(result - value).abs() < 1e-6);
    }

    /// An early `return` inside an `if` in a function body.
    #[test]
    fn exec_branching_return() {
        let source = "f(x) = {
    if x < 0 return -x
    x
}
f(-5)";
        assert_eq!(run_program(source).unwrap(), Value::Integer(int(5)));
    }

    /// A function with an empty block body evaluates to unit.
    #[test]
    fn exec_unit_mess() {
        let source = "f() = {}
f()";
        assert_eq!(run_program(source).unwrap(), Value::Unit);
    }

    /// The closed-form arithmetic series agrees with a `sum` expression.
    #[test]
    fn exec_arithmetic_sequence_summation() {
        let source = "f(a, d, n) = n / 2 * (2a + (n - 1)d)
g(a, d, n) = sum i in 0..n of a + i d
f(1, 2, 10) == g(1, 2, 10)";
        assert_eq!(run_program(source).unwrap(), Value::Boolean(true));
    }

    /// A factorial written as a `product` expression over an inclusive range.
    #[test]
    fn exec_product_factorial() {
        let program = "iter_fac(n) = product i in 1..=n of i; iter_fac(5)";
        assert_eq!(run_program(program).unwrap(), Value::Integer(int(120)));
    }

    /// A `for` loop that mutates list elements and stops early with `break`.
    #[test]
    fn exec_for_loop_with_control() {
        let source = "list = [20, 30, 40, 50, 60]
for i in 0..5 {
    list[i] += 5
    if i == 2 break
}; list";
        let list = run_program(source).unwrap();
        assert_eq!(list, vec![
            25.into(),
            35.into(),
            45.into(),
            50.into(),
            60.into()
        ].into());
    }

    /// A `for` loop that skips odd indices with `continue`.
    #[test]
    fn exec_sum_even_indices() {
        let source = "list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
total = 0
for i in 0..10 {
    if i & 1 == 1 continue
    total += list[i]
}; total";
        assert_eq!(run_program(source).unwrap(), Value::Integer(int(25)));
    }

    /// Compound assignment on individual list elements.
    #[test]
    fn exec_incr_list() {
        let source = "arr = [1, 2, 3]
arr[0] += 5
arr[1] += 6
arr[2] += 7
arr";
        let list = run_program(source).unwrap();
        assert_eq!(list, vec![6.into(), 8.into(), 10.into()].into());
    }

    /// Runs `examples/bad_lcm.calc`.
    #[test]
    fn example_bad_lcm() {
        let output = run_program(include_str!("../../examples/bad_lcm.calc")).unwrap();
        assert_eq!(output, 1517.into());
    }

    /// Runs `examples/factorial.calc`.
    #[test]
    fn example_factorial() {
        let output = run_program(include_str!("../../examples/factorial.calc")).unwrap();
        assert_eq!(output, true.into());
    }

    /// Runs `examples/convert_binary.calc`.
    #[test]
    fn example_convert_binary() {
        let output = run_program(include_str!("../../examples/convert_binary.calc")).unwrap();
        assert_eq!(output, vec![true.into(); 7].into());
    }

    /// Runs `examples/environment_capture.calc`.
    #[test]
    fn example_environment_capture() {
        let output = run_program(include_str!("../../examples/environment_capture.calc")).unwrap();
        assert_eq!(output, vec![55.into(), 20.into()].into());
    }

    /// Runs `examples/function_scope.calc`.
    #[test]
    fn example_function_scope() {
        let output = run_program(include_str!("../../examples/function_scope.calc")).unwrap();
        assert_eq!(output, 14.into());
    }

    /// Runs `examples/higher_order_function.calc`.
    #[test]
    fn example_higher_order_function() {
        let output = run_program(include_str!("../../examples/higher_order_function.calc")).unwrap();
        assert_eq!(output, vec![
            16.0.into(),
            (float(32) / float(3)).into(),
            20.0.into(),
            (float(40) / float(3)).into(),
        ].into());
    }

    /// Runs `examples/if_branching.calc`.
    #[test]
    fn example_if_branching() {
        let output = run_program(include_str!("../../examples/if_branching.calc")).unwrap();
        assert_eq!(output.coerce_float(), float(5).log2().into());
    }

    /// Runs `examples/manual_abs.calc`.
    #[test]
    fn example_manual_abs() {
        let output = run_program(include_str!("../../examples/manual_abs.calc")).unwrap();
        assert_eq!(output, 4.into());
    }

    /// Runs `examples/map_list.calc`.
    #[test]
    fn example_map_list() {
        let output = run_program(include_str!("../../examples/map_list.calc")).unwrap();
        assert_eq!(output, vec![
            complex(0).into(),
            complex(1).into(),
            complex(2).into(),
            complex(3).into(),
            complex(4).into(),
            complex(5).into(),
            complex(6).into(),
            complex(7).into(),
            complex(8).into(),
            complex(9).into(),
        ].into());
    }

    /// Runs `examples/memoized_fib.calc`.
    #[test]
    fn example_memoized_fib() {
        let output = run_program(include_str!("../../examples/memoized_fib.calc")).unwrap();
        assert_eq!(output, 6_557_470_319_842.into());
    }

    /// Runs `examples/ncr.calc`.
    #[test]
    fn example_ncr() {
        let output = run_program(include_str!("../../examples/ncr.calc")).unwrap();
        assert_eq!(output, true.into());
    }

    /// Runs `examples/prime_notation.calc`.
    #[test]
    fn example_prime_notation() {
        let output = run_program(include_str!("../../examples/prime_notation.calc")).unwrap();
        assert_eq!(output, vec![true.into(); 15].into());
    }

    /// Runs `examples/resolving_calls.calc`.
    #[test]
    fn example_resolving_calls() {
        let output = run_program(include_str!("../../examples/resolving_calls.calc")).unwrap();
        assert_eq!(output, true.into());
    }

    /// A failed compilation must leave no trace in the REPL state.
    #[test]
    fn repl() {
        // ensure REPL state is restored if compilation fails
        let inputs = [
            "f() = x", // compile error: `x` is not defined
            "f()", // compile error: `f` is not defined
        ];

        let mut repl = ReplVm::new();
        for line in inputs {
            let mut parser = Parser::new(line);
            let stmt = parser.try_parse_full::<Stmt>().unwrap();
            repl.execute(vec![stmt]).unwrap_err();
        }
    }
}