use std::arch::x86_64::{
__m256i, _mm_loadu_si128, _mm_setzero_si128, _mm_shuffle_epi32, _mm_storeu_si128,
_mm_unpacklo_epi64, _mm256_cmpgt_epi32, _mm256_cmpgt_epi64, _mm256_cvtepu8_epi32,
_mm256_cvtepu8_epi64, _mm256_cvtepu16_epi32, _mm256_cvtepu16_epi64, _mm256_cvtepu32_epi64,
_mm256_extracti128_si256, _mm256_loadu_si256, _mm256_mask_i32gather_epi32,
_mm256_mask_i64gather_epi32, _mm256_mask_i64gather_epi64, _mm256_set1_epi32,
_mm256_set1_epi64x, _mm256_setzero_si256, _mm256_storeu_si256,
};
use std::convert::identity;
use vortex_buffer::{Alignment, Buffer, BufferMut};
use vortex_dtype::{
NativePType, PType, UnsignedPType, match_each_native_ptype, match_each_unsigned_integer_ptype,
};
use vortex_error::VortexResult;
use crate::arrays::primitive::PrimitiveArray;
use crate::arrays::primitive::compute::take::{TakeImpl, take_primitive_scalar};
use crate::validity::Validity;
use crate::{ArrayRef, IntoArray};
/// AVX2-accelerated `take` kernel for primitive arrays (x86_64 only).
#[allow(unused)]
pub(super) struct TakeKernelAVX2;
impl TakeImpl for TakeKernelAVX2 {
    /// Dispatches `take` to a monomorphized AVX2 kernel based on the runtime
    /// ptypes of `indices` and `values`.
    ///
    /// # Panics
    /// Panics if `indices` is not an unsigned-integer array.
    #[allow(clippy::cognitive_complexity)]
    #[inline(always)]
    fn take(
        &self,
        values: &PrimitiveArray,
        indices: &PrimitiveArray,
        validity: Validity,
    ) -> VortexResult<ArrayRef> {
        assert!(indices.ptype().is_unsigned_int());
        match_each_unsigned_integer_ptype!(indices.ptype(), |I| {
            match_each_native_ptype!(values.ptype(), |V| {
                // SAFETY: `take_primitive_avx2` is `#[target_feature(enable = "avx2")]`.
                // NOTE(review): presumably this kernel is only selected when the CPU
                // supports AVX2 — confirm the feature check happens at kernel selection.
                Ok(unsafe {
                    take_primitive_avx2(indices.as_slice::<I>(), values.as_slice::<V>(), validity)
                }
                .into_array())
            })
        })
    }
}
/// One fixed-width gather step: load a vector of indices, bounds-check them
/// against `max_idx`, and gather the corresponding `Values` from `src` into
/// `dst`, writing zero for out-of-range lanes.
pub(crate) trait GatherFn<Idx, Values> {
    /// Number of output elements written to `dst` per call.
    const WIDTH: usize;
    /// Number of `Idx` elements covered by one vector load from `indices`.
    /// May exceed `WIDTH` when the index load reads more lanes than are
    /// consumed (e.g. 16 `u8` indices loaded but only 8 extended/gathered).
    const STRIDE: usize = Self::WIDTH;
    /// # Safety
    /// `indices` must be readable for `STRIDE` elements, `dst` writable for
    /// `WIDTH` elements, and `src` valid for every in-range gathered offset.
    unsafe fn gather(indices: *const Idx, max_idx: Idx, src: *const Values, dst: *mut Values);
}
enum AVX2Gather {}
/// Generates `GatherFn<$idx, $value>` implementations for `AVX2Gather`.
///
/// Each table entry names the intrinsics for one (index, value) pair:
/// - `load`: raw vector load of the index lanes,
/// - `extend`: zero-extension of the loaded indices to gather-lane width,
/// - `splat`: broadcast of the exclusive index bound,
/// - `mask_indices`: lane-wise `bound > index` compare,
/// - `mask_cvt`: optional reshaping of that mask for the gather intrinsic,
/// - `gather` / `store`: the masked gather and the result store,
/// - `WIDTH` / `STRIDE`: outputs written per call / index lanes loaded per call.
macro_rules! impl_gather {
    ($idx:ty, $({$value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => {
        $(
            impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE);
        )*
    };
    (single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => {
        impl GatherFn<$idx, $value> for AVX2Gather {
            const WIDTH: usize = $WIDTH;
            const STRIDE: usize = $STRIDE;
            #[allow(unused_unsafe, clippy::cast_possible_truncation)]
            #[inline(always)]
            unsafe fn gather(indices: *const $idx, max_idx: $idx, src: *const $value, dst: *mut $value) {
                const {
                    assert!($WIDTH <= $STRIDE, "dst cannot advance by more than the stride");
                }
                // Gather scale in bytes: index lanes address `$value`-sized elements.
                const SCALE: i32 = std::mem::size_of::<$value>() as i32;
                let indices_vec = unsafe { $load(indices.cast()) };
                let indices_vec = unsafe { $extend(indices_vec) };
                let max_idx_vec = unsafe { $splat(max_idx as _) };
                // Despite the name, this mask is all-ones for *in-range* lanes
                // (`max_idx > idx`); the masked gather loads only those lanes and
                // leaves out-of-range lanes equal to `zero_vec` (i.e. zero).
                // NOTE(review): the compares are signed (`_mm256_cmpgt_epi*`), so
                // a bound or index with the top bit set would mis-classify —
                // presumably unreachable for supported lengths; confirm upstream.
                let invalid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
                let invalid_mask = {
                    let $mask_var = invalid_mask;
                    $mask_cvt
                };
                let zero_vec = unsafe { $zero_vec() };
                let values_vec = unsafe { $masked_gather::<SCALE>(zero_vec, src.cast(), indices_vec, invalid_mask) };
                unsafe { $store(dst.cast(), values_vec) };
            }
        }
    };
}
// u8 indices: each 128-bit index load reads 16 bytes (STRIDE = 16), of which
// 8 lanes are zero-extended and gathered for 32-bit values and 4 lanes for
// 64-bit values (WIDTH), so the caller must keep a full 16-index load in bounds.
impl_gather!(u8,
    { u32 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu8_epi32,
        splat: _mm256_set1_epi32,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi32,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i32gather_epi32,
        store: _mm256_storeu_si256,
        WIDTH = 8, STRIDE = 16
    },
    { i32 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu8_epi32,
        splat: _mm256_set1_epi32,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi32,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i32gather_epi32,
        store: _mm256_storeu_si256,
        WIDTH = 8, STRIDE = 16
    },
    { u64 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu8_epi64,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i64gather_epi64,
        store: _mm256_storeu_si256,
        WIDTH = 4, STRIDE = 16
    },
    { i64 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu8_epi64,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i64gather_epi64,
        store: _mm256_storeu_si256,
        WIDTH = 4, STRIDE = 16
    }
);
// u16 indices: each 128-bit load covers 8 indices (STRIDE = 8); all 8 lanes
// are gathered for 32-bit values, only the low 4 for 64-bit values.
impl_gather!(u16,
    { u32 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu16_epi32,
        splat: _mm256_set1_epi32,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi32,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i32gather_epi32,
        store: _mm256_storeu_si256,
        WIDTH = 8, STRIDE = 8
    },
    { i32 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu16_epi32,
        splat: _mm256_set1_epi32,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi32,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i32gather_epi32,
        store: _mm256_storeu_si256,
        WIDTH = 8, STRIDE = 8
    },
    { u64 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu16_epi64,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i64gather_epi64,
        store: _mm256_storeu_si256,
        WIDTH = 4, STRIDE = 8
    },
    { i64 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu16_epi64,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i64gather_epi64,
        store: _mm256_storeu_si256,
        WIDTH = 4, STRIDE = 8
    }
);
// u32 indices: 32-bit values use a full 256-bit index load (8 lanes, no
// extension needed — `identity`); 64-bit values load 4 indices from 128 bits
// and zero-extend them to 64-bit lanes.
impl_gather!(u32,
    { u32 =>
        load: _mm256_loadu_si256,
        extend: identity,
        splat: _mm256_set1_epi32,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi32,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i32gather_epi32,
        store: _mm256_storeu_si256,
        WIDTH = 8, STRIDE = 8
    },
    { i32 =>
        load: _mm256_loadu_si256,
        extend: identity,
        splat: _mm256_set1_epi32,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi32,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i32gather_epi32,
        store: _mm256_storeu_si256,
        WIDTH = 8, STRIDE = 8
    },
    { u64 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu32_epi64,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i64gather_epi64,
        store: _mm256_storeu_si256,
        WIDTH = 4, STRIDE = 4
    },
    { i64 =>
        load: _mm_loadu_si128,
        extend: _mm256_cvtepu32_epi64,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i64gather_epi64,
        store: _mm256_storeu_si256,
        WIDTH = 4, STRIDE = 4
    }
);
// u64 indices: 4 lanes per call. For 32-bit values the 4x64-bit compare mask
// must be narrowed to 4x32-bit lanes for `_mm256_mask_i64gather_epi32`: each
// 64-bit mask lane is all-ones or all-zero, so taking the high dword of each
// lane (dwords 1 and 3 of each 128-bit half) and packing the halves yields
// [m0, m1, m2, m3].
impl_gather!(u64,
    { u32 =>
        load: _mm256_loadu_si256,
        extend: identity,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm_setzero_si128,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |m| {
            unsafe {
                let lo_bits = _mm256_extracti128_si256::<0>(m);
                let hi_bits = _mm256_extracti128_si256::<1>(m);
                // BUG FIX: the shuffle control was 0b01_01_01_01, which
                // replicates dword 1 of each half and, after the unpack,
                // produces [m0, m0, m2, m2] — lanes 1 and 3 then used the
                // wrong bounds mask (possible out-of-bounds gather).
                // 0b11_01_11_01 selects dwords [1, 3, 1, 3], giving
                // [m0, m1, m0, m1] / [m2, m3, m2, m3] per half.
                let lo_packed = _mm_shuffle_epi32::<0b11_01_11_01>(lo_bits);
                let hi_packed = _mm_shuffle_epi32::<0b11_01_11_01>(hi_bits);
                // Low qword of each half = [m0, m1] and [m2, m3] -> [m0, m1, m2, m3].
                _mm_unpacklo_epi64(lo_packed, hi_packed)
            }
        },
        gather: _mm256_mask_i64gather_epi32,
        store: _mm_storeu_si128,
        WIDTH = 4, STRIDE = 4
    },
    { i32 =>
        load: _mm256_loadu_si256,
        extend: identity,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm_setzero_si128,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |m| {
            unsafe {
                let lo_bits = _mm256_extracti128_si256::<0>(m);
                let hi_bits = _mm256_extracti128_si256::<1>(m);
                // BUG FIX: see the u32 entry above — 0b11_01_11_01 picks the
                // high dword of each 64-bit mask lane in order.
                let lo_packed = _mm_shuffle_epi32::<0b11_01_11_01>(lo_bits);
                let hi_packed = _mm_shuffle_epi32::<0b11_01_11_01>(hi_bits);
                _mm_unpacklo_epi64(lo_packed, hi_packed)
            }
        },
        gather: _mm256_mask_i64gather_epi32,
        store: _mm_storeu_si128,
        WIDTH = 4, STRIDE = 4
    },
    { u64 =>
        load: _mm256_loadu_si256,
        extend: identity,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i64gather_epi64,
        store: _mm256_storeu_si256,
        WIDTH = 4, STRIDE = 4
    },
    { i64 =>
        load: _mm256_loadu_si256,
        extend: identity,
        splat: _mm256_set1_epi64x,
        zero_vec: _mm256_setzero_si256,
        mask_indices: _mm256_cmpgt_epi64,
        mask_cvt: |x| { x },
        gather: _mm256_mask_i64gather_epi64,
        store: _mm256_storeu_si256,
        WIDTH = 4, STRIDE = 4
    }
);
/// Takes `values[indices[i]]` for every index: a vectorized gather loop over
/// the bulk of the indices, then a scalar tail for the remainder.
///
/// Out-of-range indices in the vectorized portion yield `0` (masked gather);
/// out-of-range indices in the scalar tail panic on the slice index.
#[inline(always)]
fn exec_take<Idx, Value, Gather>(indices: &[Idx], values: &[Value]) -> Buffer<Value>
where
    Idx: UnsignedPType,
    Value: Copy,
    Gather: GatherFn<Idx, Value>,
{
    let indices_len = indices.len();
    // Exclusive bound for valid indices; saturates to `Idx::MAX` when
    // `values.len()` does not fit in `Idx`.
    // NOTE(review): when saturated, an index equal to `Idx::MAX` is a valid
    // position but fails the strict `max > idx` check and gathers 0 in the
    // vector loop — confirm callers cannot hit this case.
    let max_index = Idx::from(values.len()).unwrap_or_else(|| Idx::max_value());
    // Over-aligned to a full AVX2 vector so vector stores stay inside the
    // allocation.
    let mut buffer =
        BufferMut::<Value>::with_capacity_aligned(indices_len, Alignment::of::<__m256i>());
    let buf_uninit = buffer.spare_capacity_mut();
    let mut offset = 0;
    // Vector loop: each step loads STRIDE indices and writes WIDTH outputs.
    // `<` (rather than `<=`) is slightly conservative: the exact-fit load
    // falls through to the scalar tail.
    while offset + Gather::STRIDE < indices_len {
        // SAFETY: `offset + STRIDE < indices_len` keeps the index load in
        // bounds, and `WIDTH <= STRIDE` (checked in the impl) keeps the
        // output store within the reserved capacity.
        unsafe {
            Gather::gather(
                indices.as_ptr().add(offset),
                max_index,
                values.as_ptr(),
                buf_uninit.as_mut_ptr().add(offset).cast(),
            )
        };
        offset += Gather::WIDTH;
    }
    // Scalar tail for the remaining (fewer than STRIDE + 1) indices.
    while offset < indices_len {
        buf_uninit[offset].write(values[indices[offset].as_()]);
        offset += 1;
    }
    assert_eq!(offset, indices_len);
    // SAFETY: every element in 0..indices_len was initialized above.
    unsafe { buffer.set_len(indices_len) };
    buffer.freeze()
}
/// AVX2 `take` kernel: builds a new primitive array of `values[indices[i]]`.
///
/// Dispatches on the runtime (index, value) ptype pair to a monomorphized
/// `exec_take`. Float values are routed through bit-identical transmutes to
/// the equal-width unsigned integer type (the gather intrinsics used here
/// are integer gathers); unsupported pairs fall back to the scalar path.
///
/// # Safety
/// `#[target_feature(enable = "avx2")]`: must only be called on CPUs that
/// support AVX2.
#[target_feature(enable = "avx2")]
#[allow(unused, clippy::cognitive_complexity, clippy::useless_transmute)]
pub(crate) fn take_primitive_avx2<I, V>(
    indices: &[I],
    values: &[V],
    validity: Validity,
) -> PrimitiveArray
where
    I: UnsignedPType,
    V: NativePType,
{
    // Reinterprets the generic slices as the concrete types selected by the
    // match below; the `cast:` form runs the gather on `$cast` (same width as
    // `$values`) and transmutes the result buffer back, preserving bit patterns.
    macro_rules! dispatch_avx2 {
        ($indices:ty, $values:ty) => {
            { let result = dispatch_avx2!($indices, $values, cast: $values); result }
        };
        ($indices:ty, $values:ty, cast: $cast:ty) => {{
            // SAFETY: the match arm guarantees `I == $indices` and
            // `V == $values`; `$cast` is the same size as `$values`.
            let indices = unsafe { std::mem::transmute::<&[I], &[$indices]>(indices) };
            let values = unsafe { std::mem::transmute::<&[V], &[$cast]>(values) };
            let result = exec_take::<$indices, $cast, AVX2Gather>(indices, values);
            let result = unsafe { std::mem::transmute::<Buffer<$cast>, Buffer<$values>>(result) };
            PrimitiveArray::new(
                unsafe { std::mem::transmute::<Buffer<$values>, Buffer<V>>(result) },
                validity,
            )
        }};
    }
    match (I::PTYPE, V::PTYPE) {
        (PType::U8, PType::I32) => dispatch_avx2!(u8, i32),
        (PType::U8, PType::U32) => dispatch_avx2!(u8, u32),
        (PType::U8, PType::I64) => dispatch_avx2!(u8, i64),
        (PType::U8, PType::U64) => dispatch_avx2!(u8, u64),
        (PType::U16, PType::I32) => dispatch_avx2!(u16, i32),
        (PType::U16, PType::U32) => dispatch_avx2!(u16, u32),
        (PType::U16, PType::I64) => dispatch_avx2!(u16, i64),
        (PType::U16, PType::U64) => dispatch_avx2!(u16, u64),
        (PType::U32, PType::I32) => dispatch_avx2!(u32, i32),
        (PType::U32, PType::U32) => dispatch_avx2!(u32, u32),
        (PType::U32, PType::I64) => dispatch_avx2!(u32, i64),
        (PType::U32, PType::U64) => dispatch_avx2!(u32, u64),
        // Floats gather through the same-width unsigned integer type.
        (PType::U8, PType::F32) => dispatch_avx2!(u8, f32, cast: u32),
        (PType::U16, PType::F32) => dispatch_avx2!(u16, f32, cast: u32),
        (PType::U32, PType::F32) => dispatch_avx2!(u32, f32, cast: u32),
        (PType::U64, PType::F32) => dispatch_avx2!(u64, f32, cast: u32),
        (PType::U8, PType::F64) => dispatch_avx2!(u8, f64, cast: u64),
        (PType::U16, PType::F64) => dispatch_avx2!(u16, f64, cast: u64),
        (PType::U32, PType::F64) => dispatch_avx2!(u32, f64, cast: u64),
        (PType::U64, PType::F64) => dispatch_avx2!(u64, f64, cast: u64),
        _ => {
            // No vector kernel for this pair (e.g. 8/16-bit value types):
            // scalar fallback.
            log::trace!(
                "take AVX2 kernel missing for indices {} values {}, falling back to scalar",
                I::PTYPE,
                V::PTYPE
            );
            let result = take_primitive_scalar(values, indices);
            PrimitiveArray::new(result, validity)
        }
    }
}
/// Scalar fallback for non-x86_64 targets.
///
/// Mirrors the x86_64 kernel's interface exactly (the previous declaration
/// took an unimported `Nullability` and returned `Option<PrimitiveArray>`,
/// which neither compiled nor matched the `(indices, values, Validity) ->
/// PrimitiveArray` signature callers dispatch against), and delegates to the
/// scalar implementation.
#[cfg(not(target_arch = "x86_64"))]
pub(crate) fn take_primitive_avx2<I, V>(
    indices: &[I],
    values: &[V],
    validity: Validity,
) -> PrimitiveArray
where
    I: UnsignedPType,
    V: NativePType,
{
    PrimitiveArray::new(take_primitive_scalar(values, indices), validity)
}
#[cfg(test)]
// FIX: `#[ignore]` is only valid on test functions, so the previous
// `#[cfg_attr(miri, ignore)]` on the module did not skip these tests under
// Miri (which cannot execute AVX2 intrinsics). Excluding the module from
// Miri builds entirely does.
#[cfg(not(miri))]
#[cfg(target_arch = "x86_64")]
mod tests {
    use super::*;
    // Generates, for each (index, value) type pair: a round-trip identity
    // take, an empty-values panic test, and an out-of-range-indices test.
    macro_rules! test_cases {
        (index_type => $IDX:ty, value_types => $($VAL:ty),+) => {
            paste::paste! {
                $(
                    #[test]
                    #[allow(clippy::cast_possible_truncation)]
                    fn [<test_avx2_take_simple_ $IDX _ $VAL>]() {
                        let values: Vec<$VAL> = (1..=127).map(|x| x as $VAL).collect();
                        let indices: Vec<$IDX> = (0..127).collect();
                        let result = unsafe { take_primitive_avx2(&indices, &values, Validity::NonNullable) };
                        assert_eq!(&values, result.as_slice::<$VAL>());
                    }
                    #[test]
                    #[should_panic]
                    #[allow(clippy::cast_possible_truncation)]
                    fn [<test_avx2_take_empty_ $IDX _ $VAL>]() {
                        // Empty `values`: the scalar tail panics indexing the
                        // empty slice.
                        let values: Vec<$VAL> = vec![];
                        let indices: Vec<$IDX> = (0..127).collect();
                        let result = unsafe { take_primitive_avx2(&indices, &values, Validity::NonNullable) };
                        assert!(result.is_empty());
                    }
                    #[test]
                    #[should_panic]
                    #[allow(clippy::cast_possible_truncation)]
                    fn [<test_avx2_take_invalid_ $IDX _ $VAL>]() {
                        let values: Vec<$VAL> = (1..=127).map(|x| x as $VAL).collect();
                        // NOTE(review): `127..=254` yields 128 indices while the
                        // expected slice has 127 zeros, so the assert panics on
                        // a length mismatch — confirm whether the intent was a
                        // panic on out-of-range or an all-zeros comparison.
                        let indices: Vec<$IDX> = (127..=254).collect();
                        let result = unsafe { take_primitive_avx2(&indices, &values, Validity::NonNullable) };
                        assert_eq!(&[0 as $VAL; 127], result.as_slice::<$VAL>());
                    }
                )+
            }
        };
    }
    test_cases!(
        index_type => u8,
        value_types => u32, i32, u64, i64, f32, f64
    );
    test_cases!(
        index_type => u16,
        value_types => u32, i32, u64, i64, f32, f64
    );
    test_cases!(
        index_type => u32,
        value_types => u32, i32, u64, i64, f32, f64
    );
    test_cases!(
        index_type => u64,
        value_types => u32, i32, u64, i64, f32, f64
    );
}