use cel::{Context, ExecutionError, ResolveResult, extractors::Arguments, objects::Value};
/// Registers the `math.*` extension functions on the given CEL context.
///
/// Each CEL-visible name is bound to the Rust handler of the same purpose
/// defined in this module.
pub fn register(ctx: &mut Context<'_>) {
    // Expands every `"name" => handler` pair into one `ctx.add_function`
    // call, so the registration list reads as a table.
    macro_rules! bind {
        ($($name:literal => $handler:path),* $(,)?) => {
            $(ctx.add_function($name, $handler);)*
        };
    }

    bind! {
        // Rounding.
        "math.ceil" => math_ceil,
        "math.floor" => math_floor,
        "math.round" => math_round,
        "math.trunc" => math_trunc,
        // Magnitude and sign.
        "math.abs" => math_abs,
        "math.sign" => math_sign,
        // Floating-point classification.
        "math.isInf" => math_is_inf,
        "math.isNaN" => math_is_nan,
        "math.isFinite" => math_is_finite,
        // Bitwise operations.
        "math.bitAnd" => math_bit_and,
        "math.bitOr" => math_bit_or,
        "math.bitXor" => math_bit_xor,
        "math.bitNot" => math_bit_not,
        "math.bitShiftLeft" => math_bit_shift_left,
        "math.bitShiftRight" => math_bit_shift_right,
        // Roots and extrema.
        "math.sqrt" => math_sqrt,
        "math.greatest" => math_greatest,
        "math.least" => math_least,
    }
}
/// CEL `math.ceil`: the smallest whole number greater than or equal to `v`,
/// returned as a double.
fn math_ceil(v: f64) -> ResolveResult {
    let rounded_up = f64::ceil(v);
    Ok(Value::Float(rounded_up))
}
/// CEL `math.floor`: the largest whole number less than or equal to `v`,
/// returned as a double.
fn math_floor(v: f64) -> ResolveResult {
    let rounded_down = f64::floor(v);
    Ok(Value::Float(rounded_down))
}
/// CEL `math.round`: rounds `v` to the nearest whole number.
///
/// Half-way cases round away from zero (`1.5 -> 2.0`, `-1.5 -> -2.0`), as
/// `f64::round` does.
fn math_round(v: f64) -> ResolveResult {
    let nearest = f64::round(v);
    Ok(Value::Float(nearest))
}
/// CEL `math.trunc`: drops the fractional part of `v`, rounding toward zero.
fn math_trunc(v: f64) -> ResolveResult {
    let whole_part = f64::trunc(v);
    Ok(Value::Float(whole_part))
}
/// CEL `math.abs` for `int`, `uint`, and `double` arguments.
///
/// Only the first argument is inspected; any further arguments are ignored.
///
/// # Errors
/// - no argument was supplied;
/// - `int` overflow (the magnitude of `i64::MIN` is not representable);
/// - the argument is not an `int`, `uint`, or `double`.
fn math_abs(Arguments(args): Arguments) -> ResolveResult {
    let missing = || ExecutionError::function_error("math.abs", "missing argument");
    match args.first().ok_or_else(missing)? {
        Value::Int(n) => n
            .checked_abs()
            .map(Value::Int)
            .ok_or_else(|| ExecutionError::function_error("math.abs", "integer overflow")),
        // An unsigned value is already its own magnitude.
        Value::UInt(n) => Ok(Value::UInt(*n)),
        Value::Float(f) => Ok(Value::Float(f.abs())),
        _ => Err(ExecutionError::function_error(
            "math.abs",
            "expected int, uint, or double",
        )),
    }
}
/// CEL `math.sign`: -1, 0, or 1 according to the sign of the argument, in
/// the argument's own numeric type.
///
/// For doubles, NaN yields NaN and both zeros yield `0.0`. Only the first
/// argument is inspected.
///
/// # Errors
/// Returns an error if no argument is supplied or it is not an `int`,
/// `uint`, or `double`.
fn math_sign(Arguments(args): Arguments) -> ResolveResult {
    let arg = match args.first() {
        Some(value) => value,
        None => {
            return Err(ExecutionError::function_error(
                "math.sign",
                "missing argument",
            ));
        }
    };
    match arg {
        Value::Int(n) => Ok(Value::Int(n.signum())),
        // `true as u64` is 1 and `false` is 0: nonzero unsigned values map to 1.
        Value::UInt(n) => Ok(Value::UInt(u64::from(*n != 0))),
        Value::Float(f) if f.is_nan() => Ok(Value::Float(f64::NAN)),
        // Matches both +0.0 and -0.0; `f64::signum` would map them to ±1.0.
        Value::Float(f) if *f == 0.0 => Ok(Value::Float(0.0)),
        Value::Float(f) => Ok(Value::Float(f.signum())),
        _ => Err(ExecutionError::function_error(
            "math.sign",
            "expected int, uint, or double",
        )),
    }
}
/// CEL `math.isInf`: true when `v` is positive or negative infinity.
fn math_is_inf(v: f64) -> ResolveResult {
    let infinite = f64::is_infinite(v);
    Ok(Value::Bool(infinite))
}
/// CEL `math.isNaN`: true when `v` is NaN.
fn math_is_nan(v: f64) -> ResolveResult {
    let not_a_number = f64::is_nan(v);
    Ok(Value::Bool(not_a_number))
}
/// CEL `math.isFinite`: true when `v` is neither infinite nor NaN.
fn math_is_finite(v: f64) -> ResolveResult {
    let finite = f64::is_finite(v);
    Ok(Value::Bool(finite))
}
/// CEL `math.bitAnd(int, int)`: bitwise AND of the two operands.
fn math_bit_and(a: i64, b: i64) -> ResolveResult {
    let common_bits = a & b;
    Ok(Value::Int(common_bits))
}
/// CEL `math.bitOr(int, int)`: bitwise OR of the two operands.
fn math_bit_or(a: i64, b: i64) -> ResolveResult {
    let either_bits = a | b;
    Ok(Value::Int(either_bits))
}
/// CEL `math.bitXor(int, int)`: bitwise exclusive OR of the two operands.
fn math_bit_xor(a: i64, b: i64) -> ResolveResult {
    let differing_bits = a ^ b;
    Ok(Value::Int(differing_bits))
}
/// CEL `math.bitNot(int)`: flips every bit of `v` (two's-complement NOT).
fn math_bit_not(v: i64) -> ResolveResult {
    // XOR with all ones (-1) flips every bit, identical to `!v`.
    let flipped = v ^ -1;
    Ok(Value::Int(flipped))
}
/// CEL `math.bitShiftLeft(int, int)`: shifts `v` left by `shift` bits.
///
/// Bits shifted beyond bit 63 are discarded; no overflow error is raised.
///
/// # Errors
/// Returns an error when `shift` is negative or 64 or greater.
fn math_bit_shift_left(v: i64, shift: i64) -> ResolveResult {
    match shift {
        0..=63 => Ok(Value::Int(v << shift)),
        _ => Err(ExecutionError::function_error(
            "math.bitShiftLeft",
            "shift amount must be between 0 and 63",
        )),
    }
}
/// CEL `math.bitShiftRight(int, int)`: a logical right shift of `v` by
/// `shift` bits.
///
/// Vacated high bits are filled with zeros and the sign bit is not extended,
/// as the CEL math extension specifies. For example,
/// `math.bitShiftRight(-1024, 3)` yields `2305843009213693824`.
///
/// # Errors
/// Returns an error when `shift` is negative or 64 or greater. This module
/// rejects oversized shifts instead of returning 0, to stay consistent with
/// `math.bitShiftLeft`.
fn math_bit_shift_right(v: i64, shift: i64) -> ResolveResult {
    if !(0..=63).contains(&shift) {
        return Err(ExecutionError::function_error(
            "math.bitShiftRight",
            "shift amount must be between 0 and 63",
        ));
    }
    // `>>` on i64 is arithmetic (it copies the sign bit). Reinterpret the bits
    // as u64 so the shift is logical, then reinterpret back. Both `as` casts
    // between i64 and u64 preserve the bit pattern.
    Ok(Value::Int(((v as u64) >> shift) as i64))
}
/// CEL `math.sqrt`: the square root of `v`.
///
/// Negative inputs produce NaN rather than an error, following `f64::sqrt`.
fn math_sqrt(v: f64) -> ResolveResult {
    let root = f64::sqrt(v);
    Ok(Value::Float(root))
}
fn numeric_cmp(a: &Value, b: &Value) -> Result<std::cmp::Ordering, ExecutionError> {
let fa = to_f64(a)?;
let fb = to_f64(b)?;
fa.partial_cmp(&fb)
.ok_or_else(|| ExecutionError::function_error("math.greatest/least", "cannot compare NaN values"))
}
/// Widens a numeric CEL value to `f64`.
///
/// `int` and `uint` use `as` conversion, which rounds to the nearest `f64`.
/// Integers with magnitude above 2^53 may therefore lose precision.
///
/// # Errors
/// Returns an error for any value that is not an `int`, `uint`, or `double`.
fn to_f64(v: &Value) -> Result<f64, ExecutionError> {
    let widened = match *v {
        Value::Int(signed) => signed as f64,
        Value::UInt(unsigned) => unsigned as f64,
        Value::Float(already_float) => already_float,
        _ => {
            return Err(ExecutionError::function_error(
                "math.greatest/least",
                "expected numeric argument",
            ));
        }
    };
    Ok(widened)
}
fn math_greatest(Arguments(args): Arguments) -> ResolveResult {
math_extremum(&args, "math.greatest", std::cmp::Ordering::Greater)
}
fn math_least(Arguments(args): Arguments) -> ResolveResult {
math_extremum(&args, "math.least", std::cmp::Ordering::Less)
}
/// Shared implementation of `math.greatest` and `math.least`.
///
/// Accepts either one or more numeric arguments, or a single list argument
/// whose elements are used instead. Walks the values and replaces the current
/// best whenever `numeric_cmp(candidate, best)` equals `target_ord`, so on
/// ties the earliest element wins. The winner is returned with its original
/// type (`int`, `uint`, or `double`).
///
/// # Errors
/// - no arguments, or a single empty list;
/// - a non-numeric first element, including a lone value such as
///   `math.greatest("a")`;
/// - any error from `numeric_cmp` (non-numeric later element, NaN comparison).
fn math_extremum(args: &[Value], name: &str, target_ord: std::cmp::Ordering) -> ResolveResult {
    // A lone list argument is unpacked: `math.greatest([1, 2, 3])`.
    let values: &[Value] = match args {
        [Value::List(list)] => &list[..],
        _ => args,
    };

    let (first, rest) = match values.split_first() {
        Some(split) => split,
        None => {
            return Err(ExecutionError::function_error(
                name,
                "at least one argument required",
            ));
        }
    };

    // `numeric_cmp` only type-checks values it actually compares. Without
    // this check a single value (with nothing to compare against) would be
    // returned unchecked, whatever its type.
    if !matches!(first, Value::Int(_) | Value::UInt(_) | Value::Float(_)) {
        return Err(ExecutionError::function_error(
            name,
            "expected numeric argument",
        ));
    }

    // Track the winner by reference and clone only once at the end.
    let mut best = first;
    for candidate in rest {
        if numeric_cmp(candidate, best)? == target_ord {
            best = candidate;
        }
    }
    Ok(best.clone())
}
#[cfg(test)]
mod tests {
    use super::*;
    use cel::Program;

