use wide::{u8x16, u32x16};
/// Gathers one byte per lane: lane `i` receives `base[indices[i] * scale]`.
///
/// On x86_64 CPUs reporting AVX-512F and AVX-512BW a hardware gather is used, but only after
/// checking that `scale` is 1, 2, 4 or 8 and that every lane's 4-byte hardware read stays inside
/// `base` (the hardware path always loads a full 32-bit word per lane). Anything else goes through
/// the portable scalar implementation.
///
/// # Panics
///
/// Panics if any lane's byte offset `indices[i] * scale` is not a valid index into `base`.
#[inline]
pub fn gather_u32index_u8(indices: u32x16, base: &[u8], scale: u8) -> u8x16 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f")
            && is_x86_feature_detected!("avx512bw")
            && avx512_gather_in_bounds(indices, u16::MAX, scale, base.len())
        {
            // SAFETY: both target features were detected at runtime, and
            // `avx512_gather_in_bounds` verified that `scale` is encodable and that every lane's
            // 4-byte read at `indices[i] * scale` lies within `base`.
            return unsafe { gather_u32index_u8_avx512(indices, base, scale) };
        }
    }
    gather_u32index_u8_scalar(indices, base, scale)
}

/// Masked variant of [`gather_u32index_u8`]: lanes whose bit in `mask` is set receive
/// `base[indices[i] * scale]`; the others receive the low byte of `fallback[i]`.
///
/// Only active lanes are read, so inactive lanes may carry arbitrary indices. The hardware path
/// is taken under the same conditions as [`gather_u32index_u8`], checked for active lanes only.
///
/// # Panics
///
/// Panics if an active lane's byte offset is not a valid index into `base`.
#[inline]
pub fn gather_masked_u32index_u8(
    indices: u32x16,
    base: &[u8],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u8x16 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f")
            && is_x86_feature_detected!("avx512bw")
            && avx512_gather_in_bounds(indices, mask, scale, base.len())
        {
            // SAFETY: features detected at runtime; every lane enabled in `mask` was verified to
            // read 4 in-bounds bytes with an encodable scale. Disabled lanes are not accessed.
            return unsafe {
                gather_masked_u32index_u8_avx512(indices, base, scale, mask, fallback)
            };
        }
    }
    gather_masked_u32index_u8_scalar(indices, base, scale, mask, fallback)
}

/// Gathers one `u32` per lane from `base`.
///
/// `scale` multiplies each index into a *byte* offset, as in the hardware instruction; with
/// `scale == 4` lane `i` receives `base[indices[i]]`. On x86_64 CPUs reporting AVX-512F a hardware
/// gather is used only when `scale` is 1, 2, 4 or 8 and every lane's 4-byte read lies within
/// `base`; otherwise the portable scalar implementation runs.
///
/// # Panics
///
/// Panics (on the portable path) if a lane addresses data outside `base`.
#[inline]
pub fn gather_u32index_u32(indices: u32x16, base: &[u32], scale: u8) -> u32x16 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f")
            && avx512_gather_in_bounds(indices, u16::MAX, scale, std::mem::size_of_val(base))
        {
            // SAFETY: the feature was detected at runtime and every lane's 4-byte read at byte
            // offset `indices[i] * scale` was verified to lie within `base`.
            return unsafe { gather_u32index_u32_avx512(indices, base, scale) };
        }
    }
    gather_u32index_u32_scalar(indices, base, scale)
}

/// Masked variant of [`gather_u32index_u32`]: lanes whose bit in `mask` is set are gathered from
/// `base`; the others receive `fallback[i]` unchanged.
///
/// Only active lanes are read. The hardware path is taken under the same conditions as
/// [`gather_u32index_u32`], checked for active lanes only.
///
/// # Panics
///
/// Panics (on the portable path) if an active lane addresses data outside `base`.
#[inline]
pub fn gather_masked_u32index_u32(
    indices: u32x16,
    base: &[u32],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u32x16 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f")
            && avx512_gather_in_bounds(indices, mask, scale, std::mem::size_of_val(base))
        {
            // SAFETY: feature detected at runtime; every lane enabled in `mask` was verified to
            // read 4 in-bounds bytes with an encodable scale. Disabled lanes are not accessed.
            return unsafe {
                gather_masked_u32index_u32_avx512(indices, base, scale, mask, fallback)
            };
        }
    }
    gather_masked_u32index_u32_scalar(indices, base, scale, mask, fallback)
}

