use approx::assert_relative_eq;
use numrs2::array::Array;
use numrs2::math::{arange, linspace, ElementWiseMath};
use numrs2::prelude::*;
use numrs2::ufuncs::{cos, sin, tan};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One array entry as stored in the reference JSON produced by the Python
/// script under `tests/py/`.
///
/// The elements are kept as a flat list in `data`; `shape` holds the
/// dimensions that `deserialize_array` reshapes them to.
#[derive(Debug, Deserialize, Serialize)]
struct SerializedArray {
    /// Array elements as a flat list.
    data: Vec<f64>,
    /// Dimensions used to reshape `data` back into an array.
    shape: Vec<usize>,
    /// Presumably the element type name recorded by the generator (TODO confirm
    /// against the Python script); not read by any helper visible in this file.
    dtype: String,
}
/// Presumably models the different kinds of `result` value a reference case
/// can hold (TODO confirm against the generator script).
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order and
/// keeps the first that deserializes successfully. The order below is therefore
/// significant; for example, any bare JSON number becomes `Scalar`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
// Not referenced by any test visible in this file; the allow keeps the
// otherwise-unused enum from producing a dead-code warning.
#[allow(dead_code)]
enum TestResult {
    /// A single numeric result.
    Scalar(f64),
    /// A serialized floating-point array.
    Array(SerializedArray),
    /// A single boolean result.
    Boolean(bool),
    /// A flat list of booleans.
    BooleanArray(Vec<bool>),
    /// A flat list of integers.
    IntegerArray(Vec<i64>),
}
/// Reads and parses the reference-results file shared by every test module.
///
/// The outer map is keyed by test category (e.g. `"array_creation"`). Each
/// inner map is keyed by individual case name and holds that case's JSON value.
///
/// # Panics
/// Panics if the file cannot be read (it is produced by the Python script under
/// `tests/py/`) or if its contents are not JSON of this nested-object form.
fn load_reference_data() -> HashMap<String, HashMap<String, serde_json::Value>> {
    const REFERENCE_PATH: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/tests/py/array_operations_reference_data.json"
    );

    let json_text = std::fs::read_to_string(REFERENCE_PATH)
        .expect("Failed to read reference data file. Run the Python script first.");

    serde_json::from_str(&json_text).expect("Failed to parse reference data JSON")
}
fn deserialize_array(serialized: &SerializedArray) -> Array<f64> {
Array::from_vec(serialized.data.clone()).reshape(&serialized.shape)
}
/// Asserts that two arrays have the same shape and agree element-wise within
/// `tolerance`.
///
/// An element pair `(a, e)` is accepted when any of the following holds:
/// * `a == e` (this also covers equal infinities);
/// * both are NaN. NaN positions must line up, as with the default
///   `equal_nan=True` of `numpy.testing.assert_allclose`;
/// * both are finite and `|a - e| <= tolerance` (absolute check, for values
///   near zero);
/// * both are finite and `|a - e| <= tolerance * max(|a|, |e|)` (relative
///   check, for large magnitudes).
///
/// # Panics
/// Panics on a shape or size mismatch, or at the first element pair that fails
/// every check above. The message includes the flat index and both values.
fn assert_arrays_close(actual: &Array<f64>, expected: &Array<f64>, tolerance: f64) {
    assert_eq!(actual.shape(), expected.shape(), "Array shapes don't match");
    let actual_vec = actual.to_vec();
    let expected_vec = expected.to_vec();
    assert_eq!(
        actual_vec.len(),
        expected_vec.len(),
        "Array sizes don't match"
    );

    // A single mixed absolute/relative predicate. The previous version followed
    // the relative assertion with an extra absolute-only check, which made the
    // relative tolerance dead code.
    let close = |a: f64, e: f64| -> bool {
        if a == e || (a.is_nan() && e.is_nan()) {
            return true;
        }
        // An infinity against anything unequal (or NaN against a number) is a
        // mismatch. The arithmetic below would otherwise accept inf vs finite.
        if !a.is_finite() || !e.is_finite() {
            return false;
        }
        let diff = (a - e).abs();
        diff <= tolerance || diff <= tolerance * a.abs().max(e.abs())
    };

    for (i, (&a, &e)) in actual_vec.iter().zip(expected_vec.iter()).enumerate() {
        if !close(a, e) {
            panic!("Arrays differ at index {}: got {}, expected {}", i, a, e);
        }
    }
}
/// Asserts that two scalars agree within `tolerance`.
///
/// The pair is accepted when the values are equal (including equal
/// infinities) or both NaN. Otherwise both must be finite and satisfy either
/// `|a - e| <= tolerance` (absolute) or
/// `|a - e| <= tolerance * max(|a|, |e|)` (relative). The relative branch
/// matters for large results such as sums or variances. The previous
/// `assert_relative_eq!` call set only `epsilon`, so its relative tolerance
/// silently defaulted to `f64::EPSILON`.
///
/// # Panics
/// Panics when the values are not close by the rules above.
fn assert_scalar_close(actual: f64, expected: f64, tolerance: f64) {
    let close = if actual == expected || (actual.is_nan() && expected.is_nan()) {
        true
    } else if !actual.is_finite() || !expected.is_finite() {
        // An infinity against anything unequal (or NaN against a number) is a mismatch.
        false
    } else {
        let diff = (actual - expected).abs();
        diff <= tolerance || diff <= tolerance * actual.abs().max(expected.abs())
    };
    assert!(
        close,
        "Scalars differ: got {}, expected {} (tolerance {:e})",
        actual, expected, tolerance
    );
}
/// Checks the array constructors (`zeros`, `ones`, `full`, `arange`,
/// `linspace`) against the `array_creation` section of the reference data.
#[cfg(test)]
mod array_creation_tests {
    use super::*;

