use super::common::{advance_coord, contiguous_strides, out_index_from_coord};
use crate::dispatch_dtype;
use crate::dtype::Element;
use crate::error::Result;
use crate::ops::{AccumulationPrecision, ReduceOp, reduce_output_shape};
use crate::runtime::cpu::kernels::Accumulator;
use crate::runtime::cpu::{CpuClient, CpuRuntime};
use crate::tensor::Tensor;
pub(super) fn reduce_multi_dim_fused(
client: &CpuClient,
op: ReduceOp,
a: &Tensor<CpuRuntime>,
dims: &[usize],
keepdim: bool,
precision: AccumulationPrecision,
op_name: &'static str,
) -> Result<Tensor<CpuRuntime>> {
let shape = a.shape();
let out_shape = reduce_output_shape(shape, dims, keepdim);
let out = Tensor::<CpuRuntime>::empty(&out_shape, a.dtype(), &client.device);
let mut reduce_mask = vec![false; shape.len()];
for &d in dims {
reduce_mask[d] = true;
}
let kept_axes: Vec<usize> = if keepdim {
Vec::new()
} else {
(0..shape.len())
.filter(|&axis| !reduce_mask[axis])
.collect()
};
let out_strides = contiguous_strides(&out_shape);
let reduce_count = dims.iter().fold(1usize, |acc, &d| acc * shape[d]);
let numel = a.numel();
let out_numel = out.numel();
let in_ptr = a.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(a.dtype(), T => {
unsafe {
match precision {
AccumulationPrecision::Native => reduce_multi_dim_fused_native::<T>(
op,
in_ptr as *const T,
out_ptr as *mut T,
numel,
out_numel,
shape,
&reduce_mask,
keepdim,
&kept_axes,
&out_strides,
reduce_count,
),
AccumulationPrecision::FP32 | AccumulationPrecision::BF16 => {
reduce_multi_dim_fused_acc::<T, f32>(
op,
in_ptr as *const T,
out_ptr as *mut T,
numel,
out_numel,
shape,
&reduce_mask,
keepdim,
&kept_axes,
&out_strides,
reduce_count,
)
}
AccumulationPrecision::FP64 => reduce_multi_dim_fused_acc::<T, f64>(
op,
in_ptr as *const T,
out_ptr as *mut T,
numel,
out_numel,
shape,
&reduce_mask,
keepdim,
&kept_axes,
&out_strides,
reduce_count,
),
}
}
}, op_name);
Ok(out)
}
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn reduce_multi_dim_fused_native<T: Element>(
op: ReduceOp,
input: *const T,
output: *mut T,
numel: usize,
out_numel: usize,
shape: &[usize],
reduce_mask: &[bool],
keepdim: bool,
kept_axes: &[usize],
out_strides: &[usize],
reduce_count: usize,
) {
match op {
ReduceOp::Sum | ReduceOp::Mean | ReduceOp::Any => {
for i in 0..out_numel {
*output.add(i) = T::zero();
}
}
ReduceOp::Prod | ReduceOp::All => {
for i in 0..out_numel {
*output.add(i) = T::one();
}
}
ReduceOp::Max | ReduceOp::Min => {}
}
let mut initialized = if matches!(op, ReduceOp::Max | ReduceOp::Min) {
vec![false; out_numel]
} else {
Vec::new()
};
let mut coord = vec![0usize; shape.len()];
for linear in 0..numel {
let out_idx = out_index_from_coord(&coord, reduce_mask, keepdim, kept_axes, out_strides);
let val = *input.add(linear);
match op {
ReduceOp::Sum | ReduceOp::Mean => {
let acc = *output.add(out_idx);
*output.add(out_idx) = acc + val;
}
ReduceOp::Prod => {
let acc = *output.add(out_idx);
*output.add(out_idx) = acc * val;
}
ReduceOp::Max => {
if !initialized[out_idx] {
*output.add(out_idx) = val;
initialized[out_idx] = true;
} else {
let acc = *output.add(out_idx);
*output.add(out_idx) = if val > acc { val } else { acc };
}
}
ReduceOp::Min => {
if !initialized[out_idx] {
*output.add(out_idx) = val;
initialized[out_idx] = true;
} else {
let acc = *output.add(out_idx);
*output.add(out_idx) = if val < acc { val } else { acc };
}
}
ReduceOp::All => {
let acc = *output.add(out_idx);
*output.add(out_idx) = if val.to_f64() != 0.0 && acc.to_f64() != 0.0 {
T::one()
} else {
T::zero()
};
}
ReduceOp::Any => {
let acc = *output.add(out_idx);
*output.add(out_idx) = if val.to_f64() != 0.0 || acc.to_f64() != 0.0 {
T::one()
} else {
T::zero()
};
}
}
if linear + 1 < numel {
advance_coord(&mut coord, shape);
}
}
if matches!(op, ReduceOp::Mean) {
for i in 0..out_numel {
let scaled = (*output.add(i)).to_f64() / reduce_count as f64;
*output.add(i) = T::from_f64(scaled);
}
}
}
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn reduce_multi_dim_fused_acc<T: Element, A: Accumulator>(
op: ReduceOp,
input: *const T,
output: *mut T,
numel: usize,
out_numel: usize,
shape: &[usize],
reduce_mask: &[bool],
keepdim: bool,
kept_axes: &[usize],
out_strides: &[usize],
reduce_count: usize,
) {
let mut acc = match op {
ReduceOp::Sum | ReduceOp::Mean | ReduceOp::Any | ReduceOp::Max | ReduceOp::Min => {
vec![A::ZERO; out_numel]
}
ReduceOp::Prod | ReduceOp::All => vec![A::ONE; out_numel],
};
let mut initialized = if matches!(op, ReduceOp::Max | ReduceOp::Min) {
vec![false; out_numel]
} else {
Vec::new()
};
let mut coord = vec![0usize; shape.len()];
for linear in 0..numel {
let out_idx = out_index_from_coord(&coord, reduce_mask, keepdim, kept_axes, out_strides);
let val = A::acc_in((*input.add(linear)).to_f64());
match op {
ReduceOp::Sum | ReduceOp::Mean => {
acc[out_idx] = acc[out_idx].acc_add(val);
}
ReduceOp::Prod => {
acc[out_idx] = acc[out_idx].acc_mul(val);
}
ReduceOp::Max => {
if !initialized[out_idx] {
acc[out_idx] = val;
initialized[out_idx] = true;
} else if val > acc[out_idx] {
acc[out_idx] = val;
}
}
ReduceOp::Min => {
if !initialized[out_idx] {
acc[out_idx] = val;
initialized[out_idx] = true;
} else if val < acc[out_idx] {
acc[out_idx] = val;
}
}
ReduceOp::All => {
acc[out_idx] = if val != A::ZERO && acc[out_idx] != A::ZERO {
A::ONE
} else {
A::ZERO
};
}
ReduceOp::Any => {
acc[out_idx] = if val != A::ZERO || acc[out_idx] != A::ZERO {
A::ONE
} else {
A::ZERO
};
}
}
if linear + 1 < numel {
advance_coord(&mut coord, shape);
}
}
for i in 0..out_numel {
let mut out_val = acc[i];
if matches!(op, ReduceOp::Mean) {
out_val = out_val.acc_div(reduce_count);
}
*output.add(i) = T::from_f64(out_val.into());
}
}