// numr 0.5.2
//
// High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
// Documentation
// Backend parity tests for BinaryOps trait
//
// Dtype-parameterized: each test runs for all supported dtypes (F32, F64, F16, BF16, FP8).
// Tensors are created in f64 then cast to target dtype via tensor_from_f64().
// Comparison reads back in native dtype - no unnecessary f64 conversion.

use numr::dtype::DType;
use numr::ops::BinaryOps;
use numr::runtime::Runtime;
use numr::tensor::Tensor;

use crate::backend_parity::dtype_helpers::tensor_from_f64;
#[cfg(feature = "cuda")]
use crate::backend_parity::helpers::with_cuda_backend;
#[cfg(feature = "wgpu")]
use crate::backend_parity::helpers::with_wgpu_backend;
use crate::common::{
    assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes,
};

/// Element-wise binary operations covered by these parity tests.
///
/// Each variant selects the `BinaryOps` method of the same lowercase name when
/// passed to `apply_binary_op`.
#[derive(Debug, Clone, Copy)]
enum BinaryOp {
    /// Dispatches to `BinaryOps::add`.
    Add,
    /// Dispatches to `BinaryOps::sub`.
    Sub,
    /// Dispatches to `BinaryOps::mul`.
    Mul,
    /// Dispatches to `BinaryOps::div`.
    Div,
    /// Dispatches to `BinaryOps::pow`.
    Pow,
    /// Dispatches to `BinaryOps::maximum`.
    Maximum,
    /// Dispatches to `BinaryOps::minimum`.
    Minimum,
    /// Dispatches to `BinaryOps::atan2`.
    Atan2,
}

/// Operand data for a single parity check: two flat `f64` buffers plus their shapes.
///
/// Values are kept as `f64`; `test_binary_parity` hands them to `tensor_from_f64`
/// together with the dtype under test.
#[derive(Clone)]
struct TestCase {
    /// Element values of the first (left-hand) operand.
    a: Vec<f64>,
    /// Shape of the first operand.
    a_shape: Vec<usize>,
    /// Element values of the second (right-hand) operand.
    b: Vec<f64>,
    /// Shape of the second operand.
    b_shape: Vec<usize>,
}

impl TestCase {
    /// Bundles both operands into a test case.
    ///
    /// The arguments are stored exactly as given; no check is made here that a
    /// buffer's length agrees with its shape.
    fn new(a: Vec<f64>, a_shape: Vec<usize>, b: Vec<f64>, b_shape: Vec<usize>) -> Self {
        TestCase {
            a,
            a_shape,
            b,
            b_shape,
        }
    }
}

fn apply_binary_op<R: Runtime>(
    client: &impl BinaryOps<R>,
    op: BinaryOp,
    a: &Tensor<R>,
    b: &Tensor<R>,
) -> numr::error::Result<Tensor<R>> {
    match op {
        BinaryOp::Add => client.add(a, b),
        BinaryOp::Sub => client.sub(a, b),
        BinaryOp::Mul => client.mul(a, b),
        BinaryOp::Div => client.div(a, b),
        BinaryOp::Pow => client.pow(a, b),
        BinaryOp::Maximum => client.maximum(a, b),
        BinaryOp::Minimum => client.minimum(a, b),
        BinaryOp::Atan2 => client.atan2(a, b),
    }
}