    /// Yields `(name, case)` for each entry of `section` whose name passes `keep`.
    ///
    /// Entries are visited lazily. A kept entry that is not a JSON object
    /// panics when it is reached.
    fn matching_cases<'a, P>(
        section: &'a std::collections::HashMap<String, serde_json::Value>,
        keep: P,
    ) -> impl Iterator<Item = (&'a str, &'a serde_json::Map<String, serde_json::Value>)> + 'a
    where
        P: Fn(&str) -> bool + 'a,
    {
        section
            .iter()
            .filter(move |(name, _)| keep(name.as_str()))
            .map(|(name, case)| (name.as_str(), case.as_object().unwrap()))
    }

    /// Parses a JSON list of non-negative integers into `usize`s.
    fn usize_list(value: &serde_json::Value) -> Vec<usize> {
        value
            .as_array()
            .unwrap()
            .iter()
            .map(|dim| dim.as_u64().unwrap() as usize)
            .collect()
    }

    /// Decodes a `SerializedArray` from a JSON value.
    fn decode(value: &serde_json::Value) -> SerializedArray {
        serde_json::from_value(value.clone()).unwrap()
    }

    /// `zeros` for every non-empty shape in the reference data.
    #[test]
    fn test_zeros_creation() {
        let reference_data = load_reference_data();
        let section = &reference_data["array_creation"];
        for (_, case) in matching_cases(section, |name| name.starts_with("zeros_shape_")) {
            let shape = usize_list(&case["shape"]);
            let expected = decode(&case["result"]);
            // Zero-sized shapes are not checked.
            if shape.iter().product::<usize>() == 0 {
                continue;
            }
            let actual = Array::<f64>::zeros(&shape);
            assert_arrays_close(&actual, &deserialize_array(&expected), 1e-15);
        }
    }

    /// `ones` for every non-empty shape in the reference data.
    #[test]
    fn test_ones_creation() {
        let reference_data = load_reference_data();
        let section = &reference_data["array_creation"];
        for (_, case) in matching_cases(section, |name| name.starts_with("ones_shape_")) {
            let shape = usize_list(&case["shape"]);
            let expected = decode(&case["result"]);
            // Zero-sized shapes are not checked.
            if shape.iter().product::<usize>() == 0 {
                continue;
            }
            let actual = Array::<f64>::ones(&shape);
            assert_arrays_close(&actual, &deserialize_array(&expected), 1e-15);
        }
    }

    /// `full` with the recorded shape and fill value.
    #[test]
    fn test_full_creation() {
        let reference_data = load_reference_data();
        let section = &reference_data["array_creation"];
        for (_, case) in matching_cases(section, |name| name.starts_with("full_shape_")) {
            let shape = usize_list(&case["shape"]);
            let fill_value = case["fill_value"].as_f64().unwrap();
            let expected = decode(&case["result"]);
            let actual = Array::<f64>::full(&shape, fill_value);
            assert_arrays_close(&actual, &deserialize_array(&expected), 1e-15);
        }
    }

    /// `arange(start, stop, step)` with the recorded parameters.
    #[test]
    fn test_arange_creation() {
        let reference_data = load_reference_data();
        let section = &reference_data["array_creation"];
        for (_, case) in matching_cases(section, |name| name.starts_with("arange_")) {
            let params = case["params"].as_object().unwrap();
            let (start, stop, step) = (
                params["start"].as_f64().unwrap(),
                params["stop"].as_f64().unwrap(),
                params["step"].as_f64().unwrap(),
            );
            let expected = decode(&case["result"]);
            let actual = arange(start, stop, step);
            assert_arrays_close(&actual, &deserialize_array(&expected), 1e-12);
        }
    }

    /// `linspace(start, stop, num)` with the recorded parameters.
    #[test]
    fn test_linspace_creation() {
        let reference_data = load_reference_data();
        let section = &reference_data["array_creation"];
        for (_, case) in matching_cases(section, |name| name.starts_with("linspace_")) {
            let params = case["params"].as_object().unwrap();
            let (start, stop) = (
                params["start"].as_f64().unwrap(),
                params["stop"].as_f64().unwrap(),
            );
            let num = params["num"].as_u64().unwrap() as usize;
            let expected = decode(&case["result"]);
            let actual = linspace(start, stop, num);
            assert_arrays_close(&actual, &deserialize_array(&expected), 1e-12);
        }
    }
}
/// Checks shape-manipulation operations (`reshape`, `transpose`, `flatten`,
/// `squeeze`) against the `array_manipulation` section of the reference data.
#[cfg(test)]
mod array_manipulation_tests {
    use super::*;

