#[cfg(feature = "float")]
use num_traits::float::FloatCore;
use num_traits::PrimInt;
use super::super::dtype_strategy::Int;
#[cfg(any(feature = "float", feature = "half"))]
use super::super::dtype_strategy::{FloatIgnoreNaN, FloatReturnNaN};
/// Per-strategy hooks that `impl_scalar!` uses to seed and drive its linear scan.
///
/// Each dtype strategy (`Int`, `FloatReturnNaN`, `FloatIgnoreNaN`) implements this
/// for `SCALAR<Strategy>`, which lets the same generated loop handle NaN differently.
trait SCALARInit<ScalarDType>
where
    ScalarDType: Copy + PartialOrd,
{
    /// When `true`, the scan stops at the first value for which `_nan_check` holds
    /// and reports that index as the result.
    const _RETURN_AT_NAN: bool;

    /// Whether `v` is to be treated as NaN by the scan.
    fn _nan_check(v: ScalarDType) -> bool;

    /// Starting value of the running minimum, derived from the slice's first element.
    fn _init_min(first: ScalarDType) -> ScalarDType;

    /// Starting value of the running maximum, derived from the slice's first element.
    fn _init_max(first: ScalarDType) -> ScalarDType;

    /// When this returns `true` for the first element, the first element that
    /// fails `_nan_check` overwrites the running extremes unconditionally.
    fn _allow_first_non_nan_update(first: ScalarDType) -> bool;
}
/// Scalar argmin / argmax over a slice, parameterised by element type.
///
/// The implementations generated by `impl_scalar!` in this module compare with
/// strict `<` / `>`, so ties resolve to the lowest index. How NaN is treated is
/// decided by the strategy parameter of `SCALAR`.
///
/// # Panics
///
/// The `impl_scalar!`-generated implementations panic when given an empty slice.
pub trait ScalarArgMinMax<ScalarDType>
where
    ScalarDType: Copy + PartialOrd,
{
    /// Returns `(argmin_index, argmax_index)` of `data`.
    fn argminmax(data: &[ScalarDType]) -> (usize, usize);

    /// Returns the index of the minimum of `data`.
    fn argmin(data: &[ScalarDType]) -> usize;

    /// Returns the index of the maximum of `data`.
    fn argmax(data: &[ScalarDType]) -> usize;
}
/// Zero-sized dispatcher for the scalar argmin/argmax routines.
///
/// The type parameter selects the dtype strategy (for example `Int`,
/// `FloatReturnNaN` or `FloatIgnoreNaN`); no value of it is ever stored.
pub struct SCALAR<Strategy> {
    /// Marker tying `Strategy` to the type without holding any data.
    pub(crate) _dtype_strategy: ::std::marker::PhantomData<Strategy>,
}
/// Integer strategy: integers have no NaN, so the first element seeds both
/// extremes directly and the scan never stops early.
impl<T: PrimInt> SCALARInit<T> for SCALAR<Int> {
    const _RETURN_AT_NAN: bool = false;

    /// Integers are never NaN.
    #[inline(always)]
    fn _nan_check(_value: T) -> bool {
        false
    }

    /// Seeding is never deferred; the first element is always usable.
    #[inline(always)]
    fn _allow_first_non_nan_update(_first: T) -> bool {
        false
    }

    /// The running minimum starts at the first element unchanged.
    #[inline(always)]
    fn _init_min(first: T) -> T {
        first
    }

    /// The running maximum starts at the first element unchanged.
    #[inline(always)]
    fn _init_max(first: T) -> T {
        first
    }
}
#[cfg(feature = "float")]
/// NaN-propagating float strategy: the scan stops at the first NaN it meets and
/// reports that index, so the starting values never need a NaN substitute.
impl<T: FloatCore> SCALARInit<T> for SCALAR<FloatReturnNaN> {
    const _RETURN_AT_NAN: bool = true;

    /// A value counts as NaN exactly when `FloatCore::is_nan` says so.
    #[inline(always)]
    fn _nan_check(value: T) -> bool {
        value.is_nan()
    }

    /// Seeding is never deferred: a NaN at index 0 already ends the scan.
    #[inline(always)]
    fn _allow_first_non_nan_update(_first: T) -> bool {
        false
    }