/// Checks that each enabled accelerator backend reproduces the CPU result of `op`.
///
/// For every entry of `test_cases` the operands are built as `dtype` tensors on the
/// CPU and `op` is evaluated there to obtain the reference output. For each of the
/// `cuda` / `wgpu` features that is compiled in, and only if `is_dtype_supported`
/// accepts `dtype` for that backend, the same operands are rebuilt on the backend
/// and its output is handed to `assert_tensor_allclose` alongside the CPU reference.
///
/// # Panics
///
/// Panics when tensor creation or the operation itself fails on any backend.
/// Mismatches are reported by `assert_tensor_allclose` (presumably by panicking;
/// its tolerances are defined in the shared test helpers, not here).
fn test_binary_parity(op: BinaryOp, test_cases: &[TestCase], dtype: DType) {
    let (cpu_client, cpu_device) = create_cpu_client();

    // Reference outputs are kept as CPU tensors so the later comparison can work
    // on tensors rather than on values converted back to f64.
    let mut cpu_results: Vec<Tensor<numr::runtime::cpu::CpuRuntime>> =
        Vec::with_capacity(test_cases.len());
    for tc in test_cases {
        let lhs = tensor_from_f64(&tc.a, &tc.a_shape, dtype, &cpu_device, &cpu_client)
            .unwrap_or_else(|e| panic!("CPU tensor_from_f64 failed for {dtype:?}: {e}"));
        let rhs = tensor_from_f64(&tc.b, &tc.b_shape, dtype, &cpu_device, &cpu_client)
            .unwrap_or_else(|e| panic!("CPU tensor_from_f64 failed for {dtype:?}: {e}"));

        let expected = apply_binary_op(&cpu_client, op, &lhs, &rhs)
            .unwrap_or_else(|e| panic!("CPU {op:?} failed for {dtype:?}: {e}"));
        cpu_results.push(expected);
    }

    #[cfg(feature = "cuda")]
    if is_dtype_supported("cuda", dtype) {
        with_cuda_backend(|cuda_client, cuda_device| {
            // `cpu_results` holds one entry per test case, in order, so zipping
            // pairs each case with its own reference tensor.
            for (idx, (tc, expected)) in test_cases.iter().zip(&cpu_results).enumerate() {
                let lhs = tensor_from_f64(&tc.a, &tc.a_shape, dtype, &cuda_device, &cuda_client)
                    .unwrap_or_else(|e| panic!("CUDA tensor_from_f64 failed for {dtype:?}: {e}"));
                let rhs = tensor_from_f64(&tc.b, &tc.b_shape, dtype, &cuda_device, &cuda_client)
                    .unwrap_or_else(|e| panic!("CUDA tensor_from_f64 failed for {dtype:?}: {e}"));

                let actual = apply_binary_op(&cuda_client, op, &lhs, &rhs)
                    .unwrap_or_else(|e| panic!("CUDA {op:?} failed for {dtype:?}: {e}"));

                let label = format!("{op:?} CUDA vs CPU [{dtype:?}] case {idx}");
                assert_tensor_allclose(&actual, expected, dtype, &label);
            }
        });
    }

    #[cfg(feature = "wgpu")]
    if is_dtype_supported("wgpu", dtype) {
        with_wgpu_backend(|wgpu_client, wgpu_device| {
            // Same pairing as above: case `idx` is checked against reference `idx`.
            for (idx, (tc, expected)) in test_cases.iter().zip(&cpu_results).enumerate() {
                let lhs = tensor_from_f64(&tc.a, &tc.a_shape, dtype, &wgpu_device, &wgpu_client)
                    .unwrap_or_else(|e| panic!("WebGPU tensor_from_f64 failed for {dtype:?}: {e}"));
                let rhs = tensor_from_f64(&tc.b, &tc.b_shape, dtype, &wgpu_device, &wgpu_client)
                    .unwrap_or_else(|e| panic!("WebGPU tensor_from_f64 failed for {dtype:?}: {e}"));

                let actual = apply_binary_op(&wgpu_client, op, &lhs, &rhs)
                    .unwrap_or_else(|e| panic!("WebGPU {op:?} failed for {dtype:?}: {e}"));

                let label = format!("{op:?} WebGPU vs CPU [{dtype:?}] case {idx}");
                assert_tensor_allclose(&actual, expected, dtype, &label);
            }
        });
    }
}

/// Generates a `#[test]` function named `$name` that calls `test_binary_parity`
/// with `$op` and `$cases` once for every dtype yielded by `supported_dtypes("cpu")`.
///
/// `$cases` is re-evaluated on each loop iteration, so it should be a cheap
/// expression such as a borrowed array literal of `TestCase`s.
///
/// A trailing comma after `$cases` is accepted, so multi-line invocations may be
/// written like any other comma-separated list.
macro_rules! binary_case {
    ($name:ident, $op:expr, $cases:expr $(,)?) => {
        #[test]
        fn $name() {
            for dtype in supported_dtypes("cpu") {
                test_binary_parity($op, $cases, dtype);
            }
        }
    };
}

// Addition parity. Covers same-shape 1-D and 2-D operands, plus two cases where
// the right-hand operand holds a single value (shape [1] and an empty shape)
// against a shape-[4] left-hand side — presumably exercising broadcasting.
binary_case!(
    test_add_parity,
    BinaryOp::Add,
    &[
        // 1-D operands with identical shapes.
        TestCase::new(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4],
            vec![5.0, 6.0, 7.0, 8.0],
            vec![4]
        ),
        // 2x2 operands; the right-hand side is a constant fractional value.
        TestCase::new(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            vec![0.5, 0.5, 0.5, 0.5],
            vec![2, 2]
        ),
        // Right-hand side has shape [1], left-hand side has shape [4].
        TestCase::new(vec![1.0, 2.0, 3.0, 4.0], vec![4], vec![10.0], vec![1]),
        // Right-hand side has an empty shape (one value, no dimensions).
        TestCase::new(vec![1.0, 2.0, 3.0, 4.0], vec![4], vec![5.0], vec![]),
    ]
);

// Subtraction parity. Both cases use operands of identical shape and inputs
// chosen so that every difference is positive.
binary_case!(
    test_sub_parity,
    BinaryOp::Sub,
    &[
        // 1-D operands; every element-wise difference is 4.
        TestCase::new(
            vec![5.0, 6.0, 7.0, 8.0],
            vec![4],
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4]
        ),
        // 2x2 operands; subtracts a constant 1 from each element.
        TestCase::new(
            vec![10.0, 20.0, 30.0, 40.0],
            vec![2, 2],
            vec![1.0, 1.0, 1.0, 1.0],
            vec![2, 2]
        ),
    ]
);