    /// Yields `(name, case)` for each entry of `section` whose name passes `keep`.
    ///
    /// Entries are visited lazily. A kept entry that is not a JSON object
    /// panics when it is reached.
    fn matching_cases<'a, P>(
        section: &'a std::collections::HashMap<String, serde_json::Value>,
        keep: P,
    ) -> impl Iterator<Item = (&'a str, &'a serde_json::Map<String, serde_json::Value>)> + 'a
    where
        P: Fn(&str) -> bool + 'a,
    {
        section
            .iter()
            .filter(move |(name, _)| keep(name.as_str()))
            .map(|(name, case)| (name.as_str(), case.as_object().unwrap()))
    }

    /// Parses a JSON list of non-negative integers into `usize`s.
    fn usize_list(value: &serde_json::Value) -> Vec<usize> {
        value
            .as_array()
            .unwrap()
            .iter()
            .map(|dim| dim.as_u64().unwrap() as usize)
            .collect()
    }

    /// Decodes a `SerializedArray` from a JSON value.
    fn decode(value: &serde_json::Value) -> SerializedArray {
        serde_json::from_value(value.clone()).unwrap()
    }

    /// `reshape` to the recorded target shape.
    #[test]
    fn test_reshape_operations() {
        let reference_data = load_reference_data();
        let section = &reference_data["array_manipulation"];
        for (_, case) in matching_cases(section, |name| name.starts_with("reshape_")) {
            let source = decode(&case["input"]);
            let target_shape = usize_list(&case["new_shape"]);
            let expected = decode(&case["result"]);
            let source_array = deserialize_array(&source);
            let actual = source_array.reshape(&target_shape);
            assert_arrays_close(&actual, &deserialize_array(&expected), 1e-15);
        }
    }

    /// Full-axis `transpose`.
    #[test]
    fn test_transpose_operations() {
        let reference_data = load_reference_data();
        let section = &reference_data["array_manipulation"];
        for (_, case) in matching_cases(section, |name| name.starts_with("transpose_")) {
            let source = decode(&case["input"]);
            let expected = decode(&case["result"]);
            let source_array = deserialize_array(&source);
            let actual = source_array.transpose();
            assert_arrays_close(&actual, &deserialize_array(&expected), 1e-15);
        }
    }

    /// `flatten` with the default (`None`) ordering argument.
    #[test]
    fn test_flatten_operations() {
        let reference_data = load_reference_data();
        let section = &reference_data["array_manipulation"];
        for (_, case) in matching_cases(section, |name| name.starts_with("flatten_")) {
            let source = decode(&case["input"]);
            let expected = decode(&case["result"]);
            let source_array = deserialize_array(&source);
            let actual = source_array.flatten(None);
            assert_arrays_close(&actual, &deserialize_array(&expected), 1e-15);
        }
    }

    /// `squeeze` with no explicit axis.
    #[test]
    fn test_squeeze_operations() {
        let reference_data = load_reference_data();
        let section = &reference_data["array_manipulation"];
        for (_, case) in matching_cases(section, |name| name.starts_with("squeeze_")) {
            let source = decode(&case["input"]);
            let expected = decode(&case["result"]);
            let source_array = deserialize_array(&source);
            let actual = squeeze(&source_array, None).unwrap();
            assert_arrays_close(&actual, &deserialize_array(&expected), 1e-15);
        }
    }
}
/// Checks element-wise and array-scalar arithmetic against the
/// `arithmetic_operations` section of the reference data.
#[cfg(test)]
mod arithmetic_operations_tests {
    use super::*;

