rush_sh/
arithmetic.rs

1use super::state::ShellState;
2
/// A lexical token of a shell arithmetic expression.
///
/// Produced by [`tokenize_expression`], reordered by [`parse_to_rpn`] and
/// consumed by [`evaluate_rpn`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArithmeticToken {
    /// An integer constant.
    Number(i64),
    /// A variable reference, resolved against the shell state at evaluation time.
    Variable(String),
    /// A unary or binary operator.
    Operator(ArithmeticOperator),
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
}

/// Operators supported in arithmetic expressions.
///
/// Binding strength comes from [`ArithmeticOperator::precedence`] (higher binds
/// tighter); [`ArithmeticOperator::is_unary`] identifies the prefix operators.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArithmeticOperator {
    // Unary (prefix) operators, precedence 100
    LogicalNot, // !
    BitwiseNot, // ~

    // Binary operators, from highest to lowest precedence
    Multiply,     // *   (precedence 90)
    Divide,       // /   (precedence 90)
    Modulo,       // %   (precedence 90)
    Add,          // +   (precedence 80)
    Subtract,     // -   (precedence 80)
    ShiftLeft,    // <<  (precedence 70)
    ShiftRight,   // >>  (precedence 70)
    LessThan,     // <   (precedence 60)
    LessEqual,    // <=  (precedence 60)
    GreaterThan,  // >   (precedence 60)
    GreaterEqual, // >=  (precedence 60)
    Equal,        // ==  (precedence 50)
    NotEqual,     // !=  (precedence 50)
    BitwiseAnd,   // &   (precedence 40)
    BitwiseXor,   // ^   (precedence 30)
    BitwiseOr,    // |   (precedence 20)
    LogicalAnd,   // &&  (precedence 10)
    LogicalOr,    // ||  (precedence 5)
}

impl ArithmeticOperator {
    /// Binding strength of the operator; a larger value binds more tightly.
    ///
    /// The relative order matches C's operator precedence table:
    /// `! ~` > `* / %` > `+ -` > `<< >>` > relational > equality > `&` > `^` >
    /// `|` > `&&` > `||`.
    pub fn precedence(&self) -> i32 {
        match self {
            ArithmeticOperator::LogicalNot | ArithmeticOperator::BitwiseNot => 100,

            ArithmeticOperator::Multiply
            | ArithmeticOperator::Divide
            | ArithmeticOperator::Modulo => 90,
            ArithmeticOperator::Add | ArithmeticOperator::Subtract => 80,
            ArithmeticOperator::ShiftLeft | ArithmeticOperator::ShiftRight => 70,
            ArithmeticOperator::LessThan
            | ArithmeticOperator::LessEqual
            | ArithmeticOperator::GreaterThan
            | ArithmeticOperator::GreaterEqual => 60,
            ArithmeticOperator::Equal | ArithmeticOperator::NotEqual => 50,
            ArithmeticOperator::BitwiseAnd => 40,
            ArithmeticOperator::BitwiseXor => 30,
            ArithmeticOperator::BitwiseOr => 20,
            ArithmeticOperator::LogicalAnd => 10,
            ArithmeticOperator::LogicalOr => 5,
        }
    }

    /// Returns `true` for the prefix operators `!` and `~`.
    pub fn is_unary(&self) -> bool {
        matches!(
            self,
            ArithmeticOperator::LogicalNot | ArithmeticOperator::BitwiseNot
        )
    }
}
71
/// Errors that can occur while tokenizing, parsing or evaluating an arithmetic
/// expression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArithmeticError {
    /// Malformed input; the payload describes the problem.
    SyntaxError(String),
    /// `/` or `%` with a zero right-hand operand.
    DivisionByZero,
    /// A `(` without a matching `)`, or vice versa.
    UnmatchedParentheses,
    /// The expression was empty or contained only whitespace.
    EmptyExpression,
}

impl std::fmt::Display for ArithmeticError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ArithmeticError::SyntaxError(msg) => write!(f, "Syntax error: {}", msg),
            ArithmeticError::DivisionByZero => write!(f, "Division by zero"),
            ArithmeticError::UnmatchedParentheses => write!(f, "Unmatched parentheses"),
            ArithmeticError::EmptyExpression => write!(f, "Empty expression"),
        }
    }
}

