use half::f16;
#[inline(always)]
fn f16_to_i16ord(x: f16) -> i16 {
let x = unsafe { std::mem::transmute::<f16, i16>(x) };
((x >> 15) & 0x7FFF) ^ x
}
/// Return `(argmin, argmax)` over `arr`, short-circuiting on NaN: the index of
/// the first NaN encountered is returned for both positions.
///
/// Comparisons are done on the `f16_to_i16ord`-mapped values so that integer
/// ordering matches float ordering.
///
/// # Panics
/// Panics if `arr` is empty.
pub(crate) fn scalar_argminmax_f16_return_nan(arr: &[f16]) -> (usize, usize) {
    assert!(!arr.is_empty());
    let mut low_index: usize = 0;
    let mut high_index: usize = 0;
    // Seed both extremes from the first element (index 0 is in bounds per the
    // assert above); the loop re-examines index 0, which is harmless.
    let mut low: i16 = f16_to_i16ord(arr[0]);
    let mut high: i16 = low;
    // Safe iteration: the iterator elides bounds checks, so `get_unchecked`
    // (and its `unsafe`) is unnecessary.
    for (i, &v) in arr.iter().enumerate() {
        if v.is_nan() {
            // Propagate the first NaN position for both min and max.
            return (i, i);
        }
        let v: i16 = f16_to_i16ord(v);
        if v < low {
            low = v;
            low_index = i;
        } else if v > high {
            high = v;
            high_index = i;
        }
    }
    (low_index, high_index)
}
/// Return the argmin of `arr`, short-circuiting on NaN: the index of the first
/// NaN encountered is returned immediately.
///
/// # Panics
/// Panics if `arr` is empty.
pub(crate) fn scalar_argmin_f16_return_nan(arr: &[f16]) -> usize {
    assert!(!arr.is_empty());
    let mut low_index: usize = 0;
    // Index 0 is in bounds per the assert; the loop revisits it harmlessly.
    let mut low: i16 = f16_to_i16ord(arr[0]);
    // Safe iteration elides bounds checks, so no `unsafe` is needed.
    for (i, &v) in arr.iter().enumerate() {
        if v.is_nan() {
            return i;
        }
        let v: i16 = f16_to_i16ord(v);
        if v < low {
            low = v;
            low_index = i;
        }
    }
    low_index
}
/// Return the argmax of `arr`, short-circuiting on NaN: the index of the first
/// NaN encountered is returned immediately.
///
/// # Panics
/// Panics if `arr` is empty.
pub(crate) fn scalar_argmax_f16_return_nan(arr: &[f16]) -> usize {
    assert!(!arr.is_empty());
    let mut high_index: usize = 0;
    // Index 0 is in bounds per the assert; the loop revisits it harmlessly.
    let mut high: i16 = f16_to_i16ord(arr[0]);
    // Safe iteration elides bounds checks, so no `unsafe` is needed.
    for (i, &v) in arr.iter().enumerate() {
        if v.is_nan() {
            return i;
        }
        let v: i16 = f16_to_i16ord(v);
        if v > high {
            high = v;
            high_index = i;
        }
    }
    high_index
}
/// Return `(argmin, argmax)` over `arr`, skipping NaN values entirely.
/// If every element is NaN, `(0, 0)` is returned.
///
/// # Panics
/// Panics if `arr` is empty.
pub(crate) fn scalar_argminmax_f16_ignore_nan(arr: &[f16]) -> (usize, usize) {
    assert!(!arr.is_empty());
    let mut low_index: usize = 0;
    let mut high_index: usize = 0;
    // Sentinels; both are overwritten by the first non-NaN element below.
    let mut low: i16 = f16_to_i16ord(f16::INFINITY);
    let mut high: i16 = f16_to_i16ord(f16::NEG_INFINITY);
    // The first non-NaN element must seed BOTH extremes (and both indices) so
    // that an all-equal slice reports real indices, not the sentinel's 0.
    let mut first_non_nan_update = true;
    for (i, &v) in arr.iter().enumerate() {
        if v.is_nan() {
            continue; // ignore NaN entries
        }
        let v: i16 = f16_to_i16ord(v);
        if first_non_nan_update {
            low = v;
            high = v;
            low_index = i;
            high_index = i;
            first_non_nan_update = false;
        } else if v < low {
            low = v;
            low_index = i;
        } else if v > high {
            high = v;
            high_index = i;
        }
    }
    (low_index, high_index)
}
/// Return the argmin of `arr`, skipping NaN values entirely.
/// If every element is NaN, `0` is returned.
///
/// # Panics
/// Panics if `arr` is empty.
pub(crate) fn scalar_argmin_f16_ignore_nan(arr: &[f16]) -> usize {
    assert!(!arr.is_empty());
    let mut low_index: usize = 0;
    // Sentinel: +inf loses every strict `<` comparison against finite values,
    // and ties (all-infinity input) leave low_index at 0, matching intent.
    let mut low: i16 = f16_to_i16ord(f16::INFINITY);
    for (i, &v) in arr.iter().enumerate() {
        if v.is_nan() {
            continue; // ignore NaN entries
        }
        let v: i16 = f16_to_i16ord(v);
        if v < low {
            low = v;
            low_index = i;
        }
    }
    low_index
}
/// Return the argmax of `arr`, skipping NaN values entirely.
/// If every element is NaN, `0` is returned.
///
/// # Panics
/// Panics if `arr` is empty.
pub(crate) fn scalar_argmax_f16_ignore_nan(arr: &[f16]) -> usize {
    assert!(!arr.is_empty());
    let mut high_index: usize = 0;
    // Sentinel: -inf loses every strict `>` comparison against finite values,
    // and ties (all -inf input) leave high_index at 0, matching intent.
    let mut high: i16 = f16_to_i16ord(f16::NEG_INFINITY);
    for (i, &v) in arr.iter().enumerate() {
        if v.is_nan() {
            continue; // ignore NaN entries
        }
        let v: i16 = f16_to_i16ord(v);
        if v > high {
            high = v;
            high_index = i;
        }
    }
    high_index
}
#[cfg(all(feature = "float", feature = "half"))]
#[cfg(test)]
mod tests {
    use super::{
        scalar_argmax_f16_ignore_nan, scalar_argmin_f16_ignore_nan, scalar_argminmax_f16_ignore_nan,
    };
    use super::{
        scalar_argmax_f16_return_nan, scalar_argmin_f16_return_nan, scalar_argminmax_f16_return_nan,
    };
    use crate::{FloatIgnoreNaN, FloatReturnNaN, ScalarArgMinMax, SCALAR};
    use half::f16;
    use dev_utils::utils;

