use numr::dtype::DType;
use numr::ops::BinaryOps;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use crate::backend_parity::dtype_helpers::tensor_from_f64;
#[cfg(feature = "cuda")]
use crate::backend_parity::helpers::with_cuda_backend;
#[cfg(feature = "wgpu")]
use crate::backend_parity::helpers::with_wgpu_backend;
use crate::common::{
assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes,
};
#[derive(Clone, Copy, Debug)]
enum BinaryOp {
Add,
Sub,
Mul,
Div,
Pow,
Maximum,
Minimum,
Atan2,
}
#[derive(Clone)]
struct TestCase {
a: Vec<f64>,
a_shape: Vec<usize>,
b: Vec<f64>,
b_shape: Vec<usize>,
}
impl TestCase {
fn new(a: Vec<f64>, a_shape: Vec<usize>, b: Vec<f64>, b_shape: Vec<usize>) -> Self {
Self {
a,
a_shape,
b,
b_shape,
}
}
}
fn apply_binary_op<R: Runtime>(
client: &impl BinaryOps<R>,
op: BinaryOp,
a: &Tensor<R>,
b: &Tensor<R>,
) -> numr::error::Result<Tensor<R>> {
match op {
BinaryOp::Add => client.add(a, b),
BinaryOp::Sub => client.sub(a, b),
BinaryOp::Mul => client.mul(a, b),
BinaryOp::Div => client.div(a, b),
BinaryOp::Pow => client.pow(a, b),
BinaryOp::Maximum => client.maximum(a, b),
BinaryOp::Minimum => client.minimum(a, b),
BinaryOp::Atan2 => client.atan2(a, b),
}
}
fn test_binary_parity(op: BinaryOp, test_cases: &[TestCase], dtype: DType) {
let (cpu_client, cpu_device) = create_cpu_client();
let cpu_results: Vec<Tensor<numr::runtime::cpu::CpuRuntime>> = test_cases
.iter()
.map(|tc| {
let a = tensor_from_f64(&tc.a, &tc.a_shape, dtype, &cpu_device, &cpu_client)
.unwrap_or_else(|e| panic!("CPU tensor_from_f64 failed for {dtype:?}: {e}"));
let b = tensor_from_f64(&tc.b, &tc.b_shape, dtype, &cpu_device, &cpu_client)
.unwrap_or_else(|e| panic!("CPU tensor_from_f64 failed for {dtype:?}: {e}"));
apply_binary_op(&cpu_client, op, &a, &b)
.unwrap_or_else(|e| panic!("CPU {op:?} failed for {dtype:?}: {e}"))
})
.collect();
#[cfg(feature = "cuda")]
if is_dtype_supported("cuda", dtype) {
with_cuda_backend(|cuda_client, cuda_device| {
for (idx, tc) in test_cases.iter().enumerate() {
let a = tensor_from_f64(&tc.a, &tc.a_shape, dtype, &cuda_device, &cuda_client)
.unwrap_or_else(|e| panic!("CUDA tensor_from_f64 failed for {dtype:?}: {e}"));
let b = tensor_from_f64(&tc.b, &tc.b_shape, dtype, &cuda_device, &cuda_client)
.unwrap_or_else(|e| panic!("CUDA tensor_from_f64 failed for {dtype:?}: {e}"));
let result = apply_binary_op(&cuda_client, op, &a, &b)
.unwrap_or_else(|e| panic!("CUDA {op:?} failed for {dtype:?}: {e}"));
assert_tensor_allclose(
&result,
&cpu_results[idx],
dtype,
&format!("{op:?} CUDA vs CPU [{dtype:?}] case {idx}"),
);
}
});
}
#[cfg(feature = "wgpu")]
if is_dtype_supported("wgpu", dtype) {
with_wgpu_backend(|wgpu_client, wgpu_device| {
for (idx, tc) in test_cases.iter().enumerate() {
let a = tensor_from_f64(&tc.a, &tc.a_shape, dtype, &wgpu_device, &wgpu_client)
.unwrap_or_else(|e| panic!("WebGPU tensor_from_f64 failed for {dtype:?}: {e}"));
let b = tensor_from_f64(&tc.b, &tc.b_shape, dtype, &wgpu_device, &wgpu_client)
.unwrap_or_else(|e| panic!("WebGPU tensor_from_f64 failed for {dtype:?}: {e}"));
let result = apply_binary_op(&wgpu_client, op, &a, &b)
.unwrap_or_else(|e| panic!("WebGPU {op:?} failed for {dtype:?}: {e}"));
assert_tensor_allclose(
&result,
&cpu_results[idx],
dtype,
&format!("{op:?} WebGPU vs CPU [{dtype:?}] case {idx}"),
);
}
});
}
}
macro_rules! binary_case {
($name:ident, $op:expr, $cases:expr) => {
#[test]
fn $name() {
for dtype in supported_dtypes("cpu") {
test_binary_parity($op, $cases, dtype);
}
}
};
}
binary_case!(
test_add_parity,
BinaryOp::Add,
&[
TestCase::new(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
vec![5.0, 6.0, 7.0, 8.0],
vec![4]
),
TestCase::new(
vec![1.0, 2.0, 3.0, 4.0],
vec![2, 2],
vec![0.5, 0.5, 0.5, 0.5],
vec![2, 2]
),
TestCase::new(vec![1.0, 2.0, 3.0, 4.0], vec![4], vec![10.0], vec![1]),
TestCase::new(vec![1.0, 2.0, 3.0, 4.0], vec![4], vec![5.0], vec![]),
]
);
binary_case!(
test_sub_parity,
BinaryOp::Sub,
&[
TestCase::new(
vec![5.0, 6.0, 7.0, 8.0],
vec![4],
vec![1.0, 2.0, 3.0, 4.0],
vec![4]
),
TestCase::new(
vec![10.0, 20.0, 30.0, 40.0],
vec![2, 2],
vec![1.0, 1.0, 1.0, 1.0],
vec![2, 2]
),
]
);
binary_case!(
test_mul_parity,
BinaryOp::Mul,
&[
TestCase::new(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
vec![2.0, 3.0, 4.0, 5.0],
vec![4]
),
TestCase::new(
vec![0.5, 1.5, 2.5, 3.5],
vec![2, 2],
vec![2.0, 2.0, 2.0, 2.0],
vec![2, 2]
),
TestCase::new(vec![1.0, 2.0, 3.0, 4.0], vec![4], vec![2.0], vec![]),
]
);
binary_case!(
test_div_parity,
BinaryOp::Div,
&[
TestCase::new(
vec![10.0, 20.0, 30.0, 40.0],
vec![4],
vec![2.0, 4.0, 5.0, 8.0],
vec![4]
),
TestCase::new(
vec![100.0, 200.0, 300.0, 400.0],
vec![2, 2],
vec![2.0, 4.0, 5.0, 8.0],
vec![2, 2],
),
]
);
binary_case!(
test_pow_parity,
BinaryOp::Pow,
&[
TestCase::new(
vec![2.0, 3.0, 4.0, 5.0],
vec![4],
vec![2.0, 2.0, 2.0, 2.0],
vec![4]
),
TestCase::new(
vec![2.0, 3.0, 4.0, 5.0],
vec![2, 2],
vec![0.0, 1.0, 2.0, 3.0],
vec![2, 2]
),
]
);
binary_case!(
test_maximum_parity,
BinaryOp::Maximum,
&[
TestCase::new(
vec![1.0, 5.0, 3.0, 2.0],
vec![4],
vec![3.0, 2.0, 5.0, 1.0],
vec![4]
),
TestCase::new(
vec![10.0, 20.0, 30.0, 40.0],
vec![2, 2],
vec![15.0, 15.0, 15.0, 15.0],
vec![2, 2],
),
]
);
binary_case!(
test_minimum_parity,
BinaryOp::Minimum,
&[
TestCase::new(
vec![1.0, 5.0, 3.0, 2.0],
vec![4],
vec![3.0, 2.0, 5.0, 1.0],
vec![4]
),
TestCase::new(
vec![10.0, 20.0, 30.0, 40.0],
vec![2, 2],
vec![15.0, 15.0, 15.0, 15.0],
vec![2, 2],
),
]
);
binary_case!(
test_atan2_parity,
BinaryOp::Atan2,
&[
TestCase::new(
vec![0.0, 1.0, 1.0, 0.0],
vec![4],
vec![1.0, 0.0, 1.0, 1.0],
vec![4]
),
TestCase::new(
vec![1.0, -1.0, -1.0, 1.0],
vec![2, 2],
vec![1.0, 1.0, -1.0, -1.0],
vec![2, 2]
),
]
);