    /// Compiles `expr` (panicking on a parse failure) and runs it against a
    /// fresh context with the math extension registered.
    fn run(expr: &str) -> Result<Value, ExecutionError> {
        let mut ctx = Context::default();
        register(&mut ctx);
        Program::compile(expr).unwrap().execute(&ctx)
    }

    /// Evaluates `expr`, panicking if execution fails.
    fn eval(expr: &str) -> Value {
        run(expr).unwrap()
    }

    /// Evaluates `expr`, panicking if execution unexpectedly succeeds.
    fn eval_err(expr: &str) -> ExecutionError {
        run(expr).unwrap_err()
    }

    /// Asserts that each `(expression, expected)` pair evaluates to `expected`.
    fn check(cases: &[(&str, Value)]) {
        for (expr, expected) in cases {
            assert_eq!(&eval(expr), expected, "while evaluating `{expr}`");
        }
    }

    #[test]
    fn test_ceil() {
        check(&[
            ("math.ceil(1.2)", Value::Float(2.0)),
            ("math.ceil(-1.8)", Value::Float(-1.0)),
            ("math.ceil(2.0)", Value::Float(2.0)),
        ]);
    }

    #[test]
    fn test_floor() {
        check(&[
            ("math.floor(1.8)", Value::Float(1.0)),
            ("math.floor(-1.2)", Value::Float(-2.0)),
            ("math.floor(2.0)", Value::Float(2.0)),
        ]);
    }