    const ARR_LEN: usize = 1025;

    /// Build a random f16 array plus its exact f32 widening, so the generic
    /// f32 implementation can serve as the reference for the f16 one.
    fn get_arrays(len: usize) -> (Vec<f32>, Vec<f16>) {
        let vec_f16: Vec<f16> = utils::SampleUniformFullRange::get_random_array(len);
        let vec_f32: Vec<f32> = vec_f16.iter().map(|x| x.to_f32()).collect();
        (vec_f32, vec_f16)
    }

    /// The specific f16 implementations must agree with the generic scalar
    /// implementation run on the widened f32 data.
    #[test]
    fn test_generic_and_specific_impl_return_the_same_results() {
        for _ in 0..100 {
            let (vec_f32, vec_f16) = get_arrays(ARR_LEN);
            let data_f32: &[f32] = &vec_f32;
            let data_f16: &[f16] = &vec_f16;
            let (argmin_index, argmax_index) = SCALAR::<FloatReturnNaN>::argminmax(data_f32);
            let (argmin_index_f16, argmax_index_f16) = scalar_argminmax_f16_return_nan(data_f16);
            let argmin_index_f16_single = scalar_argmin_f16_return_nan(data_f16);
            let argmax_index_f16_single = scalar_argmax_f16_return_nan(data_f16);
            // Consistent ordering: min pair, min single, max pair, max single.
            assert_eq!(argmin_index, argmin_index_f16);
            assert_eq!(argmin_index, argmin_index_f16_single);
            assert_eq!(argmax_index, argmax_index_f16);
            assert_eq!(argmax_index, argmax_index_f16_single);
            let (argmin_index, argmax_index) = SCALAR::<FloatIgnoreNaN>::argminmax(data_f32);
            let (argmin_index_f16, argmax_index_f16) = scalar_argminmax_f16_ignore_nan(data_f16);
            let argmin_index_f16_single = scalar_argmin_f16_ignore_nan(data_f16);
            let argmax_index_f16_single = scalar_argmax_f16_ignore_nan(data_f16);
            assert_eq!(argmin_index, argmin_index_f16);
            assert_eq!(argmin_index, argmin_index_f16_single);
            assert_eq!(argmax_index, argmax_index_f16);
            assert_eq!(argmax_index, argmax_index_f16_single);
        }
    }

