#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "aarch64")]
use super::super::super::math::aarch64::neon::{
    hmax_f32, hmax_f64, hmin_f32, hmin_f64, hsum_f32, hsum_f64,
};
use crate::ops::ReduceOp;

/// Lane counts of a 128-bit NEON register for each element type.
const F32_LANES: usize = 4;
const F64_LANES: usize = 2;
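
/// Reduce each of `outer_size` contiguous rows of `reduce_size` `f32`
/// elements in `a` to a single value in `out`. `Sum`, `Max`, `Min`, `Prod`,
/// and `Mean` use NEON kernels; `All` and `Any` fall back to the scalar path.
///
/// # Safety
/// - `a` must be valid for `reduce_size * outer_size` reads.
/// - `out` must be valid for `outer_size` writes.
/// - `Mean` assumes `reduce_size > 0`; a zero `reduce_size` divides by zero.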
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn reduce_f32(
op: ReduceOp,
a: *const f32,
out: *mut f32,
reduce_size: usize,
outer_size: usize,
) {
match op {
ReduceOp::Sum => reduce_sum_f32(a, out, reduce_size, outer_size),
ReduceOp::Max => reduce_max_f32(a, out, reduce_size, outer_size),
ReduceOp::Min => reduce_min_f32(a, out, reduce_size, outer_size),
ReduceOp::Prod => reduce_prod_f32(a, out, reduce_size, outer_size),
ReduceOp::Mean => {
reduce_sum_f32(a, out, reduce_size, outer_size);
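            // Scale the per-row sums by 1/reduce_size: full NEON chunks
            // first, then a scalar tail for the leftover elements.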
let scale = 1.0 / reduce_size as f32;
let v_scale = vdupq_n_f32(scale);
let chunks = outer_size / F32_LANES;
for i in 0..chunks {
let offset = i * F32_LANES;
let v = vld1q_f32(out.add(offset));
vst1q_f32(out.add(offset), vmulq_f32(v, v_scale));
}
for i in (chunks * F32_LANES)..outer_size {
*out.add(i) *= scale;
}
}
ReduceOp::All | ReduceOp::Any => {
super::super::reduce_scalar_f32(op, a, out, reduce_size, outer_size);
}
}
}
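
/// Reduce each of `outer_size` contiguous rows of `reduce_size` `f64`
/// elements in `a` to a single value in `out`. Mirrors [`reduce_f32`] with
/// 2-lane `f64` vectors; the same safety requirements apply.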
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn reduce_f64(
op: ReduceOp,
a: *const f64,
out: *mut f64,
reduce_size: usize,
outer_size: usize,
) {
match op {
ReduceOp::Sum => reduce_sum_f64(a, out, reduce_size, outer_size),
ReduceOp::Max => reduce_max_f64(a, out, reduce_size, outer_size),
ReduceOp::Min => reduce_min_f64(a, out, reduce_size, outer_size),
ReduceOp::Prod => reduce_prod_f64(a, out, reduce_size, outer_size),
ReduceOp::Mean => {
reduce_sum_f64(a, out, reduce_size, outer_size);
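            // Scale the per-row sums by 1/reduce_size: full NEON chunks
            // first, then a scalar tail for the leftover elements.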
let scale = 1.0 / reduce_size as f64;
let v_scale = vdupq_n_f64(scale);
let chunks = outer_size / F64_LANES;
for i in 0..chunks {
let offset = i * F64_LANES;
let v = vld1q_f64(out.add(offset));
vst1q_f64(out.add(offset), vmulq_f64(v, v_scale));
}
for i in (chunks * F64_LANES)..outer_size {
*out.add(i) *= scale;
}
}
ReduceOp::All | ReduceOp::Any => {
super::super::reduce_scalar_f64(op, a, out, reduce_size, outer_size);
}
}
}
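
// Row-at-a-time kernels: accumulate full NEON chunks (F32_LANES or F64_LANES
// elements per step), collapse the accumulator with a horizontal reduction,
// then fold the `reduce_size % LANES` tail elements in scalar code.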
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn reduce_sum_f32(a: *const f32, out: *mut f32, reduce_size: usize, outer_size: usize) {
let chunks = reduce_size / F32_LANES;
let remainder = reduce_size % F32_LANES;
for o in 0..outer_size {
let base = a.add(o * reduce_size);
let mut acc = vdupq_n_f32(0.0);
for i in 0..chunks {
let v = vld1q_f32(base.add(i * F32_LANES));
acc = vaddq_f32(acc, v);
}
let mut sum = hsum_f32(acc);
for i in 0..remainder {
sum += *base.add(chunks * F32_LANES + i);
}
*out.add(o) = sum;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn reduce_sum_f64(a: *const f64, out: *mut f64, reduce_size: usize, outer_size: usize) {
let chunks = reduce_size / F64_LANES;
let remainder = reduce_size % F64_LANES;
for o in 0..outer_size {
let base = a.add(o * reduce_size);
let mut acc = vdupq_n_f64(0.0);
for i in 0..chunks {
let v = vld1q_f64(base.add(i * F64_LANES));
acc = vaddq_f64(acc, v);
}
let mut sum = hsum_f64(acc);
for i in 0..remainder {
sum += *base.add(chunks * F64_LANES + i);
}
*out.add(o) = sum;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn reduce_max_f32(a: *const f32, out: *mut f32, reduce_size: usize, outer_size: usize) {
let chunks = reduce_size / F32_LANES;
let remainder = reduce_size % F32_LANES;
for o in 0..outer_size {
let base = a.add(o * reduce_size);
let mut acc = vdupq_n_f32(f32::NEG_INFINITY);
for i in 0..chunks {
let v = vld1q_f32(base.add(i * F32_LANES));
acc = vmaxq_f32(acc, v);
}
let mut max_val = hmax_f32(acc);
for i in 0..remainder {
let val = *base.add(chunks * F32_LANES + i);
if val > max_val {
max_val = val;
}
}
*out.add(o) = max_val;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn reduce_max_f64(a: *const f64, out: *mut f64, reduce_size: usize, outer_size: usize) {
let chunks = reduce_size / F64_LANES;
let remainder = reduce_size % F64_LANES;
for o in 0..outer_size {
let base = a.add(o * reduce_size);
let mut acc = vdupq_n_f64(f64::NEG_INFINITY);
for i in 0..chunks {
let v = vld1q_f64(base.add(i * F64_LANES));
acc = vmaxq_f64(acc, v);
}
let mut max_val = hmax_f64(acc);
for i in 0..remainder {
let val = *base.add(chunks * F64_LANES + i);
if val > max_val {
max_val = val;
}
}
*out.add(o) = max_val;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn reduce_min_f32(a: *const f32, out: *mut f32, reduce_size: usize, outer_size: usize) {
let chunks = reduce_size / F32_LANES;
let remainder = reduce_size % F32_LANES;
for o in 0..outer_size {
let base = a.add(o * reduce_size);
let mut acc = vdupq_n_f32(f32::INFINITY);
for i in 0..chunks {
let v = vld1q_f32(base.add(i * F32_LANES));
acc = vminq_f32(acc, v);
}
let mut min_val = hmin_f32(acc);
for i in 0..remainder {
let val = *base.add(chunks * F32_LANES + i);
if val < min_val {
min_val = val;
}
}
*out.add(o) = min_val;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn reduce_min_f64(a: *const f64, out: *mut f64, reduce_size: usize, outer_size: usize) {
let chunks = reduce_size / F64_LANES;
let remainder = reduce_size % F64_LANES;
for o in 0..outer_size {
let base = a.add(o * reduce_size);
let mut acc = vdupq_n_f64(f64::INFINITY);
for i in 0..chunks {
let v = vld1q_f64(base.add(i * F64_LANES));
acc = vminq_f64(acc, v);
}
let mut min_val = hmin_f64(acc);
for i in 0..remainder {
let val = *base.add(chunks * F64_LANES + i);
if val < min_val {
min_val = val;
}
}
*out.add(o) = min_val;
}
}
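
// No horizontal-product helper exists alongside hsum/hmax/hmin, so the
// product kernels collapse the accumulator by multiplying its lanes directly.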
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn reduce_prod_f32(a: *const f32, out: *mut f32, reduce_size: usize, outer_size: usize) {
let chunks = reduce_size / F32_LANES;
let remainder = reduce_size % F32_LANES;
for o in 0..outer_size {
let base = a.add(o * reduce_size);
let mut acc = vdupq_n_f32(1.0);
for i in 0..chunks {
let v = vld1q_f32(base.add(i * F32_LANES));
acc = vmulq_f32(acc, v);
}
let mut prod = vgetq_lane_f32::<0>(acc)
* vgetq_lane_f32::<1>(acc)
* vgetq_lane_f32::<2>(acc)
* vgetq_lane_f32::<3>(acc);
for i in 0..remainder {
prod *= *base.add(chunks * F32_LANES + i);
}
*out.add(o) = prod;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn reduce_prod_f64(a: *const f64, out: *mut f64, reduce_size: usize, outer_size: usize) {
let chunks = reduce_size / F64_LANES;
let remainder = reduce_size % F64_LANES;
for o in 0..outer_size {
let base = a.add(o * reduce_size);
let mut acc = vdupq_n_f64(1.0);
for i in 0..chunks {
let v = vld1q_f64(base.add(i * F64_LANES));
acc = vmulq_f64(acc, v);
}
let mut prod = vgetq_lane_f64::<0>(acc) * vgetq_lane_f64::<1>(acc);
for i in 0..remainder {
prod *= *base.add(chunks * F64_LANES + i);
}
*out.add(o) = prod;
}
}
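
// A minimal sanity-check sketch, not an exhaustive test suite: it compares
// the NEON `Sum` and `Mean` paths against hand-computed scalar expectations
// on a small row-major buffer, and assumes the crate's tests run on an
// AArch64 host so the `target_feature` functions are callable.
#[cfg(all(test, target_arch = "aarch64"))]
mod tests {
    use super::*;

    #[test]
    fn sum_and_mean_match_scalar_expectations() {
        // NEON is mandatory on AArch64, but guard anyway before calling
        // `#[target_feature(enable = "neon")]` functions.
        if !std::arch::is_aarch64_feature_detected!("neon") {
            return;
        }
        // Two rows of five elements: one full 4-lane chunk plus a
        // one-element scalar tail per row.
        let a: [f32; 10] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let mut out = [0.0f32; 2];
        unsafe { reduce_f32(ReduceOp::Sum, a.as_ptr(), out.as_mut_ptr(), 5, 2) };
        assert_eq!(out, [15.0, 40.0]);
        unsafe { reduce_f32(ReduceOp::Mean, a.as_ptr(), out.as_mut_ptr(), 5, 2) };
        assert_eq!(out, [3.0, 8.0]);
    }
}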