1#![doc = include_str!("../README.md")]
2
3pub mod error;
4mod frame;
5mod instruction;
6mod register;
7
8use cas_compute::{
9 consts::{E, I, PHI, PI, TAU},
10 funcs::all as all_funcs,
11 numerical::{builtin::error::BuiltinError, func::Function, trig_mode::TrigMode, value::Value},
12 primitive::{complex, float},
13};
14use cas_compiler::{
15 expr::compile_stmts,
16 instruction::{Instruction, InstructionKind},
17 item::Symbol,
18 sym_table::SymbolTable,
19 Chunk,
20 Compile,
21 Compiler,
22 Label,
23};
24use cas_error::Error;
25use cas_parser::parser::ast::Stmt;
26use error::{
27 ConditionalNotBoolean,
28 IndexOutOfBounds,
29 IndexOutOfRange,
30 InternalError,
31 InvalidDifferentiation,
32 InvalidIndexTarget,
33 InvalidIndexType,
34 InvalidLengthType,
35 LengthOutOfRange,
36 MissingArgument,
37 StackOverflow,
38 TooManyArguments,
39 TypeMismatch,
40};
41use frame::Frame;
42use instruction::{
43 exec_binary_instruction,
44 exec_unary_instruction,
45 Derivative,
46};
47use register::Registers;
48use std::{cell::RefCell, collections::HashMap, ops::Range, rc::Rc};
49
/// Upper bound on the number of live call frames (the top-level frame included).
///
/// A call that pushes the call stack past this many frames fails with a stack-overflow error.
const MAX_STACK_FRAMES: usize = 1 << 16;
52
/// What the run loop should do with the instruction pointer after one instruction executes.
#[derive(Debug)]
enum ControlFlow {
    /// The instruction left the instruction pointer alone; step to the next instruction.
    Continue,

    /// The instruction already moved the instruction pointer; do not step it again.
    Jump,
}
63
/// A stack-based virtual machine that executes compiled chunks of instructions.
#[derive(Clone, Debug)]
pub struct Vm {
    /// Trigonometric mode handed to builtin functions when they are evaluated.
    trig_mode: TrigMode,

    /// Compiled instructions grouped into chunks. Execution starts at instruction 0 of chunk 0;
    /// calling a user function jumps to instruction 0 of that function's chunk.
    pub chunks: Vec<Chunk>,

    /// Resolved jump targets, as `(chunk index, instruction index)` pairs.
    labels: HashMap<Label, (usize, usize)>,

    /// Symbol table produced by the compiler alongside `chunks`.
    pub sym_table: SymbolTable,

    /// Top-level variables keyed by user-symbol id. They are moved into the top-level frame at
    /// the start of `run` and moved back out when it finishes, so they persist between runs.
    pub variables: HashMap<usize, Value>,