    /// return_nan variants must report the first NaN position, wherever it is
    /// (start, middle, end), and (0, 0) when everything is NaN.
    #[test]
    fn test_generic_and_specific_impl_return_nans() {
        let nan_pos: [usize; 3] = [0, ARR_LEN / 2, ARR_LEN - 1];
        for pos in nan_pos.iter() {
            let (vec_f32, vec_f16) = get_arrays(ARR_LEN);
            let mut data_f32: Vec<f32> = vec_f32;
            let mut data_f16: Vec<f16> = vec_f16;
            data_f32[*pos] = f32::NAN;
            data_f16[*pos] = f16::NAN;
            let (argmin_index, argmax_index) = SCALAR::<FloatReturnNaN>::argminmax(&data_f32);
            let (argmin_index_f16, argmax_index_f16) = scalar_argminmax_f16_return_nan(&data_f16);
            let argmin_index_f16_single = scalar_argmin_f16_return_nan(&data_f16);
            let argmax_index_f16_single = scalar_argmax_f16_return_nan(&data_f16);
            assert_eq!(argmin_index, argmin_index_f16);
            assert_eq!(argmin_index, argmin_index_f16_single);
            assert_eq!(argmax_index, argmax_index_f16);
            assert_eq!(argmax_index, argmax_index_f16_single);
        }
        // All-NaN input: the first NaN is at index 0.
        let (mut vec_f32, mut vec_f16) = get_arrays(ARR_LEN);
        vec_f32.iter_mut().for_each(|x| *x = f32::NAN);
        vec_f16.iter_mut().for_each(|x| *x = f16::NAN);
        let data_f32: &[f32] = &vec_f32;
        let data_f16: &[f16] = &vec_f16;
        let (argmin_index, argmax_index) = SCALAR::<FloatReturnNaN>::argminmax(data_f32);
        let (argmin_index_f16, argmax_index_f16) = scalar_argminmax_f16_return_nan(data_f16);
        let argmin_index_f16_single = scalar_argmin_f16_return_nan(data_f16);
        let argmax_index_f16_single = scalar_argmax_f16_return_nan(data_f16);
        assert_eq!(argmin_index, argmin_index_f16);
        assert_eq!(argmin_index, argmin_index_f16_single);
        assert_eq!(argmax_index, argmax_index_f16);
        assert_eq!(argmax_index, argmax_index_f16_single);
        assert_eq!(argmin_index, 0);
        assert_eq!(argmax_index, 0);
    }

    /// ignore_nan variants must skip NaN wherever it is, and fall back to
    /// index 0 when everything is NaN.
    #[test]
    fn test_generic_and_specific_impl_ignore_nans() {
        let nan_pos: [usize; 3] = [0, ARR_LEN / 2, ARR_LEN - 1];
        for pos in nan_pos.iter() {
            let (vec_f32, vec_f16) = get_arrays(ARR_LEN);
            let mut data_f32: Vec<f32> = vec_f32;
            let mut data_f16: Vec<f16> = vec_f16;
            data_f32[*pos] = f32::NAN;
            data_f16[*pos] = f16::NAN;
            let (argmin_index, argmax_index) = SCALAR::<FloatIgnoreNaN>::argminmax(&data_f32);
            let (argmin_index_f16, argmax_index_f16) = scalar_argminmax_f16_ignore_nan(&data_f16);
            let argmin_index_f16_single = scalar_argmin_f16_ignore_nan(&data_f16);
            let argmax_index_f16_single = scalar_argmax_f16_ignore_nan(&data_f16);
            assert_eq!(argmin_index, argmin_index_f16);
            assert_eq!(argmin_index, argmin_index_f16_single);
            assert_eq!(argmax_index, argmax_index_f16);
            assert_eq!(argmax_index, argmax_index_f16_single);
        }
        // All-NaN input: every implementation falls back to index 0.
        let (mut vec_f32, mut vec_f16) = get_arrays(ARR_LEN);
        vec_f32.iter_mut().for_each(|x| *x = f32::NAN);
        vec_f16.iter_mut().for_each(|x| *x = f16::NAN);
        let data_f32: &[f32] = &vec_f32;
        let data_f16: &[f16] = &vec_f16;
        let (argmin_index, argmax_index) = SCALAR::<FloatIgnoreNaN>::argminmax(data_f32);
        let (argmin_index_f16, argmax_index_f16) = scalar_argminmax_f16_ignore_nan(data_f16);
        let argmin_index_f16_single = scalar_argmin_f16_ignore_nan(data_f16);
        let argmax_index_f16_single = scalar_argmax_f16_ignore_nan(data_f16);
        assert_eq!(argmin_index, argmin_index_f16);
        assert_eq!(argmin_index, argmin_index_f16_single);
        assert_eq!(argmax_index, argmax_index_f16);
        assert_eq!(argmax_index, argmax_index_f16_single);
        assert_eq!(argmin_index, 0);
        assert_eq!(argmax_index, 0);
    }
}