mod special;
pub use special::{
argmax_kernel, argmin_kernel, softmax_bwd_kernel, softmax_kernel, variance_kernel,
};
use crate::dtype::Element;
use crate::ops::{AccumulationPrecision, ReduceOp};
/// Row-wise reduction: collapses `outer_size` contiguous rows of
/// `reduce_size` elements from `a` into one output element each, per `op`.
///
/// On x86_64/aarch64 the float dtypes are dispatched to the SIMD
/// implementations in `super::simd::reduce` (f16/bf16 only when the `f16`
/// feature is enabled); every other dtype — and every other architecture —
/// falls through to the scalar loops in `reduce_kernel_scalar`.
///
/// # Safety
/// - `a` must be valid for reads of `outer_size * reduce_size` elements of `T`.
/// - `out` must be valid for writes of `outer_size` elements of `T`.
/// - NOTE(review): the scalar `Max`/`Min` paths read the first element of each
///   row unconditionally, so `reduce_size` is presumably non-zero for those
///   ops — confirm with callers.
#[inline]
pub unsafe fn reduce_kernel<T: Element>(
    op: ReduceOp,
    a: *const T,
    out: *mut T,
    reduce_size: usize,
    outer_size: usize,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::reduce;
        use crate::dtype::DType;
        match T::DTYPE {
            DType::F32 => {
                reduce::reduce_f32(
                    op,
                    a as *const f32,
                    out as *mut f32,
                    reduce_size,
                    outer_size,
                );
                return;
            }
            DType::F64 => {
                reduce::reduce_f64(
                    op,
                    a as *const f64,
                    out as *mut f64,
                    reduce_size,
                    outer_size,
                );
                return;
            }
            // f16/bf16 SIMD paths exist only with the `f16` feature; without
            // it these DType variants are absent and the arms are compiled out.
            #[cfg(feature = "f16")]
            DType::F16 => {
                reduce::reduce_f16(
                    op,
                    a as *const half::f16,
                    out as *mut half::f16,
                    reduce_size,
                    outer_size,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                reduce::reduce_bf16(
                    op,
                    a as *const half::bf16,
                    out as *mut half::bf16,
                    reduce_size,
                    outer_size,
                );
                return;
            }
            // Remaining dtypes have no SIMD path: fall through to scalar.
            _ => {}
        }
    }
    reduce_kernel_scalar(op, a, out, reduce_size, outer_size);
}
/// Scalar fallback for [`reduce_kernel`]: straightforward per-row loops,
/// accumulating directly in the element type `T`.
///
/// # Safety
/// Same contract as [`reduce_kernel`]: `a` readable for
/// `outer_size * reduce_size` elements, `out` writable for `outer_size`
/// elements. `Max`/`Min` seed from the first element of each row, so
/// `reduce_size` must be non-zero for those ops.
#[inline]
unsafe fn reduce_kernel_scalar<T: Element>(
    op: ReduceOp,
    a: *const T,
    out: *mut T,
    reduce_size: usize,
    outer_size: usize,
) {
    match op {
        ReduceOp::Sum => {
            for row in 0..outer_size {
                let src = a.add(row * reduce_size);
                let mut acc = T::zero();
                for i in 0..reduce_size {
                    acc = acc + *src.add(i);
                }
                *out.add(row) = acc;
            }
        }
        ReduceOp::Mean => {
            // Sum in T, then apply the reciprocal once through f64.
            let inv_n = 1.0 / reduce_size as f64;
            for row in 0..outer_size {
                let src = a.add(row * reduce_size);
                let mut acc = T::zero();
                for i in 0..reduce_size {
                    acc = acc + *src.add(i);
                }
                *out.add(row) = T::from_f64(acc.to_f64() * inv_n);
            }
        }
        ReduceOp::Max => {
            for row in 0..outer_size {
                let src = a.add(row * reduce_size);
                // Seed with the row's first element, then scan the rest.
                let mut best = *src;
                for i in 1..reduce_size {
                    let v = *src.add(i);
                    if v > best {
                        best = v;
                    }
                }
                *out.add(row) = best;
            }
        }
        ReduceOp::Min => {
            for row in 0..outer_size {
                let src = a.add(row * reduce_size);
                let mut best = *src;
                for i in 1..reduce_size {
                    let v = *src.add(i);
                    if v < best {
                        best = v;
                    }
                }
                *out.add(row) = best;
            }
        }
        ReduceOp::Prod => {
            for row in 0..outer_size {
                let src = a.add(row * reduce_size);
                let mut acc = T::one();
                for i in 0..reduce_size {
                    acc = acc * *src.add(i);
                }
                *out.add(row) = acc;
            }
        }
        ReduceOp::All => {
            // Logical AND over the row; any non-zero value counts as true.
            for row in 0..outer_size {
                let src = a.add(row * reduce_size);
                let mut all_true = true;
                for i in 0..reduce_size {
                    all_true = all_true && (*src.add(i)).to_f64() != 0.0;
                }
                *out.add(row) = T::from_f64(if all_true { 1.0 } else { 0.0 });
            }
        }
        ReduceOp::Any => {
            // Logical OR over the row; any non-zero value counts as true.
            for row in 0..outer_size {
                let src = a.add(row * reduce_size);
                let mut any_true = false;
                for i in 0..reduce_size {
                    any_true = any_true || (*src.add(i)).to_f64() != 0.0;
                }
                *out.add(row) = T::from_f64(if any_true { 1.0 } else { 0.0 });
            }
        }
    }
}
/// Reduction entry point that honours the requested accumulation precision.
///
/// `Native` accumulates in `T` itself (keeping SIMD eligibility); `FP32` and
/// `BF16` both accumulate in `f32`; `FP64` accumulates in `f64`.
///
/// # Safety
/// Same contract as [`reduce_kernel`]: `a` readable for
/// `outer_size * reduce_size` elements, `out` writable for `outer_size`
/// elements, and `reduce_size` non-zero for `Max`/`Min`.
#[inline]
pub unsafe fn reduce_kernel_with_precision<T: Element>(
    op: ReduceOp,
    a: *const T,
    out: *mut T,
    reduce_size: usize,
    outer_size: usize,
    precision: AccumulationPrecision,
) {
    match precision {
        AccumulationPrecision::Native => reduce_kernel(op, a, out, reduce_size, outer_size),
        AccumulationPrecision::FP32 | AccumulationPrecision::BF16 => {
            reduce_kernel_acc::<T, f32>(op, a, out, reduce_size, outer_size)
        }
        AccumulationPrecision::FP64 => {
            reduce_kernel_acc::<T, f64>(op, a, out, reduce_size, outer_size)
        }
    }
}
/// Intermediate type used by `reduce_kernel_acc` to carry the running value
/// of a reduction independently of the element type `T`.
pub trait Accumulator: Copy + PartialOrd + PartialEq + Into<f64> {
    /// Additive identity (`0`), the seed for sums.
    const ZERO: Self;
    /// Multiplicative identity (`1`), the seed for products.
    const ONE: Self;
    /// Converts an input element (already widened to `f64`) into the accumulator.
    fn acc_in(v: f64) -> Self;
    /// `self + other`.
    fn acc_add(self, other: Self) -> Self;
    /// `self * other`.
    fn acc_mul(self, other: Self) -> Self;
    /// `self / n`; turns an accumulated sum into a mean.
    fn acc_div(self, n: usize) -> Self;
}
/// Single-precision accumulation (the `FP32`/`BF16` precision settings).
impl Accumulator for f32 {
    const ZERO: Self = 0.0;
    const ONE: Self = 1.0;
    #[inline]
    fn acc_in(v: f64) -> Self {
        // Narrowing cast: inputs wider than f32 are rounded on entry.
        v as f32
    }
    #[inline]
    fn acc_add(self, other: Self) -> Self {
        self + other
    }
    #[inline]
    fn acc_mul(self, other: Self) -> Self {
        self * other
    }
    #[inline]
    fn acc_div(self, n: usize) -> Self {
        // `n as f32` is inexact above 2^24, acceptable for reduction lengths.
        self / n as f32
    }
}
/// Double-precision accumulation (the `FP64` precision setting).
impl Accumulator for f64 {
    const ZERO: Self = 0.0;
    const ONE: Self = 1.0;
    #[inline]
    fn acc_in(v: f64) -> Self {
        // Elements arrive pre-widened to f64; identity.
        v
    }
    #[inline]
    fn acc_add(self, other: Self) -> Self {
        self + other
    }
    #[inline]
    fn acc_mul(self, other: Self) -> Self {
        self * other
    }
    #[inline]
    fn acc_div(self, n: usize) -> Self {
        self / n as f64
    }
}
/// Sums `n` elements starting at `row`, accumulating in `A`.
///
/// # Safety
/// `row` must be valid for reads of `n` elements of `T`.
#[inline]
unsafe fn acc_row_sum<T: Element, A: Accumulator>(row: *const T, n: usize) -> A {
    let mut sum = A::ZERO;
    for i in 0..n {
        sum = sum.acc_add(A::acc_in((*row.add(i)).to_f64()));
    }
    sum
}