/// Makes `ArithmeticError` a standard error so it can be boxed as
/// `Box<dyn std::error::Error>` and propagated with `?` by generic callers.
impl std::error::Error for ArithmeticError {}
91
92/// Tokenize an arithmetic expression into tokens
93pub fn tokenize_expression(expr: &str) -> Result<Vec<ArithmeticToken>, ArithmeticError> {
94    let mut tokens = Vec::new();
95    let mut chars = expr.chars().peekable();
96
97    while let Some(ch) = chars.next() {
98        match ch {
99            ' ' | '\t' | '\n' => continue, // Skip whitespace
100
101            '(' => tokens.push(ArithmeticToken::LeftParen),
102            ')' => tokens.push(ArithmeticToken::RightParen),
103
104            '+' => {
105                if let Some(next_ch) = chars.peek()
106                    && *next_ch == '+' {
107                        return Err(ArithmeticError::SyntaxError("Unexpected ++".to_string()));
108                    }
109                tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Add));
110            }
111
112            '-' => {
113                if let Some(next_ch) = chars.peek()
114                    && *next_ch == '-' {
115                        return Err(ArithmeticError::SyntaxError("Unexpected --".to_string()));
116                    }
117                tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Subtract));
118            }
119
120            '*' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Multiply)),
121            '/' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Divide)),
122            '%' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Modulo)),
123
124            '<' => {
125                if let Some(&next_ch) = chars.peek() {
126                    if next_ch == '<' {
127                        chars.next();
128                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::ShiftLeft));
129                    } else if next_ch == '=' {
130                        chars.next();
131                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessEqual));
132                    } else {
133                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessThan));
134                    }
135                } else {
136                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessThan));
137                }
138            }
139
140            '>' => {
141                if let Some(&next_ch) = chars.peek() {
142                    if next_ch == '>' {
143                        chars.next();
144                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::ShiftRight));
145                    } else if next_ch == '=' {
146                        chars.next();
147                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterEqual));
148                    } else {
149                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterThan));
150                    }
151                } else {
152                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterThan));
153                }
154            }
155
156            '=' => {
157                if let Some(&next_ch) = chars.peek() {
158                    if next_ch == '=' {
159                        chars.next();
160                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Equal));
161                    } else {
162                        return Err(ArithmeticError::SyntaxError("Unexpected =".to_string()));
163                    }
164                } else {
165                    return Err(ArithmeticError::SyntaxError("Unexpected =".to_string()));
166                }
167            }
168
169            '!' => {
170                if let Some(&next_ch) = chars.peek() {
171                    if next_ch == '=' {
172                        chars.next();
173                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::NotEqual));
174                    } else {
175                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalNot));
176                    }
177                } else {
178                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalNot));
179                }
180            }
181
182            '&' => {
183                if let Some(&next_ch) = chars.peek() {
184                    if next_ch == '&' {
185                        chars.next();
186                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalAnd));
187                    } else {
188                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseAnd));
189                    }
190                } else {
191                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseAnd));
192                }
193            }
194
195            '|' => {
196                if let Some(&next_ch) = chars.peek() {
197                    if next_ch == '|' {
198                        chars.next();
199                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalOr));
200                    } else {
201                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseOr));
202                    }
203                } else {
204                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseOr));
205                }
206            }
207
208            '^' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseXor)),
209            '~' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseNot)),
210
211            // Numbers and variables
212            '0'..='9' => {
213                let mut num_str = String::new();
214                num_str.push(ch);
215                while let Some(&next_ch) = chars.peek() {
216                    if next_ch.is_ascii_digit() {
217                        num_str.push(next_ch);
218                        chars.next();
219                    } else {
220                        break;
221                    }
222                }
223                match num_str.parse::<i64>() {
224                    Ok(num) => tokens.push(ArithmeticToken::Number(num)),
225                    Err(_) => {
226                        return Err(ArithmeticError::SyntaxError("Invalid number".to_string()));
227                    }
228                }
229            }
230
231            // Variables (start with letter or underscore)
232            'a'..='z' | 'A'..='Z' | '_' => {
233                let mut var_name = String::new();
234                var_name.push(ch);
235                while let Some(&next_ch) = chars.peek() {
236                    if next_ch.is_alphanumeric() || next_ch == '_' {
237                        var_name.push(next_ch);
238                        chars.next();
239                    } else {
240                        break;
241                    }
242                }
243                tokens.push(ArithmeticToken::Variable(var_name));
244            }
245
246            _ => {
247                return Err(ArithmeticError::SyntaxError(format!(
248                    "Unexpected character: {}",
249                    ch
250                )));
251            }
252        }
253    }
254
255    Ok(tokens)
256}
257
258/// Parse tokens into Reverse Polish Notation (RPN) using Shunting-yard algorithm
259pub fn parse_to_rpn(tokens: Vec<ArithmeticToken>) -> Result<Vec<ArithmeticToken>, ArithmeticError> {
260    let mut output = Vec::new();
261    let mut operators = Vec::new();
262
263    for token in tokens {
264        match token {
265            ArithmeticToken::Number(_) | ArithmeticToken::Variable(_) => {
266                output.push(token);
267            }
268
269            ArithmeticToken::Operator(op) => {
270                // Handle unary operators
271                if op.is_unary()
272                    && (output.is_empty()
273                        || matches!(
274                            output.last(),
275                            Some(ArithmeticToken::Operator(_) | ArithmeticToken::LeftParen)
276                        ))
277                {
278                    // This is a unary operator
279                    while !operators.is_empty() {
280                        if let Some(ArithmeticToken::Operator(top_op)) = operators.last() {
281                            if top_op.precedence() >= op.precedence() && !top_op.is_unary() {
282                                output.push(operators.pop().unwrap());
283                            } else {
284                                break;
285                            }
286                        } else {
287                            break;
288                        }
289                    }
290                    operators.push(ArithmeticToken::Operator(op));
291                } else {
292                    // Binary operator
293                    while !operators.is_empty() {
294                        if let Some(ArithmeticToken::Operator(top_op)) = operators.last() {
295                            if (top_op.precedence() > op.precedence())
296                                || (top_op.precedence() == op.precedence() && !op.is_unary())
297                            {
298                                output.push(operators.pop().unwrap());
299                            } else {
300                                break;
301                            }
302                        } else {
303                            break;
304                        }
305                    }
306                    operators.push(ArithmeticToken::Operator(op));
307                }
308            }
309
310            ArithmeticToken::LeftParen => {
311                operators.push(token);
312            }
313
314            ArithmeticToken::RightParen => {
315                let mut found_left = false;
316                while let Some(op) = operators.pop() {
317                    if op == ArithmeticToken::LeftParen {
318                        found_left = true;
319                        break;
320                    } else {
321                        output.push(op);
322                    }
323                }
324                if !found_left {
325                    return Err(ArithmeticError::UnmatchedParentheses);
326                }
327            }
328        }
329    }
330
331    // Pop remaining operators
332    while let Some(op) = operators.pop() {
333        if op == ArithmeticToken::LeftParen {
334            return Err(ArithmeticError::UnmatchedParentheses);
335        }
336        output.push(op);
337    }
338
339    Ok(output)
340}
341
342/// Evaluate an arithmetic expression in Reverse Polish Notation
343pub fn evaluate_rpn(
344    rpn_tokens: Vec<ArithmeticToken>,
345    shell_state: &ShellState,
346) -> Result<i64, ArithmeticError> {
347    let mut stack = Vec::new();
348
349    for token in rpn_tokens {
350        match token {
351            ArithmeticToken::Number(num) => {
352                stack.push(num);
353            }
354
355            ArithmeticToken::Variable(var_name) => {
356                if let Some(value) = shell_state.get_var(&var_name) {
357                    match value.parse::<i64>() {
358                        Ok(num) => stack.push(num),
359                        Err(_) => {
360                            // Variable exists but is not a valid number, treat as 0 (bash behavior)
361                            stack.push(0)
362                        }
363                    }
364                } else {
365                    // Variable is undefined, treat as 0 (bash behavior)
366                    stack.push(0)
367                }
368            }
369
370            ArithmeticToken::Operator(op) => {
371                if op.is_unary() {
372                    if stack.is_empty() {
373                        return Err(ArithmeticError::SyntaxError(
374                            "Missing operand for unary operator".to_string(),
375                        ));
376                    }
377                    let operand = stack.pop().unwrap();
378                    let result = match op {
379                        ArithmeticOperator::LogicalNot => !operand,
380                        ArithmeticOperator::BitwiseNot => !operand,
381                        _ => unreachable!(),
382                    };
383                    stack.push(result);
384                } else {
385                    if stack.len() < 2 {
386                        return Err(ArithmeticError::SyntaxError(
387                            "Missing operands for binary operator".to_string(),
388                        ));
389                    }
390                    let right = stack.pop().unwrap();
391                    let left = stack.pop().unwrap();
392                    let result = match op {
393                        ArithmeticOperator::Add => left + right,
394                        ArithmeticOperator::Subtract => left - right,
395                        ArithmeticOperator::Multiply => left * right,
396                        ArithmeticOperator::Divide => {
397                            if right == 0 {
398                                return Err(ArithmeticError::DivisionByZero);
399                            }
400                            left / right
401                        }
402                        ArithmeticOperator::Modulo => {
403                            if right == 0 {
404                                return Err(ArithmeticError::DivisionByZero);
405                            }
406                            left % right
407                        }
408                        ArithmeticOperator::ShiftLeft => left << right,
409                        ArithmeticOperator::ShiftRight => left >> right,
410                        ArithmeticOperator::LessThan => {
411                            if left < right {
412                                1
413                            } else {
414                                0
415                            }
416                        }
417                        ArithmeticOperator::LessEqual => {
418                            if left <= right {
419                                1
420                            } else {
421                                0
422                            }
423                        }
424                        ArithmeticOperator::GreaterThan => {
425                            if left > right {
426                                1
427                            } else {
428                                0
429                            }
430                        }
431                        ArithmeticOperator::GreaterEqual => {
432                            if left >= right {
433                                1
434                            } else {
435                                0
436                            }
437                        }
438                        ArithmeticOperator::Equal => {
439                            if left == right {
440                                1
441                            } else {
442                                0
443                            }
444                        }
445                        ArithmeticOperator::NotEqual => {
446                            if left != right {
447                                1
448                            } else {
449                                0
450                            }
451                        }
452                        ArithmeticOperator::BitwiseAnd => left & right,
453                        ArithmeticOperator::BitwiseXor => left ^ right,
454                        ArithmeticOperator::BitwiseOr => left | right,
455                        ArithmeticOperator::LogicalAnd => {
456                            if left != 0 && right != 0 {
457                                1
458                            } else {
459                                0
460                            }
461                        }
462                        ArithmeticOperator::LogicalOr => {
463                            if left != 0 || right != 0 {
464                                1
465                            } else {
466                                0
467                            }
468                        }
469                        _ => unreachable!(),
470                    };
471                    stack.push(result);
472                }
473            }
474
475            ArithmeticToken::LeftParen | ArithmeticToken::RightParen => {
476                return Err(ArithmeticError::SyntaxError(
477                    "Unexpected parenthesis in RPN".to_string(),
478                ));
479            }
480        }
481    }
482
483    if stack.len() != 1 {
484        return Err(ArithmeticError::SyntaxError(
485            "Invalid expression".to_string(),
486        ));
487    }
488
489    Ok(stack[0])
490}
491
492/// Main function to evaluate an arithmetic expression
493pub fn evaluate_arithmetic_expression(
494    expr: &str,
495    shell_state: &ShellState,
496) -> Result<i64, ArithmeticError> {
497    if expr.trim().is_empty() {
498        return Err(ArithmeticError::EmptyExpression);
499    }
500
501    let tokens = tokenize_expression(expr)?;
502    let rpn_tokens = parse_to_rpn(tokens)?;
503    let result = evaluate_rpn(rpn_tokens, shell_state)?;
504
505    Ok(result)
506}
507
#[cfg(test)]
mod tests {
    //! Unit tests for the tokenizer, parser and evaluator. Each case here
    //! exercises behaviour that the public entry points already guarantee.
    use super::*;

