brush_core/
arithmetic.rs

1use std::borrow::Cow;
2
3use crate::{ExecutionParameters, Shell, env, expansion, variables};
4use brush_parser::ast;
5
6/// Represents an error that occurs during evaluation of an arithmetic expression.
7#[derive(Debug, thiserror::Error)]
8pub enum EvalError {
9    /// Division by zero.
10    #[error("division by zero")]
11    DivideByZero,
12
13    /// Negative exponent.
14    #[error("exponent less than 0")]
15    NegativeExponent,
16
17    /// Failed to tokenize an arithmetic expression.
18    #[error("failed to tokenize expression")]
19    FailedToTokenizeExpression,
20
21    /// Failed to expand an arithmetic expression.
22    #[error("failed to expand expression: '{0}'")]
23    FailedToExpandExpression(String),
24
25    /// Failed to access an element of an array.
26    #[error("failed to access array")]
27    FailedToAccessArray,
28
29    /// Failed to update the shell environment in an assignment operator.
30    #[error("failed to update environment")]
31    FailedToUpdateEnvironment,
32
33    /// Failed to parse an arithmetic expression.
34    #[error("failed to parse expression: '{0}'")]
35    ParseError(String),
36
37    /// Failed to trace an arithmetic expression.
38    #[error("failed tracing expression")]
39    TraceError,
40}
41
42/// Trait implemented by arithmetic expressions that can be evaluated.
43pub trait ExpandAndEvaluate {
44    /// Evaluate the given expression, returning the resulting numeric value.
45    ///
46    /// # Arguments
47    ///
48    /// * `shell` - The shell to use for evaluation.
49    /// * `trace_if_needed` - Whether to trace the evaluation.
50    async fn eval(
51        &self,
52        shell: &mut Shell,
53        params: &ExecutionParameters,
54        trace_if_needed: bool,
55    ) -> Result<i64, EvalError>;
56}
57
58impl ExpandAndEvaluate for ast::UnexpandedArithmeticExpr {
59    async fn eval(
60        &self,
61        shell: &mut Shell,
62        params: &ExecutionParameters,
63        trace_if_needed: bool,
64    ) -> Result<i64, EvalError> {
65        expand_and_eval(shell, params, self.value.as_str(), trace_if_needed).await
66    }
67}
68
69/// Evaluate the given arithmetic expression, returning the resulting numeric value.
70///
71/// # Arguments
72///
73/// * `shell` - The shell to use for evaluation.
74/// * `expr` - The unexpanded arithmetic expression to evaluate.
75/// * `trace_if_needed` - Whether to trace the evaluation.
76pub(crate) async fn expand_and_eval(
77    shell: &mut Shell,
78    params: &ExecutionParameters,
79    expr: &str,
80    trace_if_needed: bool,
81) -> Result<i64, EvalError> {
82    // Per documentation, first shell-expand it.
83    let expanded_self = expansion::basic_expand_str_without_tilde(shell, params, expr)
84        .await
85        .map_err(|_e| EvalError::FailedToExpandExpression(expr.to_owned()))?;
86
87    // Now parse.
88    let expr = brush_parser::arithmetic::parse(&expanded_self)
89        .map_err(|_e| EvalError::ParseError(expanded_self))?;
90
91    // Trace if applicable.
92    if trace_if_needed && shell.options.print_commands_and_arguments {
93        shell
94            .trace_command(std::format!("(( {expr} ))"))
95            .await
96            .map_err(|_err| EvalError::TraceError)?;
97    }
98
99    // Now evaluate.
100    expr.eval(shell)
101}
102
103/// Trait implemented by evaluatable arithmetic expressions.
104pub trait Evaluatable {
105    /// Evaluate the given arithmetic expression, returning the resulting numeric value.
106    ///
107    /// # Arguments
108    ///
109    /// * `shell` - The shell to use for evaluation.
110    fn eval(&self, shell: &mut Shell) -> Result<i64, EvalError>;
111}
112
113impl Evaluatable for ast::ArithmeticExpr {
114    fn eval(&self, shell: &mut Shell) -> Result<i64, EvalError> {
115        let value = match self {
116            Self::Literal(l) => *l,
117            Self::Reference(lvalue) => deref_lvalue(shell, lvalue)?,
118            Self::UnaryOp(op, operand) => apply_unary_op(shell, *op, operand)?,
119            Self::BinaryOp(op, left, right) => apply_binary_op(shell, *op, left, right)?,
120            Self::Conditional(condition, then_expr, else_expr) => {
121                let conditional_eval = condition.eval(shell)?;
122
123                // Ensure we only evaluate the branch indicated by the condition.
124                if conditional_eval != 0 {
125                    then_expr.eval(shell)?
126                } else {
127                    else_expr.eval(shell)?
128                }
129            }
130            Self::Assignment(lvalue, expr) => {
131                let expr_eval = expr.eval(shell)?;
132                assign(shell, lvalue, expr_eval)?
133            }
134            Self::UnaryAssignment(op, lvalue) => apply_unary_assignment_op(shell, lvalue, *op)?,
135            Self::BinaryAssignment(op, lvalue, operand) => {
136                let value = apply_binary_op(shell, *op, &Self::Reference(lvalue.clone()), operand)?;
137                assign(shell, lvalue, value)?
138            }
139        };
140
141        Ok(value)
142    }
143}
144
145fn deref_lvalue(shell: &mut Shell, lvalue: &ast::ArithmeticTarget) -> Result<i64, EvalError> {
146    let value_str: Cow<'_, str> = match lvalue {
147        ast::ArithmeticTarget::Variable(name) => {
148            shell.get_env_str(name).unwrap_or(Cow::Borrowed(""))
149        }
150        ast::ArithmeticTarget::ArrayElement(name, index_expr) => {
151            let index_str = index_expr.eval(shell)?.to_string();
152
153            shell
154                .env
155                .get(name)
156                .map_or_else(
157                    || Ok(None),
158                    |(_, v)| v.value().get_at(index_str.as_str(), shell),
159                )
160                .map_err(|_err| EvalError::FailedToAccessArray)?
161                .unwrap_or(Cow::Borrowed(""))
162        }
163    };
164
165    let parsed_value = brush_parser::arithmetic::parse(value_str.as_ref())
166        .map_err(|_err| EvalError::ParseError(value_str.to_string()))?;
167
168    parsed_value.eval(shell)
169}
170
171fn apply_unary_op(
172    shell: &mut Shell,
173    op: ast::UnaryOperator,
174    operand: &ast::ArithmeticExpr,
175) -> Result<i64, EvalError> {
176    let operand_eval = operand.eval(shell)?;
177
178    match op {
179        ast::UnaryOperator::UnaryPlus => Ok(operand_eval),
180        ast::UnaryOperator::UnaryMinus => Ok(-operand_eval),
181        ast::UnaryOperator::BitwiseNot => Ok(!operand_eval),
182        ast::UnaryOperator::LogicalNot => Ok(bool_to_i64(operand_eval == 0)),
183    }
184}
185
186fn apply_binary_op(
187    shell: &mut Shell,
188    op: ast::BinaryOperator,
189    left: &ast::ArithmeticExpr,
190    right: &ast::ArithmeticExpr,
191) -> Result<i64, EvalError> {
192    // First, special-case short-circuiting operators. For those, we need
193    // to ensure we don't eagerly evaluate both operands. After we
194    // get these out of the way, we can easily just evaluate operands
195    // for the other operators.
196    match op {
197        ast::BinaryOperator::LogicalAnd => {
198            let left = left.eval(shell)?;
199            if left == 0 {
200                return Ok(bool_to_i64(false));
201            }
202
203            let right = right.eval(shell)?;
204            return Ok(bool_to_i64(right != 0));
205        }
206        ast::BinaryOperator::LogicalOr => {
207            let left = left.eval(shell)?;
208            if left != 0 {
209                return Ok(bool_to_i64(true));
210            }
211
212            let right = right.eval(shell)?;
213            return Ok(bool_to_i64(right != 0));
214        }
215        _ => (),
216    }
217
218    // The remaining operators unconditionally operate both operands.
219    let left = left.eval(shell)?;
220    let right = right.eval(shell)?;
221
222    #[allow(clippy::cast_possible_truncation)]
223    #[allow(clippy::cast_sign_loss)]
224    match op {
225        ast::BinaryOperator::Power => {
226            if right >= 0 {
227                Ok(wrapping_pow_u64(left, right as u64))
228            } else {
229                Err(EvalError::NegativeExponent)
230            }
231        }
232        ast::BinaryOperator::Multiply => Ok(left.wrapping_mul(right)),
233        ast::BinaryOperator::Divide => {
234            if right == 0 {
235                Err(EvalError::DivideByZero)
236            } else {
237                Ok(left.wrapping_div(right))
238            }
239        }
240        ast::BinaryOperator::Modulo => {
241            if right == 0 {
242                Err(EvalError::DivideByZero)
243            } else {
244                Ok(left % right)
245            }
246        }
247        ast::BinaryOperator::Comma => Ok(right),
248        ast::BinaryOperator::Add => Ok(left.wrapping_add(right)),
249        ast::BinaryOperator::Subtract => Ok(left.wrapping_sub(right)),
250        ast::BinaryOperator::ShiftLeft => Ok(left.wrapping_shl(right as u32)),
251        ast::BinaryOperator::ShiftRight => Ok(left.wrapping_shr(right as u32)),
252        ast::BinaryOperator::LessThan => Ok(bool_to_i64(left < right)),
253        ast::BinaryOperator::LessThanOrEqualTo => Ok(bool_to_i64(left <= right)),
254        ast::BinaryOperator::GreaterThan => Ok(bool_to_i64(left > right)),
255        ast::BinaryOperator::GreaterThanOrEqualTo => Ok(bool_to_i64(left >= right)),
256        ast::BinaryOperator::Equals => Ok(bool_to_i64(left == right)),
257        ast::BinaryOperator::NotEquals => Ok(bool_to_i64(left != right)),
258        ast::BinaryOperator::BitwiseAnd => Ok(left & right),
259        ast::BinaryOperator::BitwiseXor => Ok(left ^ right),
260        ast::BinaryOperator::BitwiseOr => Ok(left | right),
261        ast::BinaryOperator::LogicalAnd => unreachable!("LogicalAnd covered above"),
262        ast::BinaryOperator::LogicalOr => unreachable!("LogicalOr covered above"),
263    }
264}
265
266fn apply_unary_assignment_op(
267    shell: &mut Shell,
268    lvalue: &ast::ArithmeticTarget,
269    op: ast::UnaryAssignmentOperator,
270) -> Result<i64, EvalError> {
271    let value = deref_lvalue(shell, lvalue)?;
272
273    match op {
274        ast::UnaryAssignmentOperator::PrefixIncrement => {
275            let new_value = value + 1;
276            assign(shell, lvalue, new_value)?;
277            Ok(new_value)
278        }
279        ast::UnaryAssignmentOperator::PrefixDecrement => {
280            let new_value = value - 1;
281            assign(shell, lvalue, new_value)?;
282            Ok(new_value)
283        }
284        ast::UnaryAssignmentOperator::PostfixIncrement => {
285            let new_value = value + 1;
286            assign(shell, lvalue, new_value)?;
287            Ok(value)
288        }
289        ast::UnaryAssignmentOperator::PostfixDecrement => {
290            let new_value = value - 1;
291            assign(shell, lvalue, new_value)?;
292            Ok(value)
293        }
294    }
295}
296
297fn assign(shell: &mut Shell, lvalue: &ast::ArithmeticTarget, value: i64) -> Result<i64, EvalError> {
298    match lvalue {
299        ast::ArithmeticTarget::Variable(name) => {
300            shell
301                .env
302                .update_or_add(
303                    name.as_str(),
304                    variables::ShellValueLiteral::Scalar(value.to_string()),
305                    |_| Ok(()),
306                    env::EnvironmentLookup::Anywhere,
307                    env::EnvironmentScope::Global,
308                )
309                .map_err(|_err| EvalError::FailedToUpdateEnvironment)?;
310        }
311        ast::ArithmeticTarget::ArrayElement(name, index_expr) => {
312            let index_str = index_expr.eval(shell)?.to_string();
313
314            shell
315                .env
316                .update_or_add_array_element(
317                    name.as_str(),
318                    index_str,
319                    value.to_string(),
320                    |_| Ok(()),
321                    env::EnvironmentLookup::Anywhere,
322                    env::EnvironmentScope::Global,
323                )
324                .map_err(|_err| EvalError::FailedToUpdateEnvironment)?;
325        }
326    }
327
328    Ok(value)
329}
330
/// Converts a boolean into the shell's arithmetic truth value: `1` for `true`,
/// `0` for `false`.
const fn bool_to_i64(value: bool) -> i64 {
    // `i64::from(bool)` is not usable in a `const fn`; a `bool` cast is lossless
    // and yields exactly 0 or 1.
    value as i64
}
334
/// Computes `base` raised to `exponent` using square-and-multiply, wrapping on
/// overflow at every step.
///
/// N.B. This is a custom version of `wrapping_pow` that accepts a 64-bit
/// exponent. The standard library's `i64::wrapping_pow` only accepts a `u32`
/// exponent, so this is the most direct way to handle overflow correctly for
/// exponents of any size.
const fn wrapping_pow_u64(base: i64, exponent: u64) -> i64 {
    let mut acc: i64 = 1;
    let mut square = base;
    let mut remaining = exponent;

    while remaining != 0 {
        // Multiply in the current power of `base` whenever the low bit is set.
        if (remaining & 1) != 0 {
            acc = acc.wrapping_mul(square);
        }

        remaining >>= 1;
        square = square.wrapping_mul(square);
    }

    acc
}