    /// Bookkeeping for the function call being set up (name, signature, argument counts and
    /// call-site spans), used to validate arguments and to build error messages.
    registers: Registers,
}
89
90impl Default for Vm {
91 fn default() -> Self {
92 Self {
93 trig_mode: TrigMode::default(),
94 chunks: vec![Chunk::default()], labels: HashMap::new(),
96 sym_table: SymbolTable::default(),
97 variables: HashMap::new(),
98 registers: Registers::default(),
99 }
100 }
101}
102
103impl From<Compiler> for Vm {
104 fn from(compiler: Compiler) -> Self {
105 Self {
106 trig_mode: TrigMode::default(),
107 chunks: compiler.chunks,
108 labels: compiler.labels
109 .into_iter()
110 .map(|(label, location)| (label, location.unwrap()))
111 .collect(),
112 sym_table: compiler.sym_table,
113 variables: HashMap::new(),
114 registers: Registers::default(),
115 }
116 }
117}
118
119impl Vm {
120 pub fn new() -> Self {
123 Self::default()
124 }
125
126 pub fn compile<T: Compile>(expr: T) -> Result<Self, Error> {
128 let compiler = Compiler::compile(expr)?;
129 Ok(Self {
130 trig_mode: TrigMode::default(),
131 chunks: compiler.chunks,
132 labels: compiler.labels
133 .into_iter()
134 .map(|(label, location)| (label, location.unwrap()))
135 .collect(),
136 sym_table: compiler.sym_table,
137 variables: HashMap::new(),
138 registers: Registers::default(),
139 })
140 }
141
142 pub fn compile_program(stmts: Vec<Stmt>) -> Result<Self, Error> {
144 let compiler = Compiler::compile_program(stmts)?;
145 Ok(Self {
146 trig_mode: TrigMode::default(),
147 chunks: compiler.chunks,
148 labels: compiler.labels
149 .into_iter()
150 .map(|(label, location)| (label, location.unwrap()))
151 .collect(),
152 sym_table: compiler.sym_table,
153 variables: HashMap::new(),
154 registers: Registers::default(),
155 })
156 }
157
158 pub fn with_trig_mode(mut self, mode: TrigMode) -> Self {
160 self.trig_mode = mode;
161 self
162 }
163
164 fn run_one(
168 &mut self,
169 value_stack: &mut Vec<Value>,
170 call_stack: &mut Vec<Frame>,
171 derivative_stack: &mut Vec<Derivative>,
172 instruction_pointer: &mut (usize, usize),
173 ) -> Result<ControlFlow, Error> {
174 fn check_stack_overflow(call_span: &[Range<usize>], call_stack: &[Frame]) -> Result<(), Error> {
176 if call_stack.len() > MAX_STACK_FRAMES {
178 return Err(Error::new(call_span.to_vec(), StackOverflow));
179 }
180
181 Ok(())
182 }
183
184 fn extract_index(value: Value, spans: Vec<Range<usize>>) -> Result<usize, Error> {
186 let typename = value.typename();
187 let Value::Integer(int) = value.coerce_integer() else {
188 return Err(Error::new(
189 spans,
190 InvalidIndexType {
191 expr_type: typename,
192 },
193 ));
194 };
195
196 int.to_usize().ok_or_else(|| Error::new(
197 spans,
198 IndexOutOfRange,
199 ))
200 }
201
202 fn extract_length(value: Value, spans: Vec<Range<usize>>) -> Result<usize, Error> {
206 let typename = value.typename();
207 let Value::Integer(int) = value.coerce_integer() else {
208 return Err(Error::new(
209 spans,
210 InvalidLengthType {
211 expr_type: typename,
212 },
213 ));
214 };
215
216 int.to_usize().ok_or_else(|| Error::new(
217 spans,
218 LengthOutOfRange,
219 ))
220 }
221
222 fn from_builtin_error(err: BuiltinError, spans: Vec<Range<usize>>) -> Error {
224 match err {
225 BuiltinError::TooManyArguments(err) => {
226 let spans = vec![
228 spans[0].clone(),
229 spans[1].clone(),
230 spans[2 + err.expected].start..spans.last().unwrap().end,
231 ];
232 Error::new(spans, TooManyArguments::from(err))
233 },
234 BuiltinError::MissingArgument(err) => {
235 Error::new(spans, MissingArgument::from(err))
236 },
237 BuiltinError::TypeMismatch(err) => {
238 let spans = vec![
240 spans[0].clone(),
241 spans[1].clone(),
242 spans[2 + err.index].clone(),
243 ];
244 Error::new(spans, TypeMismatch::from(err))
245 },
246 _ => todo!(),
247 }
248 }
249
250 fn internal_err(instruction: &Instruction, data: impl Into<String>) -> Error {
252 Error::new(
253 instruction.spans.clone(),
254 InternalError {
255 instruction: format!("{:?}", instruction.kind),
256 data: data.into(),
257 },
258 )
259 }
260
261 let instruction = &self.chunks[instruction_pointer.0].instructions[instruction_pointer.1];
262 match &instruction.kind {
269 InstructionKind::InitFunc(fn_name, fn_signature, num_params, num_default_params) => {
270 self.registers.fn_name = fn_name.clone();
271 self.registers.fn_signature = fn_signature.clone();
272 self.registers.num_params = *num_params;
273 self.registers.num_default_params = *num_default_params;
274
275 if call_stack.last().unwrap().derivative && self.registers.num_params != 1 {
276 return Err(Error::new(
277 self.registers.call_site_spans.clone(),
278 InvalidDifferentiation {
279 name: self.registers.fn_name.clone(),
280 actual: self.registers.num_params,
281 },
282 ));
283 } else if self.registers.num_args > self.registers.num_params {
284 let spans = &self.registers.call_site_spans;
285 let expected = self.registers.num_params;
286 return Err(Error::new(
287 vec![
290 spans[0].clone(),
291 spans[1].clone(),
292 spans[2 + expected].start..spans.last().unwrap().end,
293 ],
294 TooManyArguments {
295 name: self.registers.fn_name.clone(),
296 expected,
297 given: self.registers.num_args,
298 signature: self.registers.fn_signature.clone(),
299 },
300 ));
301 }
302 },
303 InstructionKind::CheckExecReady => {
304 value_stack.push(Value::Boolean(
305 self.registers.num_args == self.registers.num_params,
306 ));
307 },
308 InstructionKind::NextArg => {
309 self.registers.num_args += 1;
311 },
312 InstructionKind::ErrorIfMissingArgs => {
313 if self.registers.num_args != self.registers.num_params {
314 return Err(Error::new(
315 self.registers.call_site_spans.clone(),
316 MissingArgument {
317 name: self.registers.fn_name.clone(),
318 indices: {
319 let offset = self.registers.num_default_params;
320 let start_idx = self.registers.num_args - offset;
321 let end_idx = self.registers.num_params - offset;
322 start_idx..end_idx
323 },
324 expected: self.registers.num_params,
325 given: self.registers.num_args,
326 signature: self.registers.fn_signature.clone(),
327 },
328 ));
329 }
330 },
331 InstructionKind::LoadConst(value) => {
332 let mut value = value.clone();
333
334 if let Value::Function(Function::User(user)) = &mut value {
337 user.environment = call_stack.last().unwrap().variables.clone();
338 }
339
340 value_stack.push(value);
341 },
342 InstructionKind::CreateList(len) => {
343 let elements = value_stack.split_off(value_stack.len() - *len);
344 value_stack.push(Value::List(Rc::new(RefCell::new(elements))));
345 },
346 InstructionKind::CreateListRepeat => {
347 let count = value_stack.pop().ok_or_else(|| internal_err(
348 instruction,
349 "missing # of repetitions",
350 ))?;
351 let value = value_stack.pop().ok_or_else(|| internal_err(
352 instruction,
353 "missing value to repeat",
354 ))?;
355 let list = vec![value; extract_length(count, instruction.spans.clone())?];
356 value_stack.push(Value::List(Rc::new(RefCell::new(list))));
357 },
358 InstructionKind::CreateRange(kind) => {
359 let end = value_stack.pop().ok_or_else(|| internal_err(
360 instruction,
361 "missing end of range",
362 ))?;
363 let start = value_stack.pop().ok_or_else(|| internal_err(
364 instruction,
365 "missing start of range",
366 ))?;
367 value_stack.push(Value::Range(
368 Box::new(start),
369 *kind,
370 Box::new(end),
371 ));
372 },
373 InstructionKind::LoadVar(symbol) => match symbol {
374 Symbol::User(id) => {
375 let value = call_stack
376 .iter()
377 .rev()
378 .find_map(|frame| frame.get_variable(*id))
379 .cloned()
380 .ok_or_else(|| internal_err(
381 instruction,
382 format!("user variable `{}` not initialized", id),
383 ))?;
384 value_stack.push(value);
385 }
386 Symbol::Builtin(name) => {
387 value_stack.push(match *name {
388 "i" => Value::Complex(complex(&*I)),
389 "e" => Value::Float(float(&*E)),
390 "phi" => Value::Float(float(&*PHI)),
391 "pi" => Value::Float(float(&*PI)),
392 "tau" => Value::Float(float(&*TAU)),
393 other => Value::Function(Function::Builtin(
394 all_funcs()
395 .get(other)
396 .ok_or_else(|| internal_err(
397 instruction,
398 format!("builtin function `{}` not found", other),
399 ))?
400 .as_ref()
401 )),
402 });
403 }
404 },
405 InstructionKind::StoreVar(id) => {
406 let last_frame = call_stack.last_mut().ok_or_else(|| internal_err(
407 instruction,
408 "no stack frame to store variable in",
409 ))?;
410 let value = value_stack.last().cloned().ok_or_else(|| internal_err(
411 instruction,
412 "no value to store in variable",
413 ))?;
414 last_frame.add_variable(id.to_owned(), value);
415 },
416 InstructionKind::AssignVar(id) => {
417 let last_frame = call_stack.last_mut().ok_or_else(|| internal_err(
418 instruction,
419 "no stack frame to store variable in",
420 ))?;
421 last_frame.add_variable(id.to_owned(), value_stack.pop().ok_or_else(|| internal_err(
422 instruction,
423 "no value to store in variable",
424 ))?);
425 },
426 InstructionKind::StoreIndexed => {
427 let index = value_stack.pop().ok_or_else(|| internal_err(
428 instruction,
429 "missing index to store value at",
430 ))?;
431 let list = value_stack.pop().ok_or_else(|| internal_err(
432 instruction,
433 "missing list to store value in",
434 ))?;
435 let value = value_stack.last().cloned().ok_or_else(|| internal_err(
436 instruction,
437 "missing value to store",
438 ))?;
439 let Value::List(list) = list else {
440 return Err(Error::new(instruction.spans.clone(), InvalidIndexTarget {
441 expr_type: list.typename(),
442 }));
443 };
444
445 let index = extract_index(index, instruction.spans.clone())?;
446
447 let mut list = list.borrow_mut();
448 let len = list.len();
449 *list.get_mut(index).ok_or_else(|| {
450 Error::new(instruction.spans.clone(), IndexOutOfBounds { len, index })
451 })? = value;
452 },
453 InstructionKind::LoadIndexed => {
454 let index = value_stack.pop().ok_or_else(|| internal_err(
455 instruction,
456 "missing index to load value at",
457 ))?;
458 let list = value_stack.pop().ok_or_else(|| internal_err(
459 instruction,
460 "missing list to load value from",
461 ))?;
462 let Value::List(list) = list else {
463 return Err(Error::new(instruction.spans.clone(), InvalidIndexTarget {
464 expr_type: list.typename(),
465 }));
466 };
467
468 let index = extract_index(index, instruction.spans.clone())?;
469
470 let list = list.borrow();
471 let len = list.len();
472 let value = list.get(index).cloned().ok_or_else(|| {
473 Error::new(instruction.spans.clone(), IndexOutOfBounds { len, index })
474 })?;
475 value_stack.push(value);
476 },
477 InstructionKind::Drop => {
480 value_stack.pop().ok_or_else(|| internal_err(
481 instruction,
482 "nothing to drop",
483 ))?;
484 },
485 InstructionKind::Binary(op) => {
486 if let Err(err) = exec_binary_instruction(*op, value_stack) {
487 return Err(err.into_error(instruction.spans.clone()));
488 }
489 },
490 InstructionKind::Unary(op) => {
491 if let Err(err) = exec_unary_instruction(*op, value_stack) {
492 return Err(err.into_error(instruction.spans.clone()));
493 }
494 },
495 InstructionKind::Call(args_given) => {
496 let Value::Function(func) = value_stack.pop().ok_or_else(|| internal_err(
497 instruction,
498 "missing function to call",
499 ))? else {
500 return Err(internal_err(instruction, "cannot call non-function"));
501 };
502 match func {
503 Function::User(user) => {
504 self.registers.call_site_spans = instruction.spans.clone();
505 self.registers.num_args = *args_given;
506
507 call_stack.push(Frame::new((
508 instruction_pointer.0,
509 instruction_pointer.1 + 1,
510 )).with_variables(user.environment));
511 check_stack_overflow(&instruction.spans, call_stack)?;
512 *instruction_pointer = (user.index, 0);
513 return Ok(ControlFlow::Jump);
514 },
515 Function::Builtin(func) => {
516 let args = value_stack.split_off(value_stack.len() - *args_given);
517 let value = func
518 .eval(self.trig_mode, args)
519 .map_err(|err| from_builtin_error(err, instruction.spans.clone()))?;
520 value_stack.push(value);
521 },
522 }
523 },
524 InstructionKind::CallDerivative(derivatives, num_args) => {
525 let Value::Function(func) = value_stack.pop().ok_or_else(|| internal_err(
526 instruction,
527 "missing function for prime notation",
528 ))? else {
529 return Err(internal_err(instruction, "cannot use prime notation on non-function"));
530 };
531 match func {
532 Function::User(user) => {
533 self.registers.call_site_spans = instruction.spans.clone();
534 self.registers.num_args = *num_args;
535
536 let initial = value_stack.pop().ok_or_else(|| internal_err(
537 instruction,
538 "missing initial value for derivative computation",
539 ))?;
540 let derivative = Derivative::new(*derivatives, initial)
541 .map_err(|err| err.into_error(instruction.spans.clone()))?;
542 call_stack.push(Frame::new((
543 instruction_pointer.0,
544 instruction_pointer.1 + 1,
545 )).with_derivative());
546 check_stack_overflow(&instruction.spans, call_stack)?;
547 value_stack.push(derivative.next_eval().ok_or_else(|| internal_err(
548 instruction,
549 "missing next value in derivative computation",
550 ))?);
551 derivative_stack.push(derivative);
552 *instruction_pointer = (user.index, 0);
553 return Ok(ControlFlow::Jump);
554 },
555 Function::Builtin(builtin) => {
556 if builtin.sig().len() != 1 {
557 return Err(Error::new(
558 instruction.spans.clone(),
559 InvalidDifferentiation {
560 name: builtin.name().to_string(),
561 actual: builtin.sig().len(),
562 },
563 ));
564 }
565
566 let initial = value_stack.pop().ok_or_else(|| internal_err(
567 instruction,
568 "missing initial value for derivative computation",
569 ))?;
570 let value = Derivative::new(*derivatives, initial)
571 .and_then(|mut derv| derv.eval_builtin(builtin))
572 .map_err(|err| err.into_error(instruction.spans.clone()))?;
573 value_stack.push(value);
574 },
575 }
576 },
577 InstructionKind::Return => {
578 let frame = call_stack.pop().ok_or_else(|| internal_err(
579 instruction,
580 "no frame to return to",
581 ))?;
582
583 if frame.derivative {
584 let derivative = derivative_stack.last_mut().ok_or_else(|| internal_err(
586 instruction,
587 "no derivative computation to return to",
588 ))?;
589 let value = value_stack.pop().ok_or_else(|| internal_err(
590 instruction,
591 "missing value to feed in derivative computation",
592 ))?;
593 if let Some(value) = derivative.advance(value)
594 .map_err(|err| err.into_error(instruction.spans.clone()))?
595 {
596 value_stack.push(value);
597 *instruction_pointer = (frame.return_instruction.0, frame.return_instruction.1);
598 } else {
599 call_stack.push(frame);
602
603 value_stack.push(derivative.next_eval().ok_or_else(|| internal_err(
604 instruction,
605 "missing next value in derivative computation",
606 ))?);
607 *instruction_pointer = (instruction_pointer.0, 0);
608 }
609 } else {
610 *instruction_pointer = frame.return_instruction;
612 }
613
614 return Ok(ControlFlow::Jump);
615 },
616 InstructionKind::Jump(label) => {
617 *instruction_pointer = self.labels[label];
618 return Ok(ControlFlow::Jump);
619 },
620 InstructionKind::JumpIfTrue(label) => {
621 let b = match value_stack.pop() {
622 Some(Value::Boolean(b)) => b,
623 Some(other) => return Err(Error::new(instruction.spans.clone(), ConditionalNotBoolean {
624 expr_type: other.typename(),
625 })),
626 None => return Err(internal_err(instruction, "missing value to check")),
627 };
628
629 if b {
630 *instruction_pointer = self.labels[label];
631 return Ok(ControlFlow::Jump);
632 }
633 },
634 InstructionKind::JumpIfFalse(label) => {
635 let b = match value_stack.pop() {
636 Some(Value::Boolean(b)) => b,
637 Some(other) => return Err(Error::new(instruction.spans.clone(), ConditionalNotBoolean {
638 expr_type: other.typename(),
639 })),
640 None => return Err(internal_err(instruction, "missing value to check")),
641 };
642
643 if !b {
644 *instruction_pointer = self.labels[label];
645 return Ok(ControlFlow::Jump);
646 }
647 },
648 }
649
650 Ok(ControlFlow::Continue)
651 }
652
653 pub fn run(&mut self) -> Result<Value, Error> {
655 let mut call_stack = vec![Frame::new((0, 0)).with_variables(std::mem::take(&mut self.variables))];
656 let mut derivative_stack = vec![];
657 let mut value_stack = vec![];
658 let mut instruction_pointer = (0, 0);
659
660 self.registers = Registers::default();
661
662 while instruction_pointer.1 < self.chunks[instruction_pointer.0].instructions.len() {
663 match self.run_one(
664 &mut value_stack,
665 &mut call_stack,
666 &mut derivative_stack,
667 &mut instruction_pointer,
668 ).inspect_err(|_| {
669 self.variables = std::mem::take(&mut call_stack[0].variables);
673 })? {
674 ControlFlow::Continue => instruction_pointer.1 += 1,
675 ControlFlow::Jump => (),
676 }
677 }
678
679 assert_eq!(value_stack.len(), 1);
680 assert_eq!(call_stack.len(), 1);
681
682 self.variables = call_stack.pop().unwrap().variables;
683 Ok(value_stack.pop().unwrap())
684 }
685}
686
/// A VM for interactive use: statements are compiled and run incrementally, keeping the
/// functions, symbols and variables produced by earlier calls to `ReplVm::execute`.
#[derive(Debug, Default)]
pub struct ReplVm {
    /// Compiler state between executions. It temporarily takes over the VM's chunks, labels
    /// and symbol table while new statements are compiled.
    compiler: Compiler,

