use super::core::BinaryOp;
use crate::shape_error_taxonomy::ShapeErrorUtils;
use crate::{Device, Result, Tensor, TensorError};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use scirs2_core::numeric::Zero;
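/// Applies an element-wise binary operation to two tensors, broadcasting
/// their shapes (NumPy-style) when they differ.
///
/// Both tensors must live on the same device; mixed-device calls return a
/// `device_mismatch` error. A minimal usage sketch, assuming an `Add` unit
/// struct implementing `BinaryOp<f32>` (the GPU dispatch below matches ops by
/// that name):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let a = Tensor::from_array(ArrayD::from_elem(IxDyn(&[2, 3]), 1.0f32));
/// let b = Tensor::from_array(ArrayD::from_elem(IxDyn(&[1, 3]), 2.0f32));
/// let sum = binary_op(&a, &b, Add)?; // [1, 3] broadcasts across [2, 3]
/// ```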
pub fn binary_op<T, Op>(a: &Tensor<T>, b: &Tensor<T>, op: Op) -> Result<Tensor<T>>
where
T: Clone + Default + Zero + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
Op: BinaryOp<T>,
{
if a.device() != b.device() {
return Err(TensorError::device_mismatch(
"binary_op",
&a.device().to_string(),
&b.device().to_string(),
));
}
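    // Broadcast the two shapes (NumPy-style); incompatible shapes surface as
    // a shape error.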
let broadcast_shape = a.shape().broadcast_shape(b.shape()).ok_or_else(|| {
ShapeErrorUtils::broadcast_incompatible("binary_op", a.shape(), b.shape())
})?;
match (a.device(), b.device()) {
(Device::Cpu, Device::Cpu) => {
cpu_binary_op(a, b, op, &broadcast_shape)
}
#[cfg(feature = "gpu")]
(Device::Gpu(_), Device::Gpu(_)) => {
gpu_binary_op(a, b, op, &broadcast_shape)
}
#[cfg(feature = "rocm")]
(Device::Rocm(_), Device::Rocm(_)) => {
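            // Native ROCm kernels are pending; fall back to the CPU path.
            // This assumes ROCm tensors are currently readable as CPU storage.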
eprintln!("Warning: ROCm binary operations using CPU fallback - native implementation pending");
cpu_binary_op(a, b, op, &broadcast_shape)
}
#[cfg(any(feature = "gpu", feature = "rocm"))]
_ => {
Err(TensorError::device_mismatch(
"binary_op",
&a.device().to_string(),
&b.device().to_string(),
))
}
}
}
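/// CPU implementation of [`binary_op`]: operates directly on the backing
/// `ndarray` storage, taking a fast element-wise path when the shapes already
/// match and falling back to strided broadcasting otherwise.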
fn cpu_binary_op<T, Op>(
a: &Tensor<T>,
b: &Tensor<T>,
op: Op,
broadcast_shape: &crate::Shape,
) -> Result<Tensor<T>>
where
T: Clone + Default + Zero + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
Op: BinaryOp<T>,
{
use crate::tensor::TensorStorage;
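    // Unwrap the CPU-backed arrays; any other storage reaching this function
    // indicates a dispatch bug upstream.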
let (a_array, b_array) = match (&a.storage, &b.storage) {
(TensorStorage::Cpu(ref a_arr), TensorStorage::Cpu(ref b_arr)) => (a_arr, b_arr),
#[cfg(feature = "gpu")]
_ => {
return Err(TensorError::device_mismatch(
"cpu_binary_op",
"cpu",
"non-cpu",
))
}
};
let output_dims = IxDyn(broadcast_shape.dims());
let mut output = ArrayD::zeros(output_dims);
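    // Fast path: identical shapes and no broadcasting, so the arrays can be
    // zipped element for element.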
if a.shape() == b.shape() && a.shape().dims() == broadcast_shape.dims() {
for ((a_val, b_val), out_val) in a_array.iter().zip(b_array.iter()).zip(output.iter_mut()) {
*out_val = op.apply(*a_val, *b_val);
}
} else {
broadcast_operation(
a_array,
b_array,
&mut output,
&op,
a.shape(),
b.shape(),
broadcast_shape,
)?;
}
Ok(Tensor::from_array(output))
}
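/// Applies `op` under NumPy-style broadcasting by giving every size-1 input
/// dimension a stride of zero, so its single element is reused across the
/// corresponding output dimension.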
fn broadcast_operation<T, Op>(
a_array: &ArrayD<T>,
b_array: &ArrayD<T>,
output: &mut ArrayD<T>,
op: &Op,
a_shape: &crate::Shape,
b_shape: &crate::Shape,
broadcast_shape: &crate::Shape,
) -> Result<()>
where
T: Clone + Default + Zero + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
Op: BinaryOp<T>,
{
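    // Fast paths: a single-element operand broadcasts as a scalar against the
    // whole of the other array.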
if a_shape.size() == 1 {
if let Some(a_scalar) = a_array.iter().next() {
for (b_val, out_val) in b_array.iter().zip(output.iter_mut()) {
*out_val = op.apply(*a_scalar, *b_val);
}
}
return Ok(());
}
if b_shape.size() == 1 {
if let Some(b_scalar) = b_array.iter().next() {
for (a_val, out_val) in a_array.iter().zip(output.iter_mut()) {
*out_val = op.apply(*a_val, *b_scalar);
}
}
return Ok(());
}
let output_shape = broadcast_shape.dims();
let a_dims = a_shape.dims();
let b_dims = b_shape.dims();
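    // Build per-dimension strides for each input, right-aligned against the
    // output rank. Size-1 dimensions get stride 0 so their element repeats.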
let mut a_strides = vec![0; output_shape.len()];
let mut b_strides = vec![0; output_shape.len()];
let a_offset = output_shape.len() - a_dims.len();
let b_offset = output_shape.len() - b_dims.len();
let mut a_stride_acc = 1;
for i in (0..a_dims.len()).rev() {
let out_idx = a_offset + i;
        if a_dims[i] == 1 {
            a_strides[out_idx] = 0;
        } else {
            a_strides[out_idx] = a_stride_acc;
        }
a_stride_acc *= a_dims[i];
}
let mut b_stride_acc = 1;
for i in (0..b_dims.len()).rev() {
let out_idx = b_offset + i;
        if b_dims[i] == 1 {
            b_strides[out_idx] = 0;
        } else {
            b_strides[out_idx] = b_stride_acc;
        }
b_stride_acc *= b_dims[i];
}
    let total_elements: usize = output_shape.iter().product();
    // Take contiguous views once, outside the loop; `as_slice` only succeeds
    // for standard-layout arrays, so non-contiguous inputs surface as an
    // error instead of a panic.
    let a_slice = a_array.as_slice().ok_or_else(|| {
        TensorError::invalid_argument("broadcast_operation: lhs array is not contiguous".to_string())
    })?;
    let b_slice = b_array.as_slice().ok_or_else(|| {
        TensorError::invalid_argument("broadcast_operation: rhs array is not contiguous".to_string())
    })?;
    let out_slice = output.as_slice_mut().ok_or_else(|| {
        TensorError::invalid_argument("broadcast_operation: output array is not contiguous".to_string())
    })?;
    let mut coords = vec![0; output_shape.len()];
    for linear_idx in 0..total_elements {
        let mut remaining = linear_idx;
        for i in (0..output_shape.len()).rev() {
            coords[i] = remaining % output_shape[i];
            remaining /= output_shape[i];
        }
        let mut a_idx = 0;
        let mut b_idx = 0;
        for i in 0..output_shape.len() {
            a_idx += coords[i] * a_strides[i];
            b_idx += coords[i] * b_strides[i];
        }
        out_slice[linear_idx] = op.apply(a_slice[a_idx], b_slice[b_idx]);
    }
Ok(())
}
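/// GPU implementation of [`binary_op`]: translates the generic op into a
/// kernel selector and executes it through `crate::gpu::binary_ops`.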
#[cfg(feature = "gpu")]
fn gpu_binary_op<T, Op>(
a: &Tensor<T>,
b: &Tensor<T>,
op: Op,
broadcast_shape: &crate::Shape,
) -> Result<Tensor<T>>
where
T: Clone + Default + Zero + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
Op: BinaryOp<T>,
{
use crate::gpu::binary_ops;
use crate::tensor::TensorStorage;
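    // Unwrap the GPU buffers; dispatch in `binary_op` guarantees both tensors
    // live on the GPU.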
let (a_buffer, b_buffer) = match (&a.storage, &b.storage) {
(TensorStorage::Gpu(ref a_buf), TensorStorage::Gpu(ref b_buf)) => (a_buf, b_buf),
_ => {
return Err(TensorError::device_mismatch(
"gpu_binary_op",
"gpu",
"non-gpu",
))
}
};
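    // Map the generic op (identified by name) onto the GPU kernel enum.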
let gpu_op = match op.name() {
"Add" => binary_ops::BinaryOp::Add,
"Sub" => binary_ops::BinaryOp::Sub,
"Mul" => binary_ops::BinaryOp::Mul,
"Div" => binary_ops::BinaryOp::Div,
"Pow" => binary_ops::BinaryOp::Pow,
"Min" => binary_ops::BinaryOp::Min,
"Max" => binary_ops::BinaryOp::Max,
_ => {
return Err(TensorError::invalid_argument(format!(
"Unsupported GPU binary operation: {}",
op.name()
)))
}
};
let output_len = broadcast_shape.size();
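    // Same-shape inputs run the direct kernel; mismatched shapes use the
    // broadcasting variant.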
let result_buffer = if a.shape() == b.shape() && a.shape().dims() == broadcast_shape.dims() {
binary_ops::execute_binary_op(a_buffer, b_buffer, gpu_op, output_len)?
} else {
binary_ops::execute_binary_op_with_broadcasting(
a_buffer,
b_buffer,
gpu_op,
a.shape().dims(),
b.shape().dims(),
broadcast_shape.dims(),
output_len,
)?
};
Ok(Tensor::from_gpu_buffer(
result_buffer,
broadcast_shape.clone(),
))
}