    #[test]
    fn test_round() {
        check(&[
            ("math.round(1.5)", Value::Float(2.0)),
            ("math.round(1.4)", Value::Float(1.0)),
            ("math.round(-1.5)", Value::Float(-2.0)),
        ]);
    }

    #[test]
    fn test_trunc() {
        check(&[
            ("math.trunc(1.9)", Value::Float(1.0)),
            ("math.trunc(-1.9)", Value::Float(-1.0)),
        ]);
    }

    #[test]
    fn test_abs_int() {
        check(&[
            ("math.abs(-5)", Value::Int(5)),
            ("math.abs(5)", Value::Int(5)),
            ("math.abs(0)", Value::Int(0)),
        ]);
    }

    #[test]
    fn test_abs_float() {
        check(&[
            ("math.abs(-3.14)", Value::Float(3.14)),
            ("math.abs(3.14)", Value::Float(3.14)),
        ]);
    }

    #[test]
    fn test_sign_int() {
        check(&[
            ("math.sign(-3)", Value::Int(-1)),
            ("math.sign(0)", Value::Int(0)),
            ("math.sign(5)", Value::Int(1)),
        ]);
    }

    #[test]
    fn test_sign_float() {
        check(&[
            ("math.sign(-3.0)", Value::Float(-1.0)),
            ("math.sign(0.0)", Value::Float(0.0)),
            ("math.sign(5.0)", Value::Float(1.0)),
        ]);
    }