    /// Yields `(name, case)` for each entry of `section` whose name passes `keep`.
    ///
    /// Entries are visited lazily. A kept entry that is not a JSON object
    /// panics when it is reached.
    fn matching_cases<'a, P>(
        section: &'a std::collections::HashMap<String, serde_json::Value>,
        keep: P,
    ) -> impl Iterator<Item = (&'a str, &'a serde_json::Map<String, serde_json::Value>)> + 'a
    where
        P: Fn(&str) -> bool + 'a,
    {
        section
            .iter()
            .filter(move |(name, _)| keep(name.as_str()))
            .map(|(name, case)| (name.as_str(), case.as_object().unwrap()))
    }

    /// Decodes a `SerializedArray` from a JSON value.
    fn decode(value: &serde_json::Value) -> SerializedArray {
        serde_json::from_value(value.clone()).unwrap()
    }

    /// Binary array-array operations. Every case in the section must carry an
    /// `operation` string; only the four basic operations are checked.
    #[test]
    fn test_element_wise_arithmetic() {
        let reference_data = load_reference_data();
        let section = &reference_data["arithmetic_operations"];
        for case in section.values().map(|value| value.as_object().unwrap()) {
            let operation = case["operation"].as_str().unwrap();
            if !matches!(operation, "add" | "subtract" | "multiply" | "divide") {
                continue;
            }
            let lhs = decode(&case["input1"]);
            let rhs = decode(&case["input2"]);
            let expected = decode(&case["result"]);
            let (lhs, rhs) = (deserialize_array(&lhs), deserialize_array(&rhs));
            let expected = deserialize_array(&expected);
            let actual = match operation {
                "add" => lhs.add(&rhs),
                "subtract" => lhs.subtract(&rhs),
                "multiply" => lhs.multiply(&rhs),
                "divide" => lhs.divide(&rhs),
                _ => unreachable!(),
            };
            assert_arrays_close(&actual, &expected, 1e-12);
        }
    }

    /// Array-scalar operations from cases whose name starts with `scalar_`.
    #[test]
    fn test_scalar_arithmetic() {
        let reference_data = load_reference_data();
        let section = &reference_data["arithmetic_operations"];
        for (_, case) in matching_cases(section, |name| name.starts_with("scalar_")) {
            let operation = case["operation"].as_str().unwrap();
            let source = decode(&case["input"]);
            let scalar = case["scalar"].as_f64().unwrap();
            let expected = decode(&case["result"]);
            let source_array = deserialize_array(&source);
            let expected = deserialize_array(&expected);
            let actual = match operation {
                "scalar_add" => source_array.add_scalar(scalar),
                "scalar_subtract" => source_array.subtract_scalar(scalar),
                "scalar_multiply" => source_array.multiply_scalar(scalar),
                "scalar_divide" => source_array.divide_scalar(scalar),
                _ => unreachable!(),
            };
            assert_arrays_close(&actual, &expected, 1e-12);
        }
    }
}
/// Checks unary math functions, scalar powers and rounding against the
/// `mathematical_functions` section of the reference data.
#[cfg(test)]
mod mathematical_functions_tests {
    use super::*;

