brush_core/
arithmetic.rs

//! Arithmetic evaluation: shell-expanding, parsing, and evaluating shell arithmetic expressions.
2
3use std::borrow::Cow;
4
5use crate::{ExecutionParameters, Shell, env, expansion, variables};
6use brush_parser::ast;
7
8/// Represents an error that occurs during evaluation of an arithmetic expression.
9#[derive(Debug, thiserror::Error)]
10pub enum EvalError {
11    /// Division by zero.
12    #[error("division by zero")]
13    DivideByZero,
14
15    /// Negative exponent.
16    #[error("exponent less than 0")]
17    NegativeExponent,
18
19    /// Failed to tokenize an arithmetic expression.
20    #[error("failed to tokenize expression")]
21    FailedToTokenizeExpression,
22
23    /// Failed to expand an arithmetic expression.
24    #[error("failed to expand expression: '{0}'")]
25    FailedToExpandExpression(String),
26
27    /// Failed to access an element of an array.
28    #[error("failed to access array")]
29    FailedToAccessArray,
30
31    /// Failed to update the shell environment in an assignment operator.
32    #[error("failed to update environment")]
33    FailedToUpdateEnvironment,
34
35    /// Failed to parse an arithmetic expression.
36    #[error("failed to parse expression: '{0}'")]
37    ParseError(String),
38
39    /// Failed to trace an arithmetic expression.
40    #[error("failed tracing expression")]
41    TraceError,
42}
43
44/// Trait implemented by arithmetic expressions that can be evaluated.
45pub(crate) trait ExpandAndEvaluate {
46    /// Evaluate the given expression, returning the resulting numeric value.
47    ///
48    /// # Arguments
49    ///
50    /// * `shell` - The shell to use for evaluation.
51    /// * `trace_if_needed` - Whether to trace the evaluation.
52    async fn eval(
53        &self,
54        shell: &mut Shell,
55        params: &ExecutionParameters,
56        trace_if_needed: bool,
57    ) -> Result<i64, EvalError>;
58}
59
60impl ExpandAndEvaluate for ast::UnexpandedArithmeticExpr {
61    async fn eval(
62        &self,
63        shell: &mut Shell,
64        params: &ExecutionParameters,
65        trace_if_needed: bool,
66    ) -> Result<i64, EvalError> {
67        expand_and_eval(shell, params, self.value.as_str(), trace_if_needed).await
68    }
69}
70
71/// Evaluate the given arithmetic expression, returning the resulting numeric value.
72///
73/// # Arguments
74///
75/// * `shell` - The shell to use for evaluation.
76/// * `expr` - The unexpanded arithmetic expression to evaluate.
77/// * `trace_if_needed` - Whether to trace the evaluation.
78pub(crate) async fn expand_and_eval(
79    shell: &mut Shell,
80    params: &ExecutionParameters,
81    expr: &str,
82    trace_if_needed: bool,
83) -> Result<i64, EvalError> {
84    // Per documentation, first shell-expand it.
85    let expanded_self = expansion::basic_expand_str_without_tilde(shell, params, expr)
86        .await
87        .map_err(|_e| EvalError::FailedToExpandExpression(expr.to_owned()))?;
88
89    // Now parse.
90    let expr = brush_parser::arithmetic::parse(&expanded_self)
91        .map_err(|_e| EvalError::ParseError(expanded_self))?;
92
93    // Trace if applicable.
94    if trace_if_needed && shell.options.print_commands_and_arguments {
95        shell
96            .trace_command(params, std::format!("(( {expr} ))"))
97            .await
98            .map_err(|_err| EvalError::TraceError)?;
99    }
100
101    // Now evaluate.
102    expr.eval(shell)
103}
104
105/// Trait implemented by evaluatable arithmetic expressions.
106pub trait Evaluatable {
107    /// Evaluate the given arithmetic expression, returning the resulting numeric value.
108    ///
109    /// # Arguments
110    ///
111    /// * `shell` - The shell to use for evaluation.
112    fn eval(&self, shell: &mut Shell) -> Result<i64, EvalError>;
113}
114
115impl Evaluatable for ast::ArithmeticExpr {
116    fn eval(&self, shell: &mut Shell) -> Result<i64, EvalError> {
117        let value = match self {
118            Self::Literal(l) => *l,
119            Self::Reference(lvalue) => deref_lvalue(shell, lvalue)?,
120            Self::UnaryOp(op, operand) => apply_unary_op(shell, *op, operand)?,
121            Self::BinaryOp(op, left, right) => apply_binary_op(shell, *op, left, right)?,
122            Self::Conditional(condition, then_expr, else_expr) => {
123                let conditional_eval = condition.eval(shell)?;
124
125                // Ensure we only evaluate the branch indicated by the condition.
126                if conditional_eval != 0 {
127                    then_expr.eval(shell)?
128                } else {
129                    else_expr.eval(shell)?
130                }
131            }
132            Self::Assignment(lvalue, expr) => {
133                let expr_eval = expr.eval(shell)?;
134                assign(shell, lvalue, expr_eval)?
135            }
136            Self::UnaryAssignment(op, lvalue) => apply_unary_assignment_op(shell, lvalue, *op)?,
137            Self::BinaryAssignment(op, lvalue, operand) => {
138                let value = apply_binary_op(shell, *op, &Self::Reference(lvalue.clone()), operand)?;
139                assign(shell, lvalue, value)?
140            }
141        };
142
143        Ok(value)
144    }
145}
146
147fn deref_lvalue(shell: &mut Shell, lvalue: &ast::ArithmeticTarget) -> Result<i64, EvalError> {
148    let value_str: Cow<'_, str> = match lvalue {
149        ast::ArithmeticTarget::Variable(name) => shell.env_str(name).unwrap_or(Cow::Borrowed("")),
150        ast::ArithmeticTarget::ArrayElement(name, index_expr) => {
151            let index_str = index_expr.eval(shell)?.to_string();
152
153            shell
154                .env
155                .get(name)
156                .map_or_else(
157                    || Ok(None),
158                    |(_, v)| v.value().get_at(index_str.as_str(), shell),
159                )
160                .map_err(|_err| EvalError::FailedToAccessArray)?
161                .unwrap_or(Cow::Borrowed(""))
162        }
163    };
164
165    let parsed_value = brush_parser::arithmetic::parse(value_str.as_ref())
166        .map_err(|_err| EvalError::ParseError(value_str.to_string()))?;
167
168    parsed_value.eval(shell)
169}
170
171fn apply_unary_op(
172    shell: &mut Shell,
173    op: ast::UnaryOperator,
174    operand: &ast::ArithmeticExpr,
175) -> Result<i64, EvalError> {
176    let operand_eval = operand.eval(shell)?;
177
178    match op {
179        ast::UnaryOperator::UnaryPlus => Ok(operand_eval),
180        ast::UnaryOperator::UnaryMinus => Ok(-operand_eval),
181        ast::UnaryOperator::BitwiseNot => Ok(!operand_eval),
182        ast::UnaryOperator::LogicalNot => Ok(bool_to_i64(operand_eval == 0)),
183    }
184}
185
186fn apply_binary_op(
187    shell: &mut Shell,
188    op: ast::BinaryOperator,
189    left: &ast::ArithmeticExpr,
190    right: &ast::ArithmeticExpr,
191) -> Result<i64, EvalError> {
192    // First, special-case short-circuiting operators. For those, we need
193    // to ensure we don't eagerly evaluate both operands. After we
194    // get these out of the way, we can easily just evaluate operands
195    // for the other operators.
196    match op {
197        ast::BinaryOperator::LogicalAnd => {
198            let left = left.eval(shell)?;
199            if left == 0 {
200                return Ok(bool_to_i64(false));
201            }
202
203            let right = right.eval(shell)?;
204            return Ok(bool_to_i64(right != 0));
205        }
206        ast::BinaryOperator::LogicalOr => {
207            let left = left.eval(shell)?;
208            if left != 0 {
209                return Ok(bool_to_i64(true));
210            }
211
212            let right = right.eval(shell)?;
213            return Ok(bool_to_i64(right != 0));
214        }
215        _ => (),
216    }
217
218    // The remaining operators unconditionally operate both operands.
219    let left = left.eval(shell)?;
220    let right = right.eval(shell)?;
221
222    #[expect(clippy::cast_possible_truncation)]
223    #[expect(clippy::cast_sign_loss)]
224    match op {
225        ast::BinaryOperator::Power => {
226            if right >= 0 {
227                Ok(wrapping_pow_u64(left, right as u64))
228            } else {
229                Err(EvalError::NegativeExponent)
230            }
231        }
232        ast::BinaryOperator::Multiply => Ok(left.wrapping_mul(right)),
233        ast::BinaryOperator::Divide => {
234            if right == 0 {
235                Err(EvalError::DivideByZero)
236            } else {
237                Ok(left.wrapping_div(right))
238            }
239        }
240        ast::BinaryOperator::Modulo => {
241            if right == 0 {
242                Err(EvalError::DivideByZero)
243            } else {
244                Ok(left % right)
245            }
246        }
247        ast::BinaryOperator::Comma => Ok(right),
248        ast::BinaryOperator::Add => Ok(left.wrapping_add(right)),
249        ast::BinaryOperator::Subtract => Ok(left.wrapping_sub(right)),
250        ast::BinaryOperator::ShiftLeft => Ok(left.wrapping_shl(right as u32)),
251        ast::BinaryOperator::ShiftRight => Ok(left.wrapping_shr(right as u32)),
252        ast::BinaryOperator::LessThan => Ok(bool_to_i64(left < right)),
253        ast::BinaryOperator::LessThanOrEqualTo => Ok(bool_to_i64(left <= right)),
254        ast::BinaryOperator::GreaterThan => Ok(bool_to_i64(left > right)),
255        ast::BinaryOperator::GreaterThanOrEqualTo => Ok(bool_to_i64(left >= right)),
256        ast::BinaryOperator::Equals => Ok(bool_to_i64(left == right)),
257        ast::BinaryOperator::NotEquals => Ok(bool_to_i64(left != right)),
258        ast::BinaryOperator::BitwiseAnd => Ok(left & right),
259        ast::BinaryOperator::BitwiseXor => Ok(left ^ right),
260        ast::BinaryOperator::BitwiseOr => Ok(left | right),
261        ast::BinaryOperator::LogicalAnd => unreachable!("LogicalAnd covered above"),
262        ast::BinaryOperator::LogicalOr => unreachable!("LogicalOr covered above"),
263    }
264}
265
266fn apply_unary_assignment_op(
267    shell: &mut Shell,
268    lvalue: &ast::ArithmeticTarget,
269    op: ast::UnaryAssignmentOperator,
270) -> Result<i64, EvalError> {
271    let value = deref_lvalue(shell, lvalue)?;
272
273    match op {
274        ast::UnaryAssignmentOperator::PrefixIncrement => {
275            let new_value = value + 1;
276            assign(shell, lvalue, new_value)?;
277            Ok(new_value)
278        }
279        ast::UnaryAssignmentOperator::PrefixDecrement => {
280            let new_value = value - 1;
281            assign(shell, lvalue, new_value)?;
282            Ok(new_value)
283        }
284        ast::UnaryAssignmentOperator::PostfixIncrement => {
285            let new_value = value + 1;
286            assign(shell, lvalue, new_value)?;
287            Ok(value)
288        }
289        ast::UnaryAssignmentOperator::PostfixDecrement => {
290            let new_value = value - 1;
291            assign(shell, lvalue, new_value)?;
292            Ok(value)
293        }
294    }
295}
296
297fn assign(shell: &mut Shell, lvalue: &ast::ArithmeticTarget, value: i64) -> Result<i64, EvalError> {
298    match lvalue {
299        ast::ArithmeticTarget::Variable(name) => {
300            shell
301                .env
302                .update_or_add(
303                    name.as_str(),
304                    variables::ShellValueLiteral::Scalar(value.to_string()),
305                    |_| Ok(()),
306                    env::EnvironmentLookup::Anywhere,
307                    env::EnvironmentScope::Global,
308                )
309                .map_err(|_err| EvalError::FailedToUpdateEnvironment)?;
310        }
311        ast::ArithmeticTarget::ArrayElement(name, index_expr) => {
312            let index_str = index_expr.eval(shell)?.to_string();
313
314            shell
315                .env
316                .update_or_add_array_element(
317                    name.as_str(),
318                    index_str,
319                    value.to_string(),
320                    |_| Ok(()),
321                    env::EnvironmentLookup::Anywhere,
322                    env::EnvironmentScope::Global,
323                )
324                .map_err(|_err| EvalError::FailedToUpdateEnvironment)?;
325        }
326    }
327
328    Ok(value)
329}
330
/// Converts a boolean into its arithmetic truth value: `1` for `true`, `0` for `false`.
const fn bool_to_i64(value: bool) -> i64 {
    // `bool as i64` is defined to produce exactly 0 or 1.
    value as i64
}
334
// N.B. This is `i64::wrapping_pow` with a 64-bit exponent. The standard method takes a
// `u32` exponent, so narrowing a large exponent to call it would silently change the
// result. Taking `u64` keeps overflow handling correct for every non-negative exponent.
/// Computes `base` raised to `exponent`, wrapping around on overflow.
///
/// Uses binary exponentiation: each set bit of `exponent` multiplies the accumulator by
/// the matching repeated square of `base`.
const fn wrapping_pow_u64(base: i64, exponent: u64) -> i64 {
    let mut acc: i64 = 1;
    let mut square = base;
    let mut remaining = exponent;

    while remaining != 0 {
        if remaining & 1 != 0 {
            acc = acc.wrapping_mul(square);
        }

        square = square.wrapping_mul(square);
        remaining >>= 1;
    }

    acc
}