numr 0.5.1

High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
Documentation
//! Single-dimension reduction with native precision

use crate::dispatch_dtype;
use crate::dtype::Element;
use crate::error::{Error, Result};
use crate::ops::{Kernel, ReduceOp, reduce_output_shape};
use crate::runtime::cpu::{CpuClient, CpuRuntime};
use crate::tensor::Tensor;
#[cfg(feature = "rayon")]
use rayon::prelude::*;

/// Reduce a single dimension of a tensor using native precision.
///
/// Uses chunked iteration for non-last dimensions to handle strided memory access.
pub(super) fn reduce_single_dim(
    client: &CpuClient,
    op: ReduceOp,
    a: &Tensor<CpuRuntime>,
    dim: usize,
    keepdim: bool,
    op_name: &'static str,
) -> Result<Tensor<CpuRuntime>> {
    let dtype = a.dtype();
    let shape = a.shape();
    let ndim = shape.len();

    if dim >= ndim {
        return Err(Error::InvalidDimension {
            dim: dim as isize,
            ndim,
        });
    }

    let reduce_size = shape[dim];
    let outer_size: usize = shape[..dim].iter().product();
    let outer_size = outer_size.max(1);
    let inner_size: usize = shape[dim + 1..].iter().product();
    let inner_size = inner_size.max(1);

    let out_shape = reduce_output_shape(shape, &[dim], keepdim);
    let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);

    if dim == ndim - 1 {
        let a_ptr = a.ptr();
        let out_ptr = out.ptr();

        dispatch_dtype!(dtype, T => {
            unsafe {
                <CpuClient as Kernel<CpuRuntime>>::reduce::<T>(
                    client,
                    op,
                    a_ptr as *const T,
                    out_ptr as *mut T,
                    reduce_size,
                    outer_size,
                );
            }
        }, op_name);
    } else {
        let a_ptr = a.ptr();
        let out_ptr = out.ptr();

        dispatch_dtype!(dtype, T => {
            unsafe {
                reduce_non_last_dim_runtime::<T>(
                    client,
                    op,
                    a_ptr as *const T,
                    out_ptr as *mut T,
                    outer_size,
                    reduce_size,
                    inner_size,
                );
            }
        }, op_name);
    }

    Ok(out)
}

#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn reduce_non_last_dim<T: Element>(
    op: ReduceOp,
    a: *const T,
    out: *mut T,
    outer_size: usize,
    reduce_size: usize,
    inner_size: usize,
) {
    for outer in 0..outer_size {
        reduce_non_last_dim_outer(op, a, out, outer, reduce_size, inner_size);
    }
}

#[allow(unsafe_op_in_unsafe_fn)]
#[inline]
pub(super) unsafe fn reduce_non_last_dim_outer<T: Element>(
    op: ReduceOp,
    a: *const T,
    out: *mut T,
    outer: usize,
    reduce_size: usize,
    inner_size: usize,
) {
    for inner in 0..inner_size {
        let mut acc = match op {
            ReduceOp::Sum | ReduceOp::Mean => T::zero(),
            ReduceOp::Prod => T::one(),
            ReduceOp::Max => {
                let idx = outer * reduce_size * inner_size + inner;
                *a.add(idx)
            }
            ReduceOp::Min => {
                let idx = outer * reduce_size * inner_size + inner;
                *a.add(idx)
            }
            ReduceOp::All => T::one(),
            ReduceOp::Any => T::zero(),
        };

        for r in 0..reduce_size {
            let idx = outer * reduce_size * inner_size + r * inner_size + inner;
            let val = *a.add(idx);

            acc = match op {
                ReduceOp::Sum | ReduceOp::Mean => acc + val,
                ReduceOp::Prod => acc * val,
                ReduceOp::Max => {
                    if val > acc {
                        val
                    } else {
                        acc
                    }
                }
                ReduceOp::Min => {
                    if val < acc {
                        val
                    } else {
                        acc
                    }
                }
                ReduceOp::All => {
                    if val.to_f64() != 0.0 && acc.to_f64() != 0.0 {
                        T::one()
                    } else {
                        T::zero()
                    }
                }
                ReduceOp::Any => {
                    if val.to_f64() != 0.0 || acc.to_f64() != 0.0 {
                        T::one()
                    } else {
                        T::zero()
                    }
                }
            };
        }

        if matches!(op, ReduceOp::Mean) {
            acc = T::from_f64(acc.to_f64() / reduce_size as f64);
        }

        let out_idx = outer * inner_size + inner;
        *out.add(out_idx) = acc;
    }
}

#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn reduce_non_last_dim_runtime<T: Element>(
    client: &CpuClient,
    op: ReduceOp,
    a: *const T,
    out: *mut T,
    outer_size: usize,
    reduce_size: usize,
    inner_size: usize,
) {
    #[cfg(feature = "rayon")]
    {
        if outer_size > 1 {
            return reduce_non_last_dim_parallel(
                client,
                op,
                a,
                out,
                outer_size,
                reduce_size,
                inner_size,
            );
        }
    }

    #[cfg(not(feature = "rayon"))]
    let _ = client;

    reduce_non_last_dim(op, a, out, outer_size, reduce_size, inner_size);
}

#[cfg(feature = "rayon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn reduce_non_last_dim_parallel<T: Element>(
    client: &CpuClient,
    op: ReduceOp,
    a: *const T,
    out: *mut T,
    outer_size: usize,
    reduce_size: usize,
    inner_size: usize,
) {
    let min_len = client.rayon_min_len();
    let a_addr = a as usize;
    let out_addr = out as usize;
    client.install_parallelism(|| {
        (0..outer_size)
            .into_par_iter()
            .with_min_len(min_len)
            .for_each(|outer| unsafe {
                let a_ptr = a_addr as *const T;
                let out_ptr = out_addr as *mut T;
                reduce_non_last_dim_outer(op, a_ptr, out_ptr, outer, reduce_size, inner_size);
            });
    });
}