/// Like `reduce_kernel_scalar`, but carries the running Sum/Mean/Prod value
/// in the accumulator type `A` instead of the element type `T`.
///
/// `Max`, `Min`, `All` and `Any` select or test existing elements rather than
/// accumulate, so accumulation precision does not apply to them; they are
/// delegated to the native-precision kernel. (Previously `Max`/`Min` rounded
/// every element through `A`, which degraded results whenever `A` was
/// narrower than `T`, e.g. f64 data reduced with the FP32 setting.)
///
/// # Safety
/// Same contract as [`reduce_kernel`]: `a` readable for
/// `outer_size * reduce_size` elements, `out` writable for `outer_size`
/// elements, and `reduce_size` non-zero for `Max`/`Min`.
#[inline]
unsafe fn reduce_kernel_acc<T: Element, A: Accumulator>(
    op: ReduceOp,
    a: *const T,
    out: *mut T,
    reduce_size: usize,
    outer_size: usize,
) {
    match op {
        ReduceOp::Sum => {
            for o in 0..outer_size {
                let sum: A = acc_row_sum(a.add(o * reduce_size), reduce_size);
                *out.add(o) = T::from_f64(sum.into());
            }
        }
        ReduceOp::Mean => {
            for o in 0..outer_size {
                let sum: A = acc_row_sum(a.add(o * reduce_size), reduce_size);
                *out.add(o) = T::from_f64(sum.acc_div(reduce_size).into());
            }
        }
        ReduceOp::Prod => {
            for o in 0..outer_size {
                let mut prod = A::ONE;
                for r in 0..reduce_size {
                    prod = prod.acc_mul(A::acc_in((*a.add(o * reduce_size + r)).to_f64()));
                }
                *out.add(o) = T::from_f64(prod.into());
            }
        }
        // Order statistics and boolean reductions involve no accumulation:
        // run them at native precision to avoid a lossy round-trip through `A`.
        ReduceOp::Max | ReduceOp::Min | ReduceOp::All | ReduceOp::Any => {
            reduce_kernel(op, a, out, reduce_size, outer_size);
        }
    }
}