    #[test]
    fn test_is_inf() {
        check(&[
            ("math.isInf(1.0 / 0.0)", Value::Bool(true)),
            ("math.isInf(1.0)", Value::Bool(false)),
        ]);
    }

    #[test]
    fn test_is_nan() {
        check(&[
            ("math.isNaN(0.0 / 0.0)", Value::Bool(true)),
            ("math.isNaN(1.0)", Value::Bool(false)),
        ]);
    }

    #[test]
    fn test_is_finite() {
        check(&[
            ("math.isFinite(1.0)", Value::Bool(true)),
            ("math.isFinite(1.0 / 0.0)", Value::Bool(false)),
            ("math.isFinite(0.0 / 0.0)", Value::Bool(false)),
        ]);
    }

    #[test]
    fn test_bit_and() {
        check(&[("math.bitAnd(3, 5)", Value::Int(1))]);
    }

    #[test]
    fn test_bit_or() {
        check(&[("math.bitOr(3, 5)", Value::Int(7))]);
    }

    #[test]
    fn test_bit_xor() {
        check(&[("math.bitXor(3, 5)", Value::Int(6))]);
    }

    #[test]
    fn test_bit_not() {
        check(&[("math.bitNot(0)", Value::Int(-1))]);
    }

    #[test]
    fn test_bit_shift_left() {
        check(&[("math.bitShiftLeft(1, 3)", Value::Int(8))]);
    }

    #[test]
    fn test_bit_shift_right() {
        check(&[("math.bitShiftRight(8, 3)", Value::Int(1))]);
    }

    #[test]
    fn test_bit_shift_invalid() {
        let rejected = [
            "math.bitShiftLeft(1, -1)",
            "math.bitShiftLeft(1, 64)",
            "math.bitShiftRight(1, -1)",
        ];
        for expr in &rejected {
            eval_err(expr);
        }
    }

    #[test]
    fn test_sqrt() {
        check(&[
            ("math.sqrt(4.0)", Value::Float(2.0)),
            ("math.sqrt(0.0)", Value::Float(0.0)),
            ("math.sqrt(2.0)", Value::Float(std::f64::consts::SQRT_2)),
            ("math.sqrt(49.0)", Value::Float(7.0)),
        ]);
    }

    #[test]
    fn test_sqrt_negative_returns_nan() {
        match eval("math.sqrt(-1.0)") {
            Value::Float(f) => assert!(f.is_nan()),
            other => panic!("expected Float(NaN), got {other:?}"),
        }
    }

    #[test]
    fn test_sqrt_nan_check() {
        check(&[("math.isNaN(math.sqrt(-15.34))", Value::Bool(true))]);
    }

    #[test]
    fn test_abs_overflow() {
        eval_err("math.abs(-9223372036854775808)");
    }

    #[test]
    fn test_greatest() {
        check(&[
            ("math.greatest(1, 3, 2)", Value::Int(3)),
            ("math.greatest(1.0, 3.0, 2.0)", Value::Float(3.0)),
        ]);
    }

    #[test]
    fn test_least() {
        check(&[
            ("math.least(1, 3, 2)", Value::Int(1)),
            ("math.least(1.0, 3.0, 2.0)", Value::Float(1.0)),
        ]);
    }

    #[test]
    fn test_greatest_single() {
        check(&[("math.greatest(42)", Value::Int(42))]);
    }

    #[test]
    fn test_least_single() {
        check(&[("math.least(42)", Value::Int(42))]);
    }

    #[test]
    fn test_greatest_single_float() {
        check(&[("math.greatest(-0.5)", Value::Float(-0.5))]);
    }

    #[test]
    fn test_least_single_float() {
        check(&[("math.least(-0.5)", Value::Float(-0.5))]);
    }

    #[test]
    fn test_is_nan_negated() {
        check(&[("!math.isNaN(1.0)", Value::Bool(true))]);
    }

    #[test]
    fn test_is_inf_negated() {
        check(&[
            ("!math.isInf(1.0)", Value::Bool(true)),
            ("!math.isFinite(1.0 / 0.0)", Value::Bool(true)),
        ]);
    }
}