// Everything below is only active with the `nightly_simd` cargo feature.
//
// `cfg_version` is switched on first so that the `version("1.78")` predicates
// in the following attributes can be evaluated.
#![cfg_attr(feature = "nightly_simd", feature(cfg_version))]
// x86 / x86_64, compiler version >= 1.78: enable `stdarch_x86_avx512`.
#![cfg_attr(
all(
feature = "nightly_simd",
any(target_arch = "x86_64", target_arch = "x86")
),
cfg_attr(version("1.78"), feature(stdarch_x86_avx512))
)]
// 32-bit ARM, compiler version >= 1.78: enable the NEON intrinsics and the
// runtime feature-detection gates.
#![cfg_attr(
all(feature = "nightly_simd", target_arch = "arm"),
cfg_attr(
version("1.78"),
feature(stdarch_arm_neon_intrinsics),
feature(stdarch_arm_feature_detection)
)
)]
// Compiler version < 1.78: fall back to the single `stdsimd` gate instead of
// the per-architecture gates above.
#![cfg_attr(
feature = "nightly_simd",
cfg_attr(not(version("1.78")), feature(stdsimd))
)]
// Target-feature gates enabled on every architecture and compiler version
// whenever `nightly_simd` is on.
#![cfg_attr(feature = "nightly_simd", feature(avx512_target_feature))]
#![cfg_attr(feature = "nightly_simd", feature(arm_target_feature))]
#[cfg(test)]
use rstest_reuse;
pub mod dtype_strategy;
pub mod scalar;
pub mod simd;
pub(crate) use dtype_strategy::Int;
#[cfg(any(feature = "float", feature = "half"))]
pub(crate) use dtype_strategy::{FloatIgnoreNaN, FloatReturnNaN};
pub(crate) use scalar::{ScalarArgMinMax, SCALAR};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly_simd")]
pub(crate) use simd::AVX512;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) use simd::{SIMDArgMinMax, AVX2, SSE};
#[cfg(any(
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64"
))]
pub(crate) use simd::{SIMDArgMinMax, NEON};
#[cfg(feature = "half")]
use half::f16;
/// Index-of-extremum queries on a collection of values.
///
/// All methods return positions (`usize` indices), not the values themselves.
/// In this file the trait is implemented for integer and float slices, for
/// `Vec<T>`, and (behind cargo features) for ndarray / arrow / arrow2 arrays.
/// The float slice implementations dispatch with the `FloatIgnoreNaN`
/// strategy; see `NaNArgMinMax` for the `FloatReturnNaN` counterpart.
pub trait ArgMinMax {
/// Returns two indices in a single call; by the method name these are
/// `(index of the minimum, index of the maximum)`.
fn argminmax(&self) -> (usize, usize);
/// Returns the index of the minimum value.
fn argmin(&self) -> usize;
/// Returns the index of the maximum value.
fn argmax(&self) -> usize;
}
/// Float-only counterpart of `ArgMinMax`, compiled when the `float` or `half`
/// feature is enabled.
///
/// The slice implementations in this file dispatch these methods with the
/// `FloatReturnNaN` strategy (as opposed to `FloatIgnoreNaN` for `ArgMinMax`).
/// NOTE(review): by the strategy name this presumably reports the position of
/// a NaN when one is present — confirm against `dtype_strategy`.
#[cfg(any(feature = "float", feature = "half"))]
pub trait NaNArgMinMax {
/// Returns two indices in a single call; by the method name these are
/// `(index of the minimum, index of the maximum)`.
fn nanargminmax(&self) -> (usize, usize);
/// Returns the index of the minimum value.
fn nanargmin(&self) -> usize;
/// Returns the index of the maximum value.
fn nanargmax(&self) -> usize;
}
/// Crate-private compile-time metadata about an element type.
///
/// The dispatch macros below read `NB_BITS` to decide which SIMD backend may
/// handle a given element width.
trait DTypeInfo {
/// Width of the type in bits.
const NB_BITS: usize;
}
// Generates one `DTypeInfo` impl per whitespace-separated type.
//
// `NB_BITS` is derived from the in-memory size of the type (bytes * 8), so the
// macro works for any sized type without a hand-written table.
macro_rules! impl_nb_bits {
    ($($dtype:ty)*) => {
        $(
            impl DTypeInfo for $dtype {
                const NB_BITS: usize = 8 * ::std::mem::size_of::<$dtype>();
            }
        )*
    };
}
// Integer element types are always supported.
impl_nb_bits!(i8 i16 i32 i64 u8 u16 u32 u64);
// `f32` / `f64` are only compiled in with the `float` feature.
#[cfg(feature = "float")]
impl_nb_bits!(f32 f64);
// `half::f16` is only compiled in with the `half` feature.
#[cfg(feature = "half")]
impl_nb_bits!(f16);
// Implements `ArgMinMax` for `&[$int_type]` for every listed integer type.
//
// Each method selects a backend at runtime and falls back to `SCALAR` when no
// SIMD backend applies. The selection order is identical for all three
// methods:
//
//   x86 / x86_64
//     1. 8-bit elements + SSE4.1                      -> SSE
//     2. (`nightly_simd` only) AVX-512                -> AVX512
//          - 8/16-bit elements require `avx512bw`
//          - wider elements require `avx512f`
//     3. AVX2                                         -> AVX2
//     4. SSE4.1, elements narrower than 64 bits       -> SSE
//   aarch64
//     NEON, elements narrower than 64 bits            -> NEON
//   arm (`nightly_simd` only)
//     NEON, elements narrower than 64 bits            -> NEON
//
// Two dispatch conditions are deliberately stricter than a naive chain:
//   * 8/16-bit elements are never sent to AVX512 on the strength of `avx512f`
//     alone; the dedicated `avx512bw` requirement would otherwise be bypassed
//     on CPUs that have `avx512f` without `avx512bw`.
//   * The aarch64 NEON path applies the same `NB_BITS < 64` restriction in
//     `argmin` / `argmax` as it does in `argminmax` and on 32-bit ARM.
//
// SAFETY (all `unsafe` blocks below): each SIMD call is reached only after the
// matching CPU feature was detected at runtime. This assumes the backend
// functions need nothing beyond that feature — TODO confirm against `simd`.
macro_rules! impl_argminmax_int {
    ($($int_type:ty),*) => {
        $(
            impl ArgMinMax for &[$int_type] {
                fn argminmax(&self) -> (usize, usize) {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        // 8-bit elements are routed to SSE4.1 ahead of the wider instruction sets.
                        if (<$int_type>::NB_BITS == 8) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<Int>::argminmax(self) };
                        }
                        #[cfg(feature = "nightly_simd")]
                        {
                            // Pick the AVX-512 subset that matches the element width.
                            let avx512_usable = if <$int_type>::NB_BITS <= 16 {
                                is_x86_feature_detected!("avx512bw")
                            } else {
                                is_x86_feature_detected!("avx512f")
                            };
                            if avx512_usable {
                                return unsafe { AVX512::<Int>::argminmax(self) };
                            }
                        }
                        if is_x86_feature_detected!("avx2") {
                            return unsafe { AVX2::<Int>::argminmax(self) };
                        } else if (<$int_type>::NB_BITS < 64) && is_x86_feature_detected!("sse4.1") {
                            // 64-bit elements skip SSE and fall through to the scalar backend.
                            return unsafe { SSE::<Int>::argminmax(self) };
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        if (<$int_type>::NB_BITS < 64) && std::arch::is_aarch64_feature_detected!("neon") {
                            return unsafe { NEON::<Int>::argminmax(self) };
                        }
                    }
                    #[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
                    {
                        if (<$int_type>::NB_BITS < 64) && std::arch::is_arm_feature_detected!("neon") {
                            return unsafe { NEON::<Int>::argminmax(self) };
                        }
                    }
                    SCALAR::<Int>::argminmax(self)
                }

                fn argmin(&self) -> usize {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        if (<$int_type>::NB_BITS == 8) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<Int>::argmin(self) };
                        }
                        #[cfg(feature = "nightly_simd")]
                        {
                            let avx512_usable = if <$int_type>::NB_BITS <= 16 {
                                is_x86_feature_detected!("avx512bw")
                            } else {
                                is_x86_feature_detected!("avx512f")
                            };
                            if avx512_usable {
                                return unsafe { AVX512::<Int>::argmin(self) };
                            }
                        }
                        if is_x86_feature_detected!("avx2") {
                            return unsafe { AVX2::<Int>::argmin(self) };
                        } else if (<$int_type>::NB_BITS < 64) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<Int>::argmin(self) };
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        // Same 64-bit exclusion as `argminmax` and the 32-bit ARM path.
                        if (<$int_type>::NB_BITS < 64) && std::arch::is_aarch64_feature_detected!("neon") {
                            return unsafe { NEON::<Int>::argmin(self) };
                        }
                    }
                    #[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
                    {
                        if (<$int_type>::NB_BITS < 64) && std::arch::is_arm_feature_detected!("neon") {
                            return unsafe { NEON::<Int>::argmin(self) };
                        }
                    }
                    SCALAR::<Int>::argmin(self)
                }

                fn argmax(&self) -> usize {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        if (<$int_type>::NB_BITS == 8) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<Int>::argmax(self) };
                        }
                        #[cfg(feature = "nightly_simd")]
                        {
                            let avx512_usable = if <$int_type>::NB_BITS <= 16 {
                                is_x86_feature_detected!("avx512bw")
                            } else {
                                is_x86_feature_detected!("avx512f")
                            };
                            if avx512_usable {
                                return unsafe { AVX512::<Int>::argmax(self) };
                            }
                        }
                        if is_x86_feature_detected!("avx2") {
                            return unsafe { AVX2::<Int>::argmax(self) };
                        } else if (<$int_type>::NB_BITS < 64) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<Int>::argmax(self) };
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        // Same 64-bit exclusion as `argminmax` and the 32-bit ARM path.
                        if (<$int_type>::NB_BITS < 64) && std::arch::is_aarch64_feature_detected!("neon") {
                            return unsafe { NEON::<Int>::argmax(self) };
                        }
                    }
                    #[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
                    {
                        if (<$int_type>::NB_BITS < 64) && std::arch::is_arm_feature_detected!("neon") {
                            return unsafe { NEON::<Int>::argmax(self) };
                        }
                    }
                    SCALAR::<Int>::argmax(self)
                }
            }
        )*
    };
}
// Implements `ArgMinMax` and `NaNArgMinMax` for `&[$float_type]` for every
// listed float type.
//
// `ArgMinMax` methods dispatch with the `FloatIgnoreNaN` strategy,
// `NaNArgMinMax` methods with the `FloatReturnNaN` strategy. Each method picks
// a backend at runtime and falls back to `SCALAR` when no SIMD backend applies:
//
//   x86 / x86_64
//     1. (`nightly_simd` only) AVX-512                      -> AVX512
//          - 16-bit floats require `avx512bw`
//          - wider floats require `avx512f`
//     2. AVX2                                               -> AVX2
//     3. (`FloatIgnoreNaN` only) AVX, floats wider than 16 bits -> AVX2
//     4. SSE4.1, floats narrower than 64 bits               -> SSE
//   aarch64
//     NEON                                                  -> NEON
//   arm (`nightly_simd` only)
//     NEON, floats narrower than 64 bits                    -> NEON
//
// 16-bit floats are never sent to AVX512 on the strength of `avx512f` alone;
// the dedicated `avx512bw` requirement would otherwise be bypassed on CPUs
// that have `avx512f` without `avx512bw`.
//
// NOTE(review): unlike the integer dispatch, the aarch64 NEON path here has no
// `NB_BITS < 64` restriction in any of the six methods — confirm that the NEON
// backend really supports 64-bit floats on aarch64.
//
// SAFETY (all `unsafe` blocks below): each SIMD call is reached only after the
// matching CPU feature was detected at runtime. This assumes the backend
// functions need nothing beyond that feature — TODO confirm against `simd`.
#[cfg(any(feature = "float", feature = "half"))]
macro_rules! impl_argminmax_float {
    ($($float_type:ty),*) => {
        $(
            impl ArgMinMax for &[$float_type] {
                fn argminmax(&self) -> (usize, usize) {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        #[cfg(feature = "nightly_simd")]
                        {
                            // Pick the AVX-512 subset that matches the element width.
                            let avx512_usable = if <$float_type>::NB_BITS == 16 {
                                is_x86_feature_detected!("avx512bw")
                            } else {
                                is_x86_feature_detected!("avx512f")
                            };
                            if avx512_usable {
                                return unsafe { AVX512::<FloatIgnoreNaN>::argminmax(self) };
                            }
                        }
                        if is_x86_feature_detected!("avx2") {
                            return unsafe { AVX2::<FloatIgnoreNaN>::argminmax(self) };
                        } else if (<$float_type>::NB_BITS > 16) && is_x86_feature_detected!("avx") {
                            // Plain AVX is accepted for floats wider than 16 bits.
                            return unsafe { AVX2::<FloatIgnoreNaN>::argminmax(self) };
                        } else if (<$float_type>::NB_BITS < 64) && is_x86_feature_detected!("sse4.1") {
                            // 64-bit floats skip SSE and fall through to the scalar backend.
                            return unsafe { SSE::<FloatIgnoreNaN>::argminmax(self) };
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        if std::arch::is_aarch64_feature_detected!("neon") {
                            return unsafe { NEON::<FloatIgnoreNaN>::argminmax(self) };
                        }
                    }
                    #[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
                    {
                        if (<$float_type>::NB_BITS < 64) && std::arch::is_arm_feature_detected!("neon") {
                            return unsafe { NEON::<FloatIgnoreNaN>::argminmax(self) };
                        }
                    }
                    SCALAR::<FloatIgnoreNaN>::argminmax(self)
                }

                fn argmin(&self) -> usize {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        #[cfg(feature = "nightly_simd")]
                        {
                            let avx512_usable = if <$float_type>::NB_BITS == 16 {
                                is_x86_feature_detected!("avx512bw")
                            } else {
                                is_x86_feature_detected!("avx512f")
                            };
                            if avx512_usable {
                                return unsafe { AVX512::<FloatIgnoreNaN>::argmin(self) };
                            }
                        }
                        if is_x86_feature_detected!("avx2") {
                            return unsafe { AVX2::<FloatIgnoreNaN>::argmin(self) };
                        } else if (<$float_type>::NB_BITS > 16) && is_x86_feature_detected!("avx") {
                            return unsafe { AVX2::<FloatIgnoreNaN>::argmin(self) };
                        } else if (<$float_type>::NB_BITS < 64) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<FloatIgnoreNaN>::argmin(self) };
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        if std::arch::is_aarch64_feature_detected!("neon") {
                            return unsafe { NEON::<FloatIgnoreNaN>::argmin(self) };
                        }
                    }
                    #[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
                    {
                        if (<$float_type>::NB_BITS < 64) && std::arch::is_arm_feature_detected!("neon") {
                            return unsafe { NEON::<FloatIgnoreNaN>::argmin(self) };
                        }
                    }
                    SCALAR::<FloatIgnoreNaN>::argmin(self)
                }

                fn argmax(&self) -> usize {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        #[cfg(feature = "nightly_simd")]
                        {
                            let avx512_usable = if <$float_type>::NB_BITS == 16 {
                                is_x86_feature_detected!("avx512bw")
                            } else {
                                is_x86_feature_detected!("avx512f")
                            };
                            if avx512_usable {
                                return unsafe { AVX512::<FloatIgnoreNaN>::argmax(self) };
                            }
                        }
                        if is_x86_feature_detected!("avx2") {
                            return unsafe { AVX2::<FloatIgnoreNaN>::argmax(self) };
                        } else if (<$float_type>::NB_BITS > 16) && is_x86_feature_detected!("avx") {
                            return unsafe { AVX2::<FloatIgnoreNaN>::argmax(self) };
                        } else if (<$float_type>::NB_BITS < 64) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<FloatIgnoreNaN>::argmax(self) };
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        if std::arch::is_aarch64_feature_detected!("neon") {
                            return unsafe { NEON::<FloatIgnoreNaN>::argmax(self) };
                        }
                    }
                    #[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
                    {
                        if (<$float_type>::NB_BITS < 64) && std::arch::is_arm_feature_detected!("neon") {
                            return unsafe { NEON::<FloatIgnoreNaN>::argmax(self) };
                        }
                    }
                    SCALAR::<FloatIgnoreNaN>::argmax(self)
                }
            }

            impl NaNArgMinMax for &[$float_type] {
                fn nanargminmax(&self) -> (usize, usize) {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        #[cfg(feature = "nightly_simd")]
                        {
                            let avx512_usable = if <$float_type>::NB_BITS == 16 {
                                is_x86_feature_detected!("avx512bw")
                            } else {
                                is_x86_feature_detected!("avx512f")
                            };
                            if avx512_usable {
                                return unsafe { AVX512::<FloatReturnNaN>::argminmax(self) };
                            }
                        }
                        // No plain-AVX branch for the NaN-returning strategy: AVX2 or SSE4.1 only.
                        if is_x86_feature_detected!("avx2") {
                            return unsafe { AVX2::<FloatReturnNaN>::argminmax(self) };
                        } else if (<$float_type>::NB_BITS < 64) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<FloatReturnNaN>::argminmax(self) };
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        if std::arch::is_aarch64_feature_detected!("neon") {
                            return unsafe { NEON::<FloatReturnNaN>::argminmax(self) };
                        }
                    }
                    #[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
                    {
                        if (<$float_type>::NB_BITS < 64) && std::arch::is_arm_feature_detected!("neon") {
                            return unsafe { NEON::<FloatReturnNaN>::argminmax(self) };
                        }
                    }
                    SCALAR::<FloatReturnNaN>::argminmax(self)
                }

                fn nanargmin(&self) -> usize {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        #[cfg(feature = "nightly_simd")]
                        {
                            let avx512_usable = if <$float_type>::NB_BITS == 16 {
                                is_x86_feature_detected!("avx512bw")
                            } else {
                                is_x86_feature_detected!("avx512f")
                            };
                            if avx512_usable {
                                return unsafe { AVX512::<FloatReturnNaN>::argmin(self) };
                            }
                        }
                        if is_x86_feature_detected!("avx2") {
                            return unsafe { AVX2::<FloatReturnNaN>::argmin(self) };
                        } else if (<$float_type>::NB_BITS < 64) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<FloatReturnNaN>::argmin(self) };
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        if std::arch::is_aarch64_feature_detected!("neon") {
                            return unsafe { NEON::<FloatReturnNaN>::argmin(self) };
                        }
                    }
                    #[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
                    {
                        if (<$float_type>::NB_BITS < 64) && std::arch::is_arm_feature_detected!("neon") {
                            return unsafe { NEON::<FloatReturnNaN>::argmin(self) };
                        }
                    }
                    SCALAR::<FloatReturnNaN>::argmin(self)
                }

                fn nanargmax(&self) -> usize {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        #[cfg(feature = "nightly_simd")]
                        {
                            let avx512_usable = if <$float_type>::NB_BITS == 16 {
                                is_x86_feature_detected!("avx512bw")
                            } else {
                                is_x86_feature_detected!("avx512f")
                            };
                            if avx512_usable {
                                return unsafe { AVX512::<FloatReturnNaN>::argmax(self) };
                            }
                        }
                        if is_x86_feature_detected!("avx2") {
                            return unsafe { AVX2::<FloatReturnNaN>::argmax(self) };
                        } else if (<$float_type>::NB_BITS < 64) && is_x86_feature_detected!("sse4.1") {
                            return unsafe { SSE::<FloatReturnNaN>::argmax(self) };
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        if std::arch::is_aarch64_feature_detected!("neon") {
                            return unsafe { NEON::<FloatReturnNaN>::argmax(self) };
                        }
                    }
                    #[cfg(all(target_arch = "arm", feature = "nightly_simd"))]
                    {
                        if (<$float_type>::NB_BITS < 64) && std::arch::is_arm_feature_detected!("neon") {
                            return unsafe { NEON::<FloatReturnNaN>::argmax(self) };
                        }
                    }
                    SCALAR::<FloatReturnNaN>::argmax(self)
                }
            }
        )*
    };
}
// Integer slices always get an `ArgMinMax` implementation.
impl_argminmax_int!(i8, i16, i32, i64, u8, u16, u32, u64);
// `f32` / `f64` slices get `ArgMinMax` + `NaNArgMinMax` with the `float` feature.
#[cfg(feature = "float")]
impl_argminmax_float!(f32, f64);
// `half::f16` slices get `ArgMinMax` + `NaNArgMinMax` with the `half` feature.
#[cfg(feature = "half")]
impl_argminmax_float!(f16);
/// `Vec<T>` supports [`ArgMinMax`] whenever `&[T]` does: every method borrows
/// the vector as a slice and forwards to the slice implementation.
impl<T> ArgMinMax for Vec<T>
where
    for<'a> &'a [T]: ArgMinMax,
{
    /// Forwards to `argminmax` of the slice implementation.
    fn argminmax(&self) -> (usize, usize) {
        <&[T] as ArgMinMax>::argminmax(&self.as_slice())
    }

    /// Forwards to `argmin` of the slice implementation.
    fn argmin(&self) -> usize {
        <&[T] as ArgMinMax>::argmin(&self.as_slice())
    }

    /// Forwards to `argmax` of the slice implementation.
    fn argmax(&self) -> usize {
        <&[T] as ArgMinMax>::argmax(&self.as_slice())
    }
}
/// `Vec<T>` supports [`NaNArgMinMax`] whenever `&[T]` does: every method
/// borrows the vector as a slice and forwards to the slice implementation.
#[cfg(any(feature = "float", feature = "half"))]
impl<T> NaNArgMinMax for Vec<T>
where
    for<'a> &'a [T]: NaNArgMinMax,
{
    /// Forwards to `nanargminmax` of the slice implementation.
    fn nanargminmax(&self) -> (usize, usize) {
        <&[T] as NaNArgMinMax>::nanargminmax(&self.as_slice())
    }

    /// Forwards to `nanargmin` of the slice implementation.
    fn nanargmin(&self) -> usize {
        <&[T] as NaNArgMinMax>::nanargmin(&self.as_slice())
    }

    /// Forwards to `nanargmax` of the slice implementation.
    fn nanargmax(&self) -> usize {
        <&[T] as NaNArgMinMax>::nanargmax(&self.as_slice())
    }
}
/// `ArgMinMax` / `NaNArgMinMax` for one-dimensional `ndarray` arrays
/// (compiled with the `ndarray` feature).
///
/// Every method views the array as a slice and forwards to the slice
/// implementation.
///
/// # Panics
///
/// All methods panic when `as_slice()` returns `None`; per the ndarray
/// documentation that is the case when the data is not contiguous in standard
/// order.
#[cfg(feature = "ndarray")]
mod ndarray_impl {
    use super::*;
    use ndarray::{ArrayBase, Data, Ix1};

