use num_traits::AsPrimitive;
use super::task::*;
use crate::scalar::ScalarArgMinMax;
#[doc(hidden)]
pub trait SIMDOps<ScalarDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>
where
    ScalarDType: Copy + PartialOrd + AsPrimitive<usize>,
    SIMDVecDtype: Copy,
    SIMDMaskDtype: Copy,
{
    /// Highest index that the index register's lanes can represent without
    /// overflowing the scalar dtype.
    const MAX_INDEX: usize;
    /// Index register for the first chunk (lanes hold 0..LANE_SIZE).
    const INITIAL_INDEX: SIMDVecDtype;
    /// Step added to the index register once per processed chunk.
    const INDEX_INCREMENT: SIMDVecDtype;

    unsafe fn _reg_to_arr(reg: SIMDVecDtype) -> [ScalarDType; LANE_SIZE];
    unsafe fn _mm_loadu(data: *const ScalarDType) -> SIMDVecDtype;
    unsafe fn _mm_add(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDVecDtype;
    unsafe fn _mm_cmpgt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype;
    unsafe fn _mm_cmplt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype;
    unsafe fn _mm_blendv(a: SIMDVecDtype, b: SIMDVecDtype, mask: SIMDMaskDtype) -> SIMDVecDtype;

    /// Horizontal reduction: spill both registers to arrays and return the
    /// (index, value) pair of the lane holding the smallest value.
    #[inline(always)]
    unsafe fn _horiz_min(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) {
        let indices = Self::_reg_to_arr(index);
        let values = Self::_reg_to_arr(value);
        let (idx, val) = min_index_value(&indices, &values);
        (idx.as_(), val)
    }

    /// Horizontal reduction: spill both registers to arrays and return the
    /// (index, value) pair of the lane holding the largest value.
    #[inline(always)]
    unsafe fn _horiz_max(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) {
        let indices = Self::_reg_to_arr(index);
        let values = Self::_reg_to_arr(value);
        let (idx, val) = max_index_value(&indices, &values);
        (idx.as_(), val)
    }

    /// Largest multiple of LANE_SIZE not exceeding MAX_INDEX: the maximum
    /// chunk length that can be scanned without index overflow.
    #[inline(always)]
    fn _get_overflow_lane_size_limit() -> usize {
        (Self::MAX_INDEX / LANE_SIZE) * LANE_SIZE
    }

    /// Broadcast a scalar to every lane. Only the ignore-NaN float
    /// initialization needs this, so the default is unreachable.
    #[inline(always)]
    unsafe fn _mm_set1(_value: ScalarDType) -> SIMDVecDtype {
        unreachable!()
    }
}
#[doc(hidden)]
pub trait SIMDInit<ScalarDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>:
    SIMDOps<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>
where
    ScalarDType: Copy + PartialOrd + AsPrimitive<usize>,
    SIMDVecDtype: Copy,
    SIMDMaskDtype: Copy,
{
    /// True for the float variants that skip NaN values.
    const IGNORE_NAN: bool = false;

    /// Seed the (index, value) registers for the running minimum from the
    /// first LANE_SIZE elements at `arr_ptr`.
    #[inline(always)]
    unsafe fn _initialize_index_values_low(
        arr_ptr: *const ScalarDType,
    ) -> (SIMDVecDtype, SIMDVecDtype) {
        let first_chunk = Self::_mm_loadu(arr_ptr);
        (Self::INITIAL_INDEX, first_chunk)
    }

    /// Seed the (index, value) registers for the running maximum from the
    /// first LANE_SIZE elements at `arr_ptr`.
    #[inline(always)]
    unsafe fn _initialize_index_values_high(
        arr_ptr: *const ScalarDType,
    ) -> (SIMDVecDtype, SIMDVecDtype) {
        let first_chunk = Self::_mm_loadu(arr_ptr);
        (Self::INITIAL_INDEX, first_chunk)
    }

    /// Initial running minimum: the first element of the slice.
    #[inline(always)]
    fn _initialize_min_value(arr: &[ScalarDType]) -> ScalarDType {
        // SAFETY: callers must pass a non-empty slice.
        unsafe { *arr.get_unchecked(0) }
    }

    /// Initial running maximum: the first element of the slice.
    #[inline(always)]
    fn _initialize_max_value(arr: &[ScalarDType]) -> ScalarDType {
        // SAFETY: callers must pass a non-empty slice.
        unsafe { *arr.get_unchecked(0) }
    }

    /// Early-exit test used by the overflow-safe drivers; by default no
    /// value triggers an early return.
    #[inline(always)]
    fn _return_check(_v: ScalarDType) -> bool {
        false
    }

    /// NaN test; the default (non-float) dtypes have no NaN.
    #[inline(always)]
    fn _nan_check(_v: ScalarDType) -> bool {
        false
    }
}
#[cfg(any(
    target_arch = "x86",
    target_arch = "x86_64",
    all(target_arch = "arm", feature = "nightly_simd"),
    target_arch = "aarch64",
))]
macro_rules! impl_SIMDInit_Int {
    ($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
        // Integers have no NaN, so every `SIMDInit` default applies as-is.
        impl SIMDInit<$scalar_dtype, $simd_vec_dtype, $simd_mask_dtype, $lane_size> for $simd_struct {}
    };
}
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
pub(crate) use impl_SIMDInit_Int;
#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(
    target_arch = "x86",
    target_arch = "x86_64",
    all(target_arch = "arm", feature = "nightly_simd"),
    target_arch = "aarch64",
))]
macro_rules! impl_SIMDInit_FloatReturnNaN {
    ($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
        // "Return NaN" float semantics: a NaN result is propagated, so both
        // the NaN test and the early-return test are the same check.
        impl SIMDInit<$scalar_dtype, $simd_vec_dtype, $simd_mask_dtype, $lane_size>
            for $simd_struct
        {
            #[inline(always)]
            fn _nan_check(v: $scalar_dtype) -> bool {
                v.is_nan()
            }

            // Once a NaN has been found the overflow-safe drivers can stop.
            #[inline(always)]
            fn _return_check(v: $scalar_dtype) -> bool {
                v.is_nan()
            }
        }
    };
}
#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
pub(crate) use impl_SIMDInit_FloatReturnNaN;
#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(
    target_arch = "x86",
    target_arch = "x86_64",
    all(target_arch = "arm", feature = "nightly_simd"),
    target_arch = "aarch64",
))]
macro_rules! impl_SIMDInit_FloatIgnoreNaN {
    ($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
        // "Ignore NaN" float semantics: NaN lanes must never win a min/max
        // comparison, so the initial registers are seeded with +/-INFINITY
        // sentinels instead of the raw first chunk.
        impl SIMDInit<$scalar_dtype, $simd_vec_dtype, $simd_mask_dtype, $lane_size>
            for $simd_struct
        {
            const IGNORE_NAN: bool = true;

            // Seed the running-minimum registers: lanes whose loaded value is
            // NaN fail the `< INFINITY` comparison and are replaced by
            // +INFINITY (with index 0), so they can never become the minimum.
            #[inline(always)]
            unsafe fn _initialize_index_values_low(
                arr_ptr: *const $scalar_dtype,
            ) -> ($simd_vec_dtype, $simd_vec_dtype) {
                let new_values = Self::_mm_loadu(arr_ptr);
                // mask is set for non-NaN lanes strictly below +INFINITY
                let mask_low =
                    Self::_mm_cmplt(new_values, Self::_mm_set1(<$scalar_dtype>::INFINITY));
                let values_low = Self::_mm_blendv(
                    Self::_mm_set1(<$scalar_dtype>::INFINITY),
                    new_values,
                    mask_low,
                );
                let index_low = Self::_mm_blendv(
                    Self::_mm_set1(<$scalar_dtype>::zero()),
                    Self::INITIAL_INDEX,
                    mask_low,
                );
                (index_low, values_low)
            }

            // Seed the running-maximum registers: NaN lanes fail the
            // `> NEG_INFINITY` comparison and are replaced by -INFINITY
            // (with index 0), so they can never become the maximum.
            #[inline(always)]
            unsafe fn _initialize_index_values_high(
                arr_ptr: *const $scalar_dtype,
            ) -> ($simd_vec_dtype, $simd_vec_dtype) {
                let new_values = Self::_mm_loadu(arr_ptr);
                let mask_high =
                    Self::_mm_cmpgt(new_values, Self::_mm_set1(<$scalar_dtype>::NEG_INFINITY));
                let values_high = Self::_mm_blendv(
                    Self::_mm_set1(<$scalar_dtype>::NEG_INFINITY),
                    new_values,
                    mask_high,
                );
                let index_high = Self::_mm_blendv(
                    Self::_mm_set1(<$scalar_dtype>::zero()),
                    Self::INITIAL_INDEX,
                    mask_high,
                );
                (index_high, values_high)
            }

            // Scalar accumulators also start at the sentinels (not arr[0],
            // which could be NaN).
            #[inline(always)]
            fn _initialize_min_value(_: &[$scalar_dtype]) -> $scalar_dtype {
                <$scalar_dtype>::INFINITY
            }

            #[inline(always)]
            fn _initialize_max_value(_: &[$scalar_dtype]) -> $scalar_dtype {
                <$scalar_dtype>::NEG_INFINITY
            }

            // NaNs are detectable but do NOT trigger an early return here
            // (no `_return_check` override): they are simply skipped.
            #[inline(always)]
            fn _nan_check(v: $scalar_dtype) -> bool {
                v.is_nan()
            }
        }
    };
}
#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
pub(crate) use impl_SIMDInit_FloatIgnoreNaN;
#[doc(hidden)]
// Core SIMD kernels, derived for free from `SIMDOps` + `SIMDInit` (see the
// blanket impl below). The `_core_*` methods scan a slice whose length is a
// multiple of LANE_SIZE; the `_overflow_safe_*` wrappers split longer slices
// into chunks whose in-chunk indices fit the scalar dtype.
pub trait SIMDCore<ScalarDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>:
    SIMDOps<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>
    + SIMDInit<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>
