use std::cmp::Ordering;
/// Returns `(argmin, argmax)` of `arr`, combining a SIMD kernel with a scalar fallback.
///
/// `arr` is split by `split_array` into two parts:
/// - a prefix whose length is a multiple of `lane_size`, processed by `core_argminmax`;
/// - a shorter tail, processed by `scalar_argminmax`.
///
/// When both parts exist, their partial results are merged with `find_final_index_min` /
/// `find_final_index_max` and then passed through `get_correct_argminmax_result`.
///
/// * `core_argminmax` - SIMD kernel returning `(min_index, min_value, max_index, max_value)`.
/// * `scalar_argminmax` - scalar kernel returning `(min_index, max_index)`. Its indices are
///   relative to the slice it receives.
/// * `nan_check` - returns `true` for values that are to be treated as NaN.
/// * `ignore_nan` - controls NaN handling in the paths that use the SIMD kernel:
///   - if `false`, a NaN minimum/maximum is reported as both the argmin and the argmax;
///   - if `true`, NaNs are skipped when merging the partial results.
///
/// When `arr` is shorter than `lane_size`, the scalar kernel's result is returned unchanged,
/// so its own NaN semantics apply.
///
/// # Panics
/// - Panics if `arr` is empty or `lane_size` is zero.
/// - With `ignore_nan`, panics if both the SIMD and the scalar partial results are NaN.
#[inline(always)]
pub(crate) fn argminmax_generic<T: Copy + PartialOrd>(
    arr: &[T],
    lane_size: usize,
    core_argminmax: unsafe fn(&[T]) -> (usize, T, usize, T),
    scalar_argminmax: fn(&[T]) -> (usize, usize),
    nan_check: fn(T) -> bool,
    ignore_nan: bool,
) -> (usize, usize) {
    assert!(!arr.is_empty());
    match split_array(arr, lane_size) {
        (Some(simd_arr), Some(rem)) => {
            // SAFETY: `split_array` only yields a non-empty `simd_arr` whose length is a
            // multiple of `lane_size`. Any further precondition of `core_argminmax` (e.g.
            // required CPU features) is not visible here and is presumed to be upheld by
            // the caller — TODO confirm.
            let (simd_min_index, simd_min_value, simd_max_index, simd_max_value) =
                unsafe { core_argminmax(simd_arr) };
            let (rem_min_index, rem_max_index) = scalar_argminmax(rem);
            // Remainder indices are local to `rem`; shift them to positions in `arr`.
            let offset = simd_arr.len();
            let rem_min = (rem_min_index + offset, rem[rem_min_index]);
            let rem_max = (rem_max_index + offset, rem[rem_max_index]);

            let (min_index, min_value) = find_final_index_min(
                (simd_min_index, simd_min_value),
                rem_min,
                nan_check,
                ignore_nan,
            );
            let (max_index, max_value) = find_final_index_max(
                (simd_max_index, simd_max_value),
                rem_max,
                nan_check,
                ignore_nan,
            );
            get_correct_argminmax_result(
                min_index, min_value, max_index, max_value, nan_check, ignore_nan,
            )
        }
        (Some(simd_arr), None) => {
            // SAFETY: same invariant as above — `simd_arr` is non-empty and its length is a
            // multiple of `lane_size`; other kernel preconditions are the caller's
            // responsibility (TODO confirm).
            let (min_index, min_value, max_index, max_value) =
                unsafe { core_argminmax(simd_arr) };
            get_correct_argminmax_result(
                min_index, min_value, max_index, max_value, nan_check, ignore_nan,
            )
        }
        (None, Some(rem)) => scalar_argminmax(rem),
        // `arr` is non-empty (asserted above), so at least one part is always present.
        (None, None) => unreachable!("split_array returned no parts for a non-empty slice"),
    }
}
/// Returns the index of the minimum of `arr`, combining a SIMD kernel with a scalar fallback.
///
/// `arr` is split by `split_array` into two parts:
/// - a prefix whose length is a multiple of `lane_size`, handled by `core_argmin`, which
///   returns `(index, value)`;
/// - a shorter tail, handled by `scalar_argmin`, which returns an index local to the tail.
///
/// When both parts exist, the two candidates are merged with `find_final_index_min`. That
/// merge uses `nan_check` and `ignore_nan` to decide between NaN and non-NaN candidates.
///
/// # Panics
/// - Panics if `arr` is empty or `lane_size` is zero.
/// - With `ignore_nan`, panics if both partial minima are NaN.
#[inline(always)]
pub(crate) fn argmin_generic<T: Copy + PartialOrd>(
    arr: &[T],
    lane_size: usize,
    core_argmin: unsafe fn(&[T]) -> (usize, T),
    scalar_argmin: fn(&[T]) -> usize,
    nan_check: fn(T) -> bool,
    ignore_nan: bool,
) -> usize {
    assert!(!arr.is_empty());
    match split_array(arr, lane_size) {
        (Some(simd_arr), Some(rem)) => {
            // SAFETY: `split_array` only yields a non-empty `simd_arr` whose length is a
            // multiple of `lane_size`. Any further precondition of `core_argmin` (e.g.
            // required CPU features) is presumed to be upheld by the caller — TODO confirm.
            let simd_result = unsafe { core_argmin(simd_arr) };
            let rem_min_index = scalar_argmin(rem);
            // Shift the tail-local index to a position in `arr`.
            let rem_result = (rem_min_index + simd_arr.len(), rem[rem_min_index]);
            let (min_index, _) =
                find_final_index_min(simd_result, rem_result, nan_check, ignore_nan);
            min_index
        }
        (Some(simd_arr), None) => {
            // SAFETY: same invariant as above; other kernel preconditions are the caller's
            // responsibility (TODO confirm).
            let (min_index, _) = unsafe { core_argmin(simd_arr) };
            min_index
        }
        (None, Some(rem)) => scalar_argmin(rem),
        // `arr` is non-empty (asserted above), so at least one part is always present.
        (None, None) => unreachable!("split_array returned no parts for a non-empty slice"),
    }
}
/// Returns the index of the maximum of `arr`, combining a SIMD kernel with a scalar fallback.
///
/// `arr` is split by `split_array` into two parts:
/// - a prefix whose length is a multiple of `lane_size`, handled by `core_argmax`, which
///   returns `(index, value)`;
/// - a shorter tail, handled by `scalar_argmax`, which returns an index local to the tail.
///
/// When both parts exist, the two candidates are merged with `find_final_index_max`. That
/// merge uses `nan_check` and `ignore_nan` to decide between NaN and non-NaN candidates.
///
/// # Panics
/// - Panics if `arr` is empty or `lane_size` is zero.
/// - With `ignore_nan`, panics if both partial maxima are NaN.
#[inline(always)]
pub(crate) fn argmax_generic<T: Copy + PartialOrd>(
    arr: &[T],
    lane_size: usize,
    core_argmax: unsafe fn(&[T]) -> (usize, T),
    scalar_argmax: fn(&[T]) -> usize,
    nan_check: fn(T) -> bool,
    ignore_nan: bool,
) -> usize {
    assert!(!arr.is_empty());
    match split_array(arr, lane_size) {
        (Some(simd_arr), Some(rem)) => {
            // SAFETY: `split_array` only yields a non-empty `simd_arr` whose length is a
            // multiple of `lane_size`. Any further precondition of `core_argmax` (e.g.
            // required CPU features) is presumed to be upheld by the caller — TODO confirm.
            let simd_result = unsafe { core_argmax(simd_arr) };
            let rem_max_index = scalar_argmax(rem);
            // Shift the tail-local index to a position in `arr`.
            let rem_result = (rem_max_index + simd_arr.len(), rem[rem_max_index]);
            let (max_index, _) =
                find_final_index_max(simd_result, rem_result, nan_check, ignore_nan);
            max_index
        }
        (Some(simd_arr), None) => {
            // SAFETY: same invariant as above; other kernel preconditions are the caller's
            // responsibility (TODO confirm).
            let (max_index, _) = unsafe { core_argmax(simd_arr) };
            max_index
        }
        (None, Some(rem)) => scalar_argmax(rem),
        // `arr` is non-empty (asserted above), so at least one part is always present.
        (None, None) => unreachable!("split_array returned no parts for a non-empty slice"),
    }
}
/// Splits `arr` into two parts and maps each empty part to `None`.
///
/// - The first part is the leading prefix whose length is the largest multiple of
///   `lane_size`.
/// - The second part is the trailing remainder, which is shorter than `lane_size`.
///
/// Returns `(None, None)` only for an empty `arr`.
///
/// # Panics
/// Panics if `lane_size` is zero (remainder by zero).
#[inline(always)]
fn split_array<T: Copy>(arr: &[T], lane_size: usize) -> (Option<&[T]>, Option<&[T]>) {
    // Wraps a sub-slice in `Some` unless it is empty.
    fn non_empty<U>(part: &[U]) -> Option<&[U]> {
        if part.is_empty() {
            None
        } else {
            Some(part)
        }
    }

    let tail_len = arr.len() % lane_size;
    let (head, tail) = arr.split_at(arr.len() - tail_len);
    (non_empty(head), non_empty(tail))
}
/// Merges the minimum found in the SIMD-processed prefix with the minimum of the remainder.
///
/// Both arguments are `(index, value)` pairs. `simd_result` is expected to describe the
/// leading part of the array, so it is kept when the two values compare equal.
///
/// When the values are unordered (`partial_cmp` returns `None`, e.g. a float NaN):
/// * `ignore_nan == false`: the SIMD pair is returned if its value is NaN; otherwise the
///   remainder pair is returned.
/// * `ignore_nan == true`: the pair whose value is not NaN is returned.
///
/// # Panics
/// Panics when `ignore_nan` is `true` and both values are NaN according to `nan_check`.
#[inline(always)]
fn find_final_index_min<T: Copy + PartialOrd>(
    simd_result: (usize, T),
    remainder_result: (usize, T),
    nan_check: fn(T) -> bool,
    ignore_nan: bool,
) -> (usize, T) {
    if let Some(order) = simd_result.1.partial_cmp(&remainder_result.1) {
        // Only a strictly smaller remainder value displaces the SIMD candidate.
        return if order == Ordering::Greater {
            remainder_result
        } else {
            simd_result
        };
    }

    // Unordered values: for floats, at least one of them is NaN.
    let simd_is_nan = nan_check(simd_result.1);
    if !ignore_nan {
        // Propagate the NaN, preferring the SIMD candidate.
        return if simd_is_nan { simd_result } else { remainder_result };
    }

    let remainder_is_nan = nan_check(remainder_result.1);
    if simd_is_nan && remainder_is_nan {
        panic!("Data contains only NaNs (or +/- inf)");
    }
    if remainder_is_nan {
        simd_result
    } else {
        remainder_result
    }
}
/// Merges the maximum found in the SIMD-processed prefix with the maximum of the remainder.
///
/// Both arguments are `(index, value)` pairs. `simd_result` is expected to describe the
/// leading part of the array, so it is kept when the two values compare equal.
///
/// When the values are unordered (`partial_cmp` returns `None`, e.g. a float NaN):
/// * `ignore_nan == false`: the SIMD pair is returned if its value is NaN; otherwise the
///   remainder pair is returned.
/// * `ignore_nan == true`: the pair whose value is not NaN is returned.
///
/// # Panics
/// Panics when `ignore_nan` is `true` and both values are NaN according to `nan_check`.
#[inline(always)]
fn find_final_index_max<T: Copy + PartialOrd>(
    simd_result: (usize, T),
    remainder_result: (usize, T),
    nan_check: fn(T) -> bool,
    ignore_nan: bool,
) -> (usize, T) {
    if let Some(order) = simd_result.1.partial_cmp(&remainder_result.1) {
        // Only a strictly larger remainder value displaces the SIMD candidate.
        return if order == Ordering::Less {
            remainder_result
        } else {
            simd_result
        };
    }

    // Unordered values: for floats, at least one of them is NaN.
    let simd_is_nan = nan_check(simd_result.1);
    if !ignore_nan {
        // Propagate the NaN, preferring the SIMD candidate.
        return if simd_is_nan { simd_result } else { remainder_result };
    }

    let remainder_is_nan = nan_check(remainder_result.1);
    if simd_is_nan && remainder_is_nan {
        panic!("Data contains only NaNs (or +/- inf)");
    }
    if remainder_is_nan {
        simd_result
    } else {
        remainder_result
    }
}
/// Finalises an `(argmin, argmax)` pair according to the NaN policy.
///
/// With `ignore_nan == true`, or when neither value is NaN (per `nan_check`), the indices
/// are returned unchanged. Otherwise a NaN is reported as both the minimum and the maximum:
/// * both values NaN: the smaller of the two indices is returned twice;
/// * only `min_value` NaN: `min_index` is returned twice;
/// * only `max_value` NaN: `max_index` is returned twice.
fn get_correct_argminmax_result<T: Copy + PartialOrd>(
    min_index: usize,
    min_value: T,
    max_index: usize,
    max_value: T,
    nan_check: fn(T) -> bool,
    ignore_nan: bool,
) -> (usize, usize) {
    if ignore_nan {
        return (min_index, max_index);
    }
    match (nan_check(min_value), nan_check(max_value)) {
        (true, true) => {
            let first_nan = min_index.min(max_index);
            (first_nan, first_nan)
        }
        (true, false) => (min_index, min_index),
        (false, true) => (max_index, max_index),
        (false, false) => (min_index, max_index),
    }
}
/// Returns `(index, value)` of the smallest entry in `values`, where `index[i]` is the index
/// associated with `values[i]` (indices are stored with the same type `T` as the values).
///
/// Ties on value are broken in favour of the smaller associated index. An entry for which
/// both `v < current` and `v == current` are false (e.g. a float NaN) never replaces the
/// current minimum. Consequently, if `values[0]` is NaN the result is `(index[0], values[0])`.
///
/// # Panics
/// Panics if `index` is empty or if `index` and `values` differ in length.
pub(crate) fn min_index_value<T: Copy + PartialOrd>(index: &[T], values: &[T]) -> (T, T) {
    assert!(!index.is_empty());
    assert_eq!(index.len(), values.len());
    let mut min_index = index[0];
    let mut min_value = values[0];
    // Element 0 seeds the accumulator, so the scan starts at element 1. Zipping the two
    // (length-checked) slices needs neither `unsafe` indexing nor per-element bounds checks.
    for (&idx, &v) in index.iter().zip(values).skip(1) {
        if v < min_value || (v == min_value && idx < min_index) {
            min_value = v;
            min_index = idx;
        }
    }
    (min_index, min_value)
}
/// Returns `(index, value)` of the largest entry in `values`, where `index[i]` is the index
/// associated with `values[i]` (indices are stored with the same type `T` as the values).
///
/// Ties on value are broken in favour of the smaller associated index. An entry for which
/// both `v > current` and `v == current` are false (e.g. a float NaN) never replaces the
/// current maximum. Consequently, if `values[0]` is NaN the result is `(index[0], values[0])`.
///
/// # Panics
/// Panics if `index` is empty or if `index` and `values` differ in length.
pub(crate) fn max_index_value<T: Copy + PartialOrd>(index: &[T], values: &[T]) -> (T, T) {
    assert!(!index.is_empty());
    assert_eq!(index.len(), values.len());
    let mut max_index = index[0];
    let mut max_value = values[0];
    // Element 0 seeds the accumulator, so the scan starts at element 1. Zipping the two
    // (length-checked) slices needs neither `unsafe` indexing nor per-element bounds checks.
    for (&idx, &v) in index.iter().zip(values).skip(1) {
        if v > max_value || (v == max_value && idx < max_index) {
            max_value = v;
            max_index = idx;
        }
    }
    (max_index, max_value)
}