    /// The VM that runs the compiled code and keeps the variables between executions.
    vm: Vm,
}
697
698impl ReplVm {
699 pub fn new() -> Self {
701 Self::default()
702 }
703
704 pub fn execute(&mut self, stmts: Vec<Stmt>) -> Result<Value, Error> {
706 let compiler_clone = self.compiler.clone();
708 let vm_clone = self.vm.clone();
709
710 self.compiler.chunks = std::mem::replace(&mut self.vm.chunks, vec![Default::default()]);
711 self.compiler.chunks[0].instructions.clear();
712 self.compiler.labels = std::mem::take(&mut self.vm.labels)
713 .into_iter()
714 .map(|(label, location)| (label, Some(location)))
715 .collect();
716 self.compiler.sym_table = std::mem::take(&mut self.vm.sym_table);
717
718 compile_stmts(&stmts, &mut self.compiler).inspect_err(|_| {
719 self.compiler = compiler_clone;
722 self.vm = vm_clone;
723 })?;
724
725 self.vm.chunks = std::mem::replace(&mut self.compiler.chunks, vec![Default::default()]);
726 self.vm.labels = std::mem::take(&mut self.compiler.labels)
727 .into_iter()
728 .map(|(label, location)| (label, location.unwrap()))
729 .collect();
730 self.vm.sym_table = std::mem::take(&mut self.compiler.sym_table);
731
732 self.vm.run()
733 }
734}
735
#[cfg(test)]
mod tests {
    // End-to-end tests: each one parses calc source, compiles it (with `Vm::compile_program`
    // or through a `ReplVm`), runs it, and checks either the resulting value or the spans of
    // the reported error.
    use super::*;
    use cas_compute::{
        funcs::miscellaneous::{Abs, Factorial},
        numerical::{builtin::Builtin, value::Value},
        primitive::{float, float_from_str, int},
    };
    use cas_parser::parser::{ast::stmt::Stmt, Parser};
    use rug::ops::Pow;

