1use crate::interpreter::Scope;
16use crate::ast::Value;
17use anyhow::{bail, Context, Result};
18
19pub fn eval_arithmetic(expr: &str, scope: &Scope) -> Result<i64> {
31 let mut parser = ArithParser::new(expr, scope);
32 let result = parser.parse_comparison()?;
33 parser.expect_end()?;
34 Ok(result)
35}
36
37struct ArithParser<'a> {
39 input: &'a str,
40 pos: usize,
41 scope: &'a Scope,
42}
43
44impl<'a> ArithParser<'a> {
45 fn new(input: &'a str, scope: &'a Scope) -> Self {
46 Self { input, pos: 0, scope }
47 }
48
49 fn skip_whitespace(&mut self) {
50 while self.pos < self.input.len() {
51 let ch = self.input.as_bytes()[self.pos];
52 if ch == b' ' || ch == b'\t' {
53 self.pos += 1;
54 } else {
55 break;
56 }
57 }
58 }
59
60 fn peek(&mut self) -> Option<char> {
61 self.skip_whitespace();
62 self.input[self.pos..].chars().next()
63 }
64
65 fn advance(&mut self) -> Option<char> {
66 self.skip_whitespace();
67 let ch = self.input[self.pos..].chars().next()?;
68 self.pos += ch.len_utf8();
69 Some(ch)
70 }
71
72 fn peek_ahead(&mut self, n: usize) -> Option<char> {
74 self.skip_whitespace();
75 self.input[self.pos..].chars().nth(n)
76 }
77
78 fn expect_end(&mut self) -> Result<()> {
79 self.skip_whitespace();
80 if self.pos < self.input.len() {
81 bail!("unexpected characters at end of arithmetic expression: {:?}",
82 &self.input[self.pos..]);
83 }
84 Ok(())
85 }
86
87 fn parse_comparison(&mut self) -> Result<i64> {
90 let mut left = self.parse_expr()?;
91
92 loop {
93 self.skip_whitespace();
94 match (self.peek_ahead(0), self.peek_ahead(1)) {
95 (Some('>'), Some('=')) => {
97 self.advance(); self.advance(); let right = self.parse_expr()?;
100 left = if left >= right { 1 } else { 0 };
101 }
102 (Some('<'), Some('=')) => {
103 self.advance(); self.advance(); let right = self.parse_expr()?;
106 left = if left <= right { 1 } else { 0 };
107 }
108 (Some('='), Some('=')) => {
109 self.advance(); self.advance(); let right = self.parse_expr()?;
112 left = if left == right { 1 } else { 0 };
113 }
114 (Some('!'), Some('=')) => {
115 self.advance(); self.advance(); let right = self.parse_expr()?;
118 left = if left != right { 1 } else { 0 };
119 }
120 (Some('>'), _) => {
122 self.advance(); let right = self.parse_expr()?;
124 left = if left > right { 1 } else { 0 };
125 }
126 (Some('<'), _) => {
127 self.advance(); let right = self.parse_expr()?;
129 left = if left < right { 1 } else { 0 };
130 }
131 _ => break,
132 }
133 }
134
135 Ok(left)
136 }
137
138 fn parse_expr(&mut self) -> Result<i64> {
140 let mut left = self.parse_term()?;
141
142 loop {
143 match self.peek() {
144 Some('+') => {
145 self.advance();
146 let right = self.parse_term()?;
147 left = left.checked_add(right)
148 .context("arithmetic overflow in addition")?;
149 }
150 Some('-') => {
151 self.advance();
152 let right = self.parse_term()?;
153 left = left.checked_sub(right)
154 .context("arithmetic overflow in subtraction")?;
155 }
156 _ => break,
157 }
158 }
159
160 Ok(left)
161 }
162
163 fn parse_term(&mut self) -> Result<i64> {
165 let mut left = self.parse_unary()?;
166
167 loop {
168 match self.peek() {
169 Some('*') => {
170 self.advance();
171 let right = self.parse_unary()?;
172 left = left.checked_mul(right)
173 .context("arithmetic overflow in multiplication")?;
174 }
175 Some('/') => {
176 self.advance();
177 let right = self.parse_unary()?;
178 if right == 0 {
179 bail!("division by zero");
180 }
181 left = left.checked_div(right)
182 .context("arithmetic overflow in division")?;
183 }
184 Some('%') => {
185 self.advance();
186 let right = self.parse_unary()?;
187 if right == 0 {
188 bail!("modulo by zero");
189 }
190 left = left.checked_rem(right)
191 .context("arithmetic overflow in modulo")?;
192 }
193 _ => break,
194 }
195 }
196
197 Ok(left)
198 }
199
200 fn parse_unary(&mut self) -> Result<i64> {
202 match self.peek() {
203 Some('+') => {
204 self.advance();
205 self.parse_unary()
206 }
207 Some('-') => {
208 self.advance();
209 let val = self.parse_unary()?;
210 val.checked_neg().context("arithmetic overflow in negation")
211 }
212 _ => self.parse_primary(),
213 }
214 }
215
216 fn parse_primary(&mut self) -> Result<i64> {
218 self.skip_whitespace();
219
220 match self.peek() {
221 Some('(') => {
222 self.advance(); let val = self.parse_expr()?;
224 match self.peek() {
225 Some(')') => {
226 self.advance();
227 Ok(val)
228 }
229 _ => bail!("expected ')' in arithmetic expression"),
230 }
231 }
232 Some('$') => {
233 self.advance(); if self.peek() == Some('?') {
238 self.advance(); return Ok(self.scope.last_result().code);
240 }
241
242 if self.peek() == Some('$') {
244 self.advance(); return Ok(self.scope.pid() as i64);
246 }
247
248 let var_name = if self.peek() == Some('{') {
249 self.advance(); if self.peek() == Some('?') {
253 self.advance(); if self.peek() != Some('}') {
255 bail!("expected '}}' after ${{?}} in arithmetic");
256 }
257 self.advance(); return Ok(self.scope.last_result().code);
259 }
260
261 if self.peek() == Some('$') {
263 self.advance(); if self.peek() != Some('}') {
265 bail!("expected '}}' after ${{$}} in arithmetic");
266 }
267 self.advance(); return Ok(self.scope.pid() as i64);
269 }
270
271 let name = self.parse_identifier()?;
272 if self.peek() != Some('}') {
273 bail!("expected '}}' after variable name in arithmetic");
274 }
275 self.advance(); name
277 } else {
278 self.parse_identifier()?
279 };
280 self.get_var_value(&var_name)
281 }
282 Some(c) if c.is_ascii_digit() => {
283 self.parse_number()
284 }
285 Some(c) if c.is_ascii_alphabetic() || c == '_' => {
286 let var_name = self.parse_identifier()?;
288 self.get_var_value(&var_name)
289 }
290 Some(c) => bail!("unexpected character in arithmetic expression: {:?}", c),
291 None => bail!("unexpected end of arithmetic expression"),
292 }
293 }
294
295 fn parse_number(&mut self) -> Result<i64> {
296 let start = self.pos;
297 while self.pos < self.input.len() {
298 let ch = self.input.as_bytes()[self.pos];
299 if ch.is_ascii_digit() {
300 self.pos += 1;
301 } else {
302 break;
303 }
304 }
305 let num_str = &self.input[start..self.pos];
306 num_str.parse().context("invalid number in arithmetic expression")
307 }
308
309 fn parse_identifier(&mut self) -> Result<String> {
310 let start = self.pos;
311 while self.pos < self.input.len() {
312 let ch = self.input.as_bytes()[self.pos];
313 if ch.is_ascii_alphanumeric() || ch == b'_' {
314 self.pos += 1;
315 } else {
316 break;
317 }
318 }
319 if start == self.pos {
320 bail!("expected identifier in arithmetic expression");
321 }
322 Ok(self.input[start..self.pos].to_string())
323 }
324
325 fn get_var_value(&self, name: &str) -> Result<i64> {
326 if let Ok(index) = name.parse::<usize>() {
329 if let Some(pos_val) = self.scope.get_positional(index) {
330 return pos_val.parse().with_context(|| {
331 format!("${} has non-numeric value: {:?}", index, pos_val)
332 });
333 }
334 return Ok(0); }
336
337 match self.scope.get(name) {
339 Some(Value::Int(n)) => Ok(*n),
340 Some(Value::String(s)) => {
341 s.parse().with_context(|| format!(
343 "variable '{}' has non-numeric value: {:?}", name, s
344 ))
345 }
346 Some(Value::Float(f)) => Ok(*f as i64),
347 Some(Value::Bool(b)) => Ok(if *b { 1 } else { 0 }),
348 Some(Value::Null) => Ok(0), Some(Value::Json(_)) => anyhow::bail!("variable '{}' is JSON, not a number", name),
350 Some(Value::Blob(_)) => anyhow::bail!("variable '{}' is a blob, not a number", name),
351 None => Ok(0), }
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `expr` in a fresh, empty scope and panics on error.
    fn eval(expr: &str) -> i64 {
        eval_arithmetic(expr, &Scope::new()).expect("eval should succeed")
    }

    /// Evaluates `expr` in a scope that holds a single integer variable `name = value`.
    fn eval_with_var(expr: &str, name: &str, value: i64) -> i64 {
        let mut scope = Scope::new();
        scope.set(name, Value::Int(value));
        eval_arithmetic(expr, &scope).expect("eval should succeed")
    }

    /// Checks every `(expression, expected)` pair against an empty scope.
    fn assert_evals(cases: &[(&str, i64)]) {
        for &(expr, expected) in cases {
            assert_eq!(eval(expr), expected);
        }
    }

    /// Checks every `(expression, expected)` pair with `name = value` in scope.
    fn assert_evals_with_var(name: &str, value: i64, cases: &[(&str, i64)]) {
        for &(expr, expected) in cases {
            assert_eq!(eval_with_var(expr, name, value), expected);
        }
    }

    #[test]
    fn test_simple_integers() {
        assert_evals(&[("42", 42), ("0", 0), ("12345", 12345)]);
    }

    #[test]
    fn test_addition() {
        assert_evals(&[("1 + 2", 3), ("10 + 20 + 30", 60)]);
    }

    #[test]
    fn test_subtraction() {
        assert_evals(&[("10 - 3", 7), ("100 - 50 - 25", 25)]);
    }

    #[test]
    fn test_multiplication() {
        assert_evals(&[("3 * 4", 12), ("2 * 3 * 4", 24)]);
    }

    #[test]
    fn test_division() {
        assert_evals(&[("10 / 2", 5), ("100 / 10 / 2", 5)]);
    }

    #[test]
    fn test_modulo() {
        assert_evals(&[("10 % 3", 1), ("17 % 5", 2)]);
    }

    #[test]
    fn test_precedence() {
        // `*` and `/` bind tighter than `+` and `-`.
        assert_evals(&[("2 + 3 * 4", 14), ("10 - 6 / 2", 7)]);
    }

    #[test]
    fn test_parentheses() {
        assert_evals(&[("(2 + 3) * 4", 20), ("((1 + 2) * (3 + 4))", 21)]);
    }

    #[test]
    fn test_unary_minus() {
        assert_evals(&[("-5", -5), ("10 + -3", 7), ("--5", 5)]);
    }

    #[test]
    fn test_unary_plus() {
        assert_evals(&[("+5", 5), ("++5", 5)]);
    }

    #[test]
    fn test_whitespace() {
        assert_evals(&[(" 1 + 2 ", 3), ("1+2", 3)]);
    }

    #[test]
    fn test_variable_dollar() {
        assert_evals_with_var("X", 10, &[("$X", 10), ("$X + 5", 15)]);
    }

    #[test]
    fn test_variable_dollar_braces() {
        assert_evals_with_var("X", 10, &[("${X}", 10), ("${X} * 2", 20)]);
    }

    #[test]
    fn test_variable_bare() {
        // `Y` is never set, so it contributes 0.
        assert_evals_with_var("X", 10, &[("X", 10), ("X + Y", 10)]);
    }

    #[test]
    fn test_unset_variable() {
        let value = eval_arithmetic("UNDEFINED", &Scope::new()).expect("should succeed");
        assert_eq!(value, 0);
    }

    #[test]
    fn test_division_by_zero() {
        assert!(eval_arithmetic("10 / 0", &Scope::new()).is_err());
    }

    #[test]
    fn test_modulo_by_zero() {
        assert!(eval_arithmetic("10 % 0", &Scope::new()).is_err());
    }

    #[test]
    fn test_complex_expression() {
        assert_evals(&[("(1 + 2) * (3 + 4) - 5", 16)]);
    }

    #[test]
    fn test_greater_than() {
        assert_evals(&[("5 > 3", 1), ("3 > 5", 0), ("5 > 5", 0)]);
    }

    #[test]
    fn test_less_than() {
        assert_evals(&[("3 < 5", 1), ("5 < 3", 0), ("5 < 5", 0)]);
    }

    #[test]
    fn test_greater_or_equal() {
        assert_evals(&[("5 >= 3", 1), ("5 >= 5", 1), ("3 >= 5", 0)]);
    }

    #[test]
    fn test_less_or_equal() {
        assert_evals(&[("3 <= 5", 1), ("5 <= 5", 1), ("5 <= 3", 0)]);
    }

    #[test]
    fn test_equal() {
        assert_evals(&[("5 == 5", 1), ("5 == 3", 0)]);
    }

    #[test]
    fn test_not_equal() {
        assert_evals(&[("5 != 3", 1), ("5 != 5", 0)]);
    }

    #[test]
    fn test_comparison_with_arithmetic() {
        assert_evals(&[
            ("(2 + 3) > 4", 1),
            ("10 / 2 == 5", 1),
            ("3 * 4 >= 12", 1),
            ("10 - 5 < 6", 1),
        ]);
    }

    #[test]
    fn test_comparison_with_variables() {
        assert_evals_with_var("X", 10, &[("X > 5", 1), ("X == 10", 1), ("X <= 10", 1)]);
    }

    #[test]
    fn test_chained_comparison() {
        // Comparisons yield 0/1 and chain left to right:
        // (5 > 3) > 2 is 1 > 2, which is 0; (5 > 3) == 1 is 1 == 1, which is 1.
        assert_evals(&[("5 > 3 > 2", 0), ("5 > 3 == 1", 1)]);
    }
}