    /// The running minimum starts at the first element unchanged.
    #[inline(always)]
    fn _init_min(first: T) -> T {
        first
    }

    /// The running maximum starts at the first element unchanged.
    #[inline(always)]
    fn _init_max(first: T) -> T {
        first
    }
}
#[cfg(feature = "float")]
/// NaN-skipping float strategy: NaN values never become an extreme, and when the
/// slice starts with NaN the first non-NaN element seeds both extremes.
impl<T: FloatCore> SCALARInit<T> for SCALAR<FloatIgnoreNaN> {
    const _RETURN_AT_NAN: bool = false;

    /// A value counts as NaN exactly when `FloatCore::is_nan` says so.
    #[inline(always)]
    fn _nan_check(value: T) -> bool {
        value.is_nan()
    }

    /// Defer seeding exactly when the first element is NaN.
    #[inline(always)]
    fn _allow_first_non_nan_update(first: T) -> bool {
        first.is_nan()
    }

    /// A NaN first element is replaced by `+inf`; otherwise it is used as-is.
    #[inline(always)]
    fn _init_min(first: T) -> T {
        if first.is_nan() {
            return T::infinity();
        }
        first
    }

    /// A NaN first element is replaced by `-inf`; otherwise it is used as-is.
    #[inline(always)]
    fn _init_max(first: T) -> T {
        if first.is_nan() {
            return T::neg_infinity();
        }
        first
    }
}
// Generates `impl ScalarArgMinMax<$dtype> for SCALAR<$dtype_strategy>` for each
// listed `$dtype`. The strategy's `SCALARInit` hooks control the scan:
// - `_RETURN_AT_NAN`: stop at the first NaN and report its index as the result;
// - `_allow_first_non_nan_update`: when the first element is NaN, let the first
//   non-NaN element overwrite the running extremes unconditionally.
// Comparisons are strict, so ties keep the earliest index. The index loop with
// `get_unchecked` is kept as-is (presumably for performance; benchmark before
// swapping it for an iterator chain).
macro_rules! impl_scalar {
    ($dtype_strategy:ty, $($dtype:ty),*) => {
        $(
            impl ScalarArgMinMax<$dtype> for SCALAR<$dtype_strategy>
            {
                #[inline(always)]
                fn argminmax(arr: &[$dtype]) -> (usize, usize) {
                    assert!(!arr.is_empty());
                    // SAFETY: the assert above guarantees index 0 is in bounds.
                    let first: $dtype = unsafe { *arr.get_unchecked(0) };
                    let (mut min_val, mut max_val): ($dtype, $dtype) =
                        (Self::_init_min(first), Self::_init_max(first));
                    let (mut min_idx, mut max_idx): (usize, usize) = (0, 0);
                    let mut awaiting_seed: bool = Self::_allow_first_non_nan_update(first);
                    for idx in 0..arr.len() {
                        // SAFETY: `idx` ranges over `0..arr.len()`, so it is in bounds.
                        let cur: $dtype = unsafe { *arr.get_unchecked(idx) };
                        if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN && Self::_nan_check(cur) {
                            // NaN-propagating strategy: this NaN is the answer for both.
                            return (idx, idx);
                        }
                        if awaiting_seed {
                            // Leading NaNs: the first non-NaN value seeds both extremes.
                            if !Self::_nan_check(cur) {
                                min_val = cur;
                                min_idx = idx;
                                max_val = cur;
                                max_idx = idx;
                                awaiting_seed = false;
                            }
                            continue;
                        }
                        if cur < min_val {
                            min_val = cur;
                            min_idx = idx;
                        } else if cur > max_val {
                            max_val = cur;
                            max_idx = idx;
                        }
                    }
                    (min_idx, max_idx)
                }

                #[inline(always)]
                fn argmin(arr: &[$dtype]) -> usize {
                    assert!(!arr.is_empty());
                    // SAFETY: the assert above guarantees index 0 is in bounds.
                    let first: $dtype = unsafe { *arr.get_unchecked(0) };
                    let mut best: $dtype = Self::_init_min(first);
                    let mut best_idx: usize = 0;
                    let mut awaiting_seed: bool = Self::_allow_first_non_nan_update(first);
                    for idx in 0..arr.len() {
                        // SAFETY: `idx` ranges over `0..arr.len()`, so it is in bounds.
                        let cur: $dtype = unsafe { *arr.get_unchecked(idx) };
                        if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN && Self::_nan_check(cur) {
                            return idx;
                        }
                        if awaiting_seed {
                            if !Self::_nan_check(cur) {
                                best = cur;
                                best_idx = idx;
                                awaiting_seed = false;
                            }
                            continue;
                        }
                        if cur < best {
                            best = cur;
                            best_idx = idx;
                        }
                    }
                    best_idx
                }

                #[inline(always)]
                fn argmax(arr: &[$dtype]) -> usize {
                    assert!(!arr.is_empty());
                    // SAFETY: the assert above guarantees index 0 is in bounds.
                    let first: $dtype = unsafe { *arr.get_unchecked(0) };
                    let mut best: $dtype = Self::_init_max(first);
                    let mut best_idx: usize = 0;
                    let mut awaiting_seed: bool = Self::_allow_first_non_nan_update(first);
                    for idx in 0..arr.len() {
                        // SAFETY: `idx` ranges over `0..arr.len()`, so it is in bounds.
                        let cur: $dtype = unsafe { *arr.get_unchecked(idx) };
                        if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN && Self::_nan_check(cur) {
                            return idx;
                        }
                        if awaiting_seed {
                            if !Self::_nan_check(cur) {
                                best = cur;
                                best_idx = idx;
                                awaiting_seed = false;
                            }
                            continue;
                        }
                        if cur > best {
                            best = cur;
                            best_idx = idx;
                        }
                    }
                    best_idx
                }
            }
        )*
    };
}
// Integer dtypes have no NaN, so they only need the `Int` strategy.
impl_scalar!(Int, u8, u16, u32, u64, i8, i16, i32, i64);
// Float dtypes get one implementation per NaN-handling strategy.
#[cfg(feature = "float")]
impl_scalar!(FloatIgnoreNaN, f32, f64);
#[cfg(feature = "float")]
impl_scalar!(FloatReturnNaN, f32, f64);
#[cfg(feature = "half")]
use super::scalar_f16::{
scalar_argmax_f16_ignore_nan, scalar_argmin_f16_ignore_nan, scalar_argminmax_f16_ignore_nan,
};
#[cfg(feature = "half")]
use super::scalar_f16::{
scalar_argmax_f16_return_nan, scalar_argmin_f16_return_nan, scalar_argminmax_f16_return_nan,
};
#[cfg(feature = "half")]
use half::f16;
#[cfg(feature = "half")]
/// `f16` with NaN propagation: forwards to the dedicated `*_f16_return_nan`
/// routines imported from the sibling `scalar_f16` module.
impl ScalarArgMinMax<f16> for SCALAR<FloatReturnNaN> {
    /// Forwards to `scalar_argmin_f16_return_nan`.
    #[inline(always)]
    fn argmin(data: &[f16]) -> usize {
        scalar_argmin_f16_return_nan(data)
    }

