use crate::dispatch_dtype;
use crate::dtype::Element;
use crate::error::{Error, Result};
use crate::ops::{Kernel, ReduceOp, reduce_output_shape};
use crate::runtime::cpu::{CpuClient, CpuRuntime};
use crate::tensor::Tensor;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
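/// Reduce `a` along a single dimension on the CPU.
///
/// When `dim` is the last dimension, the elements being reduced are
/// contiguous, so the dtype-specialized `Kernel::reduce` row kernel handles
/// the whole tensor directly; any other dimension takes the strided
/// `reduce_non_last_dim_runtime` path below. Returns
/// `Error::InvalidDimension` when `dim >= ndim`.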
pub(super) fn reduce_single_dim(
client: &CpuClient,
op: ReduceOp,
a: &Tensor<CpuRuntime>,
dim: usize,
keepdim: bool,
op_name: &'static str,
) -> Result<Tensor<CpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
let reduce_size = shape[dim];
    let outer_size = shape[..dim].iter().product::<usize>().max(1);
    let inner_size = shape[dim + 1..].iter().product::<usize>().max(1);
let out_shape = reduce_output_shape(shape, &[dim], keepdim);
let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);
if dim == ndim - 1 {
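        // Last dim: inner_size == 1, so the input is outer_size contiguous
        // rows of reduce_size elements each.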
let a_ptr = a.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
<CpuClient as Kernel<CpuRuntime>>::reduce::<T>(
client,
op,
a_ptr as *const T,
out_ptr as *mut T,
reduce_size,
outer_size,
);
}
}, op_name);
} else {
let a_ptr = a.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
reduce_non_last_dim_runtime::<T>(
client,
op,
a_ptr as *const T,
out_ptr as *mut T,
outer_size,
reduce_size,
inner_size,
);
}
}, op_name);
}
Ok(out)
}
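/// Sequential fallback: reduce each `outer` slice in turn.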
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn reduce_non_last_dim<T: Element>(
op: ReduceOp,
a: *const T,
out: *mut T,
outer_size: usize,
reduce_size: usize,
inner_size: usize,
) {
for outer in 0..outer_size {
reduce_non_last_dim_outer(op, a, out, outer, reduce_size, inner_size);
}
}
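/// Reduces one `outer` slice of a row-major `outer x reduce x inner` view.
///
/// `Max`/`Min` seed the accumulator with the first element along the reduced
/// dimension, so this path assumes `reduce_size >= 1` for those ops.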
#[allow(unsafe_op_in_unsafe_fn)]
#[inline]
pub(super) unsafe fn reduce_non_last_dim_outer<T: Element>(
op: ReduceOp,
a: *const T,
out: *mut T,
outer: usize,
reduce_size: usize,
inner_size: usize,
) {
for inner in 0..inner_size {
        let mut acc = match op {
            ReduceOp::Sum | ReduceOp::Mean => T::zero(),
            ReduceOp::Prod => T::one(),
            // Seed with the first element along the reduced dim (r == 0);
            // the loop below then compares it against itself harmlessly.
            ReduceOp::Max | ReduceOp::Min => {
                *a.add(outer * reduce_size * inner_size + inner)
            }
            ReduceOp::All => T::one(),
            ReduceOp::Any => T::zero(),
        };
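        // Fold the reduce_size elements at stride inner_size into acc.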
for r in 0..reduce_size {
let idx = outer * reduce_size * inner_size + r * inner_size + inner;
let val = *a.add(idx);
acc = match op {
ReduceOp::Sum | ReduceOp::Mean => acc + val,
ReduceOp::Prod => acc * val,
ReduceOp::Max => {
if val > acc {
val
} else {
acc
}
}
ReduceOp::Min => {
if val < acc {
val
} else {
acc
}
}
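                // Boolean reductions: any nonzero value counts as true.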
ReduceOp::All => {
if val.to_f64() != 0.0 && acc.to_f64() != 0.0 {
T::one()
} else {
T::zero()
}
}
ReduceOp::Any => {
if val.to_f64() != 0.0 || acc.to_f64() != 0.0 {
T::one()
} else {
T::zero()
}
}
};
}
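        // Mean is a Sum fold followed by a single division at the end.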
if matches!(op, ReduceOp::Mean) {
acc = T::from_f64(acc.to_f64() / reduce_size as f64);
}
let out_idx = outer * inner_size + inner;
*out.add(out_idx) = acc;
}
}
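/// Chooses between the sequential and rayon-parallel strided paths. The
/// parallel path only pays off when there is more than one `outer` slice to
/// split across threads.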
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn reduce_non_last_dim_runtime<T: Element>(
client: &CpuClient,
op: ReduceOp,
a: *const T,
out: *mut T,
outer_size: usize,
reduce_size: usize,
inner_size: usize,
) {
#[cfg(feature = "rayon")]
{
if outer_size > 1 {
return reduce_non_last_dim_parallel(
client,
op,
a,
out,
outer_size,
reduce_size,
inner_size,
);
}
}
#[cfg(not(feature = "rayon"))]
let _ = client;
reduce_non_last_dim(op, a, out, outer_size, reduce_size, inner_size);
}
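/// Parallel strided reduction: each rayon task reduces one `outer` slice.
///
/// The raw pointers cross the closure boundary as `usize` addresses because
/// `*const T` / `*mut T` are not `Send`; this is sound here since every task
/// writes a disjoint `inner_size` region of `out`.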
#[cfg(feature = "rayon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn reduce_non_last_dim_parallel<T: Element>(
client: &CpuClient,
op: ReduceOp,
a: *const T,
out: *mut T,
outer_size: usize,
reduce_size: usize,
inner_size: usize,
) {
let min_len = client.rayon_min_len();
let a_addr = a as usize;
let out_addr = out as usize;
client.install_parallelism(|| {
(0..outer_size)
.into_par_iter()
.with_min_len(min_len)
.for_each(|outer| unsafe {
let a_ptr = a_addr as *const T;
let out_ptr = out_addr as *mut T;
reduce_non_last_dim_outer(op, a_ptr, out_ptr, outer, reduce_size, inner_size);
});
});
}
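
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of the strided path, assuming `Element` is implemented
    // for f32 (swap in a dtype this crate actually provides if it is not).
    // Reducing dim 0 of a row-major 2x3 buffer gives outer_size = 1,
    // reduce_size = 2, inner_size = 3.
    #[test]
    fn sum_over_non_last_dim() {
        let a: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut out = [0.0f32; 3];
        unsafe {
            reduce_non_last_dim::<f32>(ReduceOp::Sum, a.as_ptr(), out.as_mut_ptr(), 1, 2, 3);
        }
        // Column sums of [[1, 2, 3], [4, 5, 6]].
        assert_eq!(out, [5.0, 7.0, 9.0]);
    }
}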