// Multiplication parity. Covers same-shape 1-D and 2-D operands and one case
// whose right-hand operand has an empty shape against a shape-[4] left-hand
// side — presumably exercising broadcasting.
binary_case!(
    test_mul_parity,
    BinaryOp::Mul,
    &[
        // 1-D operands with identical shapes.
        TestCase::new(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4],
            vec![2.0, 3.0, 4.0, 5.0],
            vec![4]
        ),
        // 2x2 operands; fractional inputs doubled to whole numbers.
        TestCase::new(
            vec![0.5, 1.5, 2.5, 3.5],
            vec![2, 2],
            vec![2.0, 2.0, 2.0, 2.0],
            vec![2, 2]
        ),
        // Right-hand side has an empty shape (one value, no dimensions).
        TestCase::new(vec![1.0, 2.0, 3.0, 4.0], vec![4], vec![2.0], vec![]),
    ]
);

// Division parity. All divisors are non-zero and every quotient is a whole
// number, so no division-by-zero behaviour is exercised here.
binary_case!(
    test_div_parity,
    BinaryOp::Div,
    &[
        // 1-D operands; quotients are 5, 5, 6, 5.
        TestCase::new(
            vec![10.0, 20.0, 30.0, 40.0],
            vec![4],
            vec![2.0, 4.0, 5.0, 8.0],
            vec![4]
        ),
        // 2x2 operands with the same divisors and ten-times-larger dividends.
        TestCase::new(
            vec![100.0, 200.0, 300.0, 400.0],
            vec![2, 2],
            vec![2.0, 4.0, 5.0, 8.0],
            vec![2, 2],
        ),
    ]
);

// Power parity. Bases are positive and exponents are small non-negative whole
// numbers; negative bases and fractional exponents are not covered here.
binary_case!(
    test_pow_parity,
    BinaryOp::Pow,
    &[
        // 1-D operands; squares each base.
        TestCase::new(
            vec![2.0, 3.0, 4.0, 5.0],
            vec![4],
            vec![2.0, 2.0, 2.0, 2.0],
            vec![4]
        ),
        // 2x2 operands; exponents 0 through 3, including the zero-exponent case.
        TestCase::new(
            vec![2.0, 3.0, 4.0, 5.0],
            vec![2, 2],
            vec![0.0, 1.0, 2.0, 3.0],
            vec![2, 2]
        ),
    ]
);

// Element-wise maximum parity. The inputs make the larger value come from each
// operand at least once; no equal pairs are included.
binary_case!(
    test_maximum_parity,
    BinaryOp::Maximum,
    &[
        // 1-D operands; the larger element alternates between the two sides.
        TestCase::new(
            vec![1.0, 5.0, 3.0, 2.0],
            vec![4],
            vec![3.0, 2.0, 5.0, 1.0],
            vec![4]
        ),
        // 2x2 operands compared against a constant 15: one element below it,
        // three above.
        TestCase::new(
            vec![10.0, 20.0, 30.0, 40.0],
            vec![2, 2],
            vec![15.0, 15.0, 15.0, 15.0],
            vec![2, 2],
        ),
    ]
);

// Element-wise minimum parity. Uses the same inputs as the maximum test, so
// the smaller value comes from each operand at least once.
binary_case!(
    test_minimum_parity,
    BinaryOp::Minimum,
    &[
        // 1-D operands; the smaller element alternates between the two sides.
        TestCase::new(
            vec![1.0, 5.0, 3.0, 2.0],
            vec![4],
            vec![3.0, 2.0, 5.0, 1.0],
            vec![4]
        ),
        // 2x2 operands compared against a constant 15: one element below it,
        // three above.
        TestCase::new(
            vec![10.0, 20.0, 30.0, 40.0],
            vec![2, 2],
            vec![15.0, 15.0, 15.0, 15.0],
            vec![2, 2],
        ),
    ]
);

// atan2 parity. `a` is passed as the first argument and `b` as the second
// (presumably y and x respectively — confirm against `BinaryOps::atan2`).
binary_case!(
    test_atan2_parity,
    BinaryOp::Atan2,
    &[
        // 1-D operands containing zeros: a zero first argument (twice), a zero
        // second argument, and one pair with both arguments equal to 1.
        TestCase::new(
            vec![0.0, 1.0, 1.0, 0.0],
            vec![4],
            vec![1.0, 0.0, 1.0, 1.0],
            vec![4]
        ),
        // 2x2 operands covering all four sign combinations of (+/-1, +/-1).
        TestCase::new(
            vec![1.0, -1.0, -1.0, 1.0],
            vec![2, 2],
            vec![1.0, 1.0, -1.0, -1.0],
            vec![2, 2]
        ),
    ]
);