    /// Forwards to `scalar_argmax_f16_return_nan`.
    #[inline(always)]
    fn argmax(data: &[f16]) -> usize {
        scalar_argmax_f16_return_nan(data)
    }

    /// Forwards to `scalar_argminmax_f16_return_nan`.
    #[inline(always)]
    fn argminmax(data: &[f16]) -> (usize, usize) {
        scalar_argminmax_f16_return_nan(data)
    }
}
#[cfg(feature = "half")]
/// `f16` ignoring NaN: forwards to the dedicated `*_f16_ignore_nan` routines
/// imported from the sibling `scalar_f16` module.
impl ScalarArgMinMax<f16> for SCALAR<FloatIgnoreNaN> {
    /// Forwards to `scalar_argmin_f16_ignore_nan`.
    #[inline(always)]
    fn argmin(data: &[f16]) -> usize {
        scalar_argmin_f16_ignore_nan(data)
    }

    /// Forwards to `scalar_argmax_f16_ignore_nan`.
    #[inline(always)]
    fn argmax(data: &[f16]) -> usize {
        scalar_argmax_f16_ignore_nan(data)
    }

    /// Forwards to `scalar_argminmax_f16_ignore_nan`.
    #[inline(always)]
    fn argminmax(data: &[f16]) -> (usize, usize) {
        scalar_argminmax_f16_ignore_nan(data)
    }
}