use core::fmt::Debug;
use core::ops::Add;

use approx::{assert_relative_eq, RelativeEq};
use num_traits::{Bounded, Float, NumCast, Zero};

use crate::{
    tests::assert_eq, vload_unaligned, ReduceAdd, ReduceMax, ReduceMin, Scalar, Simd, Vector,
};
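/// Generates a function-local [`ReduceOp`] implementation: `$impl` performs the
/// vectorized reduction over a whole SIMD vector, and `$impl_scalar` combines
/// two scalar values when folding partial results together.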
macro_rules! reduce_op {
    ($trait:ident, $scalar_trait:path, $impl:expr, $impl_scalar:expr) => {
        ::paste::paste! {
            struct [<$trait Op>]<T>(::core::marker::PhantomData<T>);

            impl<T: $trait + $scalar_trait> ReduceOp<T> for [<$trait Op>]<T> {
                #[inline(always)]
                fn call<S: Simd>(lhs: Vector<S, T>) -> T {
                    $impl(lhs)
                }

                #[inline(always)]
                fn call_scalar(lhs: T, rhs: T) -> T {
                    $impl_scalar(lhs, rhs)
                }
            }
        }
    };
}
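/// Sums `a` by reducing one SIMD vector at a time, folding the partial sums
/// with scalar addition starting from zero.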
#[inline(always)]
fn test_reduce_add_impl<S: Simd, T: ReduceAdd + Add + Zero + Debug>(a: &[T]) -> T {
    reduce_op!(
        ReduceAdd,
        Add<Output = T>,
        |a: Vector<S, T>| a.reduce_add(),
        Add::add
    );
    test_reduce_op::<S, T, ReduceAddOp<T>>(a, Zero::zero())
}
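// Integer min/max use `Ord` as the scalar reference. The accumulator starts at
// the opposite bound: `max_value` for min, `min_value` for max.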
#[inline(always)]
fn test_reduce_min_ord_impl<S: Simd, T: ReduceMin + Ord + Bounded + Debug>(a: &[T]) -> T {
    reduce_op!(ReduceMin, Ord, |a: Vector<S, T>| a.reduce_min(), Ord::min);
    test_reduce_op::<S, T, ReduceMinOp<T>>(a, Bounded::max_value())
}

#[inline(always)]
fn test_reduce_max_ord_impl<S: Simd, T: ReduceMax + Ord + Bounded + Debug>(a: &[T]) -> T {
    reduce_op!(ReduceMax, Ord, |a: Vector<S, T>| a.reduce_max(), Ord::max);
    test_reduce_op::<S, T, ReduceMaxOp<T>>(a, Bounded::min_value())
}
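// Floats do not implement `Ord`, so the float variants use `Float::min` and
// `Float::max` as the scalar reference; like `f32::min`/`f64::min`, these
// return the non-NaN operand when exactly one side is NaN.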
#[inline(always)]
fn test_reduce_min_float_impl<S: Simd, T: ReduceMin + Float + Bounded + Debug>(a: &[T]) -> T {
    reduce_op!(
        ReduceMin,
        Float,
        |a: Vector<S, T>| a.reduce_min(),
        Float::min
    );
    test_reduce_op::<S, T, ReduceMinOp<T>>(a, Bounded::max_value())
}

#[inline(always)]
fn test_reduce_max_float_impl<S: Simd, T: ReduceMax + Float + Bounded + Debug>(a: &[T]) -> T {
    reduce_op!(
        ReduceMax,
        Float,
        |a: Vector<S, T>| a.reduce_max(),
        Float::max
    );
    test_reduce_op::<S, T, ReduceMaxOp<T>>(a, Bounded::min_value())
}
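/// A reduction under test: `call` runs the vectorized reduction over one SIMD
/// vector, `call_scalar` folds two scalar values.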
pub(crate) trait ReduceOp<T: Scalar> {
    fn call<S: Simd>(lhs: Vector<S, T>) -> T;
    fn call_scalar(lhs: T, rhs: T) -> T;
}
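/// Reduces `a` chunk by chunk: each `lanes`-wide chunk is reduced with the
/// vectorized `Op::call`, and the partial results are folded with
/// `Op::call_scalar`, starting from `default` (the operation's identity).
/// Note that `chunks_exact` drops any remainder, so the sizes used below
/// should be multiples of every backend's lane count.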
#[inline(always)]
fn test_reduce_op<S: Simd, T: Scalar + Debug, Op: ReduceOp<T>>(a: &[T], default: T) -> T {
    let lanes = T::lanes::<S>();
    let mut output = default;
    for chunk in a.chunks_exact(lanes) {
        let v = unsafe { vload_unaligned(chunk.as_ptr()) };
        let val = Op::call::<S>(v);
        output = Op::call_scalar(output, val);
    }
    output
}
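/// Generates one test per element type: builds a random input, computes a
/// scalar reference via `fold` with `$reference`, then runs `[<$test_fn _impl>]`
/// on every SIMD backend enabled at compile time and available at runtime
/// (plus the scalar fallback), comparing each result to the reference with
/// `$assert`. Tests run under wasm-bindgen on wasm32 and as plain `#[test]`s
/// elsewhere.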
macro_rules! testgen_reduce {
    ($test_fn:ident, $reference:expr, $default:expr, $lo:expr, $hi:expr, $size:expr, $assert:ident, $($ty:ty),*) => {
        $(::paste::paste! {
            #[::wasm_bindgen_test::wasm_bindgen_test(unsupported = test)]
            fn [<$test_fn _ $ty>]() {
                use num_traits::NumCast;

                let a = $crate::tests::random_of_size::<$ty>(
                    NumCast::from($lo).unwrap(),
                    NumCast::from($hi).unwrap(),
                    $size,
                );
                let out_ref: [$ty; 1] =
                    [a.iter().copied().fold($default, |a: $ty, b: $ty| a.$reference(b))];

                #[cfg(x86)]
                {
                    use $crate::backend::x86::*;

                    #[cfg(fp16)]
                    if V4FP16::is_available() {
                        let out = V4FP16::run_vectorized(|| [<$test_fn _impl>]::<V4FP16, $ty>(&a));
                        $assert(&out_ref, &[out]);
                    }
                    #[cfg(avx512)]
                    if V4::is_available() {
                        let out = V4::run_vectorized(|| [<$test_fn _impl>]::<V4, $ty>(&a));
                        $assert(&out_ref, &[out]);
                    }
                    if V3::is_available() {
                        let out = V3::run_vectorized(|| [<$test_fn _impl>]::<V3, $ty>(&a));
                        $assert(&out_ref, &[out]);
                    }
                    if V2::is_available() {
                        let out = V2::run_vectorized(|| [<$test_fn _impl>]::<V2, $ty>(&a));
                        $assert(&out_ref, &[out]);
                    }
                }
                #[cfg(aarch64)]
                {
                    use $crate::backend::aarch64::NeonFma;

                    if NeonFma::is_available() {
                        let out = NeonFma::run_vectorized(|| [<$test_fn _impl>]::<NeonFma, $ty>(&a));
                        $assert(&out_ref, &[out]);
                    }
                }
                #[cfg(loong64)]
                {
                    use $crate::backend::loong64::*;

                    if Lasx::is_available() {
                        let out = Lasx::run_vectorized(|| [<$test_fn _impl>]::<Lasx, $ty>(&a));
                        $assert(&out_ref, &[out]);
                    }
                    if Lsx::is_available() {
                        let out = Lsx::run_vectorized(|| [<$test_fn _impl>]::<Lsx, $ty>(&a));
                        $assert(&out_ref, &[out]);
                    }
                }
                #[cfg(wasm32)]
                {
                    use $crate::backend::wasm32;

                    #[cfg(relaxed_simd)]
                    if wasm32::Simd128Relaxed::is_available() {
                        let out = wasm32::Simd128Relaxed::run_vectorized(
                            || [<$test_fn _impl>]::<wasm32::Simd128Relaxed, $ty>(&a),
                        );
                        $assert(&out_ref, &[out]);
                    }
                    if wasm32::Simd128Fallback::is_available() {
                        let out = wasm32::Simd128Fallback::run_vectorized(
                            || [<$test_fn _impl>]::<wasm32::Simd128Fallback, $ty>(&a),
                        );
                        $assert(&out_ref, &[out]);
                    }
                }

                // The scalar fallback backend is always available and serves as a baseline.
                let out = [<$test_fn _impl>]::<$crate::backend::scalar::Fallback, $ty>(&a);
                $assert(&out_ref, &[out]);
            }
        })*
    };
}
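// Float sums depend on the order in which additions are performed, so they are
// checked with a relative tolerance. Integer sums use `wrapping_add` as the
// reference, which is exact and associative, so results must match bit-for-bit.
// The `u8`/`i8` cases use a smaller range and size to keep the expected sum
// representable however a backend accumulates it.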
testgen_reduce!(
    test_reduce_add,
    add,
    Zero::zero(),
    1,
    100,
    128,
    assert_approx_eq_sum,
    f32,
    f64
);

testgen_reduce!(
    test_reduce_add,
    wrapping_add,
    Zero::zero(),
    1,
    100,
    128,
    assert_eq,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64
);

testgen_reduce!(
    test_reduce_add,
    wrapping_add,
    Zero::zero(),
    0,
    2,
    64,
    assert_eq,
    u8,
    i8
);

testgen_reduce!(
    test_reduce_min_ord,
    min,
    Bounded::max_value(),
    1,
    100,
    128,
    assert_eq,
    u8,
    u16,
    u32,
    u64
);

testgen_reduce!(
    test_reduce_min_ord,
    min,
    Bounded::max_value(),
    -50,
    50,
    128,
    assert_eq,
    i8,
    i16,
    i32,
    i64
);

testgen_reduce!(
    test_reduce_min_float,
    min,
    Bounded::max_value(),
    -50,
    50,
    128,
    assert_eq,
    f32,
    f64
);

testgen_reduce!(
    test_reduce_max_ord,
    max,
    Bounded::min_value(),
    1,
    100,
    128,
    assert_eq,
    u8,
    u16,
    u32,
    u64
);

testgen_reduce!(
    test_reduce_max_ord,
    max,
    Bounded::min_value(),
    -50,
    50,
    128,
    assert_eq,
    i8,
    i16,
    i32,
    i64
);

testgen_reduce!(
    test_reduce_max_float,
    max,
    Bounded::min_value(),
    -50,
    50,
    128,
    assert_eq,
    f32,
    f64
);
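/// Compares two slices element-wise with a relative tolerance of `2^-8`.
/// Vectorized summation reorders additions, so float results may differ from
/// the sequential reference by a small rounding error.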
fn assert_approx_eq_sum<T: RelativeEq<Epsilon = T> + Debug + NumCast + Copy>(
    lhs: &[T],
    rhs: &[T],
) {
    let epsilon = T::from(2.0f64.powi(-8)).unwrap();
    for (a, b) in lhs.iter().zip(rhs) {
        assert_relative_eq!(*a, *b, epsilon = epsilon);
    }
}