    /// Compares a unary-function result with the reference result.
    ///
    /// Shapes must match. A NaN in `expected` requires a NaN in `actual`. If
    /// either side is infinite, or `actual` is NaN, the values must be exactly
    /// equal. Finite pairs use an absolute/relative tolerance of `1e-9`.
    /// The previous version skipped any pair where EITHER value was non-finite,
    /// so a spurious NaN or inf from numrs2 against a finite expected value
    /// passed silently.
    fn assert_unary_result_matches(operation: &str, actual: &Array<f64>, expected: &Array<f64>) {
        assert_eq!(
            actual.shape(),
            expected.shape(),
            "{}: array shapes don't match",
            operation
        );
        let actual_vec = actual.to_vec();
        let expected_vec = expected.to_vec();
        assert_eq!(
            actual_vec.len(),
            expected_vec.len(),
            "Array sizes don't match"
        );
        for (i, (&a, &e)) in actual_vec.iter().zip(expected_vec.iter()).enumerate() {
            if e.is_nan() {
                assert!(
                    a.is_nan(),
                    "{}: index {}: got {}, expected NaN",
                    operation,
                    i,
                    a
                );
            } else if a.is_nan() || a.is_infinite() || e.is_infinite() {
                // Non-finite values cannot be compared with a tolerance; require equality.
                assert_eq!(a, e, "{}: index {}: non-finite value mismatch", operation, i);
            } else {
                assert_relative_eq!(a, e, epsilon = 1e-9, max_relative = 1e-9);
            }
        }
    }

    /// `sqrt`, `exp`, `log`, `abs`, `sin`, `cos` and `tan` applied element-wise.
    #[test]
    fn test_basic_math_functions() {
        let reference_data = load_reference_data();
        let math_tests = &reference_data["mathematical_functions"];
        for test_data in math_tests.values() {
            let test_obj = test_data.as_object().unwrap();
            let operation = test_obj["operation"].as_str().unwrap();
            if !["sqrt", "exp", "log", "abs", "sin", "cos", "tan"].contains(&operation) {
                continue;
            }
            let input: SerializedArray =
                serde_json::from_value(test_obj["input"].clone()).unwrap();
            let expected: SerializedArray =
                serde_json::from_value(test_obj["result"].clone()).unwrap();
            let input_array = deserialize_array(&input);
            let expected_array = deserialize_array(&expected);
            let actual = match operation {
                "sqrt" => sqrt(&input_array),
                "exp" => exp(&input_array),
                "log" => log(&input_array),
                "abs" => input_array.abs(),
                "sin" => sin(&input_array),
                "cos" => cos(&input_array),
                "tan" => tan(&input_array),
                _ => unreachable!("operation was filtered above"),
            };
            assert_unary_result_matches(operation, &actual, &expected_array);
        }
    }