where
    ScalarDType: Copy + PartialOrd + AsPrimitive<usize>,
    SIMDVecDtype: Copy,
    SIMDMaskDtype: Copy,
{
    // Returns (min_index, min_value, max_index, max_value) of `arr`.
    // Requires arr.len() to be a non-zero multiple of LANE_SIZE and small
    // enough that lane indices do not overflow the scalar dtype.
    #[inline(always)]
    unsafe fn _core_argminmax(arr: &[ScalarDType]) -> (usize, ScalarDType, usize, ScalarDType) {
        assert_eq!(arr.len() % LANE_SIZE, 0);
        let mut arr_ptr = arr.as_ptr();
        let mut new_index = Self::INITIAL_INDEX;
        // The first chunk is consumed by the initializers, hence the
        // `len / LANE_SIZE - 1` loop trip count below.
        let (mut index_low, mut values_low) = Self::_initialize_index_values_low(arr_ptr);
        let (mut index_high, mut values_high) = Self::_initialize_index_values_high(arr_ptr);
        for _ in 0..arr.len() / LANE_SIZE - 1 {
            new_index = Self::_mm_add(new_index, Self::INDEX_INCREMENT);
            arr_ptr = arr_ptr.add(LANE_SIZE);
            let new_values = Self::_mm_loadu(arr_ptr);
            // Per-lane running minimum: blendv keeps the old value unless the
            // new one compares strictly lower (strict => first occurrence wins).
            let mask_low = Self::_mm_cmplt(new_values, values_low);
            values_low = Self::_mm_blendv(values_low, new_values, mask_low);
            index_low = Self::_mm_blendv(index_low, new_index, mask_low);
            // Per-lane running maximum, same scheme with cmpgt.
            let mask_high = Self::_mm_cmpgt(new_values, values_high);
            values_high = Self::_mm_blendv(values_high, new_values, mask_high);
            index_high = Self::_mm_blendv(index_high, new_index, mask_high);
        }
        // Horizontal reductions collapse the per-lane results to scalars.
        let (min_index, min_value) = Self::_horiz_min(index_low, values_low);
        let (max_index, max_value) = Self::_horiz_max(index_high, values_high);
        (min_index, min_value, max_index, max_value)
    }

    // Min-only variant of `_core_argminmax`; same preconditions.
    #[inline(always)]
    unsafe fn _core_argmin(arr: &[ScalarDType]) -> (usize, ScalarDType) {
        let mut arr_ptr = arr.as_ptr();
        let mut new_index = Self::INITIAL_INDEX;
        let (mut index_low, mut values_low) = Self::_initialize_index_values_low(arr_ptr);
        for _ in 0..arr.len() / LANE_SIZE - 1 {
            new_index = Self::_mm_add(new_index, Self::INDEX_INCREMENT);
            arr_ptr = arr_ptr.add(LANE_SIZE);
            let new_values = Self::_mm_loadu(arr_ptr);
            let mask_low = Self::_mm_cmplt(new_values, values_low);
            values_low = Self::_mm_blendv(values_low, new_values, mask_low);
            index_low = Self::_mm_blendv(index_low, new_index, mask_low);
        }
        Self::_horiz_min(index_low, values_low)
    }

    // Max-only variant of `_core_argminmax`; same preconditions.
    #[inline(always)]
    unsafe fn _core_argmax(arr: &[ScalarDType]) -> (usize, ScalarDType) {
        let mut arr_ptr = arr.as_ptr();
        let mut new_index = Self::INITIAL_INDEX;
        let (mut index_high, mut values_high) = Self::_initialize_index_values_high(arr_ptr);
        for _ in 0..arr.len() / LANE_SIZE - 1 {
            new_index = Self::_mm_add(new_index, Self::INDEX_INCREMENT);
            arr_ptr = arr_ptr.add(LANE_SIZE);
            let new_values = Self::_mm_loadu(arr_ptr);
            let mask_high = Self::_mm_cmpgt(new_values, values_high);
            values_high = Self::_mm_blendv(values_high, new_values, mask_high);
            index_high = Self::_mm_blendv(index_high, new_index, mask_high);
        }
        Self::_horiz_max(index_high, values_high)
    }

    // Overflow-safe argminmax: scans `arr` in chunks of at most
    // `_get_overflow_lane_size_limit()` elements so the SIMD index lanes
    // never overflow, folding each chunk's result into scalar accumulators.
    #[inline(always)]
    unsafe fn _overflow_safe_core_argminmax(
        arr: &[ScalarDType],
    ) -> (usize, ScalarDType, usize, ScalarDType) {
        assert!(!arr.is_empty());
        assert_eq!(arr.len() % LANE_SIZE, 0);
        let dtype_max = Self::_get_overflow_lane_size_limit();
        let n_loops = arr.len() / dtype_max;
        let mut min_index: usize = 0;
        let mut min_value: ScalarDType = Self::_initialize_min_value(arr);
        let mut max_index: usize = 0;
        let mut max_value: ScalarDType = Self::_initialize_max_value(arr);
        let mut start: usize = 0;
        for _ in 0..n_loops {
            // Early exit (e.g. a NaN result in the return-NaN float variant).
            if Self::_return_check(min_value) || Self::_return_check(max_value) {
                return (min_index, min_value, max_index, max_value);
            }
            let (min_index_, min_value_, max_index_, max_value_) =
                Self::_core_argminmax(&arr[start..start + dtype_max]);
            // Strict comparison keeps the earliest occurrence across chunks;
            // `_return_check` forces adoption of an early-exit value.
            if min_value_ < min_value || Self::_return_check(min_value_) {
                min_index = start + min_index_;
                min_value = min_value_;
            }
            if max_value_ > max_value || Self::_return_check(max_value_) {
                max_index = start + max_index_;
                max_value = max_value_;
            }
            start += dtype_max;
        }
        // Remainder chunk (still a multiple of LANE_SIZE by the assert above).
        if start < arr.len() {
            if Self::_return_check(min_value) || Self::_return_check(max_value) {
                return (min_index, min_value, max_index, max_value);
            }
            let (min_index_, min_value_, max_index_, max_value_) =
                Self::_core_argminmax(&arr[start..]);
            if min_value_ < min_value || Self::_return_check(min_value_) {
                min_index = start + min_index_;
                min_value = min_value_;
            }
            if max_value_ > max_value || Self::_return_check(max_value_) {
                max_index = start + max_index_;
                max_value = max_value_;
            }
        }
        (min_index, min_value, max_index, max_value)
    }

    // Overflow-safe argmin; same chunking scheme as
    // `_overflow_safe_core_argminmax`, minimum only.
    #[inline(always)]
    unsafe fn _overflow_safe_core_argmin(arr: &[ScalarDType]) -> (usize, ScalarDType) {
        assert!(!arr.is_empty());
        assert_eq!(arr.len() % LANE_SIZE, 0);
        let dtype_max = Self::_get_overflow_lane_size_limit();
        let n_loops = arr.len() / dtype_max;
        let mut min_index: usize = 0;
        let mut min_value: ScalarDType = Self::_initialize_min_value(arr);
        let mut start: usize = 0;
        for _ in 0..n_loops {
            if Self::_return_check(min_value) {
                return (min_index, min_value);
            }
            let (min_index_, min_value_) = Self::_core_argmin(&arr[start..start + dtype_max]);
            if min_value_ < min_value || Self::_return_check(min_value_) {
                min_index = start + min_index_;
                min_value = min_value_;
            }
            start += dtype_max;
        }
        if start < arr.len() {
            if Self::_return_check(min_value) {
                return (min_index, min_value);
            }
            let (min_index_, min_value_) = Self::_core_argmin(&arr[start..]);
            if min_value_ < min_value || Self::_return_check(min_value_) {
                min_index = start + min_index_;
                min_value = min_value_;
            }
        }
        (min_index, min_value)
    }

    // Overflow-safe argmax; same chunking scheme, maximum only.
    #[inline(always)]
    unsafe fn _overflow_safe_core_argmax(arr: &[ScalarDType]) -> (usize, ScalarDType) {
        assert!(!arr.is_empty());
        assert_eq!(arr.len() % LANE_SIZE, 0);
        let dtype_max = Self::_get_overflow_lane_size_limit();
        let n_loops = arr.len() / dtype_max;
        let mut max_index: usize = 0;
        let mut max_value: ScalarDType = Self::_initialize_max_value(arr);
        let mut start: usize = 0;
        for _ in 0..n_loops {
            if Self::_return_check(max_value) {
                return (max_index, max_value);
            }
            let (max_index_, max_value_) = Self::_core_argmax(&arr[start..start + dtype_max]);
            if max_value_ > max_value || Self::_return_check(max_value_) {
                max_index = start + max_index_;
                max_value = max_value_;
            }
            start += dtype_max;
        }
        if start < arr.len() {
            if Self::_return_check(max_value) {
                return (max_index, max_value);
            }
            let (max_index_, max_value_) = Self::_core_argmax(&arr[start..]);
            if max_value_ > max_value || Self::_return_check(max_value_) {
                max_index = start + max_index_;
                max_value = max_value_;
            }
        }
        (max_index, max_value)
    }
}
// Blanket implementation: any type providing both the raw SIMD operations
// (`SIMDOps`) and the initialization hooks (`SIMDInit`) automatically gets
// the `SIMDCore` kernels via their default method bodies.
impl<T, ScalarDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>
    SIMDCore<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE> for T
