use super::super::kernels;
use super::super::{CpuClient, CpuRuntime};
use crate::dispatch_dtype;
use crate::error::Result;
use crate::ops::{BinaryOp, Kernel};
use crate::runtime::{compute_broadcast_shape, validate_binary_dtypes};
use crate::tensor::Tensor;
/// Element-wise binary operation (`op`) between two CPU tensors.
///
/// Validates that both operands share a dtype, computes the broadcast
/// output shape, and allocates the result. When the inputs already match
/// the output shape and are both contiguous, a flat kernel runs straight
/// over the raw buffers; otherwise both inputs are broadcast to the
/// output shape and a strided kernel walks them via explicit strides.
///
/// # Errors
/// Propagates failures from dtype validation, broadcast-shape
/// computation, and `broadcast_to`.
pub fn binary_op_impl(
    client: &CpuClient,
    op: BinaryOp,
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
    op_name: &'static str,
) -> Result<Tensor<CpuRuntime>> {
    let dtype = validate_binary_dtypes(a, b)?;
    let result_shape = compute_broadcast_shape(a, b)?;
    let result = Tensor::<CpuRuntime>::empty(&result_shape, dtype, &client.device);
    let dst = result.ptr();

    // Fast path applies only when no broadcasting or stride arithmetic
    // is needed: identical shapes equal to the output, both contiguous.
    let flat_eligible = a.shape() == b.shape()
        && a.shape() == result_shape.as_slice()
        && a.is_contiguous()
        && b.is_contiguous();

    if flat_eligible {
        let n = a.numel();
        let lhs = a.ptr();
        let rhs = b.ptr();
        dispatch_dtype!(dtype, T => {
            // SAFETY: all three buffers hold `n` elements of the
            // dispatched dtype `T` — shapes are equal and contiguous,
            // and `result` was allocated with the same shape and dtype.
            unsafe {
                <CpuClient as Kernel<CpuRuntime>>::binary_op::<T>(
                    client, op,
                    lhs as *const T,
                    rhs as *const T,
                    dst as *mut T,
                    n,
                );
            }
        }, op_name);
    } else {
        // Slow path: broadcast both operands to the output shape and
        // iterate via their layout strides. NOTE(review): assumes
        // `broadcast_to` yields views whose expanded axes carry a
        // stride the strided kernel understands — confirm in `Tensor`.
        let lhs_view = a.broadcast_to(&result_shape)?;
        let rhs_view = b.broadcast_to(&result_shape)?;
        let lhs = lhs_view.ptr();
        let rhs = rhs_view.ptr();
        let lhs_strides: Vec<isize> = lhs_view.layout().strides().to_vec();
        let rhs_strides: Vec<isize> = rhs_view.layout().strides().to_vec();
        dispatch_dtype!(dtype, T => {
            // SAFETY: pointers and strides come from views broadcast to
            // `result_shape`, and `dst` was allocated with that shape
            // and the dispatched dtype `T`.
            unsafe {
                kernels::binary_op_strided_kernel::<T>(
                    op,
                    lhs as *const T,
                    rhs as *const T,
                    dst as *mut T,
                    &result_shape,
                    &lhs_strides,
                    &rhs_strides,
                    0, // starting byte/element offset into lhs
                    0, // starting byte/element offset into rhs
                );
            }
        }, op_name);
    }

    Ok(result)
}