    /// `pow` with a scalar exponent, from cases whose name starts with `power_`.
    #[test]
    fn test_power_operations() {
        let reference_data = load_reference_data();
        let math_tests = &reference_data["mathematical_functions"];
        for (test_name, test_data) in math_tests {
            if !test_name.starts_with("power_") {
                continue;
            }
            let test_obj = test_data.as_object().unwrap();
            let base: SerializedArray =
                serde_json::from_value(test_obj["base"].clone()).unwrap();
            let exponent = test_obj["exponent"].as_f64().unwrap();
            let expected: SerializedArray =
                serde_json::from_value(test_obj["result"].clone()).unwrap();
            let base_array = deserialize_array(&base);
            let expected_array = deserialize_array(&expected);
            let actual = base_array.pow(exponent);
            assert_arrays_close(&actual, &expected_array, 1e-12);
        }
    }

    /// `floor`, `ceil` and `round` applied element-wise.
    #[test]
    fn test_rounding_functions() {
        let reference_data = load_reference_data();
        let math_tests = &reference_data["mathematical_functions"];
        for test_data in math_tests.values() {
            let test_obj = test_data.as_object().unwrap();
            let operation = test_obj["operation"].as_str().unwrap();
            if !["floor", "ceil", "round"].contains(&operation) {
                continue;
            }
            let input: SerializedArray =
                serde_json::from_value(test_obj["input"].clone()).unwrap();
            let expected: SerializedArray =
                serde_json::from_value(test_obj["result"].clone()).unwrap();
            let input_array = deserialize_array(&input);
            let expected_array = deserialize_array(&expected);
            let actual = match operation {
                "floor" => floor(&input_array),
                "ceil" => ceil(&input_array),
                "round" => round(&input_array),
                _ => unreachable!("operation was filtered above"),
            };
            assert_arrays_close(&actual, &expected_array, 1e-15);
        }
    }
}
/// Checks reductions, axis-wise reductions and percentiles against the
/// `statistical_operations` section of the reference data.
#[cfg(test)]
mod statistical_operations_tests {
    use super::*;

    /// Whole-array reductions, from cases whose name contains `_overall`.
    #[test]
    fn test_basic_statistics() {
        let reference_data = load_reference_data();
        let stats_tests = &reference_data["statistical_operations"];
        for (test_name, test_data) in stats_tests {
            if !test_name.contains("_overall") {
                continue;
            }
            let test_obj = test_data.as_object().unwrap();
            let operation = test_obj["operation"].as_str().unwrap();
            let input: SerializedArray =
                serde_json::from_value(test_obj["input"].clone()).unwrap();
            let expected = test_obj["result"].as_f64().unwrap();
            let input_array = deserialize_array(&input);
            let actual = match operation {
                "mean" => input_array.mean(),
                "sum" => input_array.sum(),
                "min" => input_array.min(),
                "max" => input_array.max(),
                "std" => input_array.std(),
                "var" => input_array.var(),
                // Other reductions are not checked here.
                _ => continue,
            };
            assert_scalar_close(actual, expected, 1e-12);
        }
    }

    /// `sum` and `mean` along a single axis, from cases whose name contains
    /// `_axis` but not `_overall`.
    ///
    /// The reference data holds a result for every such case. An `Err` from
    /// numrs2 is therefore a discrepancy and fails the test; the previous
    /// `if let Ok(..)` silently skipped it.
    #[test]
    fn test_axis_statistics() {
        let reference_data = load_reference_data();
        let stats_tests = &reference_data["statistical_operations"];
        for (test_name, test_data) in stats_tests {
            if !test_name.contains("_axis") || test_name.contains("_overall") {
                continue;
            }
            let test_obj = test_data.as_object().unwrap();
            let operation = test_obj["operation"].as_str().unwrap();
            let input: SerializedArray =
                serde_json::from_value(test_obj["input"].clone()).unwrap();
            let axis = test_obj["axis"].as_u64().unwrap() as usize;
            let input_array = deserialize_array(&input);
            let actual = match operation {
                "sum" => input_array.sum_axis(axis).unwrap_or_else(|err| {
                    panic!("{}: sum_axis({}) failed: {:?}", test_name, axis, err)
                }),
                "mean" => input_array.mean_axis(Some(axis)).unwrap_or_else(|err| {
                    panic!("{}: mean_axis(Some({})) failed: {:?}", test_name, axis, err)
                }),
                // Other axis reductions are not checked here.
                _ => continue,
            };
            let expected: SerializedArray =
                serde_json::from_value(test_obj["result"].clone()).unwrap();
            let expected_array = deserialize_array(&expected);
            assert_arrays_close(&actual, &expected_array, 1e-12);
        }
    }

