use std::collections::HashMap;
use mathlex::{BinaryOp, Expression, MathFloat, UnaryOp};
use mathlex_eval::{EvalInput, NumericResult, compile, eval};
use proptest::prelude::*;
fn finite_f64() -> impl Strategy<Value = f64> {
prop::num::f64::NORMAL
.prop_filter("finite", |v| v.is_finite())
.prop_map(|v| {
v.clamp(-1e6_f64, 1e6_f64)
})
}
/// Strictly positive `f64` values drawn from the half-open range `[0.001, 1000)`.
fn positive_f64() -> impl Strategy<Value = f64> {
    // Every value in this bounded range is finite, so no filtering is needed.
    0.001_f64..1_000.0_f64
}
/// Small signed integers in the closed range `[-100, 100]`.
fn small_i64() -> impl Strategy<Value = i64> {
    const BOUND: i64 = 100;
    -BOUND..=BOUND
}
fn leaf_ast() -> impl Strategy<Value = Expression> {
prop_oneof![
small_i64().prop_map(Expression::Integer),
finite_f64().prop_map(|v| Expression::Float(MathFloat::from(v))),
Just(Expression::Variable("x".into())),
]
}
fn safe_binary_op() -> impl Strategy<Value = BinaryOp> {
prop_oneof![
Just(BinaryOp::Add),
Just(BinaryOp::Sub),
Just(BinaryOp::Mul),
]
}
/// Random expression trees nested at most `depth` levels deep.
///
/// At depth `0` only leaves are produced. Otherwise a node is one of:
/// - a leaf (weight 3);
/// - a `safe_binary_op` node over two depth-`depth - 1` subtrees (weight 2);
/// - a unary negation of one such subtree (weight 1).
fn ast_strategy(depth: u32) -> impl Strategy<Value = Expression> {
    if depth == 0 {
        leaf_ast().boxed()
    } else {
        // Each subtree gets its own freshly built strategy one level shallower.
        let child = || ast_strategy(depth - 1);
        let binary_node = (safe_binary_op(), child(), child())
            .prop_map(|(op, lhs, rhs)| Expression::Binary {
                op,
                left: Box::new(lhs),
                right: Box::new(rhs),
            })
            .boxed();
        let negation = child()
            .prop_map(|inner| Expression::Unary {
                op: UnaryOp::Neg,
                operand: Box::new(inner),
            })
            .boxed();
        prop_oneof![
            3 => leaf_ast().boxed(),
            2 => binary_node,
            1 => negation,
        ]
        .boxed()
    }
}
/// ASTs of the form `x <op> <literal>`.
///
/// Each AST is returned together with the operator and the float literal used
/// to build it. `op` comes from `safe_binary_op` and the literal from
/// `finite_f64`.
fn simple_binary_ast() -> impl Strategy<Value = (Expression, BinaryOp, f64)> {
    (safe_binary_op(), finite_f64()).prop_map(|(op, lit)| {
        let variable = Box::new(Expression::Variable("x".into()));
        let literal = Box::new(Expression::Float(MathFloat::from(lit)));
        let ast = Expression::Binary {
            op,
            left: variable,
            right: literal,
        };
        (ast, op, lit)
    })
}
/// Independent pairs of array lengths, each in the closed range `[1, 8]`.
fn array_pair_sizes() -> impl Strategy<Value = (usize, usize)> {
    let side_len = 1_usize..=8;
    (side_len.clone(), side_len)
}
/// Straightforward recursive reference evaluator for `ast` with `x` bound to `x`.
///
/// Returns `None` in the following cases:
/// - a variable other than `x`;
/// - an operator or node kind it does not model;
/// - division or remainder by zero;
/// - a non-finite float literal;
/// - a non-finite power.
///
/// `Add`/`Sub`/`Mul` results are returned unchecked and may be non-finite.
/// `Mod` uses Rust's `%`, i.e. the remainder takes the sign of the dividend.
fn naive_eval(ast: &Expression, x: f64) -> Option<f64> {
    // Keeps `candidate` only when it is a finite number.
    let finite = |candidate: f64| Some(candidate).filter(|c| c.is_finite());
    match ast {
        Expression::Integer(n) => Some(*n as f64),
        Expression::Float(f) => finite(f64::from(*f)),
        Expression::Variable(name) if name == "x" => Some(x),
        Expression::Binary { op, left, right } => {
            // Left is evaluated (and may short-circuit) before right.
            let (l, r) = (naive_eval(left, x)?, naive_eval(right, x)?);
            match op {
                BinaryOp::Add => Some(l + r),
                BinaryOp::Sub => Some(l - r),
                BinaryOp::Mul => Some(l * r),
                BinaryOp::Div | BinaryOp::Mod if r == 0.0 => None,
                BinaryOp::Div => Some(l / r),
                BinaryOp::Mod => Some(l % r),
                BinaryOp::Pow => finite(l.powf(r)),
                _ => None,
            }
        }
        Expression::Unary { op, operand } => {
            let v = naive_eval(operand, x)?;
            match op {
                UnaryOp::Neg => Some(-v),
                UnaryOp::Pos => Some(v),
                _ => None,
            }
        }
        _ => None,
    }
}
/// An empty constant table, for compiling expressions that use no named constants.
fn no_constants() -> HashMap<&'static str, NumericResult> {
    HashMap::default()
}
/// Argument map binding the single variable `x` to the scalar `x_val`.
fn make_scalar_args(x_val: f64) -> HashMap<&'static str, EvalInput> {
    HashMap::from([("x", EvalInput::Scalar(x_val))])
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(512))]
    /// Compiling any AST from `ast_strategy(3)` must not panic; both `Ok`
    /// and `Err` outcomes are acceptable.
    #[test]
    fn prop_valid_ast_compile_no_panic(ast in ast_strategy(3)) {
        let constants = no_constants();
        // The outcome is deliberately ignored: only the absence of a panic matters.
        let _outcome = compile(&ast, &constants);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(512))]
    /// For `x <op> <literal>` with a safe operator and a scalar `x`, the
    /// compiled evaluator must agree with `naive_eval`.
    ///
    /// The comparison uses a relative tolerance when `|naive| > 1` and an
    /// absolute one otherwise.
    ///
    /// Compiling and evaluating these simple ASTs is expected to succeed, so
    /// an error there fails the case rather than silently counting as a pass.
    ///
    /// Inputs where either side has no finite real value are *rejected*.
    /// Proptest then draws a replacement case and complains if too many are
    /// rejected, instead of reporting them as passes.
    #[test]
    fn prop_scalar_eval_matches_naive(
        (ast, op, _lit) in simple_binary_ast(),
        x_val in finite_f64(),
    ) {
        let compiled = compile(&ast, &no_constants())
            .expect("compiling `x <op> literal` should succeed");
        let handle = eval(&compiled, make_scalar_args(x_val))
            .expect("evaluating with a scalar `x` should succeed");
        let result = handle
            .scalar()
            .expect("scalar arguments should produce a scalar result");
        let (eval_val, naive_val) = match (result.to_f64(), naive_eval(&ast, x_val)) {
            (Some(e), Some(n)) if e.is_finite() && n.is_finite() => (e, n),
            _ => return Err(TestCaseError::reject("no finite real result to compare")),
        };
        // `safe_binary_op` never yields `Pow` today. The looser bound is kept
        // in case it is added to the generated operators later.
        let tolerance = match op {
            BinaryOp::Pow => 1e-6,
            _ => 1e-9,
        };
        let diff = (eval_val - naive_val).abs();
        let relative = if naive_val.abs() > 1.0 {
            diff / naive_val.abs()
        } else {
            diff
        };
        prop_assert!(
            relative <= tolerance,
            "eval={eval_val}, naive={naive_val}, diff={diff}",
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(256))]
    /// Eager (`to_array`) and lazy (`iter`) evaluation of `x * x` over the
    /// same array argument must produce identical element sequences.
    #[test]
    fn prop_eager_and_lazy_identical(
        x_vals in prop::collection::vec(finite_f64(), 1..=16),
    ) {
        let square = Expression::Binary {
            op: BinaryOp::Mul,
            left: Box::new(Expression::Variable("x".into())),
            right: Box::new(Expression::Variable("x".into())),
        };
        let compiled = compile(&square, &no_constants()).unwrap();
        // `eval` takes its argument map by value, so build a fresh one per call.
        let fresh_args = || HashMap::from([("x", EvalInput::from(x_vals.clone()))]);

        let eager_handle = eval(&compiled, fresh_args()).unwrap();
        let eager: Vec<NumericResult> =
            eager_handle.to_array().unwrap().iter().copied().collect();

        let lazy_handle = eval(&compiled, fresh_args()).unwrap();
        let lazy: Vec<NumericResult> = lazy_handle.iter().map(Result::unwrap).collect();

        prop_assert_eq!(eager.len(), lazy.len());
        for (i, (e, l)) in eager.iter().zip(&lazy).enumerate() {
            prop_assert_eq!(
                e, l,
                "mismatch at index {}: eager={:?}, lazy={:?}", i, e, l,
            );
        }
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(128))]
    /// Eager and lazy evaluation of a random AST whose only argument is `x`
    /// must agree element by element over a positive array argument.
    ///
    /// - If eager evaluation fails, lazy iteration must fail on at least one
    ///   element.
    /// - If eager evaluation succeeds, every lazy element must also succeed
    ///   and equal its eager counterpart. A lazy-only error is a mismatch.
    #[test]
    fn prop_eager_lazy_random_ast(
        ast in ast_strategy(3),
        x_vals in prop::collection::vec(positive_f64(), 1..=8),
    ) {
        let compiled = match compile(&ast, &no_constants()) {
            Ok(c) => c,
            Err(_) => return Ok(()),
        };
        let names = compiled.argument_names();
        // Only ASTs whose sole argument is `x` can be fed `x_vals`. Discard
        // the rest rather than counting them as passes.
        prop_assume!(names == ["x"]);
        let make_args = || {
            let mut args = HashMap::new();
            args.insert("x", EvalInput::from(x_vals.clone()));
            args
        };
        let eager_result = eval(&compiled, make_args())
            .unwrap()
            .to_array();
        let lazy_result: Vec<Result<NumericResult, _>> = eval(&compiled, make_args())
            .unwrap()
            .iter()
            .collect();
        match eager_result {
            Err(_) => {
                prop_assert!(lazy_result.iter().any(|r| r.is_err()));
            }
            Ok(arr) => {
                let eager_vec: Vec<NumericResult> = arr.iter().copied().collect();
                prop_assert_eq!(eager_vec.len(), lazy_result.len());
                for (i, (e, l)) in eager_vec.iter().zip(lazy_result.iter()).enumerate() {
                    // Eager succeeded for every element, so a lazy error here
                    // is itself an eager/lazy disagreement.
                    prop_assert!(
                        l.is_ok(),
                        "index {}: eager={:?} succeeded but lazy returned an error",
                        i, e,
                    );
                    if let Ok(lv) = l {
                        prop_assert_eq!(
                            e, lv,
                            "index {}: eager={:?}, lazy={:?}", i, e, lv,
                        );
                    }
                }
            }
        }
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(256))]
    /// Broadcasting `x + y` over arrays of lengths `n` and `m` must yield
    /// `n * m` elements with shape `[n, m]`.
    #[test]
    fn prop_broadcast_len_is_product_of_array_lengths(
        (n, m) in array_pair_sizes(),
        x_base in finite_f64(),
        y_base in finite_f64(),
    ) {
        let sum = Expression::Binary {
            op: BinaryOp::Add,
            left: Box::new(Expression::Variable("x".into())),
            right: Box::new(Expression::Variable("y".into())),
        };
        let compiled = compile(&sum, &no_constants()).unwrap();
        // `len` consecutive values starting at `base`: base, base + 1, ...
        let ramp = |base: f64, len: usize| -> Vec<f64> {
            (0..len).map(|i| base + i as f64).collect()
        };
        let args = HashMap::from([
            ("x", EvalInput::from(ramp(x_base, n))),
            ("y", EvalInput::from(ramp(y_base, m))),
        ]);
        let handle = eval(&compiled, args).unwrap();
        prop_assert_eq!(handle.len(), n * m);
        prop_assert_eq!(handle.shape(), &[n, m]);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(128))]
    /// Broadcasting `(x + y) + z` over arrays of lengths `a`, `b` and `c`
    /// must yield `a * b * c` elements.
    #[test]
    fn prop_broadcast_three_arrays_len_is_product(
        a in 1_usize..=5,
        b in 1_usize..=5,
        c in 1_usize..=5,
    ) {
        let x_plus_y = Expression::Binary {
            op: BinaryOp::Add,
            left: Box::new(Expression::Variable("x".into())),
            right: Box::new(Expression::Variable("y".into())),
        };
        let total = Expression::Binary {
            op: BinaryOp::Add,
            left: Box::new(x_plus_y),
            right: Box::new(Expression::Variable("z".into())),
        };
        let compiled = compile(&total, &no_constants()).unwrap();
        // 0.0, 1.0, ..., (len - 1) as f64
        let iota = |len: usize| -> Vec<f64> { (0..len).map(|i| i as f64).collect() };
        let args = HashMap::from([
            ("x", EvalInput::from(iota(a))),
            ("y", EvalInput::from(iota(b))),
            ("z", EvalInput::from(iota(c))),
        ]);
        let handle = eval(&compiled, args).unwrap();
        prop_assert_eq!(handle.len(), a * b * c);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(256))]
    /// Scalar evaluation of `x + y` must not depend on the order in which
    /// the arguments were inserted into the argument map.
    #[test]
    fn prop_arg_order_independent_scalar(
        x_val in finite_f64(),
        y_val in finite_f64(),
    ) {
        let sum = Expression::Binary {
            op: BinaryOp::Add,
            left: Box::new(Expression::Variable("x".into())),
            right: Box::new(Expression::Variable("y".into())),
        };
        let compiled = compile(&sum, &no_constants()).unwrap();
        let x_first = HashMap::from([
            ("x", EvalInput::Scalar(x_val)),
            ("y", EvalInput::Scalar(y_val)),
        ]);
        let y_first = HashMap::from([
            ("y", EvalInput::Scalar(y_val)),
            ("x", EvalInput::Scalar(x_val)),
        ]);
        let from_x_first = eval(&compiled, x_first).unwrap().scalar().unwrap().to_f64();
        let from_y_first = eval(&compiled, y_first).unwrap().scalar().unwrap().to_f64();
        prop_assert_eq!(from_x_first, from_y_first);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(128))]
    /// Array (broadcast) evaluation of `x + y` must not depend on the order in
    /// which the arguments were inserted into the argument map.
    #[test]
    fn prop_arg_order_independent_array(
        x_vals in prop::collection::vec(finite_f64(), 1..=8),
        y_vals in prop::collection::vec(finite_f64(), 1..=8),
    ) {
        let sum = Expression::Binary {
            op: BinaryOp::Add,
            left: Box::new(Expression::Variable("x".into())),
            right: Box::new(Expression::Variable("y".into())),
        };
        let compiled = compile(&sum, &no_constants()).unwrap();
        let x_first = HashMap::from([
            ("x", EvalInput::from(x_vals.clone())),
            ("y", EvalInput::from(y_vals.clone())),
        ]);
        let y_first = HashMap::from([
            ("y", EvalInput::from(y_vals.clone())),
            ("x", EvalInput::from(x_vals.clone())),
        ]);

        let handle_xy = eval(&compiled, x_first).unwrap();
        let flat_xy: Vec<NumericResult> =
            handle_xy.to_array().unwrap().iter().copied().collect();
        let handle_yx = eval(&compiled, y_first).unwrap();
        let flat_yx: Vec<NumericResult> =
            handle_yx.to_array().unwrap().iter().copied().collect();

        prop_assert_eq!(flat_xy.len(), flat_yx.len());
        for (i, (a, b)) in flat_xy.iter().zip(&flat_yx).enumerate() {
            prop_assert_eq!(a, b, "mismatch at index {}", i);
        }
    }
}