    /// Parses `source` as a sequence of statements, compiles it, and runs it in the default
    /// trigonometric mode. Panics if `source` does not parse; compile and runtime errors are
    /// returned.
    fn run_program(source: &str) -> Result<Value, Error> {
        let mut parser = Parser::new(source);
        let stmts = parser.try_parse_full_many::<Stmt>().unwrap();

        let mut vm = Vm::compile_program(stmts)?;
        vm.run()
    }

    /// Like `run_program`, but runs the VM with `TrigMode::Degrees`.
    fn run_program_degrees(source: &str) -> Result<Value, Error> {
        let mut parser = Parser::new(source);
        let stmts = parser.try_parse_full_many::<Stmt>().unwrap();

        let mut vm = Vm::compile_program(stmts)?
            .with_trig_mode(TrigMode::Degrees);
        vm.run()
    }

    // --- Arithmetic, constants and precision ---

    #[test]
    fn binary_expr() {
        let result = run_program("1 + 2").unwrap();
        assert_eq!(result, Value::Integer(int(3)));
    }

    #[test]
    fn binary_expr_2() {
        let result = run_program("1 + 2 * 3").unwrap();
        assert_eq!(result, Value::Integer(int(7)));
    }

    #[test]
    fn binary_and_unary() {
        let result = run_program("3 * -5 / 5! + 6").unwrap();
        assert_eq!(result, Value::Float(float(5.875)));
    }

