Skip to main content

kaish_kernel/
arithmetic.rs

1//! Arithmetic expression evaluation for shell-style `$(( ))` expressions.
2//!
3//! Supports:
4//! - Integer arithmetic: `+`, `-`, `*`, `/`, `%`
5//! - Comparison operators: `>`, `<`, `>=`, `<=`, `==`, `!=` (return 1 or 0)
6//! - Parentheses for grouping: `(expr)`
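//! - Variable references: `$VAR`, `${VAR}`, or bare `VAR`; positional parameters
//!   (`$1`, `${1}`) and the special parameters `$?` / `$$` (also `${?}` / `${$}`)
//! - Decimal integer literals (no hex or octal; a leading `0` is still decimal)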
9//!
10//! Does NOT support:
11//! - Floating point (pipe to `jq` for float math)
12//! - Bitwise operations (shell-ism we're skipping)
//! - Assignment within expressions (confusing)
//! - Logical operators (`&&`, `||`, `!`) and the ternary `?:`
14
15use crate::interpreter::Scope;
16use crate::ast::Value;
17use anyhow::{bail, Context, Result};
18
19/// Evaluate an arithmetic expression string.
20///
21/// The expression should be the content between `$((` and `))`.
22///
23/// # Example
24/// ```ignore
25/// let scope = Scope::new();
26/// scope.set("X", Value::Int(5));
27/// let result = eval_arithmetic("X + 3", &scope)?;
28/// assert_eq!(result, 8);
29/// ```
30pub fn eval_arithmetic(expr: &str, scope: &Scope) -> Result<i64> {
31    let mut parser = ArithParser::new(expr, scope);
32    let result = parser.parse_comparison()?;
33    parser.expect_end()?;
34    Ok(result)
35}
36
37/// Simple recursive descent parser for arithmetic expressions.
38struct ArithParser<'a> {
39    input: &'a str,
40    pos: usize,
41    scope: &'a Scope,
42}
43
44impl<'a> ArithParser<'a> {
45    fn new(input: &'a str, scope: &'a Scope) -> Self {
46        Self { input, pos: 0, scope }
47    }
48
49    fn skip_whitespace(&mut self) {
50        while self.pos < self.input.len() {
51            let ch = self.input.as_bytes()[self.pos];
52            if ch == b' ' || ch == b'\t' {
53                self.pos += 1;
54            } else {
55                break;
56            }
57        }
58    }
59
60    fn peek(&mut self) -> Option<char> {
61        self.skip_whitespace();
62        self.input[self.pos..].chars().next()
63    }
64
65    fn advance(&mut self) -> Option<char> {
66        self.skip_whitespace();
67        let ch = self.input[self.pos..].chars().next()?;
68        self.pos += ch.len_utf8();
69        Some(ch)
70    }
71
72    /// Peek at the character n positions ahead (0 = current after whitespace skip).
73    fn peek_ahead(&mut self, n: usize) -> Option<char> {
74        self.skip_whitespace();
75        self.input[self.pos..].chars().nth(n)
76    }
77
78    fn expect_end(&mut self) -> Result<()> {
79        self.skip_whitespace();
80        if self.pos < self.input.len() {
81            bail!("unexpected characters at end of arithmetic expression: {:?}",
82                  &self.input[self.pos..]);
83        }
84        Ok(())
85    }
86
87    /// Parse comparison operators (lowest precedence): >, <, >=, <=, ==, !=
88    /// Returns 1 for true, 0 for false.
89    fn parse_comparison(&mut self) -> Result<i64> {
90        let mut left = self.parse_expr()?;
91
92        loop {
93            self.skip_whitespace();
94            match (self.peek_ahead(0), self.peek_ahead(1)) {
95                // Two-character operators must be checked first
96                (Some('>'), Some('=')) => {
97                    self.advance(); // consume '>'
98                    self.advance(); // consume '='
99                    let right = self.parse_expr()?;
100                    left = if left >= right { 1 } else { 0 };
101                }
102                (Some('<'), Some('=')) => {
103                    self.advance(); // consume '<'
104                    self.advance(); // consume '='
105                    let right = self.parse_expr()?;
106                    left = if left <= right { 1 } else { 0 };
107                }
108                (Some('='), Some('=')) => {
109                    self.advance(); // consume '='
110                    self.advance(); // consume '='
111                    let right = self.parse_expr()?;
112                    left = if left == right { 1 } else { 0 };
113                }
114                (Some('!'), Some('=')) => {
115                    self.advance(); // consume '!'
116                    self.advance(); // consume '='
117                    let right = self.parse_expr()?;
118                    left = if left != right { 1 } else { 0 };
119                }
120                // Single-character operators
121                (Some('>'), _) => {
122                    self.advance(); // consume '>'
123                    let right = self.parse_expr()?;
124                    left = if left > right { 1 } else { 0 };
125                }
126                (Some('<'), _) => {
127                    self.advance(); // consume '<'
128                    let right = self.parse_expr()?;
129                    left = if left < right { 1 } else { 0 };
130                }
131                _ => break,
132            }
133        }
134
135        Ok(left)
136    }
137
138    /// Parse an expression: handles + and - (lowest precedence)
139    fn parse_expr(&mut self) -> Result<i64> {
140        let mut left = self.parse_term()?;
141
142        loop {
143            match self.peek() {
144                Some('+') => {
145                    self.advance();
146                    let right = self.parse_term()?;
147                    left = left.checked_add(right)
148                        .context("arithmetic overflow in addition")?;
149                }
150                Some('-') => {
151                    self.advance();
152                    let right = self.parse_term()?;
153                    left = left.checked_sub(right)
154                        .context("arithmetic overflow in subtraction")?;
155                }
156                _ => break,
157            }
158        }
159
160        Ok(left)
161    }
162
163    /// Parse a term: handles * / % (higher precedence)
164    fn parse_term(&mut self) -> Result<i64> {
165        let mut left = self.parse_unary()?;
166
167        loop {
168            match self.peek() {
169                Some('*') => {
170                    self.advance();
171                    let right = self.parse_unary()?;
172                    left = left.checked_mul(right)
173                        .context("arithmetic overflow in multiplication")?;
174                }
175                Some('/') => {
176                    self.advance();
177                    let right = self.parse_unary()?;
178                    if right == 0 {
179                        bail!("division by zero");
180                    }
181                    left = left.checked_div(right)
182                        .context("arithmetic overflow in division")?;
183                }
184                Some('%') => {
185                    self.advance();
186                    let right = self.parse_unary()?;
187                    if right == 0 {
188                        bail!("modulo by zero");
189                    }
190                    left = left.checked_rem(right)
191                        .context("arithmetic overflow in modulo")?;
192                }
193                _ => break,
194            }
195        }
196
197        Ok(left)
198    }
199
200    /// Parse unary operators: + and - prefix
201    fn parse_unary(&mut self) -> Result<i64> {
202        match self.peek() {
203            Some('+') => {
204                self.advance();
205                self.parse_unary()
206            }
207            Some('-') => {
208                self.advance();
209                let val = self.parse_unary()?;
210                val.checked_neg().context("arithmetic overflow in negation")
211            }
212            _ => self.parse_primary(),
213        }
214    }
215
216    /// Parse primary: numbers, variables, parenthesized expressions
217    fn parse_primary(&mut self) -> Result<i64> {
218        self.skip_whitespace();
219
220        match self.peek() {
221            Some('(') => {
222                self.advance(); // consume '('
223                let val = self.parse_expr()?;
224                match self.peek() {
225                    Some(')') => {
226                        self.advance();
227                        Ok(val)
228                    }
229                    _ => bail!("expected ')' in arithmetic expression"),
230                }
231            }
232            Some('$') => {
233                // $VAR, ${VAR}, $?, $$, ${?}, ${$} syntax
234                self.advance(); // consume '$'
235
236                // Special case: $? (last exit code)
237                if self.peek() == Some('?') {
238                    self.advance(); // consume '?'
239                    return Ok(self.scope.last_result().code);
240                }
241
242                // Special case: $$ (current PID)
243                if self.peek() == Some('$') {
244                    self.advance(); // consume second '$'
245                    return Ok(self.scope.pid() as i64);
246                }
247
248                let var_name = if self.peek() == Some('{') {
249                    self.advance(); // consume '{'
250
251                    // Special case: ${?} (last exit code, braced form)
252                    if self.peek() == Some('?') {
253                        self.advance(); // consume '?'
254                        if self.peek() != Some('}') {
255                            bail!("expected '}}' after ${{?}} in arithmetic");
256                        }
257                        self.advance(); // consume '}'
258                        return Ok(self.scope.last_result().code);
259                    }
260
261                    // Special case: ${$} (current PID, braced form)
262                    if self.peek() == Some('$') {
263                        self.advance(); // consume '$'
264                        if self.peek() != Some('}') {
265                            bail!("expected '}}' after ${{$}} in arithmetic");
266                        }
267                        self.advance(); // consume '}'
268                        return Ok(self.scope.pid() as i64);
269                    }
270
271                    let name = self.parse_identifier()?;
272                    if self.peek() != Some('}') {
273                        bail!("expected '}}' after variable name in arithmetic");
274                    }
275                    self.advance(); // consume '}'
276                    name
277                } else {
278                    self.parse_identifier()?
279                };
280                self.get_var_value(&var_name)
281            }
282            Some(c) if c.is_ascii_digit() => {
283                self.parse_number()
284            }
285            Some(c) if c.is_ascii_alphabetic() || c == '_' => {
286                // Bare variable name (bash allows this in $(( )))
287                let var_name = self.parse_identifier()?;
288                self.get_var_value(&var_name)
289            }
290            Some(c) => bail!("unexpected character in arithmetic expression: {:?}", c),
291            None => bail!("unexpected end of arithmetic expression"),
292        }
293    }
294
295    fn parse_number(&mut self) -> Result<i64> {
296        let start = self.pos;
297        while self.pos < self.input.len() {
298            let ch = self.input.as_bytes()[self.pos];
299            if ch.is_ascii_digit() {
300                self.pos += 1;
301            } else {
302                break;
303            }
304        }
305        let num_str = &self.input[start..self.pos];
306        num_str.parse().context("invalid number in arithmetic expression")
307    }
308
309    fn parse_identifier(&mut self) -> Result<String> {
310        let start = self.pos;
311        while self.pos < self.input.len() {
312            let ch = self.input.as_bytes()[self.pos];
313            if ch.is_ascii_alphanumeric() || ch == b'_' {
314                self.pos += 1;
315            } else {
316                break;
317            }
318        }
319        if start == self.pos {
320            bail!("expected identifier in arithmetic expression");
321        }
322        Ok(self.input[start..self.pos].to_string())
323    }
324
325    fn get_var_value(&self, name: &str) -> Result<i64> {
326        // Check for positional parameters ($0, $1, $2, ... $9, etc.)
327        // Name is just the digits when called from `$1` or `${1}` parsing
328        if let Ok(index) = name.parse::<usize>() {
329            if let Some(pos_val) = self.scope.get_positional(index) {
330                return pos_val.parse().with_context(|| {
331                    format!("${} has non-numeric value: {:?}", index, pos_val)
332                });
333            }
334            return Ok(0); // Unset positional defaults to 0
335        }
336
337        // Regular variable lookup
338        match self.scope.get(name) {
339            Some(Value::Int(n)) => Ok(*n),
340            Some(Value::String(s)) => {
341                // Try to parse string as integer
342                s.parse().with_context(|| format!(
343                    "variable '{}' has non-numeric value: {:?}", name, s
344                ))
345            }
346            Some(Value::Float(f)) => Ok(*f as i64),
347            Some(Value::Bool(b)) => Ok(if *b { 1 } else { 0 }),
348            Some(Value::Null) => Ok(0), // Unset variables default to 0 in arithmetic
349            Some(Value::Json(_)) => anyhow::bail!("variable '{}' is JSON, not a number", name),
350            Some(Value::Blob(_)) => anyhow::bail!("variable '{}' is a blob, not a number", name),
351            None => Ok(0), // Unset variables default to 0 in arithmetic
352        }
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluate `expr` against an empty scope, panicking on error.
    fn eval(expr: &str) -> i64 {
        let scope = Scope::new();
        eval_arithmetic(expr, &scope).expect("eval should succeed")
    }

    /// Evaluate `expr` with a single integer variable bound, panicking on error.
    fn eval_with_var(expr: &str, name: &str, value: i64) -> i64 {
        let mut scope = Scope::new();
        scope.set(name, Value::Int(value));
        eval_arithmetic(expr, &scope).expect("eval should succeed")
    }

    /// True if evaluating `expr` against an empty scope returns an error.
    fn eval_fails(expr: &str) -> bool {
        let scope = Scope::new();
        eval_arithmetic(expr, &scope).is_err()
    }

    #[test]
    fn test_simple_integers() {
        assert_eq!(eval("42"), 42);
        assert_eq!(eval("0"), 0);
        assert_eq!(eval("12345"), 12345);
    }

    #[test]
    fn test_addition() {
        assert_eq!(eval("1 + 2"), 3);
        assert_eq!(eval("10 + 20 + 30"), 60);
    }

    #[test]
    fn test_subtraction() {
        assert_eq!(eval("10 - 3"), 7);
        assert_eq!(eval("100 - 50 - 25"), 25);
    }

    #[test]
    fn test_multiplication() {
        assert_eq!(eval("3 * 4"), 12);
        assert_eq!(eval("2 * 3 * 4"), 24);
    }

    #[test]
    fn test_division() {
        assert_eq!(eval("10 / 2"), 5);
        assert_eq!(eval("100 / 10 / 2"), 5);
    }

    #[test]
    fn test_modulo() {
        assert_eq!(eval("10 % 3"), 1);
        assert_eq!(eval("17 % 5"), 2);
    }

    #[test]
    fn test_precedence() {
        assert_eq!(eval("2 + 3 * 4"), 14); // Not 20
        assert_eq!(eval("10 - 6 / 2"), 7); // Not 2
    }

    #[test]
    fn test_parentheses() {
        assert_eq!(eval("(2 + 3) * 4"), 20);
        assert_eq!(eval("((1 + 2) * (3 + 4))"), 21);
    }

    #[test]
    fn test_unary_minus() {
        assert_eq!(eval("-5"), -5);
        assert_eq!(eval("10 + -3"), 7);
        assert_eq!(eval("--5"), 5);
    }

    #[test]
    fn test_unary_plus() {
        assert_eq!(eval("+5"), 5);
        assert_eq!(eval("++5"), 5);
    }

    #[test]
    fn test_whitespace() {
        assert_eq!(eval("  1  +  2  "), 3);
        assert_eq!(eval("1+2"), 3);
    }

    #[test]
    fn test_variable_dollar() {
        assert_eq!(eval_with_var("$X", "X", 10), 10);
        assert_eq!(eval_with_var("$X + 5", "X", 10), 15);
    }

    #[test]
    fn test_variable_dollar_braces() {
        assert_eq!(eval_with_var("${X}", "X", 10), 10);
        assert_eq!(eval_with_var("${X} * 2", "X", 10), 20);
    }

    #[test]
    fn test_variable_bare() {
        assert_eq!(eval_with_var("X", "X", 10), 10);
        assert_eq!(eval_with_var("X + Y", "X", 10), 10); // Y is unset = 0
    }

    #[test]
    fn test_unset_variable() {
        let scope = Scope::new();
        let result = eval_arithmetic("UNDEFINED", &scope).expect("should succeed");
        assert_eq!(result, 0); // Unset variables default to 0
    }

    #[test]
    fn test_division_by_zero() {
        let scope = Scope::new();
        let result = eval_arithmetic("10 / 0", &scope);
        assert!(result.is_err());
    }

    #[test]
    fn test_modulo_by_zero() {
        let scope = Scope::new();
        let result = eval_arithmetic("10 % 0", &scope);
        assert!(result.is_err());
    }

    #[test]
    fn test_complex_expression() {
        assert_eq!(eval("(1 + 2) * (3 + 4) - 5"), 16);
    }

    // Comparison operator tests
    #[test]
    fn test_greater_than() {
        assert_eq!(eval("5 > 3"), 1);
        assert_eq!(eval("3 > 5"), 0);
        assert_eq!(eval("5 > 5"), 0);
    }

    #[test]
    fn test_less_than() {
        assert_eq!(eval("3 < 5"), 1);
        assert_eq!(eval("5 < 3"), 0);
        assert_eq!(eval("5 < 5"), 0);
    }

    #[test]
    fn test_greater_or_equal() {
        assert_eq!(eval("5 >= 3"), 1);
        assert_eq!(eval("5 >= 5"), 1);
        assert_eq!(eval("3 >= 5"), 0);
    }

    #[test]
    fn test_less_or_equal() {
        assert_eq!(eval("3 <= 5"), 1);
        assert_eq!(eval("5 <= 5"), 1);
        assert_eq!(eval("5 <= 3"), 0);
    }

    #[test]
    fn test_equal() {
        assert_eq!(eval("5 == 5"), 1);
        assert_eq!(eval("5 == 3"), 0);
    }

    #[test]
    fn test_not_equal() {
        assert_eq!(eval("5 != 3"), 1);
        assert_eq!(eval("5 != 5"), 0);
    }

    #[test]
    fn test_comparison_with_arithmetic() {
        assert_eq!(eval("(2 + 3) > 4"), 1);
        assert_eq!(eval("10 / 2 == 5"), 1);
        assert_eq!(eval("3 * 4 >= 12"), 1);
        assert_eq!(eval("10 - 5 < 6"), 1);
    }

    #[test]
    fn test_comparison_with_variables() {
        assert_eq!(eval_with_var("X > 5", "X", 10), 1);
        assert_eq!(eval_with_var("X == 10", "X", 10), 1);
        assert_eq!(eval_with_var("X <= 10", "X", 10), 1);
    }

    #[test]
    fn test_chained_comparison() {
        // Note: chained comparisons work left-to-right, not mathematically
        // (5 > 3) > 2 = 1 > 2 = 0
        assert_eq!(eval("5 > 3 > 2"), 0);
        // (5 > 3) == 1 = 1 == 1 = 1
        assert_eq!(eval("5 > 3 == 1"), 1);
    }

    // Edge cases and error paths

    #[test]
    fn test_comparison_binds_looser_than_arithmetic() {
        assert_eq!(eval("5 >= 3 + 2"), 1);
        assert_eq!(eval("2 * 3 != 6"), 0);
        assert_eq!(eval_with_var("${X} != 10", "X", 10), 0);
        assert_eq!(eval("UNSET == 0"), 1);
    }

    #[test]
    fn test_negative_division_truncates_toward_zero() {
        assert_eq!(eval("-7 / 2"), -3);
        assert_eq!(eval("-10 % 3"), -1);
        assert_eq!(eval("10 % -3"), 1);
    }

    #[test]
    fn test_i64_boundaries() {
        assert_eq!(eval("9223372036854775807"), i64::MAX);
        assert_eq!(eval("-9223372036854775807 - 1"), i64::MIN);
    }

    #[test]
    fn test_overflow_is_an_error() {
        assert!(eval_fails("9223372036854775807 + 1"));
        assert!(eval_fails("-9223372036854775807 - 2"));
        assert!(eval_fails("9223372036854775807 * 2"));
        assert!(eval_fails("(-9223372036854775807 - 1) / -1"));
        assert!(eval_fails("(-9223372036854775807 - 1) % -1"));
        assert!(eval_fails("-(-9223372036854775807 - 1)"));
        assert!(eval_fails("9223372036854775808")); // literal does not fit in i64
    }

    #[test]
    fn test_malformed_expressions_are_errors() {
        assert!(eval_fails(""));
        assert!(eval_fails("1 +"));
        assert!(eval_fails("(1 + 2"));
        assert!(eval_fails("1 2"));
        assert!(eval_fails("5 = 3")); // assignment is unsupported
        assert!(eval_fails("$"));
        assert!(eval_fails("${X"));
        assert!(eval_fails("1 + @"));
    }
}