    #[test]
    fn test_tokenize_simple_numbers() {
        let tokens = tokenize_expression("42").unwrap();
        assert_eq!(tokens, vec![ArithmeticToken::Number(42)]);
    }

    #[test]
    fn test_tokenize_operators() {
        let tokens = tokenize_expression("2+3").unwrap();
        assert_eq!(
            tokens,
            vec![
                ArithmeticToken::Number(2),
                ArithmeticToken::Operator(ArithmeticOperator::Add),
                ArithmeticToken::Number(3)
            ]
        );
    }

    #[test]
    fn test_tokenize_parentheses() {
        let tokens = tokenize_expression("(2+3)").unwrap();
        assert_eq!(
            tokens,
            vec![
                ArithmeticToken::LeftParen,
                ArithmeticToken::Number(2),
                ArithmeticToken::Operator(ArithmeticOperator::Add),
                ArithmeticToken::Number(3),
                ArithmeticToken::RightParen
            ]
        );
    }

    #[test]
    fn test_tokenize_variables() {
        let tokens = tokenize_expression("x+y").unwrap();
        assert_eq!(
            tokens,
            vec![
                ArithmeticToken::Variable("x".to_string()),
                ArithmeticToken::Operator(ArithmeticOperator::Add),
                ArithmeticToken::Variable("y".to_string())
            ]
        );
    }