    #[test]
    fn parenthesized() {
        let result = run_program("((1 + 9) / 5) * 3").unwrap();
        assert_eq!(result, Value::Integer(int(6)));
    }

    #[test]
    fn degree_to_radian() {
        let result = run_program_degrees("90 * 2 * pi / 360").unwrap();
        assert_eq!(result, Value::Float(float(&*PI) / 2));
    }

    #[test]
    fn precision() {
        let result = run_program("e^2 - tau").unwrap();
        assert_eq!(result, Value::Float(float(&*E).pow(2) - float(&*TAU)));
    }

    #[test]
    fn precision_2() {
        let result = run_program("pi^2 * 17! / -4.9 + e").unwrap();

        let fac_17 = if let Value::Integer(fac_17) = Factorial::eval_static(float(17)) {
            fac_17
        } else {
            unreachable!("factorial of 17 is an integer")
        };
        let expected = float(&*PI).pow(2) * fac_17 / -float_from_str("4.9") + float(&*E);
        assert_eq!(result, Value::Float(expected));
    }

    // --- Function definitions, calls and default parameters ---

    #[test]
    fn func_call() {
        // Each statement runs in the same REPL, so `f` stays defined for the second one.
        let source = [
            ("f(x) = x^2 + 5x + 6;", Value::Unit),
            ("f(7)", 90.into()),
        ];

        let mut vm = ReplVm::new();
        for (stmt, expected) in source {
            let mut parser = Parser::new(stmt);
            let stmt = parser.try_parse_full::<Stmt>().unwrap();
            assert_eq!(vm.execute(vec![stmt]).unwrap(), expected);
        }
    }

