use crate::dtype::{DTypeMapping, get_dtype_info};
use crate::view::{ArcTensorView, RcTensorView, TensorViewOps};
use anyhow::{Result, bail};
/// Generates a CPU element-wise kernel, plus its private offset helper, as
/// inherent associated functions on the given tensor view type.
///
/// The view type must provide `shape()`, `strides()`, `size()`, `dtype()` and
/// an unsafe `raw_data_ptr()`; those are supplied elsewhere in the crate.
macro_rules! impl_cpu_elementwise {
    ($view_type:ident) => {
        impl $view_type {
            /// Applies `f` at every logical index of `output`.
            ///
            /// For each index the matching element of every view in `inputs`
            /// is gathered, in order, into a slice that is handed to `f`; the
            /// value `f` returns is stored at the same index of `output`.
            /// With no inputs, `f` receives an empty slice.
            ///
            /// Strides are interpreted as byte strides: the accumulated offset
            /// is divided by `size_of::<T>()` to obtain an element index.
            ///
            /// # Errors
            ///
            /// Returns an error, before touching any data, when:
            /// - the dtype of `output` or of any input differs from `T::DTYPE`;
            /// - any input's shape differs from the shape of `output`;
            /// - any view has fewer strides than `output` has dimensions;
            /// - `T` is zero-sized.
            pub fn elementwise<T, F>(output: &mut Self, inputs: &[&Self], mut f: F) -> Result<()>
            where
                T: bytemuck::Pod + DTypeMapping,
                F: FnMut(&[T]) -> T,
            {
                let shape = output.shape();
                let total = output.size();
                let dtype = T::DTYPE;

                if output.dtype() != dtype {
                    bail!("Output dtype mismatch");
                }
                // Report malformed stride vectors as errors up front instead
                // of panicking on an out-of-range index in the hot loop.
                if output.strides().len() < shape.len() {
                    bail!("Output stride rank mismatch");
                }
                for inp in inputs {
                    if inp.dtype() != dtype {
                        bail!("Input dtype mismatch");
                    }
                    if inp.shape() != shape {
                        bail!("Input shape mismatch");
                    }
                    if inp.strides().len() < shape.len() {
                        bail!("Input stride rank mismatch");
                    }
                }

                let elem_size = std::mem::size_of::<T>();
                if elem_size == 0 {
                    // The byte-offset to element-index conversion below
                    // divides by the element size.
                    bail!("Zero-sized element types are not supported");
                }

                // SAFETY: the contract of `raw_data_ptr` is defined by the view
                // type and is not visible from this macro; the pointer is only
                // cast here, not dereferenced.
                let out_ptr = unsafe { output.raw_data_ptr() } as *mut T;
                let out_strides = output.strides();

                let input_data: Vec<_> = inputs
                    .iter()
                    .map(|inp| {
                        // SAFETY: as above, cast only.
                        let ptr = unsafe { inp.raw_data_ptr() } as *const T;
                        (ptr, inp.strides())
                    })
                    .collect();

                // NOTE(review): nothing here stops an input from sharing
                // storage with `output`; if the views overlap with different
                // strides the result depends on iteration order. Confirm
                // whether callers rely on in-place use.

                // One gather buffer reused for every element, so the loop body
                // performs no heap allocation.
                let mut values: Vec<T> = Vec::with_capacity(inputs.len());

                for linear_idx in 0..total {
                    let out_offset =
                        Self::linear_to_offset(linear_idx, shape, out_strides) / elem_size;

                    values.clear();
                    for (ptr, strides) in &input_data {
                        let inp_offset =
                            Self::linear_to_offset(linear_idx, shape, strides) / elem_size;
                        // SAFETY: sound only if every offset derived from the
                        // view's shape and strides stays inside the buffer
                        // behind `raw_data_ptr`, an invariant of the view type
                        // that is not re-checked here. `read_unaligned` is
                        // used because nothing visible establishes that the
                        // buffer is aligned for `T`.
                        values.push(unsafe { ptr.add(inp_offset).read_unaligned() });
                    }

                    let result = f(&values);
                    // SAFETY: same in-bounds invariant as for the reads above,
                    // applied to the output view; the write is likewise
                    // alignment-agnostic.
                    unsafe { out_ptr.add(out_offset).write_unaligned(result) };
                }
                Ok(())
            }

            /// Converts a row-major linear index over `shape` into a strided
            /// offset, expressed in the same unit as `strides`.
            ///
            /// The last dimension varies fastest. Only the first
            /// `shape.len()` entries of `strides` are used.
            ///
            /// # Panics
            ///
            /// Panics if `strides` has fewer entries than `shape`, or if a
            /// dimension of `shape` is zero.
            fn linear_to_offset(linear_idx: usize, shape: &[usize], strides: &[usize]) -> usize {
                let mut remaining = linear_idx;
                let mut offset = 0;
                // Peel coordinates off from the innermost dimension outwards.
                for (&dim, &stride) in shape.iter().zip(&strides[..shape.len()]).rev() {
                    offset += (remaining % dim) * stride;
                    remaining /= dim;
                }
                offset
            }
        }
    };
}
// Instantiate the macro for both tensor view types imported above, giving
// each an inherent `elementwise` function and its private `linear_to_offset`
// helper.
impl_cpu_elementwise!(RcTensorView);
impl_cpu_elementwise!(ArcTensorView);