/// Returns `true` when a 16-lane 32-bit hardware gather over a buffer of `byte_len` bytes is
/// guaranteed to stay in bounds.
///
/// Requirements:
/// - `scale` must be one the instruction encodes (1, 2, 4 or 8).
/// - For every lane enabled in `mask`, the index must be at most `i32::MAX`, because the hardware
///   sign-extends indices.
/// - For every such lane, the 4 bytes starting at `index * scale` must end at or before
///   `byte_len`.
#[cfg(target_arch = "x86_64")]
#[inline]
fn avx512_gather_in_bounds(indices: u32x16, mask: u16, scale: u8, byte_len: usize) -> bool {
    // Every hardware lane loads a full 32-bit word, even when the caller only wants one byte.
    const LANE_BYTES: usize = 4;
    if !matches!(scale, 1 | 2 | 4 | 8) {
        return false;
    }
    let scale = usize::from(scale);
    indices
        .to_array()
        .iter()
        .enumerate()
        .filter(|&(lane, _)| (mask >> lane) & 1 != 0)
        .all(|(_, &idx)| {
            idx <= i32::MAX as u32
                && (idx as usize)
                    .checked_mul(scale)
                    .and_then(|start| start.checked_add(LANE_BYTES))
                    .is_some_and(|end| end <= byte_len)
        })
}
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "x86_64")]
use std::arch::is_x86_feature_detected;
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn gather_u32index_u8_avx512(indices: u32x16, base: &[u8], scale: u8) -> u8x16 {
unsafe {
let idx = std::mem::transmute::<u32x16, __m512i>(indices);
let gathered = match scale {
1 => _mm512_i32gather_epi32::<1>(idx, base.as_ptr() as *const i32),
2 => _mm512_i32gather_epi32::<2>(idx, base.as_ptr() as *const i32),
4 => _mm512_i32gather_epi32::<4>(idx, base.as_ptr() as *const i32),
8 => _mm512_i32gather_epi32::<8>(idx, base.as_ptr() as *const i32),
_ => _mm512_i32gather_epi32::<1>(idx, base.as_ptr() as *const i32),
};
extract_low_bytes_avx512(gathered)
}
}
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn gather_masked_u32index_u8_avx512(
indices: u32x16,
base: &[u8],
scale: u8,
mask: u16,
fallback: u32x16,
) -> u8x16 {
unsafe {
let idx = std::mem::transmute::<u32x16, __m512i>(indices);
let src = std::mem::transmute::<u32x16, __m512i>(fallback);
let gathered = match scale {
1 => _mm512_mask_i32gather_epi32::<1>(src, mask, idx, base.as_ptr() as *const i32),
2 => _mm512_mask_i32gather_epi32::<2>(src, mask, idx, base.as_ptr() as *const i32),
4 => _mm512_mask_i32gather_epi32::<4>(src, mask, idx, base.as_ptr() as *const i32),
8 => _mm512_mask_i32gather_epi32::<8>(src, mask, idx, base.as_ptr() as *const i32),
_ => _mm512_mask_i32gather_epi32::<1>(src, mask, idx, base.as_ptr() as *const i32),
};
extract_low_bytes_avx512(gathered)
}
}
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn gather_u32index_u32_avx512(indices: u32x16, base: &[u32], scale: u8) -> u32x16 {
unsafe {
let idx = std::mem::transmute::<u32x16, __m512i>(indices);
let gathered = match scale {
1 => _mm512_i32gather_epi32::<1>(idx, base.as_ptr() as *const i32),
2 => _mm512_i32gather_epi32::<2>(idx, base.as_ptr() as *const i32),
4 => _mm512_i32gather_epi32::<4>(idx, base.as_ptr() as *const i32),
8 => _mm512_i32gather_epi32::<8>(idx, base.as_ptr() as *const i32),
_ => _mm512_i32gather_epi32::<4>(idx, base.as_ptr() as *const i32),
};
std::mem::transmute::<__m512i, u32x16>(gathered)
}
}
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn gather_masked_u32index_u32_avx512(
indices: u32x16,
base: &[u32],
scale: u8,
mask: u16,
fallback: u32x16,
) -> u32x16 {
unsafe {
let idx = std::mem::transmute::<u32x16, __m512i>(indices);
let src = std::mem::transmute::<u32x16, __m512i>(fallback);
let gathered = match scale {
1 => _mm512_mask_i32gather_epi32::<1>(src, mask, idx, base.as_ptr() as *const i32),
2 => _mm512_mask_i32gather_epi32::<2>(src, mask, idx, base.as_ptr() as *const i32),
4 => _mm512_mask_i32gather_epi32::<4>(src, mask, idx, base.as_ptr() as *const i32),
8 => _mm512_mask_i32gather_epi32::<8>(src, mask, idx, base.as_ptr() as *const i32),
_ => _mm512_mask_i32gather_epi32::<4>(src, mask, idx, base.as_ptr() as *const i32),
};
std::mem::transmute::<__m512i, u32x16>(gathered)
}
}
/// Narrows each of the 16 32-bit lanes of `gathered` to its least-significant byte and returns
/// the bytes in lane order. The narrowing truncates rather than saturates.
///
/// # Safety
///
/// The CPU must support the target features enabled on this function.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn extract_low_bytes_avx512(gathered: __m512i) -> u8x16 {
    // SAFETY: feature support is guaranteed by the caller; both types are 16 bytes wide.
    unsafe { std::mem::transmute::<__m128i, u8x16>(_mm512_cvtepi32_epi8(gathered)) }
}
/// Portable implementation of [`gather_u32index_u8`]: lane `i` receives
/// `base[indices[i] * scale]`.
///
/// # Panics
///
/// Panics if any lane's byte offset is not a valid index into `base`.
#[inline]
fn gather_u32index_u8_scalar(indices: u32x16, base: &[u8], scale: u8) -> u8x16 {
    let stride = usize::from(scale);
    let picked = indices
        .to_array()
        .map(|lane_index| base[lane_index as usize * stride]);
    u8x16::from(picked)
}
/// Portable implementation of [`gather_masked_u32index_u8`].
///
/// Lane `i` is `base[indices[i] * scale]` when bit `i` of `mask` is set, and the low byte of
/// `fallback[i]` otherwise. Inactive lanes never touch `base`.
///
/// # Panics
///
/// Panics if an active lane's byte offset is not a valid index into `base`.
#[inline]
fn gather_masked_u32index_u8_scalar(
    indices: u32x16,
    base: &[u8],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u8x16 {
    let stride = usize::from(scale);
    let offsets = indices.to_array();
    let defaults = fallback.to_array();
    // `from_fn` visits lanes in ascending order, so evaluation order matches a plain loop.
    let lanes: [u8; 16] = std::array::from_fn(|lane| {
        let active = (mask >> lane) & 1 == 1;
        if active {
            base[offsets[lane] as usize * stride]
        } else {
            defaults[lane] as u8
        }
    });
    u8x16::from(lanes)
}
/// Reads the native-endian `u32` whose first byte is `byte_offset` bytes into `base`.
///
/// This reproduces the memory access of a byte-scaled hardware gather. When `byte_offset` is not
/// a multiple of 4, the value straddles two neighbouring elements instead of being rounded down
/// to one of them.
///
/// # Panics
///
/// Panics if any of the 4 bytes lies beyond the end of `base`.
#[inline]
fn read_u32_at_byte_offset(base: &[u32], byte_offset: usize) -> u32 {
    const WORD: usize = 4;
    if byte_offset % WORD == 0 {
        // Aligned: exactly one element.
        return base[byte_offset / WORD];
    }
    let mut bytes = [0u8; WORD];
    for (k, byte) in bytes.iter_mut().enumerate() {
        let pos = byte_offset + k;
        *byte = base[pos / WORD].to_ne_bytes()[pos % WORD];
    }
    u32::from_ne_bytes(bytes)
}

/// Portable implementation of [`gather_u32index_u32`]: lane `i` receives the `u32` stored at byte
/// offset `indices[i] * scale` from the start of `base`.
///
/// Offsets that are not a multiple of 4 (e.g. with `scale` 1 or 2) read an unaligned word
/// spanning two elements. This is the same access the AVX-512 path performs, so results do not
/// depend on the CPU.
///
/// # Panics
///
/// Panics if any lane's 4-byte read extends past the end of `base`.
#[inline]
fn gather_u32index_u32_scalar(indices: u32x16, base: &[u32], scale: u8) -> u32x16 {
    let scale = usize::from(scale);
    let lanes = indices
        .to_array()
        .map(|idx| read_u32_at_byte_offset(base, idx as usize * scale));
    u32x16::from(lanes)
}

/// Portable implementation of [`gather_masked_u32index_u32`].
///
/// Lane `i` is the `u32` at byte offset `indices[i] * scale` when bit `i` of `mask` is set, and
/// `fallback[i]` otherwise. Inactive lanes never touch `base`.
///
/// # Panics
///
/// Panics if an active lane's 4-byte read extends past the end of `base`.
#[inline]
fn gather_masked_u32index_u32_scalar(
    indices: u32x16,
    base: &[u32],
    scale: u8,
    mask: u16,
    fallback: u32x16,
) -> u32x16 {
    let idx_arr = indices.to_array();
    let scale = usize::from(scale);
    // Start from the fallback and overwrite only active lanes.
    let mut result = fallback.to_array();
    for (lane, slot) in result.iter_mut().enumerate() {
        if (mask >> lane) & 1 != 0 {
            *slot = read_u32_at_byte_offset(base, idx_arr[lane] as usize * scale);
        }
    }
    u32x16::from(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lane indices `0, 1, …, 15`, the probe pattern shared by most tests.
    const ASCENDING: [u32; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];

    /// 256 bytes in which every byte equals its own position.
    fn byte_ramp() -> Vec<u8> {
        (0..=u8::MAX).collect()
    }

    #[test]
    fn test_gather_u32index_u8_basic() {
        let data = byte_ramp();
        let got = gather_u32index_u8(u32x16::from(ASCENDING), &data, 1);
        let expected: [u8; 16] = std::array::from_fn(|lane| lane as u8);
        assert_eq!(got.to_array(), expected);
    }

    #[test]
    fn test_gather_u32index_u8_scaled() {
        let data = byte_ramp();
        let got = gather_u32index_u8(u32x16::from(ASCENDING), &data, 2);
        let expected: [u8; 16] = std::array::from_fn(|lane| (lane * 2) as u8);
        assert_eq!(got.to_array(), expected);
    }

    #[test]
    fn test_gather_u32index_u8_non_sequential() {
        let data = byte_ramp();
        let picks: [u32; 16] = [
            100, 50, 200, 25, 150, 75, 225, 10, 0, 255, 128, 64, 192, 32, 96, 160,
        ];
        let got = gather_u32index_u8(u32x16::from(picks), &data, 1);
        // The ramp maps each position to itself, so the gathered bytes mirror the indices.
        assert_eq!(got.to_array(), picks.map(|p| p as u8));
    }

    #[test]
    fn test_gather_masked_u32index_u8() {
        let data = byte_ramp();
        let mask = 0b0101_0101_0101_0101u16;
        let got = gather_masked_u32index_u8(
            u32x16::from(ASCENDING),
            &data,
            1,
            mask,
            u32x16::from([255; 16]),
        );
        let expected: [u8; 16] =
            std::array::from_fn(|lane| if lane % 2 == 0 { lane as u8 } else { 255 });
        assert_eq!(got.to_array(), expected);
    }

    #[test]
    fn test_gather_masked_u32index_u8_all_masked() {
        let data = byte_ramp();
        let got = gather_masked_u32index_u8(
            u32x16::from(ASCENDING),
            &data,
            1,
            0u16,
            u32x16::from([42; 16]),
        );
        assert_eq!(got.to_array(), [42; 16]);
    }

    #[test]
    fn test_gather_u32index_u32_basic() {
        let data: Vec<u32> = (0..256u32).map(|i| i * 1000).collect();
        let got = gather_u32index_u32(u32x16::from(ASCENDING), &data, 4);
        let expected: [u32; 16] = std::array::from_fn(|lane| lane as u32 * 1000);
        assert_eq!(got.to_array(), expected);
    }

    #[test]
    fn test_gather_masked_u32index_u32() {
        let data: Vec<u32> = (0..256u32).map(|i| i * 100).collect();
        let mask = 0b1010_1010_1010_1010u16;
        let got = gather_masked_u32index_u32(
            u32x16::from(ASCENDING),
            &data,
            4,
            mask,
            u32x16::from([999; 16]),
        );
        let expected: [u32; 16] =
            std::array::from_fn(|lane| if lane % 2 == 1 { lane as u32 * 100 } else { 999 });
        assert_eq!(got.to_array(), expected);
    }

    #[test]
    fn test_gather_u32index_u32_non_sequential() {
        let data: Vec<u32> = (0..256).collect();
        let mut reversed = ASCENDING;
        reversed.reverse();
        let got = gather_u32index_u32(u32x16::from(reversed), &data, 4);
        assert_eq!(got.to_array(), reversed);
    }
}