Skip to main content

openjd_expr/functions/
math.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5//! Math function implementations (min, max, floor, ceil, round, sum).
6
7use crate::error::ExpressionError;
8use crate::function_library::EvalContext;
9use crate::value::{ExprValue, Float64};
10
11type R = Result<ExprValue, ExpressionError>;
12type Ctx<'a> = &'a mut dyn EvalContext;
13
14fn min_max_items(a: &[ExprValue], name: &str) -> Result<Vec<ExprValue>, ExpressionError> {
15    if a.is_empty() {
16        return Err(ExpressionError::new(format!(
17            "{name}() requires at least 1 argument"
18        )));
19    }
20    if a.len() == 1 {
21        match &a[0] {
22            val if val.is_list() => {
23                let elements: Vec<ExprValue> =
24                    val.list_iter().expect("guard ensures list").collect();
25                if elements.is_empty() {
26                    return Err(ExpressionError::new(format!(
27                        "{name}() requires a non-empty list"
28                    )));
29                }
30                Ok(elements)
31            }
32            ExprValue::RangeExpr(r) => {
33                if r.is_empty() {
34                    return Err(ExpressionError::new(format!(
35                        "{name}() requires a non-empty list"
36                    )));
37                }
38                if name == "min" {
39                    return Ok(vec![ExprValue::Int(r.iter().next().unwrap())]);
40                } else {
41                    Ok(vec![ExprValue::Int(r.get(r.len() as i64 - 1).unwrap())])
42                }
43            }
44            _ => Ok(a.to_vec()),
45        }
46    } else {
47        Ok(a.to_vec())
48    }
49}
50
51pub fn min_fn(ctx: Ctx, a: &[ExprValue]) -> R {
52    let items = min_max_items(a, "min")?;
53    ctx.count_ops(items.len())?;
54    let mut result = items[0].clone();
55    for item in &items[1..] {
56        if result.compare(item)?.is_gt() {
57            result = item.clone();
58        }
59    }
60    if items.iter().any(|i| matches!(i, ExprValue::Float(_))) {
61        if let ExprValue::Int(i) = &result {
62            return Ok(ExprValue::Float(Float64::new(*i as f64)?));
63        }
64    }
65    Ok(result)
66}
67
68pub fn max_fn(ctx: Ctx, a: &[ExprValue]) -> R {
69    let items = min_max_items(a, "max")?;
70    ctx.count_ops(items.len())?;
71    let mut result = items[0].clone();
72    for item in &items[1..] {
73        if result.compare(item)?.is_lt() {
74            result = item.clone();
75        }
76    }
77    if items.iter().any(|i| matches!(i, ExprValue::Float(_))) {
78        if let ExprValue::Int(i) = &result {
79            return Ok(ExprValue::Float(Float64::new(*i as f64)?));
80        }
81    }
82    Ok(result)
83}
84
/// Rounds `x` to the nearest integer, resolving exact .5 ties toward the even
/// neighbour (banker's rounding). NaN and infinities are returned unchanged.
fn round_half_even(x: f64) -> f64 {
    // `f64::round` sends ties away from zero; fix those up afterwards.
    let nearest = x.round();
    let is_tie = (x - nearest).abs() == 0.5;
    let nearest_is_odd = nearest as i64 % 2 != 0;
    if is_tie && nearest_is_odd {
        // Step back toward zero onto the even neighbour.
        nearest - x.signum()
    } else {
        nearest
    }
}
97
98pub fn floor_float(_: Ctx, a: &[ExprValue]) -> R {
99    match &a[0] {
100        ExprValue::Float(f) => {
101            let v = f.floor();
102            if v.abs() > i64::MAX as f64 {
103                return Err(ExpressionError::integer_overflow());
104            }
105            Ok(ExprValue::Int(v as i64))
106        }
107        _ => Err(ExpressionError::type_error("type error")),
108    }
109}
110
111pub fn floor_int(_: Ctx, a: &[ExprValue]) -> R {
112    match &a[0] {
113        ExprValue::Int(i) => Ok(ExprValue::Int(*i)),
114        _ => Err(ExpressionError::type_error("type error")),
115    }
116}
117
118pub fn ceil_float(_: Ctx, a: &[ExprValue]) -> R {
119    match &a[0] {
120        ExprValue::Float(f) => {
121            let v = f.ceil();
122            if v.abs() > i64::MAX as f64 {
123                return Err(ExpressionError::integer_overflow());
124            }
125            Ok(ExprValue::Int(v as i64))
126        }
127        _ => Err(ExpressionError::type_error("type error")),
128    }
129}
130
131pub fn ceil_int(_: Ctx, a: &[ExprValue]) -> R {
132    match &a[0] {
133        ExprValue::Int(i) => Ok(ExprValue::Int(*i)),
134        _ => Err(ExpressionError::type_error("type error")),
135    }
136}
137
138pub fn round_fn(_: Ctx, a: &[ExprValue]) -> R {
139    match &a[0] {
140        ExprValue::Float(f) => {
141            let has_ndigits = a.len() > 1;
142            let ndigits = a
143                .get(1)
144                .and_then(|v| match v {
145                    ExprValue::Int(n) => Some(*n),
146                    _ => None,
147                })
148                .unwrap_or(0);
149            if !has_ndigits {
150                let v = round_half_even(f.value());
151                if v.abs() > i64::MAX as f64 {
152                    return Err(ExpressionError::integer_overflow());
153                }
154                Ok(ExprValue::Int(v as i64))
155            } else if ndigits >= 0 {
156                let factor = 10f64.powi(ndigits as i32);
157                let rounded = round_half_even(f.value() * factor) / factor;
158                if ndigits == 0 {
159                    Ok(ExprValue::Float(Float64::with_str(
160                        rounded,
161                        format!("{}.0", rounded as i64),
162                    )?))
163                } else {
164                    Ok(ExprValue::Float(Float64::with_str(
165                        rounded,
166                        format!("{:.prec$}", rounded, prec = ndigits as usize),
167                    )?))
168                }
169            } else {
170                let factor = 10f64.powi((-ndigits) as i32);
171                Ok(ExprValue::Float(Float64::new(
172                    round_half_even(f.value() / factor) * factor,
173                )?))
174            }
175        }
176        ExprValue::Int(i) => {
177            let ndigits = a
178                .get(1)
179                .and_then(|v| match v {
180                    ExprValue::Int(n) => Some(*n),
181                    _ => None,
182                })
183                .unwrap_or(0);
184            if ndigits >= 0 {
185                Ok(ExprValue::Int(*i))
186            } else {
187                let factor = 10f64.powi((-ndigits) as i32);
188                let v = round_half_even(*i as f64 / factor) * factor;
189                if v.abs() > i64::MAX as f64 {
190                    return Err(ExpressionError::integer_overflow());
191                }
192                Ok(ExprValue::Int(v as i64))
193            }
194        }
195        _ => Err(ExpressionError::new("round() requires numeric argument")),
196    }
197}
198
199pub fn sum_list(ctx: Ctx, a: &[ExprValue]) -> R {
200    if let Some(iter) = a[0].list_iter() {
201        let mut int_sum: i64 = 0;
202        let mut is_float = false;
203        let mut float_sum: f64 = 0.0;
204        for e in iter {
205            ctx.count_op()?;
206            match e {
207                ExprValue::Int(i) => {
208                    int_sum = int_sum
209                        .checked_add(i)
210                        .ok_or_else(ExpressionError::integer_overflow)?;
211                    float_sum += i as f64;
212                }
213                ExprValue::Float(f) => {
214                    is_float = true;
215                    float_sum += f.value();
216                }
217                _ => return Err(ExpressionError::new("sum() elements must be numeric")),
218            }
219        }
220        if is_float {
221            Ok(ExprValue::Float(Float64::new(float_sum)?))
222        } else {
223            Ok(ExprValue::Int(int_sum))
224        }
225    } else if let ExprValue::RangeExpr(r) = &a[0] {
226        for _ in r.iter() {
227            ctx.count_op()?;
228        }
229        Ok(ExprValue::Int(r.iter().sum()))
230    } else {
231        Err(ExpressionError::new("sum() requires list or range_expr"))
232    }
233}