    /// `percentile`, from cases whose name starts with `percentile_`.
    #[test]
    fn test_percentiles() {
        let reference_data = load_reference_data();
        let stats_tests = &reference_data["statistical_operations"];
        for (test_name, test_data) in stats_tests {
            if !test_name.starts_with("percentile_") {
                continue;
            }
            let test_obj = test_data.as_object().unwrap();
            let input: SerializedArray =
                serde_json::from_value(test_obj["input"].clone()).unwrap();
            // The reference value is divided by 100. Presumably it is on
            // NumPy's 0-100 scale while `Array::percentile` takes a 0-1
            // fraction (TODO confirm against the numrs2 docs).
            let percentile = test_obj["percentile"].as_f64().unwrap() / 100.0;
            let expected = test_obj["result"].as_f64().unwrap();
            let input_array = deserialize_array(&input);
            let actual = input_array.percentile(percentile);
            assert_scalar_close(actual, expected, 1e-12);
        }
    }
}
/// Checks element-wise comparisons, `array_equal` and `allclose` against the
/// `comparison_operations` section of the reference data.
#[cfg(test)]
mod comparison_operations_tests {
    use super::*;

    /// Yields `(name, case)` for each entry of `section` whose name passes `keep`.
    ///
    /// Entries are visited lazily. A kept entry that is not a JSON object
    /// panics when it is reached.
    fn matching_cases<'a, P>(
        section: &'a std::collections::HashMap<String, serde_json::Value>,
        keep: P,
    ) -> impl Iterator<Item = (&'a str, &'a serde_json::Map<String, serde_json::Value>)> + 'a
    where
        P: Fn(&str) -> bool + 'a,
    {
        section
            .iter()
            .filter(move |(name, _)| keep(name.as_str()))
            .map(|(name, case)| (name.as_str(), case.as_object().unwrap()))
    }

    /// Decodes a `SerializedArray` from a JSON value.
    fn decode(value: &serde_json::Value) -> SerializedArray {
        serde_json::from_value(value.clone()).unwrap()
    }

    /// Parses a JSON list of booleans.
    fn bool_list(value: &serde_json::Value) -> Vec<bool> {
        value
            .as_array()
            .unwrap()
            .iter()
            .map(|flag| flag.as_bool().unwrap())
            .collect()
    }

    /// Element-wise comparisons between two arrays, from cases whose name ends
    /// with `_arrays`; the result is compared as a flat list of booleans.
    #[test]
    fn test_element_wise_comparisons() {
        let reference_data = load_reference_data();
        let section = &reference_data["comparison_operations"];
        for (_, case) in matching_cases(section, |name| name.ends_with("_arrays")) {
            let operation = case["operation"].as_str().unwrap();
            let lhs = decode(&case["input1"]);
            let rhs = decode(&case["input2"]);
            let expected = bool_list(&case["result"]);
            let (lhs, rhs) = (deserialize_array(&lhs), deserialize_array(&rhs));
            let mask = match operation {
                "greater" => greater(&lhs, &rhs).unwrap(),
                "greater_equal" => greater_equal(&lhs, &rhs).unwrap(),
                "less" => less(&lhs, &rhs).unwrap(),
                "less_equal" => less_equal(&lhs, &rhs).unwrap(),
                "equal" => equal(&lhs, &rhs).unwrap(),
                "not_equal" => not_equal(&lhs, &rhs).unwrap(),
                // Other comparisons are not checked here.
                _ => continue,
            };
            let actual = mask.to_vec();
            assert_eq!(
                actual, expected,
                "Comparison operation {} failed",
                operation
            );
        }
    }

    /// Whole-array equality via `array_equal`.
    #[test]
    fn test_array_equality() {
        let reference_data = load_reference_data();
        let section = &reference_data["comparison_operations"];
        for (test_name, case) in matching_cases(section, |name| name.starts_with("array_equal_")) {
            let lhs = decode(&case["input1"]);
            let rhs = decode(&case["input2"]);
            let expected = case["result"].as_bool().unwrap();
            let (lhs, rhs) = (deserialize_array(&lhs), deserialize_array(&rhs));
            let actual = array_equal(&lhs, &rhs, None);
            assert_eq!(actual, expected, "Array equality test {} failed", test_name);
        }
    }

    /// Approximate whole-array equality via `allclose` with its default tolerances.
    #[test]
    fn test_allclose() {
        let reference_data = load_reference_data();
        let section = &reference_data["comparison_operations"];
        for (test_name, case) in matching_cases(section, |name| name.starts_with("allclose_")) {
            let lhs = decode(&case["input1"]);
            let rhs = decode(&case["input2"]);
            let expected = case["result"].as_bool().unwrap();
            let (lhs, rhs) = (deserialize_array(&lhs), deserialize_array(&rhs));
            let actual = allclose(&lhs, &rhs);
            assert_eq!(actual, expected, "Allclose test {} failed", test_name);
        }
    }
}
/// Checks element access and single-index slicing against the
/// `indexing_operations` section of the reference data.
#[cfg(test)]
mod indexing_operations_tests {
    use super::*;

