Skip to main content

brush_core/
arithmetic.rs

1//! Arithmetic evaluation
2
3use std::borrow::Cow;
4
5use crate::{ExecutionParameters, Shell, env, expansion, extensions, variables};
6use brush_parser::ast;
7
/// Upper bound on how many nested variable dereferences a single arithmetic
/// evaluation may perform before it gives up.
///
/// This guards against reference cycles such as `a=b`, `b=c`, `c=a`, where
/// each variable's value names another variable.
const MAX_VARIABLE_DEREF_DEPTH: u32 = 1_024;
11
12/// Represents an error that occurs during evaluation of an arithmetic expression.
13#[derive(Debug, thiserror::Error)]
14pub enum EvalError {
15    /// Division by zero.
16    #[error("division by zero")]
17    DivideByZero,
18
19    /// Negative exponent.
20    #[error("exponent less than 0")]
21    NegativeExponent,
22
23    /// Failed to tokenize an arithmetic expression.
24    #[error("failed to tokenize expression")]
25    FailedToTokenizeExpression,
26
27    /// Failed to expand an arithmetic expression.
28    #[error("failed to expand expression: {0}")]
29    FailedToExpandExpression(String),
30
31    /// Failed to access an element of an array.
32    #[error("failed to access array")]
33    FailedToAccessArray,
34
35    /// Failed to update the shell environment in an assignment operator.
36    #[error("failed to update environment")]
37    FailedToUpdateEnvironment,
38
39    /// Failed to parse an arithmetic expression.
40    #[error("failed to parse expression: {0}")]
41    ParseError(String),
42
43    /// Error expanding an unset variable.
44    #[error("expanding unset variable: {0}")]
45    ExpandingUnsetVariable(String),
46
47    /// Expression recursion level exceeded.
48    #[error("expression recursion level exceeded")]
49    RecursionLimitExceeded,
50}
51
52/// Trait implemented by arithmetic expressions that can be evaluated.
53pub(crate) trait ExpandAndEvaluate {
54    /// Evaluate the given expression, returning the resulting numeric value.
55    ///
56    /// # Arguments
57    ///
58    /// * `shell` - The shell to use for evaluation.
59    /// * `trace_if_needed` - Whether to trace the evaluation.
60    async fn eval(
61        &self,
62        shell: &mut Shell<impl extensions::ShellExtensions>,
63        params: &ExecutionParameters,
64        trace_if_needed: bool,
65    ) -> Result<i64, EvalError>;
66}
67
68impl ExpandAndEvaluate for ast::UnexpandedArithmeticExpr {
69    async fn eval(
70        &self,
71        shell: &mut Shell<impl extensions::ShellExtensions>,
72        params: &ExecutionParameters,
73        trace_if_needed: bool,
74    ) -> Result<i64, EvalError> {
75        expand_and_eval(shell, params, self.value.as_str(), trace_if_needed).await
76    }
77}
78
79/// Evaluate the given arithmetic expression, returning the resulting numeric value.
80///
81/// # Arguments
82///
83/// * `shell` - The shell to use for evaluation.
84/// * `expr` - The unexpanded arithmetic expression to evaluate.
85/// * `trace_if_needed` - Whether to trace the evaluation.
86pub(crate) async fn expand_and_eval(
87    shell: &mut Shell<impl extensions::ShellExtensions>,
88    params: &ExecutionParameters,
89    expr: &str,
90    trace_if_needed: bool,
91) -> Result<i64, EvalError> {
92    // Per documentation, first shell-expand it.
93    let options = expansion::ExpanderOptions {
94        tilde_expand: false,
95        ..Default::default()
96    };
97    let expanded_self = expansion::basic_expand_word_with_options(shell, params, expr, &options)
98        .await
99        .map_err(|_e| EvalError::FailedToExpandExpression(expr.to_owned()))?;
100
101    // Now parse.
102    let expr = brush_parser::arithmetic::parse(&expanded_self)
103        .map_err(|_e| EvalError::ParseError(expanded_self))?;
104
105    // Trace if applicable.
106    if trace_if_needed && shell.options().print_commands_and_arguments {
107        shell
108            .trace_command(params, std::format!("(( {expr} ))"))
109            .await;
110    }
111
112    // Now evaluate.
113    expr.eval(shell)
114}
115
116/// Trait implemented by evaluatable arithmetic expressions.
117pub trait Evaluatable {
118    /// Evaluate the given arithmetic expression, returning the resulting numeric value.
119    ///
120    /// # Arguments
121    ///
122    /// * `shell` - The shell to use for evaluation.
123    fn eval(&self, shell: &mut Shell<impl extensions::ShellExtensions>) -> Result<i64, EvalError>;
124}
125
126impl Evaluatable for ast::ArithmeticExpr {
127    fn eval(&self, shell: &mut Shell<impl extensions::ShellExtensions>) -> Result<i64, EvalError> {
128        eval_expr_impl(self, shell, 0)
129    }
130}
131
132fn eval_expr_impl(
133    expr: &ast::ArithmeticExpr,
134    shell: &mut Shell<impl extensions::ShellExtensions>,
135    depth: u32,
136) -> Result<i64, EvalError> {
137    let value = match expr {
138        ast::ArithmeticExpr::Literal(l) => *l,
139        ast::ArithmeticExpr::Reference(lvalue) => deref_lvalue(shell, lvalue, depth)?,
140        ast::ArithmeticExpr::UnaryOp(op, operand) => apply_unary_op(shell, *op, operand, depth)?,
141        ast::ArithmeticExpr::BinaryOp(op, left, right) => {
142            apply_binary_op(shell, *op, left, right, depth)?
143        }
144        ast::ArithmeticExpr::Conditional(condition, then_expr, else_expr) => {
145            let conditional_eval = eval_expr_impl(condition, shell, depth)?;
146
147            // Ensure we only evaluate the branch indicated by the condition.
148            if conditional_eval != 0 {
149                eval_expr_impl(then_expr, shell, depth)?
150            } else {
151                eval_expr_impl(else_expr, shell, depth)?
152            }
153        }
154        ast::ArithmeticExpr::Assignment(lvalue, rhs) => {
155            let expr_eval = eval_expr_impl(rhs, shell, depth)?;
156            assign(shell, lvalue, expr_eval, depth)?
157        }
158        ast::ArithmeticExpr::UnaryAssignment(op, lvalue) => {
159            apply_unary_assignment_op(shell, lvalue, *op, depth)?
160        }
161        ast::ArithmeticExpr::BinaryAssignment(op, lvalue, operand) => {
162            let value = apply_binary_op(
163                shell,
164                *op,
165                &ast::ArithmeticExpr::Reference(lvalue.clone()),
166                operand,
167                depth,
168            )?;
169            assign(shell, lvalue, value, depth)?
170        }
171    };
172
173    Ok(value)
174}
175
176fn get_var_value<'a>(
177    shell: &'a Shell<impl extensions::ShellExtensions>,
178    name: &str,
179) -> Result<Cow<'a, str>, EvalError> {
180    let value = shell.env_var(name).map(|var| var.resolve_value(shell));
181
182    if let Some(value) = value
183        && value.is_set()
184    {
185        return Ok(value.to_cow_str(shell).to_string().into());
186    }
187
188    if shell.options().treat_unset_variables_as_error {
189        return Err(EvalError::ExpandingUnsetVariable(name.into()));
190    }
191
192    Ok("".into())
193}
194
195fn deref_lvalue(
196    shell: &mut Shell<impl extensions::ShellExtensions>,
197    lvalue: &ast::ArithmeticTarget,
198    depth: u32,
199) -> Result<i64, EvalError> {
200    let value_str: Cow<'_, str> = match lvalue {
201        ast::ArithmeticTarget::Variable(name) => get_var_value(shell, name.as_str())?,
202        ast::ArithmeticTarget::ArrayElement(name, index_expr) => {
203            let index_str = eval_expr_impl(index_expr, shell, depth)?.to_string();
204
205            shell
206                .env()
207                .get(name)
208                .map_or_else(
209                    || Ok(None),
210                    |(_, v)| v.value().get_at(index_str.as_str(), shell),
211                )
212                .map_err(|_err| EvalError::FailedToAccessArray)?
213                .unwrap_or(Cow::Borrowed(""))
214        }
215    };
216
217    let parsed_value = brush_parser::arithmetic::parse(value_str.as_ref())
218        .map_err(|_err| EvalError::ParseError(value_str.to_string()))?;
219
220    // Literals don't need depth tracking — they can't cause recursion.
221    // Only increment depth when the parsed value requires further evaluation
222    // (i.e., it references other variables), matching bash's behavior.
223    if matches!(parsed_value, ast::ArithmeticExpr::Literal(_)) {
224        return eval_expr_impl(&parsed_value, shell, depth);
225    }
226
227    let new_depth = depth + 1;
228    if new_depth > MAX_VARIABLE_DEREF_DEPTH {
229        return Err(EvalError::RecursionLimitExceeded);
230    }
231
232    eval_expr_impl(&parsed_value, shell, new_depth)
233}
234
235fn apply_unary_op(
236    shell: &mut Shell<impl extensions::ShellExtensions>,
237    op: ast::UnaryOperator,
238    operand: &ast::ArithmeticExpr,
239    depth: u32,
240) -> Result<i64, EvalError> {
241    let operand_eval = eval_expr_impl(operand, shell, depth)?;
242
243    match op {
244        ast::UnaryOperator::UnaryPlus => Ok(operand_eval),
245        ast::UnaryOperator::UnaryMinus => Ok(operand_eval.wrapping_neg()),
246        ast::UnaryOperator::BitwiseNot => Ok(!operand_eval),
247        ast::UnaryOperator::LogicalNot => Ok(bool_to_i64(operand_eval == 0)),
248    }
249}
250
251fn apply_binary_op(
252    shell: &mut Shell<impl extensions::ShellExtensions>,
253    op: ast::BinaryOperator,
254    left: &ast::ArithmeticExpr,
255    right: &ast::ArithmeticExpr,
256    depth: u32,
257) -> Result<i64, EvalError> {
258    // First, special-case short-circuiting operators. For those, we need
259    // to ensure we don't eagerly evaluate both operands. After we
260    // get these out of the way, we can easily just evaluate operands
261    // for the other operators.
262    match op {
263        ast::BinaryOperator::LogicalAnd => {
264            let left = eval_expr_impl(left, shell, depth)?;
265            if left == 0 {
266                return Ok(bool_to_i64(false));
267            }
268
269            let right = eval_expr_impl(right, shell, depth)?;
270            return Ok(bool_to_i64(right != 0));
271        }
272        ast::BinaryOperator::LogicalOr => {
273            let left = eval_expr_impl(left, shell, depth)?;
274            if left != 0 {
275                return Ok(bool_to_i64(true));
276            }
277
278            let right = eval_expr_impl(right, shell, depth)?;
279            return Ok(bool_to_i64(right != 0));
280        }
281        _ => (),
282    }
283
284    // The remaining operators unconditionally operate both operands.
285    let left = eval_expr_impl(left, shell, depth)?;
286    let right = eval_expr_impl(right, shell, depth)?;
287
288    #[expect(clippy::cast_possible_truncation)]
289    #[expect(clippy::cast_sign_loss)]
290    match op {
291        ast::BinaryOperator::Power => {
292            if right >= 0 {
293                Ok(wrapping_pow_u64(left, right as u64))
294            } else {
295                Err(EvalError::NegativeExponent)
296            }
297        }
298        ast::BinaryOperator::Multiply => Ok(left.wrapping_mul(right)),
299        ast::BinaryOperator::Divide => {
300            if right == 0 {
301                Err(EvalError::DivideByZero)
302            } else {
303                Ok(left.wrapping_div(right))
304            }
305        }
306        ast::BinaryOperator::Modulo => {
307            if right == 0 {
308                Err(EvalError::DivideByZero)
309            } else {
310                Ok(left.wrapping_rem(right))
311            }
312        }
313        ast::BinaryOperator::Comma => Ok(right),
314        ast::BinaryOperator::Add => Ok(left.wrapping_add(right)),
315        ast::BinaryOperator::Subtract => Ok(left.wrapping_sub(right)),
316        ast::BinaryOperator::ShiftLeft => Ok(left.wrapping_shl(right as u32)),
317        ast::BinaryOperator::ShiftRight => Ok(left.wrapping_shr(right as u32)),
318        ast::BinaryOperator::LessThan => Ok(bool_to_i64(left < right)),
319        ast::BinaryOperator::LessThanOrEqualTo => Ok(bool_to_i64(left <= right)),
320        ast::BinaryOperator::GreaterThan => Ok(bool_to_i64(left > right)),
321        ast::BinaryOperator::GreaterThanOrEqualTo => Ok(bool_to_i64(left >= right)),
322        ast::BinaryOperator::Equals => Ok(bool_to_i64(left == right)),
323        ast::BinaryOperator::NotEquals => Ok(bool_to_i64(left != right)),
324        ast::BinaryOperator::BitwiseAnd => Ok(left & right),
325        ast::BinaryOperator::BitwiseXor => Ok(left ^ right),
326        ast::BinaryOperator::BitwiseOr => Ok(left | right),
327        ast::BinaryOperator::LogicalAnd => unreachable!("LogicalAnd covered above"),
328        ast::BinaryOperator::LogicalOr => unreachable!("LogicalOr covered above"),
329    }
330}
331
332fn apply_unary_assignment_op(
333    shell: &mut Shell<impl extensions::ShellExtensions>,
334    lvalue: &ast::ArithmeticTarget,
335    op: ast::UnaryAssignmentOperator,
336    depth: u32,
337) -> Result<i64, EvalError> {
338    let value = deref_lvalue(shell, lvalue, depth)?;
339
340    match op {
341        ast::UnaryAssignmentOperator::PrefixIncrement => {
342            let new_value = value.wrapping_add(1);
343            assign(shell, lvalue, new_value, depth)?;
344            Ok(new_value)
345        }
346        ast::UnaryAssignmentOperator::PrefixDecrement => {
347            let new_value = value.wrapping_sub(1);
348            assign(shell, lvalue, new_value, depth)?;
349            Ok(new_value)
350        }
351        ast::UnaryAssignmentOperator::PostfixIncrement => {
352            let new_value = value.wrapping_add(1);
353            assign(shell, lvalue, new_value, depth)?;
354            Ok(value)
355        }
356        ast::UnaryAssignmentOperator::PostfixDecrement => {
357            let new_value = value.wrapping_sub(1);
358            assign(shell, lvalue, new_value, depth)?;
359            Ok(value)
360        }
361    }
362}
363
364fn assign(
365    shell: &mut Shell<impl extensions::ShellExtensions>,
366    lvalue: &ast::ArithmeticTarget,
367    value: i64,
368    depth: u32,
369) -> Result<i64, EvalError> {
370    match lvalue {
371        ast::ArithmeticTarget::Variable(name) => {
372            shell
373                .env_mut()
374                .update_or_add(
375                    name.as_str(),
376                    variables::ShellValueLiteral::Scalar(value.to_string()),
377                    |_| Ok(()),
378                    env::EnvironmentLookup::Anywhere,
379                    env::EnvironmentScope::Global,
380                )
381                .map_err(|_err| EvalError::FailedToUpdateEnvironment)?;
382        }
383        ast::ArithmeticTarget::ArrayElement(name, index_expr) => {
384            let index_str = eval_expr_impl(index_expr, shell, depth)?.to_string();
385
386            shell
387                .env_mut()
388                .update_or_add_array_element(
389                    name.as_str(),
390                    index_str,
391                    value.to_string(),
392                    |_| Ok(()),
393                    env::EnvironmentLookup::Anywhere,
394                    env::EnvironmentScope::Global,
395                )
396                .map_err(|_err| EvalError::FailedToUpdateEnvironment)?;
397        }
398    }
399
400    Ok(value)
401}
402
/// Converts a boolean into shell arithmetic's truth value: 1 for `true`,
/// 0 for `false`.
const fn bool_to_i64(value: bool) -> i64 {
    // `bool as i64` is defined to yield exactly 0 or 1; `i64::from` would be
    // the usual choice but is not callable in a `const fn`.
    value as i64
}
406
/// Raises `base` to the power `exponent` by repeated squaring, wrapping on
/// overflow at every multiplication.
///
/// This exists instead of `i64::wrapping_pow` because that method only
/// accepts a `u32` exponent, while shell arithmetic may supply any
/// non-negative 64-bit exponent.
const fn wrapping_pow_u64(mut base: i64, mut exponent: u64) -> i64 {
    let mut acc: i64 = 1;

    while exponent != 0 {
        // Multiply in the current power of `base` when this exponent bit is set.
        if exponent & 1 != 0 {
            acc = acc.wrapping_mul(base);
        }
        base = base.wrapping_mul(base);
        exponent >>= 1;
    }

    acc
}