use std::simd::prelude::*;
use std::simd::{LaneCount, SimdElement, SupportedLaneCount};
use arrow::array::PrimitiveArray;
use arrow::bitmap::Bitmap;
use arrow::bitmap::bitmask::BitMask;
use arrow::types::NativeType;
use polars_utils::min_max::MinMax;
use super::MinMaxKernel;
/// Horizontally reduces a fixed-size array to its NaN-propagating minimum.
///
/// The fold is left-to-right, matching the lane order of the SIMD
/// accumulator this is applied to. Panics if `N == 0` (never the case for
/// the lane counts instantiated in this file).
fn scalar_reduce_min_propagate_nan<T: MinMax + Copy, const N: usize>(arr: &[T; N]) -> T {
    let (&first, rest) = arr.split_first().unwrap();
    rest.iter().copied().fold(first, MinMax::min_propagate_nan)
}
/// Horizontally reduces a fixed-size array to its NaN-propagating maximum.
///
/// The fold is left-to-right, matching the lane order of the SIMD
/// accumulator this is applied to. Panics if `N == 0` (never the case for
/// the lane counts instantiated in this file).
fn scalar_reduce_max_propagate_nan<T: MinMax + Copy, const N: usize>(arr: &[T; N]) -> T {
    let (&first, rest) = arr.split_first().unwrap();
    rest.iter().copied().fold(first, MinMax::max_propagate_nan)
}
/// Folds `arr` into a single `N`-lane SIMD accumulator using `simd_f`,
/// substituting `scalar_identity` for lanes that are null (per `validity`)
/// or that pad the trailing partial chunk.
///
/// `scalar_identity` must be an identity element for the reduction that
/// `simd_f` performs (e.g. `T::MAX` for min, `T::MIN` for max) so that
/// masked-out lanes cannot influence the result.
///
/// Returns `None` when `arr` is empty or every element is null; otherwise
/// the caller horizontally reduces the returned vector to a scalar.
fn fold_agg_kernel<const N: usize, T, F>(
    arr: &[T],
    validity: Option<&Bitmap>,
    scalar_identity: T,
    mut simd_f: F,
) -> Option<Simd<T, N>>
where
    T: SimdElement + NativeType,
    F: FnMut(Simd<T, N>, Simd<T, N>) -> Simd<T, N>,
    LaneCount<N>: SupportedLaneCount,
{
    if arr.is_empty() {
        return None;
    }
    let mut arr_chunks = arr.chunks_exact(N);
    let identity = Simd::splat(scalar_identity);
    let mut state = identity;
    if let Some(valid) = validity {
        // All-null input: the fold would just return the identity, which is
        // not a valid aggregate, so bail out early.
        if valid.unset_bits() == arr.len() {
            return None;
        }
        let mask = BitMask::from_bitmap(valid);
        let mut offset = 0;
        for c in arr_chunks.by_ref() {
            // Replace null lanes with the identity before folding the chunk in.
            let m: Mask<_, N> = mask.get_simd(offset);
            state = simd_f(state, m.select(Simd::from_slice(c), identity));
            offset += N;
        }
        if !arr.len().is_multiple_of(N) {
            // Trailing partial chunk: pad with the identity, then mask nulls.
            // NOTE(review): assumes `BitMask::get_simd` treats bits read past
            // the end of the mask as unset — confirm in its implementation.
            let mut rest: [T; N] = identity.to_array();
            let arr_rest = arr_chunks.remainder();
            rest[..arr_rest.len()].copy_from_slice(arr_rest);
            let m: Mask<_, N> = mask.get_simd(offset);
            state = simd_f(state, m.select(Simd::from_array(rest), identity));
        }
    } else {
        // No validity mask: fold full chunks directly.
        for c in arr_chunks.by_ref() {
            state = simd_f(state, Simd::from_slice(c));
        }
        if !arr.len().is_multiple_of(N) {
            // Trailing partial chunk: pad unused lanes with the identity.
            let mut rest: [T; N] = identity.to_array();
            let arr_rest = arr_chunks.remainder();
            rest[..arr_rest.len()].copy_from_slice(arr_rest);
            state = simd_f(state, Simd::from_array(rest));
        }
    }
    Some(state)
}
/// Like `fold_agg_kernel`, but tracks min and max accumulators
/// simultaneously in a single pass over `arr`.
///
/// `min_scalar_identity` / `max_scalar_identity` fill null lanes and the
/// padding of the trailing partial chunk in the respective accumulator, so
/// each must be an identity for its half of the reduction `simd_f` performs
/// (e.g. `T::MAX` for the min half, `T::MIN` for the max half).
///
/// Returns `None` when `arr` is empty or every element is null; otherwise
/// the caller horizontally reduces each returned vector to a scalar.
fn fold_agg_min_max_kernel<const N: usize, T, F>(
    arr: &[T],
    validity: Option<&Bitmap>,
    min_scalar_identity: T,
    max_scalar_identity: T,
    mut simd_f: F,
) -> Option<(Simd<T, N>, Simd<T, N>)>
where
    T: SimdElement + NativeType,
    F: FnMut((Simd<T, N>, Simd<T, N>), (Simd<T, N>, Simd<T, N>)) -> (Simd<T, N>, Simd<T, N>),
    LaneCount<N>: SupportedLaneCount,
{
    if arr.is_empty() {
        return None;
    }
    let mut arr_chunks = arr.chunks_exact(N);
    let min_identity = Simd::splat(min_scalar_identity);
    let max_identity = Simd::splat(max_scalar_identity);
    // State is (min accumulator, max accumulator).
    let mut state = (min_identity, max_identity);
    if let Some(valid) = validity {
        // All-null input: the fold would just return the identities, which
        // are not valid aggregates, so bail out early.
        if valid.unset_bits() == arr.len() {
            return None;
        }
        let mask = BitMask::from_bitmap(valid);
        let mut offset = 0;
        for c in arr_chunks.by_ref() {
            let m: Mask<_, N> = mask.get_simd(offset);
            let slice = Simd::from_slice(c);
            // Null lanes get the matching identity in each accumulator.
            state = simd_f(
                state,
                (m.select(slice, min_identity), m.select(slice, max_identity)),
            );
            offset += N;
        }
        if !arr.len().is_multiple_of(N) {
            // Trailing partial chunk: pad each accumulator's input with its
            // own identity, then mask nulls as above.
            // NOTE(review): assumes `BitMask::get_simd` treats bits read past
            // the end of the mask as unset — confirm in its implementation.
            let mut min_rest: [T; N] = min_identity.to_array();
            let mut max_rest: [T; N] = max_identity.to_array();
            let arr_rest = arr_chunks.remainder();
            min_rest[..arr_rest.len()].copy_from_slice(arr_rest);
            max_rest[..arr_rest.len()].copy_from_slice(arr_rest);
            let m: Mask<_, N> = mask.get_simd(offset);
            let min_rest = Simd::from_array(min_rest);
            let max_rest = Simd::from_array(max_rest);
            state = simd_f(
                state,
                (
                    m.select(min_rest, min_identity),
                    m.select(max_rest, max_identity),
                ),
            );
        }
    } else {
        // No validity mask: feed each full chunk to both accumulators.
        for c in arr_chunks.by_ref() {
            let slice = Simd::from_slice(c);
            state = simd_f(state, (slice, slice));
        }
        if !arr.len().is_multiple_of(N) {
            // Trailing partial chunk: pad unused lanes with each identity.
            let mut min_rest: [T; N] = min_identity.to_array();
            let mut max_rest: [T; N] = max_identity.to_array();
            let arr_rest = arr_chunks.remainder();
            min_rest[..arr_rest.len()].copy_from_slice(arr_rest);
            max_rest[..arr_rest.len()].copy_from_slice(arr_rest);
            let min_rest = Simd::from_array(min_rest);
            let max_rest = Simd::from_array(max_rest);
            state = simd_f(state, (min_rest, max_rest));
        }
    }
    Some(state)
}
/// Implements `MinMaxKernel` for an integer type `$T` on both
/// `PrimitiveArray<$T>` (validity-aware) and plain `[$T]` slices,
/// using `$N` SIMD lanes in the fold kernels.
///
/// Integers have no NaN, so the `*_propagate_nan` methods simply
/// delegate to their `*_ignore_nan` counterparts.
macro_rules! impl_min_max_kernel_int {
    ($T:ty, $N:literal) => {
        impl MinMaxKernel for PrimitiveArray<$T> {
            type Scalar<'a> = $T;
            fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                // `$T::MAX` is the identity for min: masked lanes cannot win.
                fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::MAX, |a, b| {
                    a.simd_min(b)
                })
                .map(|s| s.reduce_min())
            }
            fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                // `$T::MIN` is the identity for max.
                fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::MIN, |a, b| {
                    a.simd_max(b)
                })
                .map(|s| s.reduce_max())
            }
            fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                // Single pass computing both extremes at once.
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::MAX,
                    <$T>::MIN,
                    |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),
                )
                .map(|(min, max)| (min.reduce_min(), max.reduce_max()))
            }
            fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                // No NaN for integers: identical to the ignore-NaN variant.
                self.min_ignore_nan_kernel()
            }
            fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                self.max_ignore_nan_kernel()
            }
            fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                self.min_max_ignore_nan_kernel()
            }
        }
        impl MinMaxKernel for [$T] {
            type Scalar<'a> = $T;
            fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                // Slices carry no validity mask, hence `None`.
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::MAX, |a, b| a.simd_min(b))
                    .map(|s| s.reduce_min())
            }
            fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::MIN, |a, b| a.simd_max(b))
                    .map(|s| s.reduce_max())
            }
            fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self,
                    None,
                    <$T>::MAX,
                    <$T>::MIN,
                    |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),
                )
                .map(|(min, max)| (min.reduce_min(), max.reduce_max()))
            }
            fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                // No NaN for integers: identical to the ignore-NaN variant.
                self.min_ignore_nan_kernel()
            }
            fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                self.max_ignore_nan_kernel()
            }
            fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                self.min_max_ignore_nan_kernel()
            }
        }
    };
}
// Lane counts per element width; presumably sized to fill wide vector
// registers (e.g. 32×u8 = 256 bits, 8×u64 = 512 bits) — TODO confirm
// against target benchmarks.
impl_min_max_kernel_int!(u8, 32);
impl_min_max_kernel_int!(u16, 16);
impl_min_max_kernel_int!(u32, 16);
impl_min_max_kernel_int!(u64, 8);
impl_min_max_kernel_int!(i8, 32);
impl_min_max_kernel_int!(i16, 16);
impl_min_max_kernel_int!(i32, 16);
impl_min_max_kernel_int!(i64, 8);
/// Implements `MinMaxKernel` for a float type `$T` on both
/// `PrimitiveArray<$T>` (validity-aware) and plain `[$T]` slices,
/// using `$N` SIMD lanes in the fold kernels.
///
/// Ignore-NaN variants rely on `simd_min`/`simd_max` returning the
/// non-NaN operand when exactly one lane is NaN (IEEE `minNum`-style,
/// per `std::simd` docs), so `NAN` acts as the identity: null/padded
/// lanes and NaN values alike drop out of the reduction.
///
/// Propagate-NaN variants instead keep `a` when `a < b` (resp. `a > b`)
/// or when `a` is NaN (`a.simd_ne(a)` is the lane-wise NaN test), so a
/// NaN sticks in the accumulator; `±INFINITY` identities ensure padded
/// lanes lose to any real value, and the final horizontal reduction uses
/// the scalar NaN-propagating helpers.
macro_rules! impl_min_max_kernel_float {
    ($T:ty, $N:literal) => {
        impl MinMaxKernel for PrimitiveArray<$T> {
            type Scalar<'a> = $T;
            fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                // NAN identity: `simd_min` discards it against any real value.
                fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| {
                    a.simd_min(b)
                })
                .map(|s| s.reduce_min())
            }
            fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| {
                    a.simd_max(b)
                })
                .map(|s| s.reduce_max())
            }
            fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::NAN,
                    <$T>::NAN,
                    |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),
                )
                .map(|(min, max)| (min.reduce_min(), max.reduce_max()))
            }
            fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::INFINITY,
                    // Keep `a` if it is smaller or if it is NaN (NaN sticks).
                    |a, b| (a.simd_lt(b) | a.simd_ne(a)).select(a, b),
                )
                .map(|s| scalar_reduce_min_propagate_nan(s.as_array()))
            }
            fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::NEG_INFINITY,
                    // Keep `a` if it is larger or if it is NaN (NaN sticks).
                    |a, b| (a.simd_gt(b) | a.simd_ne(a)).select(a, b),
                )
                .map(|s| scalar_reduce_max_propagate_nan(s.as_array()))
            }
            fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::INFINITY,
                    <$T>::NEG_INFINITY,
                    |(cmin, cmax), (min, max)| {
                        (
                            (cmin.simd_lt(min) | cmin.simd_ne(cmin)).select(cmin, min),
                            (cmax.simd_gt(max) | cmax.simd_ne(cmax)).select(cmax, max),
                        )
                    },
                )
                .map(|(min, max)| {
                    (
                        scalar_reduce_min_propagate_nan(min.as_array()),
                        scalar_reduce_max_propagate_nan(max.as_array()),
                    )
                })
            }
        }
        impl MinMaxKernel for [$T] {
            type Scalar<'a> = $T;
            fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                // Slices carry no validity mask, hence `None`.
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_min(b))
                    .map(|s| s.reduce_min())
            }
            fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_max(b))
                    .map(|s| s.reduce_max())
            }
            fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self,
                    None,
                    <$T>::NAN,
                    <$T>::NAN,
                    |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),
                )
                .map(|(min, max)| (min.reduce_min(), max.reduce_max()))
            }
            fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::INFINITY, |a, b| {
                    // Keep `a` if it is smaller or if it is NaN (NaN sticks).
                    (a.simd_lt(b) | a.simd_ne(a)).select(a, b)
                })
                .map(|s| scalar_reduce_min_propagate_nan(s.as_array()))
            }
            fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NEG_INFINITY, |a, b| {
                    // Keep `a` if it is larger or if it is NaN (NaN sticks).
                    (a.simd_gt(b) | a.simd_ne(a)).select(a, b)
                })
                .map(|s| scalar_reduce_max_propagate_nan(s.as_array()))
            }
            fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self,
                    None,
                    <$T>::INFINITY,
                    <$T>::NEG_INFINITY,
                    |(cmin, cmax), (min, max)| {
                        (
                            (cmin.simd_lt(min) | cmin.simd_ne(cmin)).select(cmin, min),
                            (cmax.simd_gt(max) | cmax.simd_ne(cmax)).select(cmax, max),
                        )
                    },
                )
                .map(|(min, max)| {
                    (
                        scalar_reduce_min_propagate_nan(min.as_array()),
                        scalar_reduce_max_propagate_nan(max.as_array()),
                    )
                })
            }
        }
    };
}
// Lane counts per element width; presumably sized to fill wide vector
// registers (16×f32 = 8×f64 = 512 bits) — TODO confirm against target
// benchmarks.
impl_min_max_kernel_float!(f32, 16);
impl_min_max_kernel_float!(f64, 8);