use rhai::{Array, Dynamic, Engine, EvalAltResult};
use super::arrays::{determine_array_type, ArrayType};
/// Registers this module's math helpers on a Rhai [`Engine`].
///
/// Registered functions:
/// - `mod(a, b)` and the `%` operator on two integers: a remainder that yields `0`
///   instead of panicking when `b == 0` or when the operation would overflow
///   (`i64::MIN % -1`).
/// - `clamp(value, min, max)` for integers and floats (panics if `min > max`).
/// - `sum`, `mean`, `variance`, `stddev` over arrays.
pub fn register_functions(engine: &mut Engine) {
    engine.register_fn("mod", safe_rem);
    // NOTE(review): Rhai may dispatch `INT % INT` to its built-in operator rather
    // than this registration (see `Engine::set_fast_operators`) — confirm.
    engine.register_fn("%", safe_rem);
    engine.register_fn("clamp", clamp_i64);
    engine.register_fn("clamp", clamp_f64);
    engine.register_fn("sum", sum_array);
    engine.register_fn("mean", mean_array);
    engine.register_fn("variance", variance_array);
    engine.register_fn("stddev", stddev_array);
}

/// Integer remainder that never panics, exposed to scripts as `mod` / `%`.
///
/// Returns `0` when `b == 0` instead of raising an error. It also returns `0` for
/// `i64::MIN % -1`. That value is mathematically correct, but plain `%` panics on it
/// because the implied division overflows.
fn safe_rem(a: i64, b: i64) -> i64 {
    a.checked_rem(b).unwrap_or(0)
}
fn extract_numeric_values(arr: &Array) -> Vec<f64> {
arr.iter()
.map(|val| {
if val.is_int() {
val.as_int().unwrap() as f64
} else if val.is_float() {
val.as_float().unwrap()
} else {
0.0
}
})
.collect()
}
/// Sums a Rhai array of numbers.
///
/// The total is always returned as a float `Dynamic`, even when every element is
/// an integer. An empty array, or any array not classified as
/// `ArrayType::Numeric`, yields `()` (unit). Unlike `mean_array`, this never raises
/// a script error.
fn sum_array(arr: Array) -> Dynamic {
    // Bail out before classification so an empty array never reaches the helper.
    if arr.is_empty() {
        return Dynamic::UNIT;
    }
    match determine_array_type(&arr) {
        ArrayType::Numeric => {
            let total = extract_numeric_values(&arr).into_iter().sum::<f64>();
            Dynamic::from(total)
        }
        ArrayType::Empty | ArrayType::Mixed | ArrayType::String => Dynamic::UNIT,
    }
}
/// Arithmetic mean of a Rhai array of numbers.
///
/// # Errors
///
/// Returns a runtime error if the array is empty, mixes numbers with
/// non-numbers, or contains only non-numeric (string) values.
fn mean_array(arr: Array) -> Result<f64, Box<EvalAltResult>> {
    const EMPTY_MSG: &str = "Cannot calculate mean of empty array";

    if arr.is_empty() {
        return Err(EMPTY_MSG.into());
    }

    let values = match determine_array_type(&arr) {
        ArrayType::Numeric => extract_numeric_values(&arr),
        ArrayType::Empty => return Err(EMPTY_MSG.into()),
        ArrayType::Mixed => {
            return Err("Cannot calculate mean of array with mixed types (numbers and non-numbers). Use pluck_as_nums() for type coercion.".into())
        }
        ArrayType::String => {
            return Err("Cannot calculate mean of array with non-numeric values".into())
        }
    };

    let total = values.iter().sum::<f64>();
    let count = values.len() as f64;
    Ok(total / count)
}
/// Population variance of a Rhai array of numbers.
///
/// Uses the two-pass formula: first the mean, then the average squared
/// deviation. The divisor is `n`, not `n - 1`.
///
/// # Errors
///
/// Returns a runtime error if the array is empty, mixes numbers with
/// non-numbers, or contains only non-numeric (string) values.
fn variance_array(arr: Array) -> Result<f64, Box<EvalAltResult>> {
    const EMPTY_MSG: &str = "Cannot calculate variance of empty array";

    if arr.is_empty() {
        return Err(EMPTY_MSG.into());
    }

    let values = match determine_array_type(&arr) {
        ArrayType::Numeric => extract_numeric_values(&arr),
        ArrayType::Empty => return Err(EMPTY_MSG.into()),
        ArrayType::Mixed => {
            return Err("Cannot calculate variance of array with mixed types (numbers and non-numbers). Use pluck_as_nums() for type coercion.".into())
        }
        ArrayType::String => {
            return Err("Cannot calculate variance of array with non-numeric values".into())
        }
    };

    let count = values.len() as f64;
    let mean = values.iter().sum::<f64>() / count;
    let squared_deviations = values
        .iter()
        .map(|&v| (v - mean) * (v - mean))
        .sum::<f64>();
    Ok(squared_deviations / count)
}
/// Population standard deviation: the square root of [`variance_array`].
///
/// # Errors
///
/// Propagates every error from `variance_array` unchanged, including its
/// "variance" wording.
fn stddev_array(arr: Array) -> Result<f64, Box<EvalAltResult>> {
    variance_array(arr).map(f64::sqrt)
}
/// Restricts an integer to the inclusive range `[min, max]`.
///
/// # Panics
///
/// Panics if `min > max` (delegates to [`Ord::clamp`]).
fn clamp_i64(value: i64, min: i64, max: i64) -> i64 {
    Ord::clamp(value, min, max)
}
/// Restricts a float to the inclusive range `[min, max]`.
///
/// A NaN `value` is returned unchanged (NaN).
///
/// # Panics
///
/// Panics if `min > max` or if either bound is NaN (delegates to [`f64::clamp`]).
fn clamp_f64(value: f64, min: f64, max: f64) -> f64 {
    f64::clamp(value, min, max)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the clamp helpers and the array statistics functions.

    use super::*;
    use rhai::Dynamic;
    use std::panic::{catch_unwind, AssertUnwindSafe};

    /// Extracts the text of a caught panic payload (`String` or `&str`),
    /// or an empty string for any other payload type.
    fn panic_message(err: Box<dyn std::any::Any + Send>) -> String {
        match err.downcast_ref::<String>() {
            Some(owned) => owned.clone(),
            None => err
                .downcast_ref::<&str>()
                .map(|text| (*text).to_string())
                .unwrap_or_default(),
        }
    }

    /// Runs `call`, requires it to panic, and checks that the panic text is the
    /// inverted-range message produced by `Ord::clamp` / `f64::clamp`.
    fn assert_inverted_range_panics<T: std::fmt::Debug>(
        call: impl FnOnce() -> T,
        expectation: &str,
    ) {
        let payload = catch_unwind(AssertUnwindSafe(call)).expect_err(expectation);
        let msg = panic_message(payload);
        assert!(
            msg.contains("assertion failed: min <= max") || msg.contains("min > max"),
            "unexpected panic message: {msg}"
        );
    }

    /// Builds a Rhai array of integer elements.
    fn ints(values: &[i64]) -> Array {
        values.iter().copied().map(Dynamic::from).collect()
    }

    /// Builds a Rhai array of float elements.
    fn floats(values: &[f64]) -> Array {
        values.iter().copied().map(Dynamic::from).collect()
    }

    /// Builds a Rhai array of string elements.
    fn strings(values: &[&str]) -> Array {
        values.iter().map(|text| Dynamic::from(text.to_string())).collect()
    }

    #[test]
    fn test_clamp_i64_within_range() {
        for (value, min, max, expected) in
            [(5, 0, 10, 5), (50, 0, 100, 50), (0, 0, 10, 0), (10, 0, 10, 10)]
        {
            assert_eq!(clamp_i64(value, min, max), expected);
        }
    }

    #[test]
    fn test_clamp_i64_below_min() {
        for (value, min, max, expected) in [(-5, 0, 10, 0), (-100, 0, 100, 0), (-1, 0, 10, 0)] {
            assert_eq!(clamp_i64(value, min, max), expected);
        }
    }

    #[test]
    fn test_clamp_i64_above_max() {
        for (value, min, max, expected) in [(15, 0, 10, 10), (200, 0, 100, 100), (11, 0, 10, 10)] {
            assert_eq!(clamp_i64(value, min, max), expected);
        }
    }

    #[test]
    fn test_clamp_i64_negative_range() {
        for (value, min, max, expected) in [(-5, -10, -1, -5), (-15, -10, -1, -10), (0, -10, -1, -1)]
        {
            assert_eq!(clamp_i64(value, min, max), expected);
        }
    }

    #[test]
    fn test_clamp_i64_inverted_range() {
        assert_inverted_range_panics(
            || clamp_i64(5, 10, 0),
            "clamp_i64 should panic when min > max",
        );
    }

    #[test]
    fn test_clamp_f64_within_range() {
        for (value, min, max, expected) in [
            (3.5, 0.0, 5.0, 3.5),
            (2.5, 0.0, 10.0, 2.5),
            (0.0, 0.0, 5.0, 0.0),
            (5.0, 0.0, 5.0, 5.0),
        ] {
            assert_eq!(clamp_f64(value, min, max), expected);
        }
    }

    #[test]
    fn test_clamp_f64_below_min() {
        for (value, min, max, expected) in [
            (-1.5, 0.0, 5.0, 0.0),
            (-100.0, 0.0, 100.0, 0.0),
            (-0.1, 0.0, 10.0, 0.0),
        ] {
            assert_eq!(clamp_f64(value, min, max), expected);
        }
    }

    #[test]
    fn test_clamp_f64_above_max() {
        for (value, min, max, expected) in [
            (7.8, 0.0, 5.0, 5.0),
            (200.5, 0.0, 100.0, 100.0),
            (10.1, 0.0, 10.0, 10.0),
        ] {
            assert_eq!(clamp_f64(value, min, max), expected);
        }
    }

    #[test]
    fn test_clamp_f64_negative_range() {
        for (value, min, max, expected) in [
            (-5.5, -10.0, -1.0, -5.5),
            (-15.0, -10.0, -1.0, -10.0),
            (0.0, -10.0, -1.0, -1.0),
        ] {
            assert_eq!(clamp_f64(value, min, max), expected);
        }
    }

    #[test]
    fn test_clamp_f64_inverted_range() {
        assert_inverted_range_panics(
            || clamp_f64(5.0, 10.0, 0.0),
            "clamp_f64 should panic when min > max",
        );
    }

    #[test]
    fn test_clamp_f64_fractional() {
        for (value, min, max, expected) in
            [(0.5, 0.0, 1.0, 0.5), (1.5, 0.0, 1.0, 1.0), (-0.5, 0.0, 1.0, 0.0)]
        {
            assert_eq!(clamp_f64(value, min, max), expected);
        }
    }

    #[test]
    fn test_sum_integers() {
        let result = sum_array(ints(&[1, 2, 3, 4, 5]));
        assert_eq!(result.as_float().unwrap(), 15.0);
    }

    #[test]
    fn test_sum_floats() {
        let result = sum_array(floats(&[1.5, 2.5, 3.0]));
        assert_eq!(result.as_float().unwrap(), 7.0);
    }

    #[test]
    fn test_sum_mixed_numeric() {
        let arr: Array = vec![
            Dynamic::from(10i64),
            Dynamic::from(20.5f64),
            Dynamic::from(30i64),
        ];
        assert_eq!(sum_array(arr).as_float().unwrap(), 60.5);
    }

    #[test]
    fn test_sum_mixed_types_rejected() {
        let arr: Array = vec![
            Dynamic::from(10i64),
            Dynamic::from(true),
            Dynamic::from(20i64),
            Dynamic::from(false),
        ];
        assert!(sum_array(arr).is_unit());
    }

    #[test]
    fn test_sum_strings_rejected() {
        let arr: Array = vec![
            Dynamic::from(10i64),
            Dynamic::from("20".to_string()),
            Dynamic::from(30i64),
        ];
        assert!(sum_array(arr).is_unit());
    }

    #[test]
    fn test_sum_non_numeric_rejected() {
        let arr: Array = vec![
            Dynamic::from(10i64),
            Dynamic::from("not a number".to_string()),
            Dynamic::from(20i64),
        ];
        assert!(sum_array(arr).is_unit());
    }

    #[test]
    fn test_sum_empty_array() {
        assert!(sum_array(Array::new()).is_unit());
    }

    #[test]
    fn test_sum_no_numeric_values() {
        assert!(sum_array(strings(&["abc", "def"])).is_unit());
    }

    #[test]
    fn test_mean_basic() {
        assert_eq!(mean_array(ints(&[1, 2, 3, 4, 5])).unwrap(), 3.0);
    }

    #[test]
    fn test_mean_floats() {
        assert_eq!(mean_array(floats(&[10.0, 20.0, 30.0])).unwrap(), 20.0);
    }

    #[test]
    fn test_mean_mixed_numeric() {
        let arr: Array = vec![
            Dynamic::from(10i64),
            Dynamic::from(20.0f64),
            Dynamic::from(30i64),
        ];
        assert_eq!(mean_array(arr).unwrap(), 20.0);
    }

    #[test]
    fn test_mean_mixed_types_rejected() {
        let arr: Array = vec![
            Dynamic::from(10i64),
            Dynamic::from(true),
            Dynamic::from(30i64),
            Dynamic::from(false),
        ];
        assert!(mean_array(arr).is_err());
    }

    #[test]
    fn test_mean_mixed_numbers_strings_rejected() {
        let arr: Array = vec![
            Dynamic::from(10i64),
            Dynamic::from("not a number".to_string()),
            Dynamic::from(20i64),
            Dynamic::from(30i64),
        ];
        assert!(mean_array(arr).is_err());
    }

    #[test]
    fn test_mean_empty_array_error() {
        assert!(mean_array(Array::new()).is_err());
    }

    #[test]
    fn test_mean_no_numeric_values_error() {
        assert!(mean_array(strings(&["abc", "def"])).is_err());
    }

    #[test]
    fn test_variance_basic() {
        assert_eq!(variance_array(ints(&[1, 2, 3, 4, 5])).unwrap(), 2.0);
    }

    #[test]
    fn test_variance_no_variation() {
        assert_eq!(variance_array(ints(&[5, 5, 5, 5])).unwrap(), 0.0);
    }

    #[test]
    fn test_variance_floats() {
        let result = variance_array(floats(&[2.0, 4.0, 6.0])).unwrap();
        assert!((result - 8.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_variance_empty_array_error() {
        assert!(variance_array(Array::new()).is_err());
    }

    #[test]
    fn test_variance_no_numeric_values_error() {
        assert!(variance_array(strings(&["abc"])).is_err());
    }

    #[test]
    fn test_stddev_basic() {
        let result = stddev_array(ints(&[1, 2, 3, 4, 5])).unwrap();
        assert!((result - 2.0f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_stddev_no_variation() {
        assert_eq!(stddev_array(ints(&[5, 5, 5])).unwrap(), 0.0);
    }

    #[test]
    fn test_stddev_floats() {
        let result = stddev_array(floats(&[10.0, 20.0, 30.0])).unwrap();
        assert!((result - (200.0f64 / 3.0).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_stddev_empty_array_error() {
        assert!(stddev_array(Array::new()).is_err());
    }

    #[test]
    fn test_stddev_no_numeric_values_error() {
        assert!(stddev_array(strings(&["abc"])).is_err());
    }

    #[test]
    fn test_sum_numeric_strings_rejected() {
        assert!(sum_array(strings(&["10", "20", "30"])).is_unit());
    }

    #[test]
    fn test_mean_numeric_strings_rejected() {
        assert!(mean_array(strings(&["10", "20", "30"])).is_err());
    }
}