Skip to main content

openjd_expr/functions/
arithmetic.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5//! Arithmetic operator implementations.
6
7use crate::error::ExpressionError;
8use crate::function_library::EvalContext;
9use crate::types::ExprType;
10use crate::value::{ExprValue, Float64};
11
12type R = Result<ExprValue, ExpressionError>;
13type Ctx<'a> = &'a mut dyn EvalContext;
14
15// ── Integer arithmetic ──
16
17pub fn add_int(_: Ctx, a: &[ExprValue]) -> R {
18    match (&a[0], &a[1]) {
19        (ExprValue::Int(l), ExprValue::Int(r)) => Ok(ExprValue::Int(
20            l.checked_add(*r)
21                .ok_or_else(ExpressionError::integer_overflow)?,
22        )),
23        _ => Err(ExpressionError::type_error("type error")),
24    }
25}
26
27pub fn sub_int(_: Ctx, a: &[ExprValue]) -> R {
28    match (&a[0], &a[1]) {
29        (ExprValue::Int(l), ExprValue::Int(r)) => Ok(ExprValue::Int(
30            l.checked_sub(*r)
31                .ok_or_else(ExpressionError::integer_overflow)?,
32        )),
33        _ => Err(ExpressionError::type_error("type error")),
34    }
35}
36
37pub fn mul_int(_: Ctx, a: &[ExprValue]) -> R {
38    match (&a[0], &a[1]) {
39        (ExprValue::Int(l), ExprValue::Int(r)) => Ok(ExprValue::Int(
40            l.checked_mul(*r)
41                .ok_or_else(ExpressionError::integer_overflow)?,
42        )),
43        _ => Err(ExpressionError::type_error("type error")),
44    }
45}
46
47pub fn truediv_int(_: Ctx, a: &[ExprValue]) -> R {
48    match (&a[0], &a[1]) {
49        (ExprValue::Int(l), ExprValue::Int(r)) => {
50            if *r == 0 {
51                return Err(ExpressionError::division_by_zero("Division"));
52            }
53            Ok(ExprValue::Float(Float64::new(*l as f64 / *r as f64)?))
54        }
55        _ => Err(ExpressionError::type_error("type error")),
56    }
57}
58
59pub fn floordiv_int(_: Ctx, a: &[ExprValue]) -> R {
60    match (&a[0], &a[1]) {
61        (ExprValue::Int(l), ExprValue::Int(r)) => {
62            if *r == 0 {
63                return Err(ExpressionError::division_by_zero("Division"));
64            }
65            let d = l
66                .checked_div(*r)
67                .ok_or_else(ExpressionError::integer_overflow)?;
68            // Python uses floored division (toward -∞), not truncated (toward 0).
69            // Adjust when the remainder is nonzero and operands have different signs.
70            let result = if (l ^ r) < 0 && d * r != *l { d - 1 } else { d };
71            Ok(ExprValue::Int(result))
72        }
73        _ => Err(ExpressionError::type_error("type error")),
74    }
75}
76
77pub fn mod_int(_: Ctx, a: &[ExprValue]) -> R {
78    match (&a[0], &a[1]) {
79        (ExprValue::Int(l), ExprValue::Int(r)) => {
80            if *r == 0 {
81                return Err(ExpressionError::division_by_zero("Modulo"));
82            }
83            if *r == 1 || *r == -1 {
84                return Ok(ExprValue::Int(0));
85            }
86            let rem = l
87                .checked_rem(*r)
88                .ok_or_else(ExpressionError::integer_overflow)?;
89            // Python uses floored modulo: result sign matches divisor sign.
90            let result = if rem != 0 && (rem ^ r) < 0 {
91                rem + r
92            } else {
93                rem
94            };
95            Ok(ExprValue::Int(result))
96        }
97        _ => Err(ExpressionError::type_error("type error")),
98    }
99}
100
101pub fn pow_int(_: Ctx, a: &[ExprValue]) -> R {
102    match (&a[0], &a[1]) {
103        (ExprValue::Int(base), ExprValue::Int(exp)) => {
104            if *exp < 0 {
105                if *base == 0 {
106                    return Err(ExpressionError::float_error(
107                        "Cannot raise zero to a negative power",
108                    ));
109                }
110                let exp32 = i32::try_from(*exp).unwrap_or(i32::MIN);
111                return Ok(ExprValue::Float(Float64::new((*base as f64).powi(exp32))?));
112            }
113            // Guard: exponent > 63 with |base| > 1 always overflows i64
114            if *exp > 63 && !matches!(*base, -1..=1) {
115                return Err(ExpressionError::integer_overflow());
116            }
117            // Special-case base ∈ {-1, 0, 1} to avoid u32 truncation of large exponents
118            if *exp > u32::MAX as i64 {
119                return Ok(ExprValue::Int(match *base {
120                    0 => 0,
121                    1 => 1,
122                    -1 => {
123                        if *exp % 2 == 0 {
124                            1
125                        } else {
126                            -1
127                        }
128                    }
129                    _ => unreachable!(),
130                }));
131            }
132            Ok(ExprValue::Int(
133                base.checked_pow(*exp as u32)
134                    .ok_or_else(ExpressionError::integer_overflow)?,
135            ))
136        }
137        _ => Err(ExpressionError::type_error("type error")),
138    }
139}
140
141pub fn neg_int(_: Ctx, a: &[ExprValue]) -> R {
142    match &a[0] {
143        ExprValue::Int(n) => Ok(ExprValue::Int(
144            n.checked_neg()
145                .ok_or_else(ExpressionError::integer_overflow)?,
146        )),
147        _ => Err(ExpressionError::type_error("type error")),
148    }
149}
150
151pub fn pos_int(_: Ctx, a: &[ExprValue]) -> R {
152    match &a[0] {
153        ExprValue::Int(n) => Ok(ExprValue::Int(*n)),
154        _ => Err(ExpressionError::type_error("type error")),
155    }
156}
157
158// ── Float arithmetic ──
159
160pub fn add_float(_: Ctx, a: &[ExprValue]) -> R {
161    let (l, r) = get_two_floats(a)?;
162    Ok(ExprValue::Float(Float64::new(l + r)?))
163}
164
165pub fn sub_float(_: Ctx, a: &[ExprValue]) -> R {
166    let (l, r) = get_two_floats(a)?;
167    Ok(ExprValue::Float(Float64::new(l - r)?))
168}
169
170pub fn mul_float(_: Ctx, a: &[ExprValue]) -> R {
171    let (l, r) = get_two_floats(a)?;
172    Ok(ExprValue::Float(Float64::new(l * r)?))
173}
174
175pub fn truediv_float(_: Ctx, a: &[ExprValue]) -> R {
176    let (l, r) = get_two_floats(a)?;
177    if r == 0.0 {
178        return Err(ExpressionError::division_by_zero("Division"));
179    }
180    Ok(ExprValue::Float(Float64::new(l / r)?))
181}
182
183pub fn floordiv_float(_: Ctx, a: &[ExprValue]) -> R {
184    let (l, r) = get_two_floats(a)?;
185    if r == 0.0 {
186        return Err(ExpressionError::division_by_zero("Division"));
187    }
188    let v = (l / r).floor();
189    if v.abs() > i64::MAX as f64 {
190        return Err(ExpressionError::integer_overflow());
191    }
192    Ok(ExprValue::Int(v as i64))
193}
194
195pub fn mod_float(_: Ctx, a: &[ExprValue]) -> R {
196    let (l, r) = get_two_floats(a)?;
197    if r == 0.0 {
198        return Err(ExpressionError::division_by_zero("Modulo"));
199    }
200    // Python uses floored modulo: l - r * floor(l / r)
201    Ok(ExprValue::Float(Float64::new(l - r * (l / r).floor())?))
202}
203
204pub fn pow_float(_: Ctx, a: &[ExprValue]) -> R {
205    let (l, r) = get_two_floats(a)?;
206    if l == 0.0 && r < 0.0 {
207        return Err(ExpressionError::float_error(
208            "Cannot raise zero to a negative power",
209        ));
210    }
211    if l < 0.0 && r.fract() != 0.0 {
212        return Err(ExpressionError::float_error(format!(
213            "Cannot compute {} ** {} (would produce complex number)",
214            l, r
215        )));
216    }
217    let result = l.powf(r);
218    if result.is_infinite() {
219        return Err(ExpressionError::float_error(format!(
220            "Overflow computing {} ** {} (result too large for float)",
221            l, r
222        )));
223    }
224    Ok(ExprValue::Float(Float64::new(result)?))
225}
226
227pub fn neg_float(_: Ctx, a: &[ExprValue]) -> R {
228    match &a[0] {
229        ExprValue::Float(n) => Ok(ExprValue::Float(Float64::new(-n.value())?)),
230        _ => Err(ExpressionError::type_error("type error")),
231    }
232}
233
234pub fn pos_float(_: Ctx, a: &[ExprValue]) -> R {
235    match &a[0] {
236        ExprValue::Float(n) => Ok(ExprValue::Float(Float64::new(n.value())?)),
237        _ => Err(ExpressionError::type_error("type error")),
238    }
239}
240
241// ── String operators ──
242
243pub fn add_string(ctx: Ctx, a: &[ExprValue]) -> R {
244    match (&a[0], &a[1]) {
245        (ExprValue::String(l), ExprValue::String(r)) => {
246            ctx.count_string_ops(l.len() + r.len())?;
247            Ok(ExprValue::String(format!("{l}{r}")))
248        }
249        _ => Err(ExpressionError::type_error("type error")),
250    }
251}
252
253pub fn add_string_range(_: Ctx, a: &[ExprValue]) -> R {
254    match (&a[0], &a[1]) {
255        (ExprValue::String(l), ExprValue::RangeExpr(r)) => Ok(ExprValue::String(format!("{l}{r}"))),
256        _ => Err(ExpressionError::type_error("type error")),
257    }
258}
259
260pub fn add_range_string(_: Ctx, a: &[ExprValue]) -> R {
261    match (&a[0], &a[1]) {
262        (ExprValue::RangeExpr(l), ExprValue::String(r)) => Ok(ExprValue::String(format!("{l}{r}"))),
263        _ => Err(ExpressionError::type_error("type error")),
264    }
265}
266
267pub fn mul_string(ctx: Ctx, a: &[ExprValue]) -> R {
268    match (&a[0], &a[1]) {
269        (ExprValue::String(s), ExprValue::Int(n)) => {
270            if *n < 0 {
271                return Ok(ExprValue::String(String::new()));
272            }
273            let result_len = s.len() * (*n as usize);
274            ctx.count_string_ops(result_len)?;
275            ctx.check_memory(result_len)?;
276            Ok(ExprValue::String(s.repeat(*n as usize)))
277        }
278        _ => Err(ExpressionError::type_error("type error")),
279    }
280}
281
282// ── Path operators ──
283
284pub fn path_div(ctx: Ctx, a: &[ExprValue]) -> R {
285    let (l, format) = match &a[0] {
286        ExprValue::Path { value, format } => (value.as_str(), *format),
287        ExprValue::String(s) => (s.as_str(), ctx.path_format()),
288        _ => return Err(ExpressionError::type_error("type error")),
289    };
290    let r = match &a[1] {
291        ExprValue::Path { value, .. } | ExprValue::String(value) => value.as_str(),
292        _ => return Err(ExpressionError::type_error("type error")),
293    };
294    ctx.count_string_ops(l.len() + r.len())?;
295    Ok(ExprValue::new_path(super::path::join(l, r, format), format))
296}
297
298pub fn add_path_string(ctx: Ctx, a: &[ExprValue]) -> R {
299    match (&a[0], &a[1]) {
300        (ExprValue::Path { value: l, format }, ExprValue::String(r)) => {
301            ctx.count_string_ops(l.len() + r.len())?;
302            Ok(ExprValue::new_path(format!("{l}{r}"), *format))
303        }
304        _ => Err(ExpressionError::type_error("type error")),
305    }
306}
307
308// ── List operators ──
309
310pub fn add_range_range(ctx: Ctx, a: &[ExprValue]) -> R {
311    match (&a[0], &a[1]) {
312        (ExprValue::RangeExpr(l), ExprValue::RangeExpr(r)) => {
313            ctx.count_ops(l.len() + r.len())?;
314            let elements: Vec<ExprValue> = l.iter().chain(r.iter()).map(ExprValue::Int).collect();
315            Ok(ExprValue::make_list_checked(ctx, elements, ExprType::INT)?)
316        }
317        _ => Err(ExpressionError::type_error("type error")),
318    }
319}
320
321pub fn mul_list(ctx: Ctx, a: &[ExprValue]) -> R {
322    let (elements, elem_type) = a[0]
323        .clone()
324        .into_list()
325        .ok_or_else(|| ExpressionError::type_error("type error"))?;
326    let n = match &a[1] {
327        ExprValue::Int(n) => *n,
328        _ => return Err(ExpressionError::type_error("type error")),
329    };
330    if n <= 0 {
331        return ExprValue::make_list_checked(ctx, Vec::new(), elem_type);
332    }
333    let result_len = elements.len() * n as usize;
334    for _ in 0..result_len {
335        ctx.count_op()?;
336    }
337    let mut result = Vec::new();
338    for _ in 0..n {
339        result.extend(elements.iter().cloned());
340    }
341    ExprValue::make_list_checked(ctx, result, elem_type)
342}
343
344pub fn add_list_list(ctx: Ctx, a: &[ExprValue]) -> R {
345    let (l, lt) = a[0]
346        .clone()
347        .into_list()
348        .ok_or_else(|| ExpressionError::type_error("type error"))?;
349    let (r, rt) = a[1]
350        .clone()
351        .into_list()
352        .ok_or_else(|| ExpressionError::type_error("type error"))?;
353    ctx.count_ops(l.len() + r.len())?;
354    if lt != rt
355        && lt != ExprType::NULLTYPE
356        && rt != ExprType::NULLTYPE
357        && !((lt == ExprType::INT && rt == ExprType::FLOAT)
358            || (lt == ExprType::FLOAT && rt == ExprType::INT))
359        && !((lt == ExprType::PATH && rt == ExprType::STRING)
360            || (lt == ExprType::STRING && rt == ExprType::PATH))
361    {
362        return Err(ExpressionError::type_error(format!(
363            "Cannot concatenate list[{lt}] and list[{rt}]"
364        )));
365    }
366    let mut combined = l;
367    combined.extend(r);
368    let result_type = if lt == ExprType::NULLTYPE { rt } else { lt };
369    ExprValue::make_list_checked(ctx, combined, result_type)
370}
371
372pub fn add_list_range(ctx: Ctx, a: &[ExprValue]) -> R {
373    let (mut l, et) = a[0]
374        .clone()
375        .into_list()
376        .ok_or_else(|| ExpressionError::type_error("type error"))?;
377    let r = match &a[1] {
378        ExprValue::RangeExpr(r) => r,
379        _ => return Err(ExpressionError::type_error("type error")),
380    };
381    ctx.count_ops(l.len() + r.len())?;
382    l.extend(r.iter().map(ExprValue::Int));
383    ExprValue::make_list_checked(ctx, l, et)
384}
385
386pub fn add_range_list(ctx: Ctx, a: &[ExprValue]) -> R {
387    let r = match &a[0] {
388        ExprValue::RangeExpr(r) => r,
389        _ => return Err(ExpressionError::type_error("type error")),
390    };
391    let (l, et) = a[1]
392        .clone()
393        .into_list()
394        .ok_or_else(|| ExpressionError::type_error("type error"))?;
395    ctx.count_ops(r.len() + l.len())?;
396    let mut combined: Vec<ExprValue> = r.iter().map(ExprValue::Int).collect();
397    combined.extend(l);
398    ExprValue::make_list_checked(ctx, combined, et)
399}
400
// ── Logical operators ──
402
403pub fn not_bool(_: Ctx, a: &[ExprValue]) -> R {
404    match &a[0] {
405        ExprValue::Bool(b) => Ok(ExprValue::Bool(!*b)),
406        _ => Err(ExpressionError::type_error("type error")),
407    }
408}
409
410// ── Helpers ──
411
412fn get_two_floats(a: &[ExprValue]) -> Result<(f64, f64), ExpressionError> {
413    let l = match &a[0] {
414        ExprValue::Float(f) => f.value(),
415        ExprValue::Int(i) => *i as f64,
416        _ => return Err(ExpressionError::type_error("type error")),
417    };
418    let r = match &a[1] {
419        ExprValue::Float(f) => f.value(),
420        ExprValue::Int(i) => *i as f64,
421        _ => return Err(ExpressionError::type_error("type error")),
422    };
423    Ok((l, r))
424}