    impl<S> ArgMinMax for ArrayBase<S, Ix1>
    where
        S: Data,
        for<'a> &'a [S::Elem]: ArgMinMax,
    {
        fn argminmax(&self) -> (usize, usize) {
            let elements = self.as_slice().unwrap();
            elements.argminmax()
        }

        fn argmin(&self) -> usize {
            let elements = self.as_slice().unwrap();
            elements.argmin()
        }

        fn argmax(&self) -> usize {
            let elements = self.as_slice().unwrap();
            elements.argmax()
        }
    }

    #[cfg(any(feature = "float", feature = "half"))]
    impl<S> NaNArgMinMax for ArrayBase<S, Ix1>
    where
        S: Data,
        for<'a> &'a [S::Elem]: NaNArgMinMax,
    {
        fn nanargminmax(&self) -> (usize, usize) {
            let elements = self.as_slice().unwrap();
            elements.nanargminmax()
        }

        fn nanargmin(&self) -> usize {
            let elements = self.as_slice().unwrap();
            elements.nanargmin()
        }

        fn nanargmax(&self) -> usize {
            let elements = self.as_slice().unwrap();
            elements.nanargmax()
        }
    }
}
/// `ArgMinMax` / `NaNArgMinMax` for `arrow` primitive arrays (compiled with
/// the `arrow` feature).
///
/// Every method borrows the array's values buffer as a slice and forwards to
/// the slice implementation.
///
/// NOTE(review): only `values()` is read here; the array's validity (null)
/// information is not consulted — confirm that this is the intended contract.
#[cfg(feature = "arrow")]
mod arrow_impl {
    use super::*;
    use arrow::array::PrimitiveArray;