    #[test]
    fn test_tokenize_rejects_unsupported_syntax() {
        // Increment/decrement, assignment and unknown characters are not supported.
        assert!(tokenize_expression("1 ++ 2").is_err());
        assert!(tokenize_expression("a -- b").is_err());
        assert!(tokenize_expression("a = 1").is_err());
        assert!(tokenize_expression("2 $ 3").is_err());
    }

    #[test]
    fn test_evaluate_simple() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("42", &shell_state).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_evaluate_addition() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("2+3", &shell_state).unwrap();
        assert_eq!(result, 5);
    }

    #[test]
    fn test_evaluate_with_precedence() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("2+3*4", &shell_state).unwrap();
        assert_eq!(result, 14); // 3*4 = 12, +2 = 14
    }

    #[test]
    fn test_evaluate_with_parentheses() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("(2+3)*4", &shell_state).unwrap();
        assert_eq!(result, 20); // (2+3) = 5, *4 = 20
    }

    #[test]
    fn test_evaluate_comparison() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("5>3", &shell_state).unwrap();
        assert_eq!(result, 1); // true

        let result = evaluate_arithmetic_expression("3>5", &shell_state).unwrap();
        assert_eq!(result, 0); // false
    }

    #[test]
    fn test_evaluate_left_associativity_and_operators() {
        let shell_state = ShellState::new();
        let cases = [
            ("10 - 4 - 3", 3),   // (10 - 4) - 3
            ("100 / 10 / 5", 2), // (100 / 10) / 5
            ("7 % 3", 1),
            ("1 << 4 | 1", 17), // << binds tighter than |
            ("8 >> 2", 2),
            ("6 & 3 ^ 1", 3),   // (6 & 3) ^ 1
            ("1 && 0 || 1", 1), // (1 && 0) || 1
            ("1 == 1 && 2 > 1", 1),
            ("3 != 3", 0),
            ("2 <= 2", 1),
        ];
        for (expr, expected) in cases {
            let result = evaluate_arithmetic_expression(expr, &shell_state).unwrap();
            assert_eq!(result, expected, "{expr}");
        }
    }

    #[test]
    fn test_evaluate_variable() {
        let mut shell_state = ShellState::new();
        shell_state.set_var("x", "10".to_string());
        let result = evaluate_arithmetic_expression("x + 5", &shell_state).unwrap();
        assert_eq!(result, 15);
    }

    #[test]
    fn test_evaluate_division_by_zero() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("5/0", &shell_state);
        assert!(matches!(result, Err(ArithmeticError::DivisionByZero)));
    }

    #[test]
    fn test_evaluate_undefined_variable() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("undefined + 5", &shell_state);
        // Undefined variables are treated as 0 (bash behavior)
        assert_eq!(result.unwrap(), 5);
    }

    #[test]
    fn test_evaluate_empty_expression() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("   ", &shell_state);
        assert!(matches!(result, Err(ArithmeticError::EmptyExpression)));
    }

    #[test]
    fn test_evaluate_unmatched_parentheses() {
        let shell_state = ShellState::new();
        for expr in ["(1+2", "1+2)"] {
            let result = evaluate_arithmetic_expression(expr, &shell_state);
            assert!(
                matches!(result, Err(ArithmeticError::UnmatchedParentheses)),
                "{expr}"
            );
        }
    }

    #[test]
    fn test_error_display() {
        assert_eq!(ArithmeticError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            ArithmeticError::SyntaxError("oops".to_string()).to_string(),
            "Syntax error: oops"
        );
    }
}