use numrs2::array_ops::advanced_indexing;
use numrs2::bitwise_ops;
use numrs2::complex_ops;
use numrs2::prelude::*;
use scirs2_core::Complex;
/// Asserts that the values of `a` match the expected slice `b` within an
/// absolute tolerance.
///
/// The values are taken from `a.to_vec()` and compared pairwise with `b`.
/// A pair passes when the two values are exactly equal or when their absolute
/// difference is strictly below `tolerance`.
///
/// # Panics
///
/// Panics if the lengths differ, or at the first index whose pair does not
/// pass. A NaN on either side never passes.
fn assert_arrays_close_f64(a: &Array<f64>, b: &[f64], tolerance: f64) {
    let a_vec = a.to_vec();
    assert_eq!(a_vec.len(), b.len(), "Array lengths don't match");
    for (i, (&got, &expected)) in a_vec.iter().zip(b.iter()).enumerate() {
        let diff = (got - expected).abs();
        // Exact equality is accepted up front so that matching infinities pass:
        // `inf - inf` is NaN, and NaN never compares below the tolerance.
        assert!(
            got == expected || diff < tolerance,
            "Arrays differ at index {}: got {}, expected {}, diff {}",
            i,
            got,
            expected,
            diff
        );
    }
}
/// Asserts that the values of `a` match the expected slice `b` within an
/// absolute tolerance (single-precision variant).
///
/// The values are taken from `a.to_vec()` and compared pairwise with `b`.
/// A pair passes when the two values are exactly equal or when their absolute
/// difference is strictly below `tolerance`.
///
/// # Panics
///
/// Panics if the lengths differ, or at the first index whose pair does not
/// pass. A NaN on either side never passes.
// No test visible in this part of the file calls the f32 variant, hence the
// lint allowance.
#[allow(dead_code)]
fn assert_arrays_close_f32(a: &Array<f32>, b: &[f32], tolerance: f32) {
    let a_vec = a.to_vec();
    assert_eq!(a_vec.len(), b.len(), "Array lengths don't match");
    for (i, (&got, &expected)) in a_vec.iter().zip(b.iter()).enumerate() {
        let diff = (got - expected).abs();
        // Exact equality is accepted up front so that matching infinities pass:
        // `inf - inf` is NaN, and NaN never compares below the tolerance.
        assert!(
            got == expected || diff < tolerance,
            "Arrays differ at index {}: got {}, expected {}, diff {}",
            i,
            got,
            expected,
            diff
        );
    }
}
/// Checks element-wise AND, OR, XOR and both shift directions of
/// `bitwise_ops` against fixed expected values.
#[test]
fn test_bitwise_operations_numpy_equivalence() {
    let lhs = Array::from_vec(vec![13, 17, 21, 5, 8]);
    let rhs = Array::from_vec(vec![9, 7, 15, 12, 3]);

    // Binary operators applied pairwise to the two operand arrays.
    let and_values = bitwise_ops::bitwise_and(&lhs, &rhs).unwrap().to_vec();
    assert_eq!(and_values, vec![9, 1, 5, 4, 0]);

    let or_values = bitwise_ops::bitwise_or(&lhs, &rhs).unwrap().to_vec();
    assert_eq!(or_values, vec![13, 23, 31, 13, 11]);

    let xor_values = bitwise_ops::bitwise_xor(&lhs, &rhs).unwrap().to_vec();
    assert_eq!(xor_values, vec![4, 22, 26, 9, 11]);

    // Shifting left by per-element amounts, then shifting the same values
    // back right, must round-trip to the original operands.
    let shifts = Array::from_vec(vec![1, 2, 1, 3, 2]);

    let shifted_left = bitwise_ops::left_shift(&lhs, &shifts).unwrap().to_vec();
    assert_eq!(shifted_left, vec![26, 68, 42, 40, 32]);

    let pre_shifted = Array::from_vec(vec![26, 68, 42, 40, 32]);
    let shifted_right = bitwise_ops::right_shift(&pre_shifted, &shifts)
        .unwrap()
        .to_vec();
    assert_eq!(shifted_right, vec![13, 17, 21, 5, 8]);
}
/// Checks `complex_ops::{absolute, angle, real, imag, conj}` on a small set of
/// complex numbers against fixed expected values.
#[test]
fn test_complex_operations_numpy_equivalence() {
    let complex_data = Array::from_vec(vec![
        Complex::new(3.0, 4.0),
        Complex::new(1.0, 0.0),
        Complex::new(0.0, 1.0),
        Complex::new(-1.0, -1.0),
        Complex::new(5.0, 0.0),
    ]);

    // Magnitudes: |3+4i| = 5 and |-1-i| = sqrt(2).
    let expected_abs = vec![5.0, 1.0, 1.0, std::f64::consts::SQRT_2, 5.0];
    let result_abs = complex_ops::absolute(&complex_data);
    assert_arrays_close_f64(&result_abs, &expected_abs, 1e-10);

    // Phase angles; the expected values are in radians.
    let expected_angle = vec![
        0.9272952180016122,
        0.0,
        std::f64::consts::FRAC_PI_2,
        -2.356194490192345,
        0.0,
    ];
    let result_angle = complex_ops::angle(&complex_data, false);
    assert_arrays_close_f64(&result_angle, &expected_angle, 1e-10);

    let expected_real = vec![3.0, 1.0, 0.0, -1.0, 5.0];
    let result_real = complex_ops::real(&complex_data);
    assert_arrays_close_f64(&result_real, &expected_real, 1e-15);

    let expected_imag = vec![4.0, 0.0, 1.0, -1.0, 0.0];
    let result_imag = complex_ops::imag(&complex_data);
    assert_arrays_close_f64(&result_imag, &expected_imag, 1e-15);

    let expected_conj = [
        Complex::new(3.0, -4.0),
        Complex::new(1.0, 0.0),
        Complex::new(0.0, -1.0),
        Complex::new(-1.0, 1.0),
        Complex::new(5.0, 0.0),
    ];
    let result_conj = complex_ops::conj(&complex_data);
    let result_conj_vec = result_conj.to_vec();
    // `zip` stops at the shorter side, so without this check a truncated
    // result would pass the loop below without every element being compared.
    assert_eq!(
        result_conj_vec.len(),
        expected_conj.len(),
        "Conjugate array length doesn't match"
    );
    for (i, (got, expected)) in result_conj_vec.iter().zip(expected_conj.iter()).enumerate() {
        assert!(
            (got.re - expected.re).abs() < 1e-15,
            "Conjugate real part differs at index {}: got {}, expected {}",
            i,
            got.re,
            expected.re
        );
        assert!(
            (got.im - expected.im).abs() < 1e-15,
            "Conjugate imaginary part differs at index {}: got {}, expected {}",
            i,
            got.im,
            expected.im
        );
    }
}
/// Checks the element-wise `exp`, `sin`, `cos`, `sqrt` and `log` methods of
/// `Array` against fixed reference values.
#[test]
fn test_mathematical_functions_numpy_equivalence() {
    let inputs = Array::from_vec(vec![0.0, 0.5, 1.0, 1.5, 2.0]);

    // Exponential.
    let exp_values = inputs.exp();
    assert_arrays_close_f64(
        &exp_values,
        &[
            1.0,
            1.6487212707001282,
            std::f64::consts::E,
            4.4816890703380645,
            7.38905609893065,
        ],
        1e-10,
    );

    // Sine.
    let sin_values = inputs.sin();
    assert_arrays_close_f64(
        &sin_values,
        &[
            0.0,
            0.479425538604203,
            0.8414709848078965,
            0.9974949866040544,
            0.9092974268256817,
        ],
        1e-10,
    );

    // Cosine.
    let cos_values = inputs.cos();
    assert_arrays_close_f64(
        &cos_values,
        &[
            1.0,
            0.8775825618903728,
            0.5403023058681398,
            0.070_737_201_667_702_9,
            -0.4161468365471424,
        ],
        1e-10,
    );

    // Square root of perfect squares.
    let squares = Array::from_vec(vec![0.0, 1.0, 4.0, 9.0, 16.0]);
    let sqrt_values = squares.sqrt();
    assert_arrays_close_f64(&sqrt_values, &[0.0, 1.0, 2.0, 3.0, 4.0], 1e-15);

    // The expected values below are natural logarithms (ln 2, ln e, ln 10).
    let log_inputs = Array::from_vec(vec![1.0, 2.0, std::f64::consts::E, 10.0]);
    let log_values = log_inputs.log();
    assert_arrays_close_f64(
        &log_values,
        &[0.0, std::f64::consts::LN_2, 1.0, std::f64::consts::LN_10],
        1e-10,
    );
}
/// Checks `sum`, `mean`, `std` and `var` of the values 1..=10 against fixed
/// reference values.
///
/// The expected variance (8.25) and standard deviation (sqrt(8.25)) are the
/// population statistics of 1..=10, i.e. the divisor is `n`, not `n - 1`.
#[test]
fn test_statistical_functions_numpy_equivalence() {
    let test_data = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);

    let expected_sum = 55.0f64;
    let result_sum = test_data.sum();
    assert!(
        (result_sum - expected_sum).abs() < 1e-15,
        "Sum mismatch: got {}, expected {}",
        result_sum,
        expected_sum
    );

    let expected_mean = 5.5f64;
    let result_mean = test_data.mean();
    assert!(
        (result_mean - expected_mean).abs() < 1e-15,
        "Mean mismatch: got {}, expected {}",
        result_mean,
        expected_mean
    );

    let expected_std = 2.8722813232690143f64;
    let result_std = test_data.std();
    assert!(
        (result_std - expected_std).abs() < 1e-10,
        "Std dev mismatch: got {}, expected {}",
        result_std,
        expected_std
    );

    // NOTE(review): 1e-15 is below one ULP at 8.25, so this is effectively an
    // exact-equality check; it holds only while the computation is exact.
    let expected_var = 8.25f64;
    let result_var = test_data.var();
    assert!(
        (result_var - expected_var).abs() < 1e-15,
        "Variance mismatch: got {}, expected {}",
        result_var,
        expected_var
    );
}
/// Checks `Array::zeros` / `Array::ones` shapes and contents, plus hand-built
/// equivalents of an integer range and an evenly spaced 0.0..=1.0 grid.
#[test]
fn test_array_creation_numpy_equivalence() {
    // zeros: shape [3, 2], six elements all 0.0.
    let all_zero = Array::<f64>::zeros(&[3, 2]);
    assert_eq!(all_zero.shape(), &[3, 2]);
    assert_eq!(all_zero.to_vec(), vec![0.0; 6]);

    // ones: shape [2, 3], six elements all 1.0.
    let all_one = Array::<f64>::ones(&[2, 3]);
    assert_eq!(all_one.shape(), &[2, 3]);
    assert_eq!(all_one.to_vec(), vec![1.0; 6]);

    // Range equivalent: the integers 0..10 converted to f64.
    let counting = Array::from_vec((0..10).map(|i| i as f64).collect::<Vec<f64>>());
    assert_eq!(
        counting.to_vec(),
        vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
    );

    // Evenly spaced equivalent: eleven points with step 0.1, compared with a
    // tolerance because i * 0.1 is not exactly representable for every i.
    let grid_points: Vec<f64> = (0..11).map(|i| (i as f64) * 0.1).collect();
    let evenly_spaced = Array::from_vec(grid_points);
    assert_arrays_close_f64(
        &evenly_spaced,
        &[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
        1e-15,
    );
}
/// Checks `reshape` and `transpose`: shapes after each step and the element
/// order reported by `to_vec()` for the reshaped and transposed arrays.
#[test]
fn test_array_manipulation_numpy_equivalence() {
    let flat = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

    // 1-D -> 2x3: element order is unchanged.
    let grid = flat.reshape(&[2, 3]);
    assert_eq!(grid.shape(), &[2, 3]);
    assert_eq!(grid.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

    // 2x3 -> 3x2: `to_vec()` now interleaves the two original rows.
    let swapped = grid.transpose();
    assert_eq!(swapped.shape(), &[3, 2]);
    assert_eq!(swapped.to_vec(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

    // Back to 1-D: only the resulting shape is checked here.
    let flattened = swapped.reshape(&[6]);
    assert_eq!(flattened.shape(), &[6]);
}
/// Builds element-wise greater-than, less-than and approximate-equality masks
/// from two arrays' values and checks them after a round trip through
/// `Array<bool>`.
#[test]
fn test_boolean_operations_numpy_equivalence() {
    let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    let b = Array::from_vec(vec![3.0, 2.0, 1.0, 6.0, 4.0]);
    let lhs: Vec<f64> = a.to_vec();
    let rhs: Vec<f64> = b.to_vec();

    // Applies a pairwise predicate to the two value lists.
    let elementwise = |pred: fn(f64, f64) -> bool| -> Vec<bool> {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(&x, &y)| pred(x, y))
            .collect()
    };

    let gt_result = Array::from_vec(elementwise(|x, y| x > y));
    assert_eq!(gt_result.to_vec(), vec![false, false, true, false, true]);

    let lt_result = Array::from_vec(elementwise(|x, y| x < y));
    assert_eq!(lt_result.to_vec(), vec![true, false, false, true, false]);

    // Equality is approximate: values closer than 1e-15 count as equal.
    let eq_result = Array::from_vec(elementwise(|x, y| (x - y).abs() < 1e-15f64));
    assert_eq!(eq_result.to_vec(), vec![false, true, false, false, false]);
}
/// Checks that `advanced_indexing::extract` returns exactly the elements whose
/// mask entry is `true`, in their original order.
#[test]
fn test_advanced_indexing_numpy_equivalence() {
    let values = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);
    let mask = Array::from_vec(vec![false, true, false, true, true, false]);

    // The mask selects the elements at positions 1, 3 and 4.
    let picked = advanced_indexing::extract(&values, &mask).unwrap().to_vec();
    assert_eq!(picked, vec![20.0, 40.0, 50.0]);
}
/// Checks `exp` on very small, large and negative inputs.
///
/// Every element of each input array is asserted, so no input is built
/// without being verified.
#[test]
fn test_precision_and_edge_cases() {
    // Tiny arguments: exp(x) is about 1 + x, so the result must stay close to 1.
    let small_numbers = Array::from_vec(vec![1e-15, 1e-10, 1e-5]);
    let exp_small: Vec<f64> = small_numbers.exp().to_vec();
    assert!((exp_small[0] - 1.0f64).abs() < 1e-14);
    assert!((exp_small[1] - 1.0f64).abs() < 1e-9);
    assert!(
        (exp_small[2] - 1.0000100000500002f64).abs() < 1e-10,
        "exp(1e-5) mismatch: got {}",
        exp_small[2]
    );

    // Large arguments.
    let large_numbers = Array::from_vec(vec![10.0, 20.0]);
    let exp_large: Vec<f64> = large_numbers.exp().to_vec();
    assert!((exp_large[0] - 22026.465794806718f64).abs() < 1e-6);
    // exp(20) is about 4.85e8, so a relative error is used; an absolute
    // tolerance at this magnitude would be only a few ULPs wide.
    let expected_exp_20 = 485165195.4097903f64;
    assert!(
        ((exp_large[1] - expected_exp_20) / expected_exp_20).abs() < 1e-10,
        "exp(20) mismatch: got {}, expected {}",
        exp_large[1],
        expected_exp_20
    );

    // Negative arguments.
    let negative_numbers = Array::from_vec(vec![-1.0, -2.0, -0.5]);
    let exp_negative: Vec<f64> = negative_numbers.exp().to_vec();
    assert!((exp_negative[0] - 0.36787944117144233f64).abs() < 1e-10);
    assert!((exp_negative[1] - 0.1353352832366127f64).abs() < 1e-10);
    assert!(
        (exp_negative[2] - 0.6065306597126334f64).abs() < 1e-10,
        "exp(-0.5) mismatch: got {}",
        exp_negative[2]
    );
}
/// Prints a human-readable summary of the areas covered by this test file.
///
/// This test contains no assertions: it always passes and only writes to
/// stdout.
#[test]
fn test_numpy_compatibility_summary() {
    // One entry per printed line; the leading "\n" on the first and last
    // entries produces the blank lines around the summary.
    let summary_lines = [
        "\n=== NumPy Compatibility Validation Summary ===",
        "✅ Bitwise operations: AND, OR, XOR, left/right shift",
        "✅ Complex operations: absolute, angle, real, imag, conjugate",
        "✅ Mathematical functions: exp, sin, cos, sqrt, log",
        "✅ Statistical functions: sum, mean, std, var",
        "✅ Array creation: zeros, ones, arange equivalent",
        "✅ Array manipulation: reshape, transpose",
        "✅ Boolean operations: comparisons",
        "✅ Advanced indexing: extract operation",
        "✅ Precision and edge cases: small/large/negative numbers",
        "\nNumRS2 demonstrates excellent NumPy compatibility!",
    ];
    for line in summary_lines {
        println!("{}", line);
    }
}