Skip to main content

bubbles/runtime/
eval.rs

1//! Expression evaluator — walks an [`Expr`] AST and produces a [`Value`].
2
3use crate::compiler::expr::{BinOp, Expr, UnOp};
4use crate::error::{DialogueError, Result};
5use crate::value::{Value, VariableStorage};
6
7/// Evaluate an [`Expr`] AST node using `storage` for variable reads and `fns` for function calls.
8///
9/// `fns` receives the function name and evaluated arguments and must return a [`Value`].
10///
11/// # Errors
12/// Returns [`crate::error::DialogueError`] for undefined variables, type mismatches, or failed
13/// function calls.
14pub fn eval<S, F>(expr: &Expr, storage: &S, fns: &F) -> Result<Value>
15where
16    S: VariableStorage,
17    F: Fn(&str, Vec<Value>) -> Result<Value>,
18{
19    match expr {
20        Expr::Number(n) => Ok(Value::Number(*n)),
21        Expr::Text(s) => Ok(Value::Text(s.clone())),
22        Expr::Bool(b) => Ok(Value::Bool(*b)),
23        Expr::Var(name) => storage
24            .get(name)
25            .ok_or_else(|| DialogueError::UndefinedVariable(name.clone())),
26        Expr::Call { name, args } => {
27            let evaluated: Result<Vec<Value>> =
28                args.iter().map(|a| eval(a, storage, fns)).collect();
29            fns(name, evaluated?)
30        }
31        Expr::Unary { op, expr } => {
32            let v = eval(expr, storage, fns)?;
33            match op {
34                UnOp::Neg => {
35                    if let Value::Number(n) = v {
36                        Ok(Value::Number(-n))
37                    } else {
38                        Err(DialogueError::Type(format!("cannot negate {v:?}")))
39                    }
40                }
41                UnOp::Not => Ok(Value::Bool(!v.is_truthy())),
42            }
43        }
44        Expr::Binary { left, op, right } => eval_binary(left, *op, right, storage, fns),
45    }
46}
47
48fn eval_binary<S, F>(left: &Expr, op: BinOp, right: &Expr, storage: &S, fns: &F) -> Result<Value>
49where
50    S: VariableStorage,
51    F: Fn(&str, Vec<Value>) -> Result<Value>,
52{
53    // short-circuit for `&&` and `||`
54    match op {
55        BinOp::And => {
56            let lv = eval(left, storage, fns)?;
57            if !lv.is_truthy() {
58                return Ok(Value::Bool(false));
59            }
60            return Ok(Value::Bool(eval(right, storage, fns)?.is_truthy()));
61        }
62        BinOp::Or => {
63            let lv = eval(left, storage, fns)?;
64            if lv.is_truthy() {
65                return Ok(Value::Bool(true));
66            }
67            return Ok(Value::Bool(eval(right, storage, fns)?.is_truthy()));
68        }
69        _ => {}
70    }
71
72    let lv = eval(left, storage, fns)?;
73    let rv = eval(right, storage, fns)?;
74
75    match op {
76        BinOp::Add => match (lv, rv) {
77            (Value::Number(a), Value::Number(b)) => Ok(Value::Number(a + b)),
78            (Value::Text(a), Value::Text(b)) => Ok(Value::Text(a + &b)),
79            (Value::Text(a), b) => Ok(Value::Text(a + &b.to_string())),
80            (a, b) => Err(DialogueError::Type(format!("cannot add {a:?} and {b:?}"))),
81        },
82        BinOp::Sub => num_op(lv, rv, "-", |x, y| x - y),
83        BinOp::Mul => num_op(lv, rv, "*", |x, y| x * y),
84        BinOp::Div => num_op(lv, rv, "/", |x, y| x / y),
85        BinOp::Rem => num_op(lv, rv, "%", |x, y| x % y),
86        BinOp::Eq => Ok(Value::Bool(lv == rv)),
87        BinOp::Neq => Ok(Value::Bool(lv != rv)),
88        BinOp::Lt => cmp_op(lv, rv, "<", |x: f64, y: f64| x < y),
89        BinOp::Lte => cmp_op(lv, rv, "<=", |x: f64, y: f64| x <= y),
90        BinOp::Gt => cmp_op(lv, rv, ">", |x: f64, y: f64| x > y),
91        BinOp::Gte => cmp_op(lv, rv, ">=", |x: f64, y: f64| x >= y),
92        BinOp::And | BinOp::Or => unreachable!("handled above"),
93    }
94}
95
96fn num_op(left: Value, right: Value, op: &str, calc: impl Fn(f64, f64) -> f64) -> Result<Value> {
97    match (left, right) {
98        (Value::Number(a), Value::Number(b)) => {
99            if op == "/" && b == 0.0 {
100                return Err(DialogueError::Runtime("division by zero".into()));
101            }
102            if op == "%" && b == 0.0 {
103                return Err(DialogueError::Runtime("modulo by zero".into()));
104            }
105            Ok(Value::Number(calc(a, b)))
106        }
107        (lv, rv) => Err(DialogueError::Type(format!(
108            "operator `{op}` requires numbers, got {lv:?} and {rv:?}"
109        ))),
110    }
111}
112
113fn cmp_op(left: Value, right: Value, op: &str, pred: impl Fn(f64, f64) -> bool) -> Result<Value> {
114    match (left, right) {
115        (Value::Number(a), Value::Number(b)) => Ok(Value::Bool(pred(a, b))),
116        (lv, rv) => Err(DialogueError::Type(format!(
117            "operator `{op}` requires numbers, got {lv:?} and {rv:?}"
118        ))),
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::expr::parse_expr;
    use crate::value::HashMapStorage;

    /// Function table used by every test: any call fails, since no functions are registered.
    fn no_fns(_: &str, _: Vec<Value>) -> Result<Value> {
        Err(DialogueError::Runtime("no functions registered".into()))
    }

    /// Parse `src` and evaluate it against a newly constructed `HashMapStorage` and `no_fns`.
    ///
    /// Panics if `src` does not parse; evaluation errors are returned to the caller.
    fn try_ev(src: &str) -> Result<Value> {
        let storage = HashMapStorage::new();
        let expr = parse_expr(src).unwrap();
        eval(&expr, &storage, &no_fns)
    }

    /// Evaluate `src`, panicking if evaluation fails.
    fn ev(src: &str) -> Value {
        try_ev(src).unwrap()
    }

    /// Evaluate `src`, panicking unless evaluation fails; returns the error.
    fn ev_err(src: &str) -> DialogueError {
        try_ev(src).unwrap_err()
    }

    #[test]
    fn eval_addition() {
        assert_eq!(ev("1 + 2"), Value::Number(3.0));
    }

    #[test]
    fn eval_precedence() {
        assert_eq!(ev("1 + 2 * 3"), Value::Number(7.0));
    }

    #[test]
    fn eval_parentheses() {
        assert_eq!(ev("(1 + 2) * 3"), Value::Number(9.0));
    }

    #[test]
    fn eval_comparison() {
        for (src, expected) in [("3 > 2", true), ("1 >= 2", false)] {
            assert_eq!(ev(src), Value::Bool(expected));
        }
    }

    #[test]
    fn eval_logical_and_short_circuit() {
        for (src, expected) in [("false && true", false), ("true && true", true)] {
            assert_eq!(ev(src), Value::Bool(expected));
        }
    }

    #[test]
    fn eval_logical_or_short_circuit() {
        for (src, expected) in [("true || false", true), ("false || false", false)] {
            assert_eq!(ev(src), Value::Bool(expected));
        }
    }

    #[test]
    fn eval_string_concat() {
        let joined = ev(r#""hello" + " world""#);
        assert_eq!(joined, Value::Text("hello world".into()));
    }

    #[test]
    fn eval_unary_neg() {
        assert_eq!(ev("-3"), Value::Number(-3.0));
    }

    #[test]
    fn eval_unary_not() {
        assert_eq!(ev("!false"), Value::Bool(true));
    }

    #[test]
    fn negate_non_number_errors() {
        let err = ev_err(r#"-"hello""#);
        assert!(err.to_string().contains("negate"), "got {err}");
    }

    #[test]
    fn subtract_non_numbers_errors() {
        let err = ev_err(r#""a" - 1"#);
        assert!(err.to_string().contains('-'), "got {err}");
    }

    #[test]
    fn compare_non_numbers_errors() {
        let err = ev_err(r#""a" < 1"#);
        assert!(err.to_string().contains('<'), "got {err}");
    }

    #[test]
    fn modulo_by_zero_errors() {
        let err = ev_err("5 % 0");
        assert!(err.to_string().contains("modulo"), "got {err}");
    }

    #[test]
    fn multiply_non_numbers_errors() {
        assert!(try_ev(r#""x" * 2"#).is_err());
    }

    #[test]
    fn comparison_operators_on_numbers() {
        let cases = [
            ("1 != 2", true),
            ("2 != 2", false),
            ("1 < 2", true),
            ("2 <= 2", true),
            ("3 > 2", true),
            ("3 >= 4", false),
        ];
        for (src, expected) in cases {
            assert_eq!(ev(src), Value::Bool(expected));
        }
    }

    #[test]
    fn div_and_rem_on_numbers() {
        for (src, expected) in [("9 / 2", 4.5), ("7 % 3", 1.0)] {
            assert_eq!(ev(src), Value::Number(expected));
        }
    }

    #[test]
    fn add_number_and_bool_is_type_error() {
        assert!(try_ev("1 + true").is_err());
    }

    #[test]
    fn bool_equality() {
        for (src, expected) in [("true == true", true), ("true == false", false)] {
            assert_eq!(ev(src), Value::Bool(expected));
        }
    }

    #[test]
    fn text_equality_and_inequality() {
        for src in [r#""a" == "a""#, r#""a" != "b""#] {
            assert_eq!(ev(src), Value::Bool(true));
        }
    }

    #[test]
    fn rem_requires_numbers() {
        assert!(try_ev(r#""x" % 2"#).is_err());
    }
}