// x86_64-only kernel modules. `reduce_f32` / `reduce_f64` dispatch to them for the
// `SimdLevel::Avx2Fma` and `SimdLevel::Avx512` levels.
#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "x86_64")]
mod avx512;
// aarch64-only kernel module. Its `neon` submodule is used for the `SimdLevel::Neon`
// and `SimdLevel::NeonFp16` levels.
#[cfg(target_arch = "aarch64")]
mod aarch64;
use super::{SimdLevel, detect_simd};
use crate::ops::ReduceOp;
/// Minimum `reduce_size` for which the SIMD kernels are considered; rows shorter
/// than this always take the scalar path in `reduce_f32` / `reduce_f64`.
const SIMD_THRESHOLD: usize = 64;
/// Reports whether `op` is one of the reductions that may be routed to a SIMD kernel.
///
/// Only `Sum`, `Max`, `Min` and `Prod` qualify; `Mean`, `All` and `Any` are always
/// left to the scalar implementation.
#[inline]
const fn is_simd_supported(op: ReduceOp) -> bool {
    match op {
        ReduceOp::Sum | ReduceOp::Max | ReduceOp::Min | ReduceOp::Prod => true,
        ReduceOp::Mean | ReduceOp::All | ReduceOp::Any => false,
    }
}
/// Reduces `outer_size` rows of `reduce_size` `f32` values with `op`, using a SIMD
/// kernel when one applies and the scalar implementation otherwise.
///
/// The scalar implementation (`reduce_scalar_f32`) is used when any of these hold:
/// - `reduce_size` is below `SIMD_THRESHOLD`,
/// - `detect_simd()` reports `SimdLevel::Scalar`,
/// - `op` is not accepted by `is_simd_supported`,
/// - the detected level has no kernel for the compilation target.
///
/// # Safety
///
/// All arguments are forwarded unchanged to the selected implementation. The scalar
/// path reads `a[0 .. outer_size * reduce_size]` and writes `out[0 .. outer_size]`, so
/// both ranges must be valid. The SIMD kernels are assumed to require the same —
/// NOTE(review): confirm against the kernel modules.
#[inline]
pub unsafe fn reduce_f32(
    op: ReduceOp,
    a: *const f32,
    out: *mut f32,
    reduce_size: usize,
    outer_size: usize,
) {
    let simd = detect_simd();
    let vectorize =
        reduce_size >= SIMD_THRESHOLD && simd != SimdLevel::Scalar && is_simd_supported(op);
    if !vectorize {
        // SAFETY: the caller's guarantees on `a` and `out` are passed through unchanged.
        unsafe { reduce_scalar_f32(op, a, out, reduce_size, outer_size) };
        return;
    }

    // SAFETY (all target-specific blocks below): the caller's guarantees on `a` and
    // `out` are passed through unchanged to whichever implementation is selected.
    #[cfg(target_arch = "x86_64")]
    unsafe {
        match simd {
            SimdLevel::Avx512 => avx512::reduce_f32(op, a, out, reduce_size, outer_size),
            SimdLevel::Avx2Fma => avx2::reduce_f32(op, a, out, reduce_size, outer_size),
            _ => reduce_scalar_f32(op, a, out, reduce_size, outer_size),
        }
    }
    #[cfg(target_arch = "aarch64")]
    unsafe {
        match simd {
            SimdLevel::Neon | SimdLevel::NeonFp16 => {
                aarch64::neon::reduce_f32(op, a, out, reduce_size, outer_size)
            }
            _ => reduce_scalar_f32(op, a, out, reduce_size, outer_size),
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    unsafe {
        reduce_scalar_f32(op, a, out, reduce_size, outer_size);
    }
}
/// Reduces `outer_size` rows of `reduce_size` `f64` values with `op`, using a SIMD
/// kernel when one applies and the scalar implementation otherwise.
///
/// The scalar implementation (`reduce_scalar_f64`) is used when any of these hold:
/// - `reduce_size` is below `SIMD_THRESHOLD`,
/// - `detect_simd()` reports `SimdLevel::Scalar`,
/// - `op` is not accepted by `is_simd_supported`,
/// - the detected level has no kernel for the compilation target.
///
/// # Safety
///
/// All arguments are forwarded unchanged to the selected implementation. The scalar
/// path reads `a[0 .. outer_size * reduce_size]` and writes `out[0 .. outer_size]`, so
/// both ranges must be valid. The SIMD kernels are assumed to require the same —
/// NOTE(review): confirm against the kernel modules.
#[inline]
pub unsafe fn reduce_f64(
    op: ReduceOp,
    a: *const f64,
    out: *mut f64,
    reduce_size: usize,
    outer_size: usize,
) {
    let simd = detect_simd();
    let vectorize =
        reduce_size >= SIMD_THRESHOLD && simd != SimdLevel::Scalar && is_simd_supported(op);
    if !vectorize {
        // SAFETY: the caller's guarantees on `a` and `out` are passed through unchanged.
        unsafe { reduce_scalar_f64(op, a, out, reduce_size, outer_size) };
        return;
    }

    // SAFETY (all target-specific blocks below): the caller's guarantees on `a` and
    // `out` are passed through unchanged to whichever implementation is selected.
    #[cfg(target_arch = "x86_64")]
    unsafe {
        match simd {
            SimdLevel::Avx512 => avx512::reduce_f64(op, a, out, reduce_size, outer_size),
            SimdLevel::Avx2Fma => avx2::reduce_f64(op, a, out, reduce_size, outer_size),
            _ => reduce_scalar_f64(op, a, out, reduce_size, outer_size),
        }
    }
    #[cfg(target_arch = "aarch64")]
    unsafe {
        match simd {
            SimdLevel::Neon | SimdLevel::NeonFp16 => {
                aarch64::neon::reduce_f64(op, a, out, reduce_size, outer_size)
            }
            _ => reduce_scalar_f64(op, a, out, reduce_size, outer_size),
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    unsafe {
        reduce_scalar_f64(op, a, out, reduce_size, outer_size);
    }
}
/// Scalar reduction of `outer_size` contiguous rows of `reduce_size` `f32` values.
///
/// Row `o` is read from `a[o * reduce_size .. (o + 1) * reduce_size]` and its result
/// is written to `out[o]`.
///
/// Per-operation behavior:
/// - `Sum` / `Prod`: left-to-right accumulation starting from `0.0` / `1.0`.
/// - `Max` / `Min`: seeded with the first element and updated with `>` / `<`, so a NaN
///   in the first position is returned while NaNs later in the row are skipped.
/// - `Mean`: the left-to-right sum multiplied by `1.0 / reduce_size`.
/// - `All` / `Any`: `1.0` or `0.0`; an element counts as true when it is `!= 0.0`.
///
/// Empty rows (`reduce_size == 0`) produce the identity of the operation: `0.0` for
/// `Sum`, `1.0` for `Prod`, `-inf` for `Max`, `+inf` for `Min`, `1.0` for `All`,
/// `0.0` for `Any`. `Mean` of an empty row is NaN (`0.0 * inf`).
///
/// # Safety
///
/// - `a` must be valid for reads of `outer_size * reduce_size` `f32` values.
/// - `out` must be valid for writes of `outer_size` `f32` values.
#[inline]
pub unsafe fn reduce_scalar_f32(
    op: ReduceOp,
    a: *const f32,
    out: *mut f32,
    reduce_size: usize,
    outer_size: usize,
) {
    // SAFETY: every read uses index `o * reduce_size + r` with `o < outer_size` and
    // `r < reduce_size`, and every write uses `o < outer_size`; the caller guarantees
    // both ranges are valid. No read is performed at all when `reduce_size == 0`.
    unsafe {
        // `Max` / `Min` seed their accumulator from the first element of the row, which
        // does not exist for empty rows. Write the identity instead of reading past
        // the input, matching what the other operations already yield for empty rows.
        if reduce_size == 0 {
            let identity = match op {
                ReduceOp::Max => Some(f32::NEG_INFINITY),
                ReduceOp::Min => Some(f32::INFINITY),
                _ => None,
            };
            if let Some(identity) = identity {
                for o in 0..outer_size {
                    *out.add(o) = identity;
                }
                return;
            }
        }

        match op {
            ReduceOp::Sum => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    *out.add(o) = (0..reduce_size).fold(0.0f32, |acc, r| acc + *a.add(base + r));
                }
            }
            ReduceOp::Max => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    *out.add(o) = (1..reduce_size).fold(*a.add(base), |best, r| {
                        let candidate = *a.add(base + r);
                        if candidate > best { candidate } else { best }
                    });
                }
            }
            ReduceOp::Min => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    *out.add(o) = (1..reduce_size).fold(*a.add(base), |best, r| {
                        let candidate = *a.add(base + r);
                        if candidate < best { candidate } else { best }
                    });
                }
            }
            ReduceOp::Prod => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    *out.add(o) = (0..reduce_size).fold(1.0f32, |acc, r| acc * *a.add(base + r));
                }
            }
            ReduceOp::Mean => {
                // Hoisted reciprocal: the mean is computed as `sum * (1 / n)`.
                let scale = 1.0 / reduce_size as f32;
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    let total = (0..reduce_size).fold(0.0f32, |acc, r| acc + *a.add(base + r));
                    *out.add(o) = total * scale;
                }
            }
            ReduceOp::Any => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    let hit = (0..reduce_size).any(|r| *a.add(base + r) != 0.0);
                    *out.add(o) = if hit { 1.0 } else { 0.0 };
                }
            }
            ReduceOp::All => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    let hit = (0..reduce_size).all(|r| *a.add(base + r) != 0.0);
                    *out.add(o) = if hit { 1.0 } else { 0.0 };
                }
            }
        }
    }
}
/// Scalar reduction of `outer_size` contiguous rows of `reduce_size` `f64` values.
///
/// Row `o` is read from `a[o * reduce_size .. (o + 1) * reduce_size]` and its result
/// is written to `out[o]`.
///
/// Per-operation behavior:
/// - `Sum` / `Prod`: left-to-right accumulation starting from `0.0` / `1.0`.
/// - `Max` / `Min`: seeded with the first element and updated with `>` / `<`, so a NaN
///   in the first position is returned while NaNs later in the row are skipped.
/// - `Mean`: the left-to-right sum multiplied by `1.0 / reduce_size`.
/// - `All` / `Any`: `1.0` or `0.0`; an element counts as true when it is `!= 0.0`.
///
/// Empty rows (`reduce_size == 0`) produce the identity of the operation: `0.0` for
/// `Sum`, `1.0` for `Prod`, `-inf` for `Max`, `+inf` for `Min`, `1.0` for `All`,
/// `0.0` for `Any`. `Mean` of an empty row is NaN (`0.0 * inf`).
///
/// # Safety
///
/// - `a` must be valid for reads of `outer_size * reduce_size` `f64` values.
/// - `out` must be valid for writes of `outer_size` `f64` values.
#[inline]
pub unsafe fn reduce_scalar_f64(
    op: ReduceOp,
    a: *const f64,
    out: *mut f64,
    reduce_size: usize,
    outer_size: usize,
) {
    // SAFETY: every read uses index `o * reduce_size + r` with `o < outer_size` and
    // `r < reduce_size`, and every write uses `o < outer_size`; the caller guarantees
    // both ranges are valid. No read is performed at all when `reduce_size == 0`.
    unsafe {
        // `Max` / `Min` seed their accumulator from the first element of the row, which
        // does not exist for empty rows. Write the identity instead of reading past
        // the input, matching what the other operations already yield for empty rows.
        if reduce_size == 0 {
            let identity = match op {
                ReduceOp::Max => Some(f64::NEG_INFINITY),
                ReduceOp::Min => Some(f64::INFINITY),
                _ => None,
            };
            if let Some(identity) = identity {
                for o in 0..outer_size {
                    *out.add(o) = identity;
                }
                return;
            }
        }

        match op {
            ReduceOp::Sum => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    *out.add(o) = (0..reduce_size).fold(0.0f64, |acc, r| acc + *a.add(base + r));
                }
            }
            ReduceOp::Max => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    *out.add(o) = (1..reduce_size).fold(*a.add(base), |best, r| {
                        let candidate = *a.add(base + r);
                        if candidate > best { candidate } else { best }
                    });
                }
            }
            ReduceOp::Min => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    *out.add(o) = (1..reduce_size).fold(*a.add(base), |best, r| {
                        let candidate = *a.add(base + r);
                        if candidate < best { candidate } else { best }
                    });
                }
            }
            ReduceOp::Prod => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    *out.add(o) = (0..reduce_size).fold(1.0f64, |acc, r| acc * *a.add(base + r));
                }
            }
            ReduceOp::Mean => {
                // Hoisted reciprocal: the mean is computed as `sum * (1 / n)`.
                let scale = 1.0 / reduce_size as f64;
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    let total = (0..reduce_size).fold(0.0f64, |acc, r| acc + *a.add(base + r));
                    *out.add(o) = total * scale;
                }
            }
            ReduceOp::Any => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    let hit = (0..reduce_size).any(|r| *a.add(base + r) != 0.0);
                    *out.add(o) = if hit { 1.0 } else { 0.0 };
                }
            }
            ReduceOp::All => {
                for o in 0..outer_size {
                    let base = o * reduce_size;
                    let hit = (0..reduce_size).all(|r| *a.add(base + r) != 0.0);
                    *out.add(o) = if hit { 1.0 } else { 0.0 };
                }
            }
        }
    }
}
/// Reduces `half::f16` data by widening it to `f32`, running [`reduce_f32`], and
/// narrowing the per-row results back to `f16`.
///
/// Two temporary buffers are allocated per call: one of `outer_size * reduce_size`
/// `f32` values for the widened input and one of `outer_size` values for the results.
///
/// # Safety
///
/// `a` is handed to `convert_f16_to_f32` with a length of `outer_size * reduce_size`
/// and `out` to `convert_f32_to_f16` with a length of `outer_size`, both reinterpreted
/// as `u16` pointers. Presumably they must be valid for that many elements —
/// NOTE(review): confirm against `half_convert_utils`.
#[cfg(feature = "f16")]
pub unsafe fn reduce_f16(
    op: ReduceOp,
    a: *const half::f16,
    out: *mut half::f16,
    reduce_size: usize,
    outer_size: usize,
) {
    use super::half_convert_utils::*;

    let total = outer_size * reduce_size;
    let mut widened = vec![0.0f32; total];
    let mut reduced = vec![0.0f32; outer_size];

    // SAFETY: `widened` holds `total` elements and `reduced` holds `outer_size`
    // elements, which covers everything `reduce_f32` reads and writes; validity of `a`
    // and `out` is the caller's responsibility.
    unsafe {
        convert_f16_to_f32(a as *const u16, widened.as_mut_ptr(), total);
        reduce_f32(
            op,
            widened.as_ptr(),
            reduced.as_mut_ptr(),
            reduce_size,
            outer_size,
        );
        convert_f32_to_f16(reduced.as_ptr(), out as *mut u16, outer_size);
    }
}
/// Reduces `half::bf16` data by widening it to `f32`, running [`reduce_f32`], and
/// narrowing the per-row results back to `bf16`.
///
/// Two temporary buffers are allocated per call: one of `outer_size * reduce_size`
/// `f32` values for the widened input and one of `outer_size` values for the results.
///
/// # Safety
///
/// `a` is handed to `convert_bf16_to_f32` with a length of `outer_size * reduce_size`
/// and `out` to `convert_f32_to_bf16` with a length of `outer_size`, both reinterpreted
/// as `u16` pointers. Presumably they must be valid for that many elements —
/// NOTE(review): confirm against `half_convert_utils`.
#[cfg(feature = "f16")]
pub unsafe fn reduce_bf16(
    op: ReduceOp,
    a: *const half::bf16,
    out: *mut half::bf16,
    reduce_size: usize,
    outer_size: usize,
) {
    use super::half_convert_utils::*;

    let total = outer_size * reduce_size;
    let mut widened = vec![0.0f32; total];
    let mut reduced = vec![0.0f32; outer_size];

    // SAFETY: `widened` holds `total` elements and `reduced` holds `outer_size`
    // elements, which covers everything `reduce_f32` reads and writes; validity of `a`
    // and `out` is the caller's responsibility.
    unsafe {
        convert_bf16_to_f32(a as *const u16, widened.as_mut_ptr(), total);
        reduce_f32(
            op,
            widened.as_ptr(),
            reduced.as_mut_ptr(),
            reduce_size,
            outer_size,
        );
        convert_f32_to_bf16(reduced.as_ptr(), out as *mut u16, outer_size);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 1000-element ramp `0 - offset, 1 - offset, …, 999 - offset`.
    fn ramp(offset: f32) -> Vec<f32> {
        (0..1000).map(|i| (i as f32) - offset).collect()
    }

    /// Runs the dispatching `reduce_f32` over `input`, viewed as `outer_size` rows of
    /// `reduce_size` values, and returns one result per row.
    fn run_f32(op: ReduceOp, input: &[f32], reduce_size: usize, outer_size: usize) -> Vec<f32> {
        assert_eq!(input.len(), reduce_size * outer_size);
        let mut results = vec![0.0f32; outer_size];
        // SAFETY: `input` has `reduce_size * outer_size` elements (checked above) and
        // `results` has `outer_size` elements.
        unsafe {
            reduce_f32(
                op,
                input.as_ptr(),
                results.as_mut_ptr(),
                reduce_size,
                outer_size,
            )
        };
        results
    }

    #[test]
    fn test_reduce_sum_f32() {
        let out = run_f32(ReduceOp::Sum, &ramp(0.0), 1000, 1);
        assert!((out[0] - 499500.0).abs() < 1e-3, "got {}", out[0]);
    }

    #[test]
    fn test_reduce_max_f32() {
        let out = run_f32(ReduceOp::Max, &ramp(500.0), 1000, 1);
        assert_eq!(out[0], 499.0);
    }

    #[test]
    fn test_reduce_min_f32() {
        let out = run_f32(ReduceOp::Min, &ramp(500.0), 1000, 1);
        assert_eq!(out[0], -500.0);
    }

    #[test]
    fn test_reduce_prod_f32() {
        // Calls the scalar implementation directly rather than the dispatcher.
        let factors = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mut out = [0.0f32; 1];
        // SAFETY: `factors` has 5 elements and `out` has 1, matching the sizes passed.
        unsafe {
            reduce_scalar_f32(
                ReduceOp::Prod,
                factors.as_ptr(),
                out.as_mut_ptr(),
                factors.len(),
                1,
            )
        };
        assert_eq!(out[0], 120.0);
    }

    #[test]
    fn test_reduce_multiple_outer() {
        // Two rows of 500 elements: 0..500 and 500..1000.
        let out = run_f32(ReduceOp::Sum, &ramp(0.0), 500, 2);
        assert!((out[0] - 124750.0).abs() < 1e-3, "first got {}", out[0]);
        assert!((out[1] - 374750.0).abs() < 1e-3, "second got {}", out[1]);
    }
}