    #[test]
    fn complicated_func_call() {
        let source = [
            ("f(n = 3, k = 6) = n * k;", Value::Unit),
            ("f()", 18.into()),
            ("f(9)", 54.into()),
            ("f(8, 14)", 112.into()),
        ];

        let mut vm = ReplVm::new();
        for (stmt, expected) in source {
            println!("executing: {}", stmt);
            let mut parser = Parser::new(stmt);
            let stmt = parser.try_parse_full::<Stmt>().unwrap();
            assert_eq!(vm.execute(vec![stmt]).unwrap(), expected);
        }
    }

    #[test]
    fn user_func_default_param() {
        let result = run_program("f(x = 2) = x; f() + f(3)").unwrap();
        assert_eq!(result, Value::Integer(int(5)));
    }

    #[test]
    fn user_func_mixed_param() {
        let result = run_program("f(a, b, c = 3, d = 4) = a b c d; f(1, 2)").unwrap();
        assert_eq!(result, Value::Integer(int(24)));
    }

    #[test]
    fn user_func_bad_mixed_param() {
        // `a` and `b` have no defaults, so calling with no arguments must fail.
        assert!(run_program("f(a, b, c = 3, d = 4) = a b c d; f()").is_err());
    }

    #[test]
    fn builtin_func_arg_check() {
        // Builtins validate their own argument types.
        assert_eq!(Abs.eval(Default::default(), vec![Value::from(4.0)]).unwrap().coerce_float(), 4.0.into());
        assert!(Abs.eval(Default::default(), vec![Value::Unit]).is_err());
    }

