#[cfg(target_arch = "x86")]
use core::arch::x86 as simd;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64 as simd;
use std::mem;
use simd::{
__m256i,
__m512i,
_mm256_i32gather_epi32,
_mm256_i32scatter_epi64,
_mm256_permutexvar_epi32,
_mm256_set_epi32,
_mm256_set_epi8,
_mm256_shuffle_epi8,
_mm512_loadu_si512,
_mm512_permutex2var_epi64,
_mm512_permutexvar_epi32,
_mm512_set_epi32,
_mm512_set_epi64,
_mm512_set_epi8,
_mm512_shuffle_epi8,
_mm512_storeu_si512,
_mm512_unpackhi_epi16,
_mm512_unpackhi_epi32,
_mm512_unpackhi_epi64,
_mm512_unpackhi_epi8,
_mm512_unpacklo_epi16,
_mm512_unpacklo_epi32,
_mm512_unpacklo_epi64,
_mm512_unpacklo_epi8,
_mm_set_epi32,
};
/// Size of an `i32` in bytes (4); used to step over 4-byte element columns.
const SOI32: usize = mem::size_of::<i32>();
/// Size of a 256-bit vector in bytes (32).
const SO256I: usize = mem::size_of::<__m256i>();
/// Size of a 512-bit vector in bytes (64); also the per-iteration output
/// stride of the 512-bit kernels.
const SO512I: usize = mem::size_of::<__m512i>();
/// Byte-transposes an 8x4 matrix held in one 256-bit register: the input
/// is eight 4-byte elements, and the output groups byte 0 of every
/// element, then byte 1, byte 2, byte 3 (see the unit test, which checks
/// agreement with `crate::generic::shuffle(4, ..)`). Used as the
/// per-register kernel of `shuffle_sg`.
///
/// # Safety
/// Caller must ensure AVX512F and AVX512VL are available at runtime.
#[target_feature(enable = "avx512f")]
#[target_feature(enable = "avx512vl")]
unsafe fn shuffle_8x4(ymm: __m256i) -> __m256i {
    // Within each 128-bit lane, gather byte k of every 4-byte element into
    // one contiguous 4-byte group (k = 0..4).
    #[rustfmt::skip]
    let shuf8 = _mm256_set_epi8(
        15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
        15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
    );
    // Cross-lane dword permute that pairs up the matching 4-byte groups of
    // the two lanes, completing each 8-byte output plane.
    let shuf32 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
    let ymm1 = _mm256_shuffle_epi8(ymm, shuf8);
    _mm256_permutexvar_epi32(shuf32, ymm1)
}
/// Byte-shuffle kernel for `typesize == 2`.
///
/// De-interleaves `vectorizable_elements` two-byte elements from `src`:
/// every first byte goes to the plane at `dst[0..]`, every second byte to
/// the plane at `dst[total_elements..]`. `vectorizable_elements` must be a
/// multiple of `SO512I` (64) — the caller (`shuffle`) truncates to whole
/// 128-byte chunks before calling.
///
/// # Safety
/// Caller must ensure AVX512F and AVX512BW are available, that `src` is
/// readable for `vectorizable_elements * 2` bytes, and that `dst` is
/// writable for both output planes.
#[allow(clippy::needless_range_loop)] #[target_feature(enable = "avx512f")]
#[target_feature(enable = "avx512bw")]
unsafe fn shuffle2(
    vectorizable_elements: usize,
    total_elements: usize,
    src: *const u8,
    dst: *mut u8,
) {
    const TS: usize = 2; // element size in bytes
    let mut zmm0: [__m512i; 2] = mem::zeroed();
    let mut zmm1: [__m512i; 2] = mem::zeroed();
    // Within each 128-bit lane: even source bytes (byte 0 of each element)
    // to the low 8 bytes, odd bytes (byte 1) to the high 8 bytes.
    #[rustfmt::skip]
    let shuf8 = _mm512_set_epi8(
        15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0,
        15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0,
        15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0,
        15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0,
    );
    // Qword selectors for permutex2var across both loaded registers:
    // idx0 picks all even qwords (the byte-0 halves), idx1 all odd qwords
    // (the byte-1 halves); bit 3 of each index selects the second source.
    let idx0 = _mm512_set_epi64(0xe, 0xc, 0xa, 8, 6, 4, 2, 0);
    let idx1 = _mm512_set_epi64(0xf, 0xd, 0xb, 9, 7, 5, 3, 1);
    for j in (0..vectorizable_elements).step_by(SO512I) {
        // Load and lane-shuffle 128 source bytes (64 elements).
        for k in 0..2 {
            let p = src.add(j * TS + k * SO512I) as *const __m512i;
            zmm0[k] = _mm512_loadu_si512(p);
            zmm0[k] = _mm512_shuffle_epi8(zmm0[k], shuf8);
        }
        // Collect the byte-0 halves into zmm1[0] and byte-1 halves into
        // zmm1[1].
        zmm1[0] = _mm512_permutex2var_epi64(zmm0[0], idx0, zmm0[1]);
        zmm1[1] = _mm512_permutex2var_epi64(zmm0[0], idx1, zmm0[1]);
        // Store one vector per output plane, planes `total_elements` apart.
        for k in 0..2 {
            let p = dst.add(j + k * total_elements) as *mut __m512i;
            _mm512_storeu_si512(p, zmm1[k]);
        }
    }
}
/// Byte-shuffle kernel for `typesize == 16`.
///
/// Each iteration transposes a 64x16 byte matrix (64 elements of 16 bytes,
/// i.e. 1024 source bytes) via an unpack network — interleaving at byte,
/// word, dword and qword granularity — followed by a cross-lane permute
/// and an in-lane byte shuffle so register `k` ends up holding output byte
/// plane `k` contiguously. `vectorizable_elements` must be a multiple of
/// `SO512I` (64); the caller (`shuffle`) truncates to whole 1024-byte
/// chunks before calling.
///
/// # Safety
/// Caller must ensure AVX512F and AVX512BW are available, and that `src` /
/// `dst` are valid for the addressed ranges (`vectorizable_elements * 16`
/// source bytes; 16 output planes `total_elements` bytes apart).
#[allow(clippy::needless_range_loop)] #[target_feature(enable = "avx512f")]
#[target_feature(enable = "avx512bw")]
unsafe fn shuffle16(
    vectorizable_elements: usize,
    total_elements: usize,
    src: *const u8,
    dst: *mut u8,
) {
    const TS: usize = 16; // element size in bytes
    let mut zmm0: [__m512i; TS] = mem::zeroed();
    let mut zmm1: [__m512i; TS] = mem::zeroed();
    // In-lane byte gather: every 4th byte grouped together (final fixup).
    #[rustfmt::skip]
    let shmask = _mm512_set_epi8(
        15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
        15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
        15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
        15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
    // Cross-lane dword gather with the same every-4th stride pattern.
    let shuf32 = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
    for j in (0..vectorizable_elements).step_by(SO512I) {
        // Load 16 vectors = 1024 bytes = 64 consecutive 16-byte elements.
        for k in 0..TS {
            let p = src.add(j * TS + k * SO512I) as *const __m512i;
            zmm0[k] = _mm512_loadu_si512(p);
        }
        // Stage 1: interleave bytes of adjacent register pairs.
        for k in 0..(TS / 2) {
            zmm1[k * 2] = _mm512_unpacklo_epi8(zmm0[k * 2], zmm0[k * 2 + 1]);
            zmm1[k * 2 + 1] = _mm512_unpackhi_epi8(zmm0[k * 2], zmm0[k * 2 + 1]);
        }
        // Stage 2: interleave 16-bit words of registers two apart; the
        // first-operand index `l` walks 0, 1, 4, 5, 8, 9, 12, 13.
        let mut l = 0;
        for k in 0..(TS / 2) {
            zmm0[k * 2] = _mm512_unpacklo_epi16(zmm1[l], zmm1[l + 2]);
            zmm0[k * 2 + 1] = _mm512_unpackhi_epi16(zmm1[l], zmm1[l + 2]);
            l += 1;
            if k % 2 == 1 {
                l += 2;
            }
        }
        // Stage 3: interleave 32-bit dwords of registers four apart; `l`
        // walks 0..=3 then 8..=11.
        l = 0;
        for k in 0..(TS / 2) {
            zmm1[k * 2] = _mm512_unpacklo_epi32(zmm0[l], zmm0[l + 4]);
            zmm1[k * 2 + 1] = _mm512_unpackhi_epi32(zmm0[l], zmm0[l + 4]);
            l += 1;
            if k % 4 == 3 {
                l += 4;
            }
        }
        // Stage 4: interleave 64-bit qwords of registers eight apart.
        for k in 0..(TS / 2) {
            zmm0[k * 2] = _mm512_unpacklo_epi64(zmm1[k], zmm1[k + 8]);
            zmm0[k * 2 + 1] = _mm512_unpackhi_epi64(zmm1[k], zmm1[k + 8]);
        }
        // Fixup: unpacks operate within 128-bit lanes, so finish with a
        // cross-lane dword permute and an in-lane byte shuffle to make
        // each register one contiguous byte plane.
        for k in 0..TS {
            zmm1[k] = _mm512_permutexvar_epi32(shuf32, zmm0[k]);
            zmm0[k] = _mm512_shuffle_epi8(zmm1[k], shmask);
        }
        // Store one vector per output plane, planes `total_elements` apart.
        for k in 0..TS {
            let p = dst.add(j + k * total_elements) as *mut __m512i;
            _mm512_storeu_si512(p, zmm0[k]);
        }
    }
}
/// Gather/scatter shuffle for any element size `ts` that is a multiple of
/// 4 bytes.
///
/// Processes 8 elements at a time: for each 4-byte column `j` of the
/// element (`j < ts / 4`), gathers that i32 from 8 consecutive elements,
/// byte-transposes the resulting 8x4 block with `shuffle_8x4`, and
/// scatters the four 8-byte planes to output rows `4*j .. 4*j + 3`, which
/// sit `total_elements` bytes apart in `dst`.
///
/// # Safety
/// Caller must ensure AVX512F and AVX512VL are available and that `src` /
/// `dst` are valid for `vectorizable_elements * ts` source bytes and the
/// corresponding output rows.
#[allow(clippy::needless_range_loop)] #[target_feature(enable = "avx512f")]
#[target_feature(enable = "avx512vl")]
unsafe fn shuffle_sg(
    vectorizable_elements: usize,
    total_elements: usize,
    ts: usize,
    src: *const u8,
    dst: *mut u8,
) {
    // i32 lanes per __m256i (8) = elements consumed per gather.
    const I32PM256: usize = SO256I / SOI32;
    // NOTE(review): the loop below consumes 8 elements per iteration and
    // the caller truncates to multiples of SO256I / 4 = 8 elements, so
    // this % 4 assert is looser than the actual requirement — confirm.
    debug_assert_eq!(vectorizable_elements % 4, 0);
    // Byte offsets of the same 4-byte column in 8 consecutive elements.
    let loadindex = _mm256_set_epi32(
        7 * ts as i32,
        6 * ts as i32,
        5 * ts as i32,
        4 * ts as i32,
        3 * ts as i32,
        2 * ts as i32,
        ts as i32,
        0,
    );
    // Byte offsets of four consecutive output rows (one row per plane).
    let storeindex = _mm_set_epi32(
        3 * total_elements as i32,
        2 * total_elements as i32,
        total_elements as i32,
        0,
    );
    for i in 0..(vectorizable_elements / I32PM256) {
        for j in 0..(ts / SOI32) {
            // Gather column j of elements i*8 .. i*8+7 (scale 1: the
            // indices above are byte offsets).
            let p = src.add(i * I32PM256 * ts + j * SOI32) as *const i32;
            let mut zmm = _mm256_i32gather_epi32(p, loadindex, 1);
            zmm = shuffle_8x4(zmm);
            // Scatter: each 64-bit lane now holds one byte plane of the 8
            // elements; write the four planes one output row apart.
            let p = dst.add(i * I32PM256 + j * total_elements * SOI32) as *mut i64;
            _mm256_i32scatter_epi64(p, storeindex, zmm, 1);
        }
    }
}
/// Byte-shuffles (transposes) `len` bytes of `typesize`-byte elements from
/// `src` into `dst`, so that output row `b` (rows are `total_elements`
/// bytes long) holds byte `b` of every element.
///
/// Dispatch order: the dedicated 2-byte and 16-byte 512-bit kernels, then
/// the gather/scatter kernel for multiples of 4, then the AVX2 fallback.
/// Tail bytes a vector kernel could not cover are finished with the
/// generic partial shuffle.
///
/// # Safety
/// Caller must ensure the AVX-512 subsets the kernels are compiled for
/// (F, VL, BW) are available at runtime and that `src`/`dst` are valid for
/// `len` bytes. `typesize` must be non-zero; `len` is presumably a
/// multiple of `typesize` (the element count is integer division).
pub unsafe fn shuffle(typesize: usize, len: usize, src: *const u8, dst: *mut u8) {
    let total_elements = len / typesize;
    let vectorized_chunk_size = typesize * SO512I;
    let sg_chunk_size = typesize * SO256I / 4;

    // Pick a kernel; `done` is how many leading bytes it fully handled.
    let done = if typesize == 2 && len >= vectorized_chunk_size {
        let bytes = len - (len % vectorized_chunk_size);
        shuffle2(bytes / typesize, total_elements, src, dst);
        bytes
    } else if typesize == 16 && len >= vectorized_chunk_size {
        let bytes = len - (len % vectorized_chunk_size);
        shuffle16(bytes / typesize, total_elements, src, dst);
        bytes
    } else if typesize.is_multiple_of(4) && len >= sg_chunk_size {
        let bytes = len - (len % sg_chunk_size);
        shuffle_sg(bytes / typesize, total_elements, typesize, src, dst);
        bytes
    } else {
        // No AVX-512 kernel applies; the AVX2 path handles everything.
        return crate::avx2::shuffle(typesize, len, src, dst);
    };
    // Finish whatever tail the vector kernel left over.
    if done < len {
        crate::generic::shuffle_partial(typesize, done, len, src, dst);
    }
}
#[cfg(test)]
mod t {
    /// Skips the enclosing test unless every AVX-512 subset this module's
    /// kernels are compiled for is detected at runtime: F and VL (the
    /// gather/scatter path and `shuffle_8x4`) plus BW (the 512-bit byte
    /// shuffles in `shuffle2` / `shuffle16`). Omitting the BW check could
    /// let the 16-byte kernel execute AVX512BW instructions on a CPU
    /// without them.
    macro_rules! require_avx512f {
        () => {
            if !is_x86_feature_detected!("avx512f")
                || !is_x86_feature_detected!("avx512vl")
                || !is_x86_feature_detected!("avx512bw")
            {
                eprintln!("Skipping: AVX512F, AVX512VL or AVX512BW unavailable.");
                return;
            }
        };
    }
    mod shuffle_8x4 {
        #[cfg(target_arch = "x86")]
        use core::arch::x86 as simd;
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64 as simd;
        use rstest::rstest;
        use simd::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};
        use super::super::shuffle_8x4;
        /// `shuffle_8x4` on one register must agree with the generic
        /// shuffle of eight 4-byte elements.
        #[rstest]
        fn t() {
            require_avx512f!();
            #[rustfmt::skip]
            let input = vec![
                0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51,
                64, 65, 66, 67, 80, 81, 82, 83, 96, 97, 98, 99, 112, 113, 114, 115];
            let mut want = vec![0; 32];
            let mut actual = vec![0; 32];
            unsafe {
                let x = _mm256_loadu_si256(input.as_ptr() as *const __m256i);
                let y = shuffle_8x4(x);
                crate::generic::shuffle(4, input.len(), input.as_ptr(), want.as_mut_ptr());
                _mm256_storeu_si256(actual.as_mut_ptr() as *mut __m256i, y);
            }
            assert_eq!(want, actual);
        }
    }
    mod shuffle {
        use rand::Rng;
        use rstest::rstest;
        /// Shuffles `src` with both the generic reference implementation
        /// and the AVX-512 implementation and asserts identical output.
        fn check(typesize: usize, src: &[u8]) {
            let len = src.len();
            let mut generic_dst = vec![0u8; len];
            let mut avx512_dst = vec![0u8; len];
            unsafe {
                crate::generic::shuffle(typesize, len, src.as_ptr(), generic_dst.as_mut_ptr());
                crate::avx512f::shuffle(typesize, len, src.as_ptr(), avx512_dst.as_mut_ptr());
            }
            assert_eq!(generic_dst, avx512_dst);
        }
        /// Random inputs across lengths covering the gather/scatter path
        /// (256, 512 < one 1024-byte chunk), the 512-bit kernel (4096,
        /// 65536), and a kernel-plus-tail case (4352).
        #[rstest]
        #[case(16, 256)]
        #[case(16, 512)]
        #[case(16, 4096)]
        #[case(16, 4352)]
        #[case(16, 65536)]
        fn compare(#[case] typesize: usize, #[case] len: usize) {
            require_avx512f!();
            let mut rng = rand::rng();
            let src = (0..len).map(|_| rng.random()).collect::<Vec<u8>>();
            check(typesize, &src);
        }
        /// Deterministic saturating-ramp input, one full 1024-byte chunk
        /// of the 16-byte kernel.
        #[rstest]
        fn compare16x1024() {
            require_avx512f!();
            let src = (0..1024usize).map(|i| i.min(255) as u8).collect::<Vec<u8>>();
            check(16, &src);
        }
        /// 272 bytes is below the 512-bit kernel's 1024-byte chunk, so
        /// this exercises the gather/scatter path plus a 16-byte tail.
        #[rstest]
        fn compare16x272() {
            require_avx512f!();
            let src = (0..272usize).map(|i| (i % 256) as u8).collect::<Vec<u8>>();
            check(16, &src);
        }
        /// typesize 18 matches no AVX-512 kernel (not 2, not 16, not a
        /// multiple of 4) and so takes the AVX2 fallback.
        #[rstest]
        fn compare18x288() {
            require_avx512f!();
            let src = (0..288usize).map(|i| (i % 256) as u8).collect::<Vec<u8>>();
            check(18, &src);
        }
        /// 160 bytes is exactly one gather/scatter chunk (20 * 32 / 4) for
        /// typesize 20 — no tail.
        #[rstest]
        fn compare20x160() {
            require_avx512f!();
            let src = (0..160usize).map(|i| (i % 256) as u8).collect::<Vec<u8>>();
            check(20, &src);
        }
    }
}