    /// Yields `(name, case)` for each entry of `section` whose name passes `keep`.
    ///
    /// Entries are visited lazily. A kept entry that is not a JSON object
    /// panics when it is reached.
    fn matching_cases<'a, P>(
        section: &'a std::collections::HashMap<String, serde_json::Value>,
        keep: P,
    ) -> impl Iterator<Item = (&'a str, &'a serde_json::Map<String, serde_json::Value>)> + 'a
    where
        P: Fn(&str) -> bool + 'a,
    {
        section
            .iter()
            .filter(move |(name, _)| keep(name.as_str()))
            .map(|(name, case)| (name.as_str(), case.as_object().unwrap()))
    }

    /// Parses a JSON list of non-negative integers into `usize`s.
    fn usize_list(value: &serde_json::Value) -> Vec<usize> {
        value
            .as_array()
            .unwrap()
            .iter()
            .map(|idx| idx.as_u64().unwrap() as usize)
            .collect()
    }

    /// Decodes a `SerializedArray` from a JSON value.
    fn decode(value: &serde_json::Value) -> SerializedArray {
        serde_json::from_value(value.clone()).unwrap()
    }

    /// Single-element access with `get`, from cases whose name starts with `get_`.
    #[test]
    fn test_basic_indexing() {
        let reference_data = load_reference_data();
        let section = &reference_data["indexing_operations"];
        for (_, case) in matching_cases(section, |name| name.starts_with("get_")) {
            let source = decode(&case["input"]);
            let indices = usize_list(&case["indices"]);
            let expected = case["result"].as_f64().unwrap();
            let source_array = deserialize_array(&source);
            let actual = source_array.get(&indices).unwrap();
            assert_scalar_close(actual, expected, 1e-15);
        }
    }

    /// `slice(axis, index)`: names containing `row` slice axis 0 at `row`;
    /// otherwise names containing `col` slice axis 1 at `col`.
    #[test]
    fn test_slicing() {
        let reference_data = load_reference_data();
        let section = &reference_data["indexing_operations"];
        for (test_name, case) in matching_cases(section, |name| name.starts_with("slice_")) {
            let source = decode(&case["input"]);
            let expected = decode(&case["result"]);
            let source_array = deserialize_array(&source);
            let expected_array = deserialize_array(&expected);

            // The `row` check wins when a name contains both markers.
            let axis_and_key = if test_name.contains("row") {
                Some((0, "row"))
            } else if test_name.contains("col") {
                Some((1, "col"))
            } else {
                None
            };

            if let Some((axis, key)) = axis_and_key {
                let position = case[key].as_u64().unwrap() as usize;
                let actual = source_array.slice(axis, position).unwrap();
                assert_arrays_close(&actual, &expected_array, 1e-15);
            }
        }
    }
}