1use crate::compiler::expr::{BinOp, Expr, UnOp};
4use crate::error::{DialogueError, Result};
5use crate::value::{Value, VariableStorage};
6
7pub fn eval<S, F>(expr: &Expr, storage: &S, fns: &F) -> Result<Value>
15where
16 S: VariableStorage,
17 F: Fn(&str, Vec<Value>) -> Result<Value>,
18{
19 match expr {
20 Expr::Number(n) => Ok(Value::Number(*n)),
21 Expr::Text(s) => Ok(Value::Text(s.clone())),
22 Expr::Bool(b) => Ok(Value::Bool(*b)),
23 Expr::Var(name) => storage
24 .get(name)
25 .ok_or_else(|| DialogueError::UndefinedVariable(name.clone())),
26 Expr::Call { name, args } => {
27 let evaluated: Result<Vec<Value>> =
28 args.iter().map(|a| eval(a, storage, fns)).collect();
29 fns(name, evaluated?)
30 }
31 Expr::Unary { op, expr } => {
32 let v = eval(expr, storage, fns)?;
33 match op {
34 UnOp::Neg => {
35 if let Value::Number(n) = v {
36 Ok(Value::Number(-n))
37 } else {
38 Err(DialogueError::Type(format!("cannot negate {v:?}")))
39 }
40 }
41 UnOp::Not => Ok(Value::Bool(!v.is_truthy())),
42 }
43 }
44 Expr::Binary { left, op, right } => eval_binary(left, *op, right, storage, fns),
45 }
46}
47
48fn eval_binary<S, F>(left: &Expr, op: BinOp, right: &Expr, storage: &S, fns: &F) -> Result<Value>
49where
50 S: VariableStorage,
51 F: Fn(&str, Vec<Value>) -> Result<Value>,
52{
53 match op {
55 BinOp::And => {
56 let lv = eval(left, storage, fns)?;
57 if !lv.is_truthy() {
58 return Ok(Value::Bool(false));
59 }
60 return Ok(Value::Bool(eval(right, storage, fns)?.is_truthy()));
61 }
62 BinOp::Or => {
63 let lv = eval(left, storage, fns)?;
64 if lv.is_truthy() {
65 return Ok(Value::Bool(true));
66 }
67 return Ok(Value::Bool(eval(right, storage, fns)?.is_truthy()));
68 }
69 _ => {}
70 }
71
72 let lv = eval(left, storage, fns)?;
73 let rv = eval(right, storage, fns)?;
74
75 match op {
76 BinOp::Add => match (lv, rv) {
77 (Value::Number(a), Value::Number(b)) => Ok(Value::Number(a + b)),
78 (Value::Text(a), Value::Text(b)) => Ok(Value::Text(a + &b)),
79 (Value::Text(a), b) => Ok(Value::Text(a + &b.to_string())),
80 (a, b) => Err(DialogueError::Type(format!("cannot add {a:?} and {b:?}"))),
81 },
82 BinOp::Sub => num_op(lv, rv, "-", |x, y| x - y),
83 BinOp::Mul => num_op(lv, rv, "*", |x, y| x * y),
84 BinOp::Div => num_op(lv, rv, "/", |x, y| x / y),
85 BinOp::Rem => num_op(lv, rv, "%", |x, y| x % y),
86 BinOp::Eq => Ok(Value::Bool(lv == rv)),
87 BinOp::Neq => Ok(Value::Bool(lv != rv)),
88 BinOp::Lt => cmp_op(lv, rv, "<", |x: f64, y: f64| x < y),
89 BinOp::Lte => cmp_op(lv, rv, "<=", |x: f64, y: f64| x <= y),
90 BinOp::Gt => cmp_op(lv, rv, ">", |x: f64, y: f64| x > y),
91 BinOp::Gte => cmp_op(lv, rv, ">=", |x: f64, y: f64| x >= y),
92 BinOp::And | BinOp::Or => unreachable!("handled above"),
93 }
94}
95
96fn num_op(left: Value, right: Value, op: &str, calc: impl Fn(f64, f64) -> f64) -> Result<Value> {
97 match (left, right) {
98 (Value::Number(a), Value::Number(b)) => {
99 if op == "/" && b == 0.0 {
100 return Err(DialogueError::Runtime("division by zero".into()));
101 }
102 if op == "%" && b == 0.0 {
103 return Err(DialogueError::Runtime("modulo by zero".into()));
104 }
105 Ok(Value::Number(calc(a, b)))
106 }
107 (lv, rv) => Err(DialogueError::Type(format!(
108 "operator `{op}` requires numbers, got {lv:?} and {rv:?}"
109 ))),
110 }
111}
112
113fn cmp_op(left: Value, right: Value, op: &str, pred: impl Fn(f64, f64) -> bool) -> Result<Value> {
114 match (left, right) {
115 (Value::Number(a), Value::Number(b)) => Ok(Value::Bool(pred(a, b))),
116 (lv, rv) => Err(DialogueError::Type(format!(
117 "operator `{op}` requires numbers, got {lv:?} and {rv:?}"
118 ))),
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::expr::parse_expr;
    use crate::value::HashMapStorage;

    /// Function table that rejects every call, so any call that is actually evaluated
    /// surfaces as an error.
    fn no_fns(_: &str, _: Vec<Value>) -> Result<Value> {
        Err(DialogueError::Runtime("no functions registered".into()))
    }

    /// Parses `src` and evaluates it against empty storage, returning the raw result.
    fn try_ev(src: &str) -> Result<Value> {
        let storage = HashMapStorage::new();
        let expr = parse_expr(src).unwrap();
        eval(&expr, &storage, &no_fns)
    }

    /// Evaluates `src`, panicking if evaluation fails.
    fn ev(src: &str) -> Value {
        try_ev(src).unwrap()
    }

    /// Evaluates `src`, panicking if evaluation unexpectedly succeeds.
    fn ev_err(src: &str) -> DialogueError {
        try_ev(src).unwrap_err()
    }

    #[test]
    fn eval_addition() {
        assert_eq!(ev("1 + 2"), Value::Number(3.0));
    }

    #[test]
    fn eval_precedence() {
        assert_eq!(ev("1 + 2 * 3"), Value::Number(7.0));
    }

    #[test]
    fn eval_parentheses() {
        assert_eq!(ev("(1 + 2) * 3"), Value::Number(9.0));
    }

    #[test]
    fn eval_comparison() {
        assert_eq!(ev("3 > 2"), Value::Bool(true));
        assert_eq!(ev("1 >= 2"), Value::Bool(false));
    }

    #[test]
    fn eval_logical_and_short_circuit() {
        assert_eq!(ev("false && true"), Value::Bool(false));
        assert_eq!(ev("true && true"), Value::Bool(true));
    }

    #[test]
    fn eval_logical_or_short_circuit() {
        assert_eq!(ev("true || false"), Value::Bool(true));
        assert_eq!(ev("false || false"), Value::Bool(false));
    }

    #[test]
    fn and_skips_right_operand_when_left_is_falsy() {
        // `boom()` errors whenever it is evaluated (see `no_fns`), so success here proves
        // the right-hand side was never evaluated.
        assert_eq!(ev("false && boom()"), Value::Bool(false));
        // Control: when the left side is truthy, the call must run and fail.
        assert!(try_ev("true && boom()").is_err());
    }

    #[test]
    fn or_skips_right_operand_when_left_is_truthy() {
        assert_eq!(ev("true || boom()"), Value::Bool(true));
        assert!(try_ev("false || boom()").is_err());
    }

    #[test]
    fn eval_string_concat() {
        assert_eq!(
            ev(r#""hello" + " world""#),
            Value::Text("hello world".into())
        );
    }

    #[test]
    fn eval_unary_neg() {
        assert_eq!(ev("-3"), Value::Number(-3.0));
    }

    #[test]
    fn eval_unary_not() {
        assert_eq!(ev("!false"), Value::Bool(true));
    }

    #[test]
    fn negate_non_number_errors() {
        let err = ev_err(r#"-"hello""#);
        assert!(err.to_string().contains("negate"), "got {err}");
    }

    #[test]
    fn subtract_non_numbers_errors() {
        let err = ev_err(r#""a" - 1"#);
        assert!(err.to_string().contains('-'), "got {err}");
    }

    #[test]
    fn compare_non_numbers_errors() {
        let err = ev_err(r#""a" < 1"#);
        assert!(err.to_string().contains('<'), "got {err}");
    }

    #[test]
    fn modulo_by_zero_errors() {
        let err = ev_err("5 % 0");
        assert!(err.to_string().contains("modulo"), "got {err}");
    }

    #[test]
    fn division_by_zero_errors() {
        let err = ev_err("5 / 0");
        assert!(err.to_string().contains("division"), "got {err}");
    }

    #[test]
    fn multiply_non_numbers_errors() {
        assert!(try_ev(r#""x" * 2"#).is_err());
    }

    #[test]
    fn comparison_operators_on_numbers() {
        assert_eq!(ev("1 != 2"), Value::Bool(true));
        assert_eq!(ev("2 != 2"), Value::Bool(false));
        assert_eq!(ev("1 < 2"), Value::Bool(true));
        assert_eq!(ev("2 <= 2"), Value::Bool(true));
        assert_eq!(ev("3 > 2"), Value::Bool(true));
        assert_eq!(ev("3 >= 4"), Value::Bool(false));
    }

    #[test]
    fn div_and_rem_on_numbers() {
        assert_eq!(ev("9 / 2"), Value::Number(4.5));
        assert_eq!(ev("7 % 3"), Value::Number(1.0));
    }

    #[test]
    fn add_number_and_bool_is_type_error() {
        assert!(try_ev("1 + true").is_err());
    }

    #[test]
    fn bool_equality() {
        assert_eq!(ev("true == true"), Value::Bool(true));
        assert_eq!(ev("true == false"), Value::Bool(false));
    }

    #[test]
    fn text_equality_and_inequality() {
        assert_eq!(ev(r#""a" == "a""#), Value::Bool(true));
        assert_eq!(ev(r#""a" != "b""#), Value::Bool(true));
    }

    #[test]
    fn rem_requires_numbers() {
        assert!(try_ev(r#""x" % 2"#).is_err());
    }
}
269}