1use crate::ast::BinaryOp;
2use crate::ast::Node;
3use crate::error::Error;
4
5pub struct Evaluator<'a> {
6 ast: &'a Node,
7}
8
9impl Evaluator<'_> {
10 pub fn new(ast: &Node) -> Evaluator {
11 Evaluator { ast }
12 }
13
14 pub fn eval(&self) -> Result<i32, Error> {
15 fn eval_node(node: &Node) -> Result<i32, Error> {
16 match node {
17 Node::IntLit(value) => Ok(*value),
18 Node::BinaryExpr { op, left, right } => {
19 let left = eval_node(left)?;
20 let right = eval_node(right)?;
21
22 match op {
23 BinaryOp::Add => Ok(left.wrapping_add(right)),
24 BinaryOp::Sub => Ok(left.wrapping_sub(right)),
25 BinaryOp::Mul => Ok(left.wrapping_mul(right)),
26 BinaryOp::Div => {
27 if right != 0 {
28 Ok(left.wrapping_div(right))
29 } else {
30 Err(Error::new("division by zero"))
31 }
32 }
33 }
34 }
35 }
36 }
37
38 eval_node(self.ast)
39 }
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::helpers::ast::*;

    /// Asserts that evaluating `$ast` succeeds and yields `$value`.
    macro_rules! assert_evals {
        ($ast:expr, $value:expr) => {{
            let ast = $ast;
            let evaluator = Evaluator::new(&ast);

            assert_eq!(evaluator.eval(), Ok($value));
        }};
    }

    /// Asserts that evaluating `$ast` fails with an error built from `$message`.
    macro_rules! assert_does_not_eval {
        ($ast:expr, $message:expr) => {{
            let ast = $ast;
            let evaluator = Evaluator::new(&ast);

            assert_eq!(evaluator.eval(), Err(Error::new($message)));
        }};
    }

    #[test]
    fn evals_int_lit() {
        assert_evals!(int(1), 1);
    }

    #[test]
    fn evals_binary_expr_add() {
        assert_evals!(add(int(1), int(2)), 3);

        // Overflow wraps around in both directions.
        assert_evals!(add(int(2147483647), int(1)), -2147483648);
        assert_evals!(add(int(-2147483648), int(-1)), 2147483647);
    }

    #[test]
    fn evals_binary_expr_sub() {
        assert_evals!(sub(int(3), int(2)), 1);
        assert_evals!(sub(int(2), int(3)), -1);

        // Overflow wraps around in both directions.
        assert_evals!(sub(int(2147483647), int(-1)), -2147483648);
        assert_evals!(sub(int(-2147483648), int(1)), 2147483647);
    }

    #[test]
    fn evals_binary_expr_mul() {
        assert_evals!(mul(int(2), int(3)), 6);
        assert_evals!(mul(int(-2), int(3)), -6);

        // Overflow wraps: 2^16 * 2^16 = 2^32 wraps to 0, and MIN * -1 stays MIN.
        assert_evals!(mul(int(65536), int(65536)), 0);
        assert_evals!(mul(int(-2147483648), int(-1)), -2147483648);
    }

    #[test]
    fn evals_binary_expr_div() {
        assert_evals!(div(int(6), int(3)), 2);

        // Integer division truncates toward zero.
        assert_evals!(div(int(7), int(2)), 3);
        assert_evals!(div(int(-7), int(2)), -3);
        assert_evals!(div(int(7), int(-2)), -3);

        // MIN / -1 wraps instead of panicking.
        assert_evals!(div(int(-2147483648), int(-1)), -2147483648);

        assert_does_not_eval!(div(int(1), int(0)), "division by zero");
        assert_does_not_eval!(div(int(0), int(0)), "division by zero");
    }

    #[test]
    fn evals_complex_expressions() {
        assert_evals!(mul(add(int(1), int(2)), add(int(3), int(4))), 21);
    }

    #[test]
    fn propagates_errors_from_nested_expressions() {
        assert_does_not_eval!(add(int(1), div(int(1), int(0))), "division by zero");
        assert_does_not_eval!(mul(div(int(1), int(0)), int(2)), "division by zero");
    }
}