where
    ScalarDType: Copy + PartialOrd + AsPrimitive<usize>,
    SIMDVecDtype: Copy,
    SIMDMaskDtype: Copy,
    T: SIMDOps<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>
        + SIMDInit<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>,
{
}
/// Public entry point for SIMD-accelerated argmin/argmax over a slice.
///
/// `SCALAR` supplies the scalar fallback implementation used by the generic
/// drivers (e.g. for remainders that do not fill a SIMD chunk).
pub trait SIMDArgMinMax<ScalarDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize, SCALAR>:
    SIMDCore<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>
where
    ScalarDType: Copy + PartialOrd + AsPrimitive<usize>,
    SIMDVecDtype: Copy,
    SIMDMaskDtype: Copy,
    SCALAR: ScalarArgMinMax<ScalarDType>,
{
    /// Returns `(argmin, argmax)` of `data`.
    ///
    /// # Safety
    /// Callers must ensure the required target features are available
    /// (implementations are `#[target_feature]` functions).
    unsafe fn argminmax(data: &[ScalarDType]) -> (usize, usize);

    // Shared implementation; implementors forward to it from inside a
    // `#[target_feature]` shim. NOTE: the method-level
    // `SCALAR: ScalarArgMinMax<ScalarDType>` bounds previously repeated here
    // were redundant — they are already implied by the trait where-clause.
    #[doc(hidden)]
    #[inline(always)]
    unsafe fn _argminmax(data: &[ScalarDType]) -> (usize, usize) {
        argminmax_generic(
            data,
            LANE_SIZE,
            Self::_overflow_safe_core_argminmax,
            SCALAR::argminmax,
            Self::_nan_check,
            Self::IGNORE_NAN,
        )
    }

    /// Returns the index of the minimum of `data`.
    ///
    /// # Safety
    /// Same requirement as [`Self::argminmax`].
    unsafe fn argmin(data: &[ScalarDType]) -> usize;

    #[doc(hidden)]
    #[inline(always)]
    unsafe fn _argmin(data: &[ScalarDType]) -> usize {
        argmin_generic(
            data,
            LANE_SIZE,
            Self::_overflow_safe_core_argmin,
            SCALAR::argmin,
            Self::_nan_check,
            Self::IGNORE_NAN,
        )
    }

    /// Returns the index of the maximum of `data`.
    ///
    /// # Safety
    /// Same requirement as [`Self::argminmax`].
    unsafe fn argmax(data: &[ScalarDType]) -> usize;

    #[doc(hidden)]
    #[inline(always)]
    unsafe fn _argmax(data: &[ScalarDType]) -> usize {
        argmax_generic(
            data,
            LANE_SIZE,
            Self::_overflow_safe_core_argmax,
            SCALAR::argmax,
            Self::_nan_check,
            Self::IGNORE_NAN,
        )
    }
}
#[cfg(any(
    target_arch = "x86",
    target_arch = "x86_64",
    all(target_arch = "arm", feature = "nightly_simd"),
    target_arch = "aarch64",
))]
macro_rules! impl_SIMDArgMinMax {
    ($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $scalar_struct:ty, $simd_struct:ty, $target:expr) => {
        impl SIMDArgMinMax<$scalar_dtype, $simd_vec_dtype, $simd_mask_dtype, $lane_size, $scalar_struct>
            for $simd_struct
        {
            // Each method is a `#[target_feature(enable = $target)]` shim
            // around the shared `_arg*` default implementations, allowing the
            // compiler to emit $target instructions for the inlined generics.
            #[target_feature(enable = $target)]
            unsafe fn argminmax(data: &[$scalar_dtype]) -> (usize, usize) {
                Self::_argminmax(data)
            }

            #[target_feature(enable = $target)]
            unsafe fn argmin(data: &[$scalar_dtype]) -> usize {
                Self::_argmin(data)
            }

            #[target_feature(enable = $target)]
            unsafe fn argmax(data: &[$scalar_dtype]) -> usize {
                Self::_argmax(data)
            }
        }
    };
}
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
pub(crate) use impl_SIMDArgMinMax;
#[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
macro_rules! unimpl_SIMDOps {
    ($scalar_type:ty, $reg:ty, $simd_struct:ty) => {
        // Stub implementation (LANE_SIZE = 0) for scalar/register combos that
        // have no SIMD support on this target; every method panics if called.
        impl SIMDOps<$scalar_type, $reg, $reg, 0> for $simd_struct {
            const MAX_INDEX: usize = 0;
            const INITIAL_INDEX: $reg = 0;
            const INDEX_INCREMENT: $reg = 0;

            unsafe fn _reg_to_arr(_reg: $reg) -> [$scalar_type; 0] {
                unimplemented!()
            }

            unsafe fn _mm_loadu(_data: *const $scalar_type) -> $reg {
                unimplemented!()
            }

            unsafe fn _mm_add(_a: $reg, _b: $reg) -> $reg {
                unimplemented!()
            }

            unsafe fn _mm_cmpgt(_a: $reg, _b: $reg) -> $reg {
                unimplemented!()
            }

            unsafe fn _mm_cmplt(_a: $reg, _b: $reg) -> $reg {
                unimplemented!()
            }

            unsafe fn _mm_blendv(_a: $reg, _b: $reg, _mask: $reg) -> $reg {
                unimplemented!()
            }
        }
    };
}
#[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
macro_rules! unimpl_SIMDInit {
    ($scalar_type:ty, $reg:ty, $simd_struct:ty) => {
        // Companion stub to `unimpl_SIMDOps`; the `SIMDInit` defaults are
        // never reached because the `SIMDOps` stubs panic first.
        impl SIMDInit<$scalar_type, $reg, $reg, 0> for $simd_struct {}
    };
}
#[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
macro_rules! unimpl_SIMDArgMinMax {
    ($scalar_type:ty, $reg:ty, $scalar:ty, $simd_struct:ty) => {
        // Stub entry points for unsupported scalar types on this target;
        // calling any of them panics.
        impl SIMDArgMinMax<$scalar_type, $reg, $reg, 0, $scalar> for $simd_struct {
            unsafe fn argminmax(_data: &[$scalar_type]) -> (usize, usize) {
                unimplemented!()
            }

            unsafe fn argmin(_data: &[$scalar_type]) -> usize {
                unimplemented!()
            }

            unsafe fn argmax(_data: &[$scalar_type]) -> usize {
                unimplemented!()
            }
        }
    };
}
#[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
pub(crate) use unimpl_SIMDArgMinMax;
#[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
pub(crate) use unimpl_SIMDInit;
#[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
pub(crate) use unimpl_SIMDOps;