    // --- Statements, assignment and control flow ---

    #[test]
    fn exec_literal_number() {
        let result = run_program("42").unwrap();
        assert_eq!(result, Value::Integer(int(42)));
    }

    #[test]
    fn exec_multiple_assignment() {
        let result = run_program("x=y=z=5").unwrap();
        assert_eq!(result, Value::Integer(int(5)));
    }

    #[test]
    fn exec_loop() {
        let result = run_program("a = 0
while a < 10 {
    a += 1
}; a").unwrap();
        assert_eq!(result, Value::Integer(int(10)));
    }

    #[test]
    fn exec_dumb_loop() {
        let result = run_program("while true break").unwrap();
        assert_eq!(result, Value::Unit);
    }

    #[test]
    fn exec_loop_with_conditions() {
        let result = run_program("a = 0
j = 2
while a < 10 && j < 15 {
    if a < 5 {
        a += 2
    } else {
        a += 1
        j = -j + 4
    }
}; j").unwrap();
        assert_eq!(result, Value::Integer(int(2)));
    }

    #[test]
    fn exec_simple_program() {
        let result = run_program("x = 4.5
3x + 45 (x + 2) (1 + 3)").unwrap();
        assert_eq!(result, Value::Float(float(1183.5)));
    }

    #[test]
    fn exec_trig_mode() {
        let result_1 = run_program("sin(pi/2)").unwrap();
        let result_2 = run_program_degrees("sin(90)").unwrap();
        assert_eq!(result_1.coerce_float(), Value::Float(float(1)));
        assert_eq!(result_2.coerce_float(), Value::Float(float(1)));
    }

    #[test]
    fn exec_factorial() {
        let result = run_program("n = result = 8
loop {
    n -= 1
    result *= n
    if n <= 1 break result
}").unwrap();
        assert_eq!(result, Value::Integer(int(40320)));
    }

    #[test]
    fn exec_partial_factorial() {
        let result = run_program("partial_factorial(n, k) = {
    result = 1
    while n > k {
        result *= n
        n -= 1
    }
    result
}

partial_factorial(10, 7)").unwrap();
        assert_eq!(result, Value::Integer(int(720)));
    }

    #[test]
    fn exec_sum_even() {
        let result = run_program("n = 200
c = 0
total = 0
while c < n {
    c += 1
    if c & 1 == 1 continue
    total += c
}; total").unwrap();
        assert_eq!(result, Value::Integer(int(10100)));
    }

    // --- Lists and indexing ---

    #[test]
    fn exec_list_index() {
        let result = run_program("arr = [1, 2, 3]
arr[0] = 5
arr[0] + arr[1] + arr[2] == 10").unwrap();
        assert_eq!(result, Value::Boolean(true));
    }

    #[test]
    fn exec_inline_indexing() {
        let result = run_program("[1, 2, 3][im(2sqrt(-1))]").unwrap();
        assert_eq!(result, Value::Integer(int(3)));
    }

    #[test]
    fn exec_call_func() {
        let result = run_program("g() = 6
f(x) = x^2 + 5x + g()
f(32)").unwrap();
        assert_eq!(result, Value::Integer(int(1190)));
    }

    // --- Prime notation (derivatives) ---

    #[test]
    fn exec_valid_prime_notation() {
        let result = run_program("f(x) = log(x, 2)
g(x) = 1/(x ln(2))
f'(64) ~== g(64)").unwrap();
        assert_eq!(result, Value::Boolean(true));
    }

    #[test]
    fn exec_valid_prime_notation_but_wrong_args() {
        let result = run_program("f(x) = x^2 + 5x + 6; f'(64, 32)").unwrap_err();

        // The spans cover `f'(` (21..24), `)` (30..31) and the surplus argument `32` (28..30).
        assert_eq!(result.spans, vec![
            21..24,
            30..31,
            28..30,
        ]);
    }

    #[test]
    fn exec_unsupported_prime_notation() {
        // `log` takes two arguments, so it cannot be differentiated; the first span is `log'(`.
        let result = run_program("log'(32, 2)").unwrap_err();

        assert_eq!(result.spans[0], 0..5);
    }

    #[test]
    fn exec_scoping() {
        // `g` is never called, so this error comes from compilation: `j` is not visible inside
        // `f`, and the error points at its use (6..7).
        let err = run_program("f() = j + 6
g() = {
    j = 10
    f()
}").unwrap_err();

        assert_eq!(err.spans, vec![6..7]);
    }

    #[test]
    fn exec_define_and_call() {
        let result = match run_program("f(x) = 2/sqrt(x)
g(x, y) = f(x) + f(y)
g(2, 3)").unwrap() {
            Value::Float(f) => f,
            other => panic!("expected float, got {:?}", other),
        };

        let left = int(6) * float(2).sqrt();
        let right = int(4) * float(3).sqrt();
        let value = (left + right) / 6;
        assert!(float(result - value).abs() < 1e-6);
    }

    #[test]
    fn exec_branching_return() {
        let result = run_program("f(x) = {
    if x < 0 return -x
    x
}
f(-5)").unwrap();
        assert_eq!(result, Value::Integer(int(5)));
    }

    #[test]
    fn exec_unit_mess() {
        let result = run_program("f() = {}
f()").unwrap();
        assert_eq!(result, Value::Unit);
    }

    #[test]
    fn exec_arithmetic_sequence_summation() {
        let result = run_program("f(a, d, n) = n / 2 * (2a + (n - 1)d)
g(a, d, n) = sum i in 0..n of a + i d
f(1, 2, 10) == g(1, 2, 10)").unwrap();
        assert_eq!(result, Value::Boolean(true));
    }

    #[test]
    fn exec_product_factorial() {
        let result = run_program("iter_fac(n) = product i in 1..=n of i; iter_fac(5)").unwrap();
        assert_eq!(result, Value::Integer(int(120)));
    }

    #[test]
    fn exec_for_loop_with_control() {
        let result = run_program("list = [20, 30, 40, 50, 60]
for i in 0..5 {
    list[i] += 5
    if i == 2 break
}; list").unwrap();
        assert_eq!(result, vec![
            25.into(),
            35.into(),
            45.into(),
            50.into(),
            60.into()
        ].into());
    }

    #[test]
    fn exec_sum_even_indices() {
        let result = run_program("list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
total = 0
for i in 0..10 {
    if i & 1 == 1 continue
    total += list[i]
}; total").unwrap();
        assert_eq!(result, Value::Integer(int(25)));
    }

    #[test]
    fn exec_incr_list() {
        let result = run_program("arr = [1, 2, 3]
arr[0] += 5
arr[1] += 6
arr[2] += 7
arr").unwrap();
        assert_eq!(result, vec![6.into(), 8.into(), 10.into()].into());
    }

    // --- Example programs from the repository's `examples/` directory ---

    #[test]
    fn example_bad_lcm() {
        let source = include_str!("../../examples/bad_lcm.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, 1517.into());
    }

    #[test]
    fn example_factorial() {
        let source = include_str!("../../examples/factorial.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, true.into());
    }

    #[test]
    fn example_convert_binary() {
        let source = include_str!("../../examples/convert_binary.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, vec![true.into(); 7].into());
    }

    #[test]
    fn example_environment_capture() {
        let source = include_str!("../../examples/environment_capture.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, vec![55.into(), 20.into()].into());
    }

    #[test]
    fn example_function_scope() {
        let source = include_str!("../../examples/function_scope.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, 14.into());
    }

    #[test]
    fn example_higher_order_function() {
        let source = include_str!("../../examples/higher_order_function.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, vec![
            16.0.into(),
            (float(32) / float(3)).into(),
            20.0.into(),
            (float(40) / float(3)).into(),
        ].into());
    }

    #[test]
    fn example_if_branching() {
        let source = include_str!("../../examples/if_branching.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result.coerce_float(), float(5).log2().into());
    }

    #[test]
    fn example_manual_abs() {
        let source = include_str!("../../examples/manual_abs.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, 4.into());
    }

    #[test]
    fn example_map_list() {
        let source = include_str!("../../examples/map_list.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, vec![
            complex(0).into(),
            complex(1).into(),
            complex(2).into(),
            complex(3).into(),
            complex(4).into(),
            complex(5).into(),
            complex(6).into(),
            complex(7).into(),
            complex(8).into(),
            complex(9).into(),
        ].into());
    }

    #[test]
    fn example_memoized_fib() {
        let source = include_str!("../../examples/memoized_fib.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, 6_557_470_319_842.into());
    }

    #[test]
    fn example_ncr() {
        let source = include_str!("../../examples/ncr.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, true.into());
    }

    #[test]
    fn example_prime_notation() {
        let source = include_str!("../../examples/prime_notation.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, vec![true.into(); 15].into());
    }

    #[test]
    fn example_resolving_calls() {
        let source = include_str!("../../examples/resolving_calls.calc");
        let result = run_program(source).unwrap();
        assert_eq!(result, true.into());
    }

    // --- REPL state handling ---

    #[test]
    fn repl() {
        // `f() = x` must fail because `x` is undefined (and `f` is never called, so the failure
        // comes from compilation). The failed definition must not leave `f` behind, so calling
        // it afterwards must fail as well.
        let source = [
            "f() = x",
            "f()",
        ];

        let mut vm = ReplVm::new();
        for stmt in &source {
            let mut parser = Parser::new(stmt);
            let stmt = parser.try_parse_full::<Stmt>().unwrap();
            vm.execute(vec![stmt]).unwrap_err();
        }
    }
}