    impl<T> ArgMinMax for PrimitiveArray<T>
    where
        T: arrow::datatypes::ArrowNumericType,
        for<'a> &'a [T::Native]: ArgMinMax,
    {
        fn argminmax(&self) -> (usize, usize) {
            let native_values = self.values().as_ref();
            native_values.argminmax()
        }

        fn argmin(&self) -> usize {
            let native_values = self.values().as_ref();
            native_values.argmin()
        }

        fn argmax(&self) -> usize {
            let native_values = self.values().as_ref();
            native_values.argmax()
        }
    }

    #[cfg(any(feature = "float", feature = "half"))]
    impl<T> NaNArgMinMax for PrimitiveArray<T>
    where
        T: arrow::datatypes::ArrowNumericType,
        for<'a> &'a [T::Native]: NaNArgMinMax,
    {
        fn nanargminmax(&self) -> (usize, usize) {
            let native_values = self.values().as_ref();
            native_values.nanargminmax()
        }

        fn nanargmin(&self) -> usize {
            let native_values = self.values().as_ref();
            native_values.nanargmin()
        }

        fn nanargmax(&self) -> usize {
            let native_values = self.values().as_ref();
            native_values.nanargmax()
        }
    }
}
/// `ArgMinMax` / `NaNArgMinMax` for `arrow2` primitive arrays (compiled with
/// the `arrow2` feature).
///
/// The generic impls borrow the array's values buffer as a slice and forward
/// to the slice implementation. `arrow2::types::f16` arrays get dedicated
/// impls (with the `half` feature) that first reinterpret the buffer as
/// `half::f16`.
///
/// NOTE(review): only `values()` is read here; the array's validity (null)
/// information is not consulted — confirm that this is the intended contract.
#[cfg(feature = "arrow2")]
mod arrow2_impl {
    use super::*;
    use arrow2::array::PrimitiveArray;

    impl<T> ArgMinMax for PrimitiveArray<T>
    where
        T: arrow2::types::NativeType,
        for<'a> &'a [T]: ArgMinMax,
    {
        fn argminmax(&self) -> (usize, usize) {
            let native_values = self.values().as_ref();
            native_values.argminmax()
        }

        fn argmin(&self) -> usize {
            let native_values = self.values().as_ref();
            native_values.argmin()
        }

        fn argmax(&self) -> usize {
            let native_values = self.values().as_ref();
            native_values.argmax()
        }
    }

    // Only compiled with the `float` feature; `f16` arrays are covered by the
    // dedicated impl further down.
    #[cfg(feature = "float")]
    impl<T> NaNArgMinMax for PrimitiveArray<T>
    where
        T: arrow2::types::NativeType,
        for<'a> &'a [T]: NaNArgMinMax,
    {
        fn nanargminmax(&self) -> (usize, usize) {
            let native_values = self.values().as_ref();
            native_values.nanargminmax()
        }

        fn nanargmin(&self) -> usize {
            let native_values = self.values().as_ref();
            native_values.nanargmin()
        }

        fn nanargmax(&self) -> usize {
            let native_values = self.values().as_ref();
            native_values.nanargmax()
        }
    }

    /// Reinterprets the values of an `arrow2` f16 array as a `half::f16` slice
    /// without copying. The returned slice has `primitive_array_f16.len()`
    /// elements and borrows from the array.
    #[cfg(feature = "half")]
    #[inline(always)]
    fn _to_half_f16_slice(
        primitive_array_f16: &PrimitiveArray<arrow2::types::f16>,
    ) -> &[half::f16] {
        let data_ptr = primitive_array_f16.values().as_ptr() as *const half::f16;
        let element_count = primitive_array_f16.len();
        // SAFETY: the pointer and the length both come from the same borrowed
        // array. This assumes `arrow2::types::f16` and `half::f16` share the
        // same size, alignment and bit layout — TODO confirm.
        unsafe { std::slice::from_raw_parts(data_ptr, element_count) }
    }

    #[cfg(feature = "half")]
    impl ArgMinMax for PrimitiveArray<arrow2::types::f16> {
        fn argminmax(&self) -> (usize, usize) {
            let halves = _to_half_f16_slice(self);
            halves.argminmax()
        }

        fn argmin(&self) -> usize {
            let halves = _to_half_f16_slice(self);
            halves.argmin()
        }

        fn argmax(&self) -> usize {
            let halves = _to_half_f16_slice(self);
            halves.argmax()
        }
    }

    #[cfg(feature = "half")]
    impl NaNArgMinMax for PrimitiveArray<arrow2::types::f16> {
        fn nanargminmax(&self) -> (usize, usize) {
            let halves = _to_half_f16_slice(self);
            halves.nanargminmax()
        }

        fn nanargmin(&self) -> usize {
            let halves = _to_half_f16_slice(self);
            halves.nanargmin()
        }

        fn nanargmax(&self) -> usize {
            let halves = _to_half_f16_slice(self);
            halves.nanargmax()
        }
    }
}