use crate::dtype::Element;
use crate::ops::BinaryOp;
/// Element-wise binary kernel over `len` contiguous elements: `out[i] = op(a[i], b[i])`.
///
/// On `x86_64` and `aarch64`, the `F32`, `F64` and `I32` dtypes are routed to the kernels in
/// `super::simd::binary`. With the `f16` feature, `F16` and `BF16` are routed there too.
/// Every other dtype, and every other target architecture, uses the generic scalar loop.
///
/// # Safety
/// `a` and `b` must be valid for `len` reads of `T`, and `out` must be valid for `len`
/// writes of `T`. The pointer casts below assume `T::DTYPE` truthfully describes `T`'s
/// layout (e.g. `DType::F32` only for `f32`).
/// NOTE(review): whether `out` may alias `a`/`b` depends on the SIMD kernels — confirm there.
#[inline]
pub unsafe fn binary_op_kernel<T: Element>(
    op: BinaryOp,
    a: *const T,
    b: *const T,
    out: *mut T,
    len: usize,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::binary;
        use crate::dtype::DType;

        // Every specialized arm returns immediately, so the scalar fallback below
        // only runs for dtypes without a dedicated kernel.
        match T::DTYPE {
            DType::F32 => {
                binary::binary_f32(op, a.cast::<f32>(), b.cast::<f32>(), out.cast::<f32>(), len);
                return;
            }
            DType::F64 => {
                binary::binary_f64(op, a.cast::<f64>(), b.cast::<f64>(), out.cast::<f64>(), len);
                return;
            }
            DType::I32 => {
                binary::binary_i32(op, a.cast::<i32>(), b.cast::<i32>(), out.cast::<i32>(), len);
                return;
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                binary::binary_f16(
                    op,
                    a.cast::<half::f16>(),
                    b.cast::<half::f16>(),
                    out.cast::<half::f16>(),
                    len,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                binary::binary_bf16(
                    op,
                    a.cast::<half::bf16>(),
                    b.cast::<half::bf16>(),
                    out.cast::<half::bf16>(),
                    len,
                );
                return;
            }
            _ => {}
        }
    }

    // No specialized kernel for this dtype/target: use the portable loop.
    binary_op_scalar(op, a, b, out, len);
}
/// Portable fallback for `binary_op_kernel`: `out[i] = op(a[i], b[i])` for `i in 0..len`.
///
/// `Add`/`Sub`/`Mul`/`Div` use `T`'s own arithmetic operators, so integer overflow and
/// divide-by-zero behave however `T`'s operator impls behave. `Pow` and `Atan2` are
/// computed in `f64` and converted back with `T::from_f64`. `Max`/`Min` return `b[i]`
/// whenever the `>` / `<` comparison is false (e.g. when either side is NaN for floats).
///
/// # Safety
/// `a` and `b` must be valid for `len` reads of `T`, and `out` must be valid for `len`
/// writes of `T`. Elements are accessed only through raw pointers; no slice references
/// are created. That means `out` may point at the same buffer as `a` or `b` (an in-place
/// op): each element is read before it is written. It also means a dangling or null
/// pointer is never dereferenced when `len == 0`.
#[inline]
unsafe fn binary_op_scalar<T: Element>(
    op: BinaryOp,
    a: *const T,
    b: *const T,
    out: *mut T,
    len: usize,
) {
    /// Writes `f(a[i], b[i])` into `out[i]` for every `i in 0..len` via raw pointers.
    ///
    /// This helper is generic over `f`, so each op gets its own monomorphized, branch-free
    /// loop. The `match op` below therefore runs once per call, not once per element.
    #[inline]
    unsafe fn zip_map<T: Copy>(
        a: *const T,
        b: *const T,
        out: *mut T,
        len: usize,
        f: impl Fn(T, T) -> T,
    ) {
        for i in 0..len {
            *out.add(i) = f(*a.add(i), *b.add(i));
        }
    }

    match op {
        BinaryOp::Add => zip_map(a, b, out, len, |x, y| x + y),
        BinaryOp::Sub => zip_map(a, b, out, len, |x, y| x - y),
        BinaryOp::Mul => zip_map(a, b, out, len, |x, y| x * y),
        BinaryOp::Div => zip_map(a, b, out, len, |x, y| x / y),
        BinaryOp::Pow => zip_map(a, b, out, len, |x, y| {
            T::from_f64(x.to_f64().powf(y.to_f64()))
        }),
        BinaryOp::Max => zip_map(a, b, out, len, |x, y| if x > y { x } else { y }),
        BinaryOp::Min => zip_map(a, b, out, len, |x, y| if x < y { x } else { y }),
        BinaryOp::Atan2 => zip_map(a, b, out, len, |y, x| {
            T::from_f64(y.to_f64().atan2(x.to_f64()))
        }),
    }
}
/// Plain-loop `f32` kernel: `out[i] = op(a[i], b[i])` for `i in 0..len`.
///
/// `Max`/`Min` return `b[i]` whenever `a[i] > b[i]` / `a[i] < b[i]` is false. If either
/// operand is NaN, the result is therefore `b[i]`.
///
/// # Safety
/// `a` and `b` must be valid for `len` reads and `out` valid for `len` writes. Only raw
/// pointers are used, so `out` may alias `a` or `b`: each lane is read before it is written.
#[inline]
pub unsafe fn binary_scalar_f32(
    op: BinaryOp,
    a: *const f32,
    b: *const f32,
    out: *mut f32,
    len: usize,
) {
    /// Applies `f` lane by lane: `dst[k] = f(lhs[k], rhs[k])`.
    #[inline]
    unsafe fn per_lane(
        lhs: *const f32,
        rhs: *const f32,
        dst: *mut f32,
        count: usize,
        f: impl Fn(f32, f32) -> f32,
    ) {
        for k in 0..count {
            *dst.add(k) = f(*lhs.add(k), *rhs.add(k));
        }
    }

    match op {
        BinaryOp::Add => per_lane(a, b, out, len, |x, y| x + y),
        BinaryOp::Sub => per_lane(a, b, out, len, |x, y| x - y),
        BinaryOp::Mul => per_lane(a, b, out, len, |x, y| x * y),
        BinaryOp::Div => per_lane(a, b, out, len, |x, y| x / y),
        BinaryOp::Max => per_lane(a, b, out, len, |x, y| if x > y { x } else { y }),
        BinaryOp::Min => per_lane(a, b, out, len, |x, y| if x < y { x } else { y }),
        BinaryOp::Pow => per_lane(a, b, out, len, f32::powf),
        BinaryOp::Atan2 => per_lane(a, b, out, len, f32::atan2),
    }
}
/// Plain-loop `f64` kernel: `out[i] = op(a[i], b[i])` for `i in 0..len`.
///
/// `Max`/`Min` return `b[i]` whenever `a[i] > b[i]` / `a[i] < b[i]` is false. If either
/// operand is NaN, the result is therefore `b[i]`.
///
/// # Safety
/// `a` and `b` must be valid for `len` reads and `out` valid for `len` writes. Only raw
/// pointers are used, so `out` may alias `a` or `b`: each lane is read before it is written.
#[inline]
pub unsafe fn binary_scalar_f64(
    op: BinaryOp,
    a: *const f64,
    b: *const f64,
    out: *mut f64,
    len: usize,
) {
    /// Stores `combine(src0[j], src1[j])` into `dst[j]` for each `j < n`.
    #[inline]
    unsafe fn elementwise(
        src0: *const f64,
        src1: *const f64,
        dst: *mut f64,
        n: usize,
        combine: impl Fn(f64, f64) -> f64,
    ) {
        for j in 0..n {
            let lhs = *src0.add(j);
            let rhs = *src1.add(j);
            *dst.add(j) = combine(lhs, rhs);
        }
    }

    match op {
        BinaryOp::Add => elementwise(a, b, out, len, |l, r| l + r),
        BinaryOp::Sub => elementwise(a, b, out, len, |l, r| l - r),
        BinaryOp::Mul => elementwise(a, b, out, len, |l, r| l * r),
        BinaryOp::Div => elementwise(a, b, out, len, |l, r| l / r),
        BinaryOp::Max => elementwise(a, b, out, len, |l, r| if l > r { l } else { r }),
        BinaryOp::Min => elementwise(a, b, out, len, |l, r| if l < r { l } else { r }),
        BinaryOp::Pow => elementwise(a, b, out, len, f64::powf),
        BinaryOp::Atan2 => elementwise(a, b, out, len, f64::atan2),
    }
}
/// Plain-loop `i32` kernel: `out[i] = op(a[i], b[i])` for `i in 0..len`.
///
/// Integer semantics:
/// - `Add`, `Sub`, `Mul` wrap on overflow.
/// - `Div` truncates toward zero. It yields `0` when the divisor is `0`, and wraps
///   `i32::MIN / -1` to `i32::MIN`.
/// - `Pow` and `Atan2` are evaluated in `f64` and converted back with `as i32`, which
///   truncates toward zero and saturates at the `i32` bounds.
///
/// # Safety
/// `a` and `b` must be valid for `len` reads and `out` valid for `len` writes. Only raw
/// pointers are used, so `out` may alias `a` or `b`: each lane is read before it is written.
#[inline]
pub unsafe fn binary_scalar_i32(
    op: BinaryOp,
    a: *const i32,
    b: *const i32,
    out: *mut i32,
    len: usize,
) {
    /// Writes `f(lhs[k], rhs[k])` into `dst[k]` for every `k < count`.
    #[inline]
    unsafe fn lanewise(
        lhs: *const i32,
        rhs: *const i32,
        dst: *mut i32,
        count: usize,
        f: impl Fn(i32, i32) -> i32,
    ) {
        for k in 0..count {
            *dst.add(k) = f(*lhs.add(k), *rhs.add(k));
        }
    }

    match op {
        BinaryOp::Add => lanewise(a, b, out, len, i32::wrapping_add),
        BinaryOp::Sub => lanewise(a, b, out, len, i32::wrapping_sub),
        BinaryOp::Mul => lanewise(a, b, out, len, i32::wrapping_mul),
        // A zero divisor yields 0 instead of panicking.
        BinaryOp::Div => lanewise(a, b, out, len, |num, den| {
            if den == 0 {
                0
            } else {
                num.wrapping_div(den)
            }
        }),
        BinaryOp::Max => lanewise(a, b, out, len, |x, y| if x > y { x } else { y }),
        BinaryOp::Min => lanewise(a, b, out, len, |x, y| if x < y { x } else { y }),
        BinaryOp::Pow => lanewise(a, b, out, len, |base, exp| {
            f64::from(base).powf(f64::from(exp)) as i32
        }),
        BinaryOp::Atan2 => lanewise(a, b, out, len, |y, x| {
            f64::from(y).atan2(f64::from(x)) as i32
        }),
    }
}
/// Element-wise binary op over strided (possibly broadcast) views of `a` and `b`.
///
/// `out` is written densely in row-major order: `total = out_shape.iter().product()`
/// elements. The output element with multi-index `idx` combines
/// `a[a_offset + Σ idx[d] * a_strides[d]]` with `b[b_offset + Σ idx[d] * b_strides[d]]`.
/// Strides are signed element counts.
///
/// # Fast path
/// When both operands are dense row-major over `out_shape`, element `i` of each operand
/// sits at `offset + i`. The work is then handed to the contiguous `binary_op_kernel`,
/// starting at the offset pointers. Size-1 dimensions are ignored for this check, because
/// their index is always 0 and their stride never contributes. Every other layout uses
/// the generic per-element loop below.
///
/// # Safety
/// Every address reached by the formula above must be valid for reads, and `out` must be
/// valid for `total` writes of `T`. `a_strides` and `b_strides` need at least
/// `out_shape.len()` entries; fewer causes an index panic.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn binary_op_strided_kernel<T: Element>(
    op: BinaryOp,
    a: *const T,
    b: *const T,
    out: *mut T,
    out_shape: &[usize],
    a_strides: &[isize],
    b_strides: &[isize],
    a_offset: usize,
    b_offset: usize,
) {
    /// True when `strides` lay `shape` out densely in row-major order.
    /// Size-1 dimensions are skipped because their stride is never applied.
    fn is_dense_row_major(shape: &[usize], strides: &[isize]) -> bool {
        let mut expected = 1isize;
        for d in (0..shape.len()).rev() {
            if shape[d] != 1 && strides[d] != expected {
                return false;
            }
            expected *= shape[d] as isize;
        }
        true
    }

    let ndim = out_shape.len();
    let total: usize = out_shape.iter().product();
    if total == 0 {
        return;
    }

    // Dense operands (with any starting offset) can use the contiguous, possibly SIMD,
    // kernel. Offsetting the base pointers turns them into plain length-`total` buffers.
    if is_dense_row_major(out_shape, a_strides) && is_dense_row_major(out_shape, b_strides) {
        binary_op_kernel(op, a.add(a_offset), b.add(b_offset), out, total);
        return;
    }

    // Generic path: walk the output in row-major order with an odometer over `indices`.
    // The element offsets of `a` and `b` are updated incrementally from their strides.
    let mut indices = vec![0usize; ndim];
    let mut a_idx = a_offset as isize;
    let mut b_idx = b_offset as isize;
    for out_idx in 0..total {
        let a_val = *a.offset(a_idx);
        let b_val = *b.offset(b_idx);
        let result = match op {
            BinaryOp::Add => a_val + b_val,
            BinaryOp::Sub => a_val - b_val,
            BinaryOp::Mul => a_val * b_val,
            BinaryOp::Div => a_val / b_val,
            BinaryOp::Pow => T::from_f64(a_val.to_f64().powf(b_val.to_f64())),
            BinaryOp::Max => {
                if a_val > b_val {
                    a_val
                } else {
                    b_val
                }
            }
            BinaryOp::Min => {
                if a_val < b_val {
                    a_val
                } else {
                    b_val
                }
            }
            BinaryOp::Atan2 => T::from_f64(a_val.to_f64().atan2(b_val.to_f64())),
        };
        *out.add(out_idx) = result;

        // Advance the innermost dimension. When a dimension wraps, rewind its
        // contribution to the offsets and carry into the next-outer dimension.
        for dim in (0..ndim).rev() {
            indices[dim] += 1;
            a_idx += a_strides[dim];
            b_idx += b_strides[dim];
            if indices[dim] < out_shape[dim] {
                break;
            }
            indices[dim] = 0;
            a_idx -= (out_shape[dim] as isize) * a_strides[dim];
            b_idx -= (out_shape[dim] as isize) * b_strides[dim];
        }
    }
}