1use super::state::ShellState;
2
3#[derive(Debug, Clone, PartialEq)]
5pub enum ArithmeticToken {
6 Number(i64),
7 Variable(String),
8 Operator(ArithmeticOperator),
9 LeftParen,
10 RightParen,
11}
12
/// An operator usable in an arithmetic expression.
///
/// Variants are listed from tightest-binding to loosest-binding, matching the
/// values returned by `ArithmeticOperator::precedence`.
///
/// The enum carries no data, so it is `Copy` and can be compared and hashed
/// cheaply.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ArithmeticOperator {
    /// `!` — unary logical negation.
    LogicalNot,
    /// `~` — unary bitwise complement.
    BitwiseNot,
    /// `*`
    Multiply,
    /// `/`
    Divide,
    /// `%`
    Modulo,
    /// `+`
    Add,
    /// `-`
    Subtract,
    /// `<<`
    ShiftLeft,
    /// `>>`
    ShiftRight,
    /// `<`
    LessThan,
    /// `<=`
    LessEqual,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterEqual,
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
    /// `&`
    BitwiseAnd,
    /// `^`
    BitwiseXor,
    /// `|`
    BitwiseOr,
    /// `&&`
    LogicalAnd,
    /// `||`
    LogicalOr,
}
40
41impl ArithmeticOperator {
42 pub fn precedence(&self) -> i32 {
43 match self {
44 ArithmeticOperator::LogicalNot | ArithmeticOperator::BitwiseNot => 100,
45
46 ArithmeticOperator::Multiply
47 | ArithmeticOperator::Divide
48 | ArithmeticOperator::Modulo => 90,
49 ArithmeticOperator::Add | ArithmeticOperator::Subtract => 80,
50 ArithmeticOperator::ShiftLeft | ArithmeticOperator::ShiftRight => 70,
51 ArithmeticOperator::LessThan
52 | ArithmeticOperator::LessEqual
53 | ArithmeticOperator::GreaterThan
54 | ArithmeticOperator::GreaterEqual => 60,
55 ArithmeticOperator::Equal | ArithmeticOperator::NotEqual => 50,
56 ArithmeticOperator::BitwiseAnd => 40,
57 ArithmeticOperator::BitwiseXor => 30,
58 ArithmeticOperator::BitwiseOr => 20,
59 ArithmeticOperator::LogicalAnd => 10,
60 ArithmeticOperator::LogicalOr => 5,
61 }
62 }
63
64 pub fn is_unary(&self) -> bool {
65 matches!(
66 self,
67 ArithmeticOperator::LogicalNot | ArithmeticOperator::BitwiseNot
68 )
69 }
70}
71
/// Failure modes of tokenizing, parsing or evaluating an arithmetic
/// expression.
///
/// The type is comparable so callers and tests can check for an exact error
/// with `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArithmeticError {
    /// The expression is malformed; the payload describes the problem.
    SyntaxError(String),
    /// The right-hand side of `/` or `%` evaluated to zero.
    DivisionByZero,
    /// A `(` without a matching `)`, or the reverse.
    UnmatchedParentheses,
    /// The expression was empty or contained only whitespace.
    EmptyExpression,
}
80
81impl std::fmt::Display for ArithmeticError {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 match self {
84 ArithmeticError::SyntaxError(msg) => write!(f, "Syntax error: {}", msg),
85 ArithmeticError::DivisionByZero => write!(f, "Division by zero"),
86 ArithmeticError::UnmatchedParentheses => write!(f, "Unmatched parentheses"),
87 ArithmeticError::EmptyExpression => write!(f, "Empty expression"),
88 }
89 }
90}
91
92pub fn tokenize_expression(expr: &str) -> Result<Vec<ArithmeticToken>, ArithmeticError> {
94 let mut tokens = Vec::new();
95 let mut chars = expr.chars().peekable();
96
97 while let Some(ch) = chars.next() {
98 match ch {
99 ' ' | '\t' | '\n' => continue, '(' => tokens.push(ArithmeticToken::LeftParen),
102 ')' => tokens.push(ArithmeticToken::RightParen),
103
104 '+' => {
105 if let Some(next_ch) = chars.peek()
106 && *next_ch == '+'
107 {
108 return Err(ArithmeticError::SyntaxError("Unexpected ++".to_string()));
109 }
110 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Add));
111 }
112
113 '-' => {
114 if let Some(next_ch) = chars.peek()
115 && *next_ch == '-'
116 {
117 return Err(ArithmeticError::SyntaxError("Unexpected --".to_string()));
118 }
119 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Subtract));
120 }
121
122 '*' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Multiply)),
123 '/' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Divide)),
124 '%' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Modulo)),
125
126 '<' => {
127 if let Some(&next_ch) = chars.peek() {
128 if next_ch == '<' {
129 chars.next();
130 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::ShiftLeft));
131 } else if next_ch == '=' {
132 chars.next();
133 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessEqual));
134 } else {
135 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessThan));
136 }
137 } else {
138 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessThan));
139 }
140 }
141
142 '>' => {
143 if let Some(&next_ch) = chars.peek() {
144 if next_ch == '>' {
145 chars.next();
146 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::ShiftRight));
147 } else if next_ch == '=' {
148 chars.next();
149 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterEqual));
150 } else {
151 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterThan));
152 }
153 } else {
154 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterThan));
155 }
156 }
157
158 '=' => {
159 if let Some(&next_ch) = chars.peek() {
160 if next_ch == '=' {
161 chars.next();
162 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Equal));
163 } else {
164 return Err(ArithmeticError::SyntaxError("Unexpected =".to_string()));
165 }
166 } else {
167 return Err(ArithmeticError::SyntaxError("Unexpected =".to_string()));
168 }
169 }
170
171 '!' => {
172 if let Some(&next_ch) = chars.peek() {
173 if next_ch == '=' {
174 chars.next();
175 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::NotEqual));
176 } else {
177 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalNot));
178 }
179 } else {
180 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalNot));
181 }
182 }
183
184 '&' => {
185 if let Some(&next_ch) = chars.peek() {
186 if next_ch == '&' {
187 chars.next();
188 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalAnd));
189 } else {
190 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseAnd));
191 }
192 } else {
193 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseAnd));
194 }
195 }
196
197 '|' => {
198 if let Some(&next_ch) = chars.peek() {
199 if next_ch == '|' {
200 chars.next();
201 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalOr));
202 } else {
203 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseOr));
204 }
205 } else {
206 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseOr));
207 }
208 }
209
210 '^' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseXor)),
211 '~' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseNot)),
212
213 '0'..='9' => {
215 let mut num_str = String::new();
216 num_str.push(ch);
217 while let Some(&next_ch) = chars.peek() {
218 if next_ch.is_ascii_digit() {
219 num_str.push(next_ch);
220 chars.next();
221 } else {
222 break;
223 }
224 }
225 match num_str.parse::<i64>() {
226 Ok(num) => tokens.push(ArithmeticToken::Number(num)),
227 Err(_) => {
228 return Err(ArithmeticError::SyntaxError("Invalid number".to_string()));
229 }
230 }
231 }
232
233 'a'..='z' | 'A'..='Z' | '_' => {
235 let mut var_name = String::new();
236 var_name.push(ch);
237 while let Some(&next_ch) = chars.peek() {
238 if next_ch.is_alphanumeric() || next_ch == '_' {
239 var_name.push(next_ch);
240 chars.next();
241 } else {
242 break;
243 }
244 }
245 tokens.push(ArithmeticToken::Variable(var_name));
246 }
247
248 _ => {
249 return Err(ArithmeticError::SyntaxError(format!(
250 "Unexpected character: {}",
251 ch
252 )));
253 }
254 }
255 }
256
257 Ok(tokens)
258}
259
260pub fn parse_to_rpn(tokens: Vec<ArithmeticToken>) -> Result<Vec<ArithmeticToken>, ArithmeticError> {
262 let mut output = Vec::new();
263 let mut operators = Vec::new();
264
265 for token in tokens {
266 match token {
267 ArithmeticToken::Number(_) | ArithmeticToken::Variable(_) => {
268 output.push(token);
269 }
270
271 ArithmeticToken::Operator(op) => {
272 if op.is_unary()
274 && (output.is_empty()
275 || matches!(
276 output.last(),
277 Some(ArithmeticToken::Operator(_) | ArithmeticToken::LeftParen)
278 ))
279 {
280 while !operators.is_empty() {
282 if let Some(ArithmeticToken::Operator(top_op)) = operators.last() {
283 if top_op.precedence() >= op.precedence() && !top_op.is_unary() {
284 output.push(operators.pop().unwrap());
285 } else {
286 break;
287 }
288 } else {
289 break;
290 }
291 }
292 operators.push(ArithmeticToken::Operator(op));
293 } else {
294 while !operators.is_empty() {
296 if let Some(ArithmeticToken::Operator(top_op)) = operators.last() {
297 if (top_op.precedence() > op.precedence())
298 || (top_op.precedence() == op.precedence() && !op.is_unary())
299 {
300 output.push(operators.pop().unwrap());
301 } else {
302 break;
303 }
304 } else {
305 break;
306 }
307 }
308 operators.push(ArithmeticToken::Operator(op));
309 }
310 }
311
312 ArithmeticToken::LeftParen => {
313 operators.push(token);
314 }
315
316 ArithmeticToken::RightParen => {
317 let mut found_left = false;
318 while let Some(op) = operators.pop() {
319 if op == ArithmeticToken::LeftParen {
320 found_left = true;
321 break;
322 } else {
323 output.push(op);
324 }
325 }
326 if !found_left {
327 return Err(ArithmeticError::UnmatchedParentheses);
328 }
329 }
330 }
331 }
332
333 while let Some(op) = operators.pop() {
335 if op == ArithmeticToken::LeftParen {
336 return Err(ArithmeticError::UnmatchedParentheses);
337 }
338 output.push(op);
339 }
340
341 Ok(output)
342}
343
344pub fn evaluate_rpn(
346 rpn_tokens: Vec<ArithmeticToken>,
347 shell_state: &ShellState,
348) -> Result<i64, ArithmeticError> {
349 let mut stack = Vec::new();
350
351 for token in rpn_tokens {
352 match token {
353 ArithmeticToken::Number(num) => {
354 stack.push(num);
355 }
356
357 ArithmeticToken::Variable(var_name) => {
358 if let Some(value) = shell_state.get_var(&var_name) {
359 match value.parse::<i64>() {
360 Ok(num) => stack.push(num),
361 Err(_) => {
362 stack.push(0)
364 }
365 }
366 } else {
367 stack.push(0)
369 }
370 }
371
372 ArithmeticToken::Operator(op) => {
373 if op.is_unary() {
374 if stack.is_empty() {
375 return Err(ArithmeticError::SyntaxError(
376 "Missing operand for unary operator".to_string(),
377 ));
378 }
379 let operand = stack.pop().unwrap();
380 let result = match op {
381 ArithmeticOperator::LogicalNot => !operand,
382 ArithmeticOperator::BitwiseNot => !operand,
383 _ => unreachable!(),
384 };
385 stack.push(result);
386 } else {
387 if stack.len() < 2 {
388 return Err(ArithmeticError::SyntaxError(
389 "Missing operands for binary operator".to_string(),
390 ));
391 }
392 let right = stack.pop().unwrap();
393 let left = stack.pop().unwrap();
394 let result = match op {
395 ArithmeticOperator::Add => left + right,
396 ArithmeticOperator::Subtract => left - right,
397 ArithmeticOperator::Multiply => left * right,
398 ArithmeticOperator::Divide => {
399 if right == 0 {
400 return Err(ArithmeticError::DivisionByZero);
401 }
402 left / right
403 }
404 ArithmeticOperator::Modulo => {
405 if right == 0 {
406 return Err(ArithmeticError::DivisionByZero);
407 }
408 left % right
409 }
410 ArithmeticOperator::ShiftLeft => left << right,
411 ArithmeticOperator::ShiftRight => left >> right,
412 ArithmeticOperator::LessThan => {
413 if left < right {
414 1
415 } else {
416 0
417 }
418 }
419 ArithmeticOperator::LessEqual => {
420 if left <= right {
421 1
422 } else {
423 0
424 }
425 }
426 ArithmeticOperator::GreaterThan => {
427 if left > right {
428 1
429 } else {
430 0
431 }
432 }
433 ArithmeticOperator::GreaterEqual => {
434 if left >= right {
435 1
436 } else {
437 0
438 }
439 }
440 ArithmeticOperator::Equal => {
441 if left == right {
442 1
443 } else {
444 0
445 }
446 }
447 ArithmeticOperator::NotEqual => {
448 if left != right {
449 1
450 } else {
451 0
452 }
453 }
454 ArithmeticOperator::BitwiseAnd => left & right,
455 ArithmeticOperator::BitwiseXor => left ^ right,
456 ArithmeticOperator::BitwiseOr => left | right,
457 ArithmeticOperator::LogicalAnd => {
458 if left != 0 && right != 0 {
459 1
460 } else {
461 0
462 }
463 }
464 ArithmeticOperator::LogicalOr => {
465 if left != 0 || right != 0 {
466 1
467 } else {
468 0
469 }
470 }
471 _ => unreachable!(),
472 };
473 stack.push(result);
474 }
475 }
476
477 ArithmeticToken::LeftParen | ArithmeticToken::RightParen => {
478 return Err(ArithmeticError::SyntaxError(
479 "Unexpected parenthesis in RPN".to_string(),
480 ));
481 }
482 }
483 }
484
485 if stack.len() != 1 {
486 return Err(ArithmeticError::SyntaxError(
487 "Invalid expression".to_string(),
488 ));
489 }
490
491 Ok(stack[0])
492}
493
494pub fn evaluate_arithmetic_expression(
496 expr: &str,
497 shell_state: &ShellState,
498) -> Result<i64, ArithmeticError> {
499 if expr.trim().is_empty() {
500 return Err(ArithmeticError::EmptyExpression);
501 }
502
503 let tokens = tokenize_expression(expr)?;
504 let rpn_tokens = parse_to_rpn(tokens)?;
505 let result = evaluate_rpn(rpn_tokens, shell_state)?;
506
507 Ok(result)
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an operator in its token form.
    fn op(operator: ArithmeticOperator) -> ArithmeticToken {
        ArithmeticToken::Operator(operator)
    }

    /// Builds a variable token from a name.
    fn var(name: &str) -> ArithmeticToken {
        ArithmeticToken::Variable(name.to_string())
    }

    /// Evaluates `expr` against `state`, panicking if evaluation fails.
    fn eval(expr: &str, state: &ShellState) -> i64 {
        evaluate_arithmetic_expression(expr, state).unwrap()
    }

    #[test]
    fn test_tokenize_simple_numbers() {
        assert_eq!(
            tokenize_expression("42").unwrap(),
            vec![ArithmeticToken::Number(42)]
        );
    }

    #[test]
    fn test_tokenize_operators() {
        let expected = vec![
            ArithmeticToken::Number(2),
            op(ArithmeticOperator::Add),
            ArithmeticToken::Number(3),
        ];
        assert_eq!(tokenize_expression("2+3").unwrap(), expected);
    }

    #[test]
    fn test_tokenize_parentheses() {
        let expected = vec![
            ArithmeticToken::LeftParen,
            ArithmeticToken::Number(2),
            op(ArithmeticOperator::Add),
            ArithmeticToken::Number(3),
            ArithmeticToken::RightParen,
        ];
        assert_eq!(tokenize_expression("(2+3)").unwrap(), expected);
    }

    #[test]
    fn test_tokenize_variables() {
        let expected = vec![var("x"), op(ArithmeticOperator::Add), var("y")];
        assert_eq!(tokenize_expression("x+y").unwrap(), expected);
    }

    #[test]
    fn test_evaluate_simple() {
        let shell_state = ShellState::new();
        assert_eq!(eval("42", &shell_state), 42);
    }

    #[test]
    fn test_evaluate_addition() {
        let shell_state = ShellState::new();
        assert_eq!(eval("2+3", &shell_state), 5);
    }

    #[test]
    fn test_evaluate_with_precedence() {
        let shell_state = ShellState::new();
        // Multiplication binds tighter than addition: 2 + (3 * 4).
        assert_eq!(eval("2+3*4", &shell_state), 14);
    }

    #[test]
    fn test_evaluate_with_parentheses() {
        let shell_state = ShellState::new();
        // Parentheses override precedence: (2 + 3) * 4.
        assert_eq!(eval("(2+3)*4", &shell_state), 20);
    }

    #[test]
    fn test_evaluate_comparison() {
        let shell_state = ShellState::new();
        // True comparisons yield 1, false ones 0.
        assert_eq!(eval("5>3", &shell_state), 1);
        assert_eq!(eval("3>5", &shell_state), 0);
    }

    #[test]
    fn test_evaluate_variable() {
        let mut shell_state = ShellState::new();
        shell_state.set_var("x", "10".to_string());
        assert_eq!(eval("x + 5", &shell_state), 15);
    }

    #[test]
    fn test_evaluate_division_by_zero() {
        let shell_state = ShellState::new();
        let outcome = evaluate_arithmetic_expression("5/0", &shell_state);
        assert!(matches!(outcome, Err(ArithmeticError::DivisionByZero)));
    }

    #[test]
    fn test_evaluate_undefined_variable() {
        let shell_state = ShellState::new();
        // An unset variable evaluates to 0.
        let outcome = evaluate_arithmetic_expression("undefined + 5", &shell_state);
        assert_eq!(outcome.unwrap(), 5);
    }
}