use super::super::kernels;
use super::super::{CpuClient, CpuRuntime};
use crate::dispatch_dtype;
use crate::error::{Error, Result};
use crate::ops::{BinaryOp, Kernel};
use crate::runtime::{compute_broadcast_shape, validate_binary_dtypes};
use crate::tensor::Tensor;
/// Applies `op` element-wise to `a` and `b` and returns a newly allocated result tensor.
///
/// The result dtype is whatever `validate_binary_dtypes` yields for the two inputs and
/// the result shape is whatever `compute_broadcast_shape` yields; the output tensor is
/// created with `Tensor::empty` on `client.device` and then filled by `write_binary_into`.
///
/// `op_name` is passed through unchanged to the dtype dispatch.
///
/// # Errors
///
/// Propagates any error returned by dtype validation, broadcast-shape computation, or
/// `write_binary_into`.
pub fn binary_op_impl(
    client: &CpuClient,
    op: BinaryOp,
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
    op_name: &'static str,
) -> Result<Tensor<CpuRuntime>> {
    let result_dtype = validate_binary_dtypes(a, b)?;
    let broadcast_shape = compute_broadcast_shape(a, b)?;

    // Allocate the destination first, then let the shared writer pick the kernel.
    let result = Tensor::<CpuRuntime>::empty(&broadcast_shape, result_dtype, &client.device);
    write_binary_into(client, op, a, b, &result, &broadcast_shape, op_name)?;

    Ok(result)
}
/// Applies `op` element-wise to `a` and `b`, writing the result into the existing tensor `out`.
///
/// Before any data is written, `out` is checked against the dtype produced by
/// `validate_binary_dtypes` and the shape produced by `compute_broadcast_shape`, and it
/// must report itself as contiguous.
///
/// `op_name` is passed through unchanged to the dtype dispatch.
///
/// # Errors
///
/// - Any error from `validate_binary_dtypes` or `compute_broadcast_shape`.
/// - `Error::DTypeMismatch` when `out`'s dtype differs from the validated dtype.
/// - `Error::ShapeMismatch` when `out`'s shape differs from the broadcast shape.
/// - `Error::Backend` when `out` is not contiguous.
/// - Any error from `write_binary_into`.
pub fn binary_op_into_impl(
    client: &CpuClient,
    op: BinaryOp,
    out: &Tensor<CpuRuntime>,
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
    op_name: &'static str,
) -> Result<()> {
    let expected_dtype = validate_binary_dtypes(a, b)?;
    let expected_shape = compute_broadcast_shape(a, b)?;

    // Destination checks run in a fixed order: dtype, then shape, then layout.
    let actual_dtype = out.dtype();
    if actual_dtype != expected_dtype {
        return Err(Error::DTypeMismatch {
            lhs: expected_dtype,
            rhs: actual_dtype,
        });
    }

    let actual_shape = out.shape();
    if actual_shape != expected_shape.as_slice() {
        return Err(Error::ShapeMismatch {
            expected: expected_shape,
            got: actual_shape.to_vec(),
        });
    }

    if !out.is_contiguous() {
        return Err(Error::Backend(
            "binary_op_into: destination tensor must be contiguous".into(),
        ));
    }

    write_binary_into(client, op, a, b, out, &expected_shape, op_name)
}
/// Computes `op(a, b)` element-wise and writes the result through `out`'s data pointer.
///
/// Two code paths exist:
///
/// - When `a`, `b`, and `out_shape` all have the same shape and both inputs are
///   contiguous, the client's `Kernel::binary_op` is called over `a.numel()` elements.
/// - Otherwise both inputs are broadcast to `out_shape` and the strided kernel is used
///   with the strides of the broadcast views.
///
/// The element type is selected by `dispatch_dtype!` from `out.dtype()`; `op_name` is
/// handed to that macro (presumably for error reporting — confirm against the macro).
///
/// This function performs no validation of `out` itself. NOTE(review): it appears to
/// rely on callers ensuring `out` is contiguous, has shape `out_shape`, and shares its
/// dtype with `a` and `b` — confirm before adding new callers.
///
/// # Errors
///
/// Propagates errors from `broadcast_to` and any early return produced by
/// `dispatch_dtype!`.
fn write_binary_into(
    client: &CpuClient,
    op: BinaryOp,
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
    out: &Tensor<CpuRuntime>,
    out_shape: &[usize],
    op_name: &'static str,
) -> Result<()> {
    let dtype = out.dtype();
    let dst_ptr = out.ptr();

    let shapes_match = a.shape() == b.shape() && a.shape() == out_shape;
    let inputs_contiguous = a.is_contiguous() && b.is_contiguous();

    if !(shapes_match && inputs_contiguous) {
        // General path: view both operands at the output shape and walk them by stride.
        let lhs_view = a.broadcast_to(out_shape)?;
        let rhs_view = b.broadcast_to(out_shape)?;
        let lhs_ptr = lhs_view.ptr();
        let rhs_ptr = rhs_view.ptr();
        let lhs_strides: Vec<isize> = lhs_view.layout().strides().to_vec();
        let rhs_strides: Vec<isize> = rhs_view.layout().strides().to_vec();

        dispatch_dtype!(dtype, T => {
            // SAFETY: both views were broadcast to `out_shape`, and the strides passed
            // are the ones reported by those views. Assumes `out` holds
            // `out_shape` elements of `dtype` laid out contiguously, and that `a`/`b`
            // store elements of `dtype` — TODO confirm these are upheld by every caller.
            unsafe {
                kernels::binary_op_strided_kernel::<T>(
                    op,
                    lhs_ptr as *const T,
                    rhs_ptr as *const T,
                    dst_ptr as *mut T,
                    out_shape,
                    &lhs_strides,
                    &rhs_strides,
                    // Presumably the starting element offsets for the two inputs —
                    // confirm against the kernel signature.
                    0,
                    0,
                );
            }
        }, op_name);

        return Ok(());
    }

    // Fast path: identical shapes and contiguous inputs, so a flat loop over
    // `elem_count` elements suffices.
    let elem_count = a.numel();
    let lhs_ptr = a.ptr();
    let rhs_ptr = b.ptr();

    dispatch_dtype!(dtype, T => {
        // SAFETY: this branch is only reached when `a`, `b`, and `out_shape` agree on
        // shape and both inputs are contiguous, so `elem_count` covers each input.
        // Assumes `out` is contiguous with shape `out_shape` and that all three buffers
        // store elements of `dtype` — TODO confirm these are upheld by every caller.
        unsafe {
            <CpuClient as Kernel<CpuRuntime>>::binary_op::<T>(
                client,
                op,
                lhs_ptr as *const T,
                rhs_ptr as *const T,
                dst_ptr as *mut T,
                elem_count,
            );
        }
    }, op_name);

    Ok(())
}