use std::collections::HashMap;
use approx::assert_abs_diff_eq;
use mathlex::{BinaryOp, Expression};
use mathlex_eval::{EvalError, EvalInput, NumericResult, compile, eval};
/// Builds a variable-reference AST node for `name`.
fn var(name: &str) -> Expression {
    Expression::Variable(Into::into(name))
}
/// Builds an integer-literal AST node holding `v`.
fn int(v: i64) -> Expression {
    Expression::Integer(v)
}
/// AST for `x^2 + y`.
fn ast_x_sq_plus_y() -> Expression {
    // Small local constructor so the nesting reads like the formula.
    let binary = |op, lhs, rhs| Expression::Binary {
        op,
        left: Box::new(lhs),
        right: Box::new(rhs),
    };
    binary(
        BinaryOp::Add,
        binary(BinaryOp::Pow, var("x"), int(2)),
        var("y"),
    )
}
/// AST for `x + y`.
fn ast_x_plus_y() -> Expression {
    let lhs = Box::new(var("x"));
    let rhs = Box::new(var("y"));
    Expression::Binary {
        left: lhs,
        op: BinaryOp::Add,
        right: rhs,
    }
}
/// AST for `x * 2`.
fn ast_x_times_two() -> Expression {
    let (lhs, rhs) = (var("x"), int(2));
    Expression::Binary {
        op: BinaryOp::Mul,
        left: Box::new(lhs),
        right: Box::new(rhs),
    }
}
/// Empty constant table for `compile`, used when the expression references no
/// named constants.
fn no_constants() -> HashMap<&'static str, NumericResult> {
    HashMap::default()
}
/// With every argument bound to a scalar, the result is 0-d: an empty shape
/// holding exactly one element.
#[test]
fn all_scalars_produce_zero_d_output() {
    let expr = ast_x_sq_plus_y();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([
        ("x", EvalInput::Scalar(3.0)),
        ("y", EvalInput::Scalar(1.0)),
    ]);
    let handle = eval(&compiled, args).unwrap();
    let shape = handle.shape();
    assert!(
        shape.is_empty(),
        "expected empty shape for all-scalar args, got {:?}",
        shape
    );
    assert_eq!(handle.len(), 1);
}
/// The single value of an all-scalar evaluation is `3^2 + 1 = 10`.
#[test]
fn all_scalars_scalar_value_correct() {
    let expr = ast_x_sq_plus_y();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([
        ("x", EvalInput::Scalar(3.0)),
        ("y", EvalInput::Scalar(1.0)),
    ]);
    let handle = eval(&compiled, args).unwrap();
    let value = handle.scalar().unwrap().to_f64().unwrap();
    assert_abs_diff_eq!(value, 10.0, epsilon = 1e-10);
}
/// A single array argument of length 3 produces a 1-d result of shape `[3]`.
#[test]
fn one_array_arg_produces_1d_output() {
    let expr = ast_x_times_two();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([("x", EvalInput::from(vec![1.0, 2.0, 3.0]))]);
    let handle = eval(&compiled, args).unwrap();
    assert_eq!(handle.len(), 3);
    assert_eq!(handle.shape(), &[3]);
}
/// Each element of the 1-d result is the corresponding `x` doubled.
#[test]
fn one_array_arg_values_correct() {
    let expr = ast_x_times_two();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([("x", EvalInput::from(vec![1.0, 2.0, 3.0]))]);
    let arr = eval(&compiled, args).unwrap().to_array().unwrap();
    assert_eq!(arr.shape(), &[3]);
    for (idx, want) in [([0], 2.0), ([1], 4.0), ([2], 6.0)] {
        assert_eq!(*arr.get(idx).unwrap(), NumericResult::Real(want));
    }
}
/// Two array arguments (lengths 2 and 3) produce a 2-d result of shape
/// `[2, 3]`, with `x`'s axis first for this expression.
#[test]
fn two_array_args_produce_2d_output() {
    let expr = ast_x_plus_y();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([
        ("x", EvalInput::from(vec![1.0, 2.0])),
        ("y", EvalInput::from(vec![10.0, 20.0, 30.0])),
    ]);
    let handle = eval(&compiled, args).unwrap();
    assert_eq!(handle.shape(), &[2, 3]);
    assert_eq!(handle.len(), 2 * 3);
}
/// Every cell of the `[3, 2]` grid holds `x[i]^2 + y[j]`.
#[test]
fn two_array_args_grid_values_correct() {
    let expr = ast_x_sq_plus_y();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([
        ("x", EvalInput::from(vec![1.0, 2.0, 3.0])),
        ("y", EvalInput::from(vec![10.0, 20.0])),
    ]);
    let arr = eval(&compiled, args).unwrap().to_array().unwrap();
    assert_eq!(arr.shape(), &[3, 2]);
    let expected = [
        ([0, 0], 11.0),
        ([0, 1], 21.0),
        ([1, 0], 14.0),
        ([1, 1], 24.0),
        ([2, 0], 19.0),
        ([2, 1], 29.0),
    ];
    for (idx, want) in expected {
        assert_eq!(*arr.get(idx).unwrap(), NumericResult::Real(want));
    }
}
/// A scalar argument adds no axis: with array `x` and scalar `y = 0`, the
/// result is 1-d and holds `x[i]^2`.
#[test]
fn scalar_broadcasts_over_array() {
    let expr = ast_x_sq_plus_y();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([
        ("x", EvalInput::from(vec![1.0, 2.0, 3.0])),
        ("y", EvalInput::Scalar(0.0)),
    ]);
    let handle = eval(&compiled, args).unwrap();
    assert_eq!(handle.shape(), &[3]);
    let arr = handle.to_array().unwrap();
    for (idx, want) in [([0], 1.0), ([1], 4.0), ([2], 9.0)] {
        assert_eq!(*arr.get(idx).unwrap(), NumericResult::Real(want));
    }
}
/// A named constant `c` supplied to `compile` is added to every cell of a 2-d
/// grid result: `[i, j] = x[i] + y[j] + c`.
#[test]
fn scalar_broadcasts_over_2d_result() {
    let constants = HashMap::from([("c", NumericResult::Real(100.0))]);
    let expr = Expression::Binary {
        op: BinaryOp::Add,
        left: Box::new(ast_x_plus_y()),
        right: Box::new(var("c")),
    };
    let compiled = compile(&expr, &constants).unwrap();
    let args = HashMap::from([
        ("x", EvalInput::from(vec![1.0, 2.0])),
        ("y", EvalInput::from(vec![10.0, 20.0])),
    ]);
    let arr = eval(&compiled, args).unwrap().to_array().unwrap();
    assert_eq!(arr.shape(), &[2, 2]);
    let expected = [
        ([0, 0], 111.0),
        ([0, 1], 121.0),
        ([1, 0], 112.0),
        ([1, 1], 122.0),
    ];
    for (idx, want) in expected {
        assert_eq!(*arr.get(idx).unwrap(), NumericResult::Real(want));
    }
}
/// Shape, length and emptiness can all be queried on the result handle before
/// it is materialized with `to_array`.
#[test]
fn shape_inspectable_before_consumption() {
    let expr = ast_x_sq_plus_y();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([
        ("x", EvalInput::from(vec![1.0, 2.0, 3.0])),
        ("y", EvalInput::from(vec![10.0, 20.0])),
    ]);
    let handle = eval(&compiled, args).unwrap();
    let (shape, len, is_empty) = (handle.shape().to_vec(), handle.len(), handle.is_empty());
    assert_eq!(shape, vec![3, 2]);
    assert_eq!(len, 6);
    assert!(!is_empty);
    let arr = handle.to_array().unwrap();
    assert_eq!(arr.shape(), &[3, 2]);
}
/// The eager path (`to_array`) and the lazy path (`iter`) must yield the same
/// values in the same order for a 1-d result (array `x`, scalar `y`).
#[test]
fn eager_and_lazy_produce_identical_results_1d() {
    let ast = ast_x_sq_plus_y();
    let compiled = compile(&ast, &no_constants()).unwrap();
    // `eval` takes its argument map by value, so each path needs a fresh one.
    let make_args = || {
        let mut args = HashMap::new();
        args.insert("x", EvalInput::from(vec![1.0, 2.0, 3.0]));
        args.insert("y", EvalInput::Scalar(10.0));
        args
    };
    let eager: Vec<NumericResult> = eval(&compiled, make_args())
        .unwrap()
        .to_array()
        .unwrap()
        .iter()
        .copied()
        .collect();
    let lazy: Vec<NumericResult> = eval(&compiled, make_args())
        .unwrap()
        .iter()
        .map(|r| r.unwrap())
        .collect();
    // Without a length check, two empty vectors would make this test pass
    // vacuously; mirrors the guard in the 2-d variant.
    assert_eq!(eager.len(), 3);
    assert_eq!(eager, lazy);
}
/// On a 2-d grid result, the eager path (`to_array`) and the lazy path
/// (`iter`) yield the same six values in the same order.
#[test]
fn eager_and_lazy_produce_identical_results_2d() {
    let expr = ast_x_sq_plus_y();
    let compiled = compile(&expr, &no_constants()).unwrap();
    // `eval` takes its argument map by value, so each path needs its own map.
    let fresh_args = || {
        HashMap::from([
            ("x", EvalInput::from(vec![1.0, 2.0, 3.0])),
            ("y", EvalInput::from(vec![10.0, 20.0])),
        ])
    };
    let eager_handle = eval(&compiled, fresh_args()).unwrap();
    let eager: Vec<NumericResult> = eager_handle.to_array().unwrap().iter().copied().collect();
    let lazy_handle = eval(&compiled, fresh_args()).unwrap();
    let lazy: Vec<NumericResult> = lazy_handle.iter().map(|r| r.unwrap()).collect();
    assert_eq!(eager, lazy);
    assert_eq!(eager.len(), 6);
}
/// Binding the only array argument to an empty array yields an empty 1-d
/// result of shape `[0]`.
#[test]
fn empty_array_produces_empty_output() {
    let expr = ast_x_times_two();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([("x", EvalInput::from(Vec::<f64>::new()))]);
    let handle = eval(&compiled, args).unwrap();
    assert_eq!(handle.shape(), &[0]);
    assert_eq!(handle.len(), 0);
    assert!(handle.is_empty());
}
/// Materializing an empty result with `to_array` succeeds and gives an array
/// with no elements.
#[test]
fn empty_array_to_array_returns_empty_array() {
    let expr = ast_x_times_two();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let empty: Vec<f64> = Vec::new();
    let args = HashMap::from([("x", EvalInput::from(empty))]);
    let handle = eval(&compiled, args).unwrap();
    let arr = handle.to_array().unwrap();
    assert_eq!(arr.len(), 0);
}
/// Lazily iterating an empty result produces no items at all.
#[test]
fn empty_array_iter_yields_no_elements() {
    let expr = ast_x_times_two();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([("x", EvalInput::from(Vec::<f64>::new()))]);
    let handle = eval(&compiled, args).unwrap();
    let items: Vec<_> = handle.iter().collect();
    assert!(items.is_empty());
}
/// Pairing an empty array argument with a non-empty one still yields an
/// empty result.
#[test]
fn one_empty_array_with_nonempty_array_produces_empty_output() {
    let expr = ast_x_plus_y();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([
        ("x", EvalInput::from(Vec::<f64>::new())),
        ("y", EvalInput::from(vec![1.0, 2.0, 3.0])),
    ]);
    let handle = eval(&compiled, args).unwrap();
    assert!(handle.is_empty());
}
/// An `EvalInput::Iter` argument evaluates to the same doubled values an
/// equivalent array argument would.
#[test]
fn iter_input_materialized_to_correct_values() {
    let expr = ast_x_times_two();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let source = vec![5.0, 6.0, 7.0].into_iter();
    let args = HashMap::from([("x", EvalInput::Iter(Box::new(source)))]);
    let arr = eval(&compiled, args).unwrap().to_array().unwrap();
    assert_eq!(arr.shape(), &[3]);
    for (idx, want) in [([0], 10.0), ([1], 12.0), ([2], 14.0)] {
        assert_eq!(*arr.get(idx).unwrap(), NumericResult::Real(want));
    }
}
/// An iterator-backed argument yields the same shape as the equivalent array
/// argument, and that shape is the expected 1-d `[4]`.
#[test]
fn iter_input_shape_matches_array_input() {
    let ast = ast_x_times_two();
    let compiled = compile(&ast, &no_constants()).unwrap();
    let values = vec![1.0, 2.0, 3.0, 4.0];
    let mut args_iter = HashMap::new();
    args_iter.insert("x", EvalInput::Iter(Box::new(values.clone().into_iter())));
    let mut args_arr = HashMap::new();
    args_arr.insert("x", EvalInput::from(values));
    let shape_from_iter = eval(&compiled, args_iter).unwrap().shape().to_vec();
    let shape_from_arr = eval(&compiled, args_arr).unwrap().shape().to_vec();
    assert_eq!(shape_from_iter, shape_from_arr);
    // Equality alone would also pass if both paths produced the same wrong
    // shape, so pin the expected value too.
    assert_eq!(shape_from_arr, vec![4]);
}
/// An iterator-backed argument combined with an array argument forms a 2-d
/// grid, just as two arrays do.
#[test]
fn iter_input_combined_with_array_produces_2d() {
    let expr = ast_x_plus_y();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let x_source = vec![1.0, 2.0].into_iter();
    let args = HashMap::from([
        ("x", EvalInput::Iter(Box::new(x_source))),
        ("y", EvalInput::from(vec![10.0, 20.0])),
    ]);
    let handle = eval(&compiled, args).unwrap();
    assert_eq!(handle.shape(), &[2, 2]);
    assert_eq!(handle.len(), 4);
}
/// Asking a 1-d result for a single scalar is rejected with `ShapeMismatch`.
#[test]
fn scalar_on_1d_output_returns_shape_mismatch() {
    let expr = ast_x_times_two();
    let compiled = compile(&expr, &no_constants()).unwrap();
    let args = HashMap::from([("x", EvalInput::from(vec![1.0, 2.0, 3.0]))]);
    let err = eval(&compiled, args).unwrap().scalar().unwrap_err();
    let is_shape_mismatch = matches!(err, EvalError::ShapeMismatch { .. });
    assert!(is_shape_mismatch, "expected ShapeMismatch, got {:?}", err);
}
/// Asking a 2-d result for a single scalar is rejected with `ShapeMismatch`.
#[test]
fn scalar_on_2d_output_returns_shape_mismatch() {
    let ast = ast_x_plus_y();
    let compiled = compile(&ast, &no_constants()).unwrap();
    let mut args = HashMap::new();
    args.insert("x", EvalInput::from(vec![1.0, 2.0]));
    args.insert("y", EvalInput::from(vec![10.0, 20.0]));
    let handle = eval(&compiled, args).unwrap();
    let err = handle.scalar().unwrap_err();
    // Report the actual error on failure, matching the 1-d variant.
    assert!(
        matches!(err, EvalError::ShapeMismatch { .. }),
        "expected ShapeMismatch, got {:?}",
        err
    );
}