#![cfg(target_arch = "x86_64")]
use core::arch::x86_64::__m512i;
use core::arch::x86_64::_mm512_and_si512;
use core::arch::x86_64::_mm512_loadu_si512;
use core::arch::x86_64::_mm512_permutex2var_epi8;
use core::arch::x86_64::_mm512_permutexvar_epi8;
use core::arch::x86_64::_mm512_set1_epi64;
use core::arch::x86_64::_mm512_slli_epi64;
use core::arch::x86_64::_mm512_srli_epi64;
use core::arch::x86_64::_mm512_storeu_si512;
use core::arch::x86_64::_mm512_xor_si512;
use core::arch::x86_64::_pdep_u64;
use core::arch::x86_64::_pext_u64;
use crate::bit_transpose::as_byte_array;
use crate::bit_transpose::as_byte_array_mut;
use crate::bit_transpose::group_perm::group_tables;
use crate::bit_transpose::BASE_PATTERN_FIRST;
use crate::bit_transpose::BASE_PATTERN_SECOND;
use crate::bit_transpose::TRANSPOSE_2X2;
use crate::bit_transpose::TRANSPOSE_4X4;
use crate::bit_transpose::TRANSPOSE_8X8;
use crate::FastLanes;
use crate::FL_ORDER;
/// Reports whether the running CPU supports the BMI2 instruction set.
///
/// A `true` result is the precondition for calling [`transpose_bits_bmi2`]
/// and [`untranspose_bits_bmi2`], which are compiled with `bmi2` enabled.
#[cfg(feature = "std")]
#[inline]
#[must_use]
pub fn has_bmi2() -> bool {
    std::arch::is_x86_feature_detected!("bmi2")
}
/// Reports whether the running CPU supports every AVX-512 subset the VBMI
/// kernels are compiled for: AVX-512 VBMI, AVX-512BW and AVX-512F.
///
/// A `true` result is the precondition for calling [`transpose_bits_vbmi`]
/// and [`untranspose_bits_vbmi`].
#[cfg(feature = "std")]
#[inline]
#[must_use]
pub fn has_vbmi() -> bool {
    // Checked in the same order as before: VBMI first, then BW, then F.
    if !std::is_x86_feature_detected!("avx512vbmi") {
        return false;
    }
    std::is_x86_feature_detected!("avx512bw") && std::is_x86_feature_detected!("avx512f")
}
/// `BIT_MASKS[b]` selects bit `b` of every byte of a `u64`
/// (`0x0101_0101_0101_0101 << b`).
///
/// Used as the `pext`/`pdep` selector that moves one bit plane between
/// eight packed bytes and a single byte.
const BIT_MASKS: [u64; 8] = {
    let mut masks = [0u64; 8];
    let mut bit = 0;
    while bit < 8 {
        masks[bit] = 0x0101_0101_0101_0101_u64 << bit;
        bit += 1;
    }
    masks
};
/// Bit-transposes a 128-byte block using BMI2 `pext`.
///
/// The output is processed as two 64-byte halves. Each half is paired with one of
/// `BASE_PATTERN_FIRST` / `BASE_PATTERN_SECOND`. For each column base in that
/// pattern, eight bytes are gathered from the input at stride 16
/// (`base, base + 16, …, base + 112`). Bit plane `b` of those eight bytes is
/// written to output byte `half * 64 + b * 8 + group`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports BMI2 (see `has_bmi2`).
#[target_feature(enable = "bmi2")]
#[inline]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn transpose_bits_bmi2(input: &[u64; 16], output: &mut [u64; 16]) {
    let src = as_byte_array(input);
    let dst = as_byte_array_mut(output);
    let halves = [BASE_PATTERN_FIRST, BASE_PATTERN_SECOND];
    for (dst_half, bases) in dst.chunks_exact_mut(64).zip(halves.iter()) {
        for (group, &base) in bases.iter().enumerate() {
            let column = gather(src, base);
            // Each 8-byte chunk of the half holds one bit plane; `group` picks the slot.
            for (plane, &mask) in dst_half.chunks_exact_mut(8).zip(BIT_MASKS.iter()) {
                plane[group] = _pext_u64(column, mask) as u8;
            }
        }
    }
}
/// Reverses a bit transpose for lanes of `T::T` bits using BMI2 `pdep`.
///
/// For each (block, lane byte) pair, eight input bytes spaced `T::T / 8`
/// bytes apart are read. Each is one bit plane, and it is deposited into its
/// bit position of eight packed bytes. Those bytes are then written down output
/// column `FL_ORDER[lane_byte] * 2 + block` at stride 16.
///
/// The index arithmetic divides by `T::T` and `T::T / 8`. It assumes `T::T` is
/// 8, 16, 32 or 64, which are the widths exercised by this file's tests.
///
/// # Safety
///
/// The caller must ensure the running CPU supports BMI2 (see `has_bmi2`).
#[target_feature(enable = "bmi2")]
#[inline]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn untranspose_bits_bmi2<T: FastLanes>(input: &[u64; 16], output: &mut [u64; 16]) {
    let src = as_byte_array(input);
    let dst = as_byte_array_mut(output);
    let stride = T::T / 8;
    let blocks = 128 / T::T;
    for block in 0..blocks {
        for lane_byte in 0..stride {
            let start = block * T::T + lane_byte;
            // Plane `i` lives at `start + i * stride`; deposit it into bit `i` of each byte.
            let mut word = 0u64;
            for (&plane, &mask) in src[start..].iter().step_by(stride).zip(BIT_MASKS.iter()) {
                word |= _pdep_u64(u64::from(plane), mask);
            }
            // Byte `row` of `word` (little-endian) belongs to row `row` of the column.
            let column = FL_ORDER[lane_byte] * 2 + block;
            for (slot, byte) in dst[column..].iter_mut().step_by(16).zip(word.to_le_bytes()) {
                *slot = byte;
            }
        }
    }
}
/// Packs the eight bytes `input[base + 16 * row]` (for `row` in `0..8`) into
/// one `u64`, with row `row` in byte `row` (little-endian order).
///
/// Panics if `base + 112` is out of bounds, i.e. if `base >= 16`.
#[inline]
fn gather(input: &[u8; 128], base: usize) -> u64 {
    let mut column = [0u8; 8];
    for (row, byte) in column.iter_mut().enumerate() {
        *byte = input[base + row * 16];
    }
    u64::from_le_bytes(column)
}
/// Builds a 64-entry byte-gather table for one half of the 128-byte block.
///
/// Entry `group * 8 + row` is `2 * COLUMN_ORDER[group] + half + 16 * row`. That is
/// the byte in row `row` (stride 16) of the `group`-th column visited, where
/// `half = 0` selects the even columns and `half = 1` the odd ones.
/// `COLUMN_ORDER` presumably mirrors the crate's `FL_ORDER`; confirm if either changes.
const fn build_gather_table(half: u8) -> [u8; 64] {
    const COLUMN_ORDER: [u8; 8] = [0, 4, 2, 6, 1, 5, 3, 7];
    let mut table = [0u8; 64];
    let mut slot = 0;
    while slot < 64 {
        let group = slot / 8;
        let row = (slot % 8) as u8;
        table[slot] = 2 * COLUMN_ORDER[group] + half + 16 * row;
        slot += 1;
    }
    table
}

/// `_mm512_permutex2var_epi8` indices that gather the even columns of the
/// block, one 8-row column per 64-bit lane.
static GATHER_FIRST: [u8; 64] = build_gather_table(0);

/// `_mm512_permutex2var_epi8` indices that gather the odd columns of the
/// block, one 8-row column per 64-bit lane.
static GATHER_SECOND: [u8; 64] = build_gather_table(1);

/// 8x8 byte transpose: entry `8 * a + b` is `8 * b + a`. It is its own inverse.
static SCATTER_8X8: [u8; 64] = {
    let mut table = [0u8; 64];
    let mut slot = 0;
    while slot < 64 {
        table[slot] = ((slot % 8) * 8 + slot / 8) as u8;
        slot += 1;
    }
    table
};
#[target_feature(enable = "avx512f", enable = "avx512bw")]
#[inline]
#[allow(clippy::cast_possible_wrap, unsafe_op_in_unsafe_fn)]
unsafe fn bit_transpose_8x8_zmm(mut v: __m512i) -> __m512i {
let mask1 = _mm512_set1_epi64(TRANSPOSE_2X2 as i64);
let mask2 = _mm512_set1_epi64(TRANSPOSE_4X4 as i64);
let mask3 = _mm512_set1_epi64(TRANSPOSE_8X8 as i64);
let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<7>(v)), mask1);
v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<7>(t));
let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<14>(v)), mask2);
v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<14>(t));
let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<28>(v)), mask3);
_mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<28>(t))
}
/// Bit-transposes a 128-byte block using AVX-512 VBMI byte permutes.
///
/// For each output half:
/// 1. `GATHER_FIRST` / `GATHER_SECOND` collect one 8-row column per 64-bit lane
///    from the whole 128-byte input.
/// 2. Each lane is bit-transposed.
/// 3. `SCATTER_8X8` byte-transposes the register so that bit plane `b` of lane `g`
///    lands at byte `b * 8 + g`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX-512F, AVX-512BW and
/// AVX-512 VBMI (see `has_vbmi`).
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vbmi")]
#[inline]
#[allow(clippy::cast_ptr_alignment, unsafe_op_in_unsafe_fn)]
pub unsafe fn transpose_bits_vbmi(input: &[u64; 16], output: &mut [u64; 16]) {
    let bytes_in = as_byte_array(input);
    let bytes_out = as_byte_array_mut(output);
    // Unaligned loads/stores throughout, so the byte views need no alignment.
    let lo = _mm512_loadu_si512(bytes_in.as_ptr().cast::<__m512i>());
    let hi = _mm512_loadu_si512(bytes_in[64..].as_ptr().cast::<__m512i>());
    let byte_transpose = _mm512_loadu_si512(SCATTER_8X8.as_ptr().cast::<__m512i>());
    let tables = [&GATHER_FIRST, &GATHER_SECOND];
    for (out_half, table) in bytes_out.chunks_exact_mut(64).zip(tables.iter()) {
        let gather_idx = _mm512_loadu_si512(table.as_ptr().cast::<__m512i>());
        let columns = _mm512_permutex2var_epi8(lo, gather_idx, hi);
        let planes = _mm512_permutexvar_epi8(byte_transpose, bit_transpose_8x8_zmm(columns));
        _mm512_storeu_si512(out_half.as_mut_ptr().cast::<__m512i>(), planes);
    }
}
/// Reverses a bit transpose for lanes of `T::T` bits using AVX-512 VBMI.
///
/// When `T::T == 64`, each 64-byte input half is processed as follows:
/// 1. It is byte-transposed with `SCATTER_8X8`.
/// 2. Each lane is bit-transposed back into eight row bytes.
/// 3. Those bytes are written down the columns named by `BASE_PATTERN_FIRST` /
///    `BASE_PATTERN_SECOND` at stride 16.
///
/// Other widths are forwarded to `untranspose_bits_vbmi_lt64`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX-512F, AVX-512BW and
/// AVX-512 VBMI (see `has_vbmi`).
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vbmi")]
#[inline]
#[allow(clippy::cast_ptr_alignment, unsafe_op_in_unsafe_fn)]
pub unsafe fn untranspose_bits_vbmi<T: FastLanes>(input: &[u64; 16], output: &mut [u64; 16]) {
    if T::T == 64 {
        let src = as_byte_array(input);
        let dst = as_byte_array_mut(output);
        // SCATTER_8X8 is an 8x8 byte transpose and therefore its own inverse,
        // so it doubles as the gather pattern; load it straight from the static.
        let byte_transpose = _mm512_loadu_si512(SCATTER_8X8.as_ptr().cast::<__m512i>());
        let halves = [BASE_PATTERN_FIRST, BASE_PATTERN_SECOND];
        for (src_half, bases) in src.chunks_exact(64).zip(halves.iter()) {
            let planes = _mm512_loadu_si512(src_half.as_ptr().cast::<__m512i>());
            let rows = bit_transpose_8x8_zmm(_mm512_permutexvar_epi8(byte_transpose, planes));
            // Spill the lanes; lane `g` holds the 8 row bytes of column `bases[g]`.
            let mut lanes = [0u64; 8];
            _mm512_storeu_si512(lanes.as_mut_ptr().cast::<__m512i>(), rows);
            for (&lane, &base) in lanes.iter().zip(bases.iter()) {
                for (row, &byte) in lane.to_le_bytes().iter().enumerate() {
                    dst[base + row * 16] = byte;
                }
            }
        }
    } else {
        untranspose_bits_vbmi_lt64::<T>(input, output);
    }
}
/// Untranspose path for lanes narrower than 64 bits, driven by the
/// `group_tables::<T>()` index tables.
///
/// Steps:
/// 1. The two 64-byte gather halves each select bytes from the full 128-byte
///    input into one register.
/// 2. Both registers are bit-transposed per 64-bit lane.
/// 3. The two 64-byte scatter halves each pick from both transposed registers to
///    form one output half.
///
/// The table layout (two 64-byte halves each) is inferred from the `+ 64`
/// offsets below. Confirm against `group_tables` if it changes.
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vbmi")]
#[inline]
#[allow(clippy::cast_ptr_alignment, unsafe_op_in_unsafe_fn)]
unsafe fn untranspose_bits_vbmi_lt64<T: FastLanes>(input: &[u64; 16], output: &mut [u64; 16]) {
    let src = as_byte_array(input);
    let dst = as_byte_array_mut(output);
    let (gather_idx, scatter_idx) = group_tables::<T>();
    let gather_ptr = gather_idx.as_ptr();
    let scatter_ptr = scatter_idx.as_ptr();

    // Whole input in two registers, as the two sources of the two-table permutes.
    let src_lo = _mm512_loadu_si512(src.as_ptr().cast::<__m512i>());
    let src_hi = _mm512_loadu_si512(src.as_ptr().add(64).cast::<__m512i>());

    let tiles_a = bit_transpose_8x8_zmm(_mm512_permutex2var_epi8(
        src_lo,
        _mm512_loadu_si512(gather_ptr.cast::<__m512i>()),
        src_hi,
    ));
    let tiles_b = bit_transpose_8x8_zmm(_mm512_permutex2var_epi8(
        src_lo,
        _mm512_loadu_si512(gather_ptr.add(64).cast::<__m512i>()),
        src_hi,
    ));

    // Each output half is a two-source permute of the transposed tiles.
    let out = dst.as_mut_ptr();
    for half in 0..2 {
        let idx = _mm512_loadu_si512(scatter_ptr.add(half * 64).cast::<__m512i>());
        _mm512_storeu_si512(
            out.add(half * 64).cast::<__m512i>(),
            _mm512_permutex2var_epi8(tiles_a, idx, tiles_b),
        );
    }
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "std")]
    use super::*;
    #[cfg(feature = "std")]
    use crate::bit_transpose::{
        generate_test_data, transpose_bits_baseline, untranspose_bits_baseline,
    };

    /// Signature shared by every transpose / untranspose kernel under test.
    #[cfg(feature = "std")]
    type Kernel = unsafe fn(&[u64; 16], &mut [u64; 16]);

    /// Checks that `kernel` agrees with `transpose_bits_baseline` on every seed.
    #[cfg(feature = "std")]
    fn assert_transpose_matches_baseline(label: &str, kernel: Kernel) {
        for seed in [0, 42, 123, 255] {
            let input = generate_test_data(seed);
            let mut expected = [0u64; 16];
            let mut actual = [0u64; 16];
            transpose_bits_baseline(&input, &mut expected);
            // SAFETY: every caller checks CPU support for `kernel` before calling this.
            unsafe { kernel(&input, &mut actual) };
            assert_eq!(
                expected, actual,
                "{label} transpose doesn't match baseline for seed {seed}"
            );
        }
    }

    /// Checks that `inverse(forward(x)) == x` on every seed.
    #[cfg(feature = "std")]
    fn assert_roundtrip(label: &str, forward: Kernel, inverse: Kernel) {
        for seed in [0, 42, 123, 255] {
            let input = generate_test_data(seed);
            let mut transposed = [0u64; 16];
            let mut roundtrip = [0u64; 16];
            // SAFETY: every caller checks CPU support for both kernels before calling this.
            unsafe {
                forward(&input, &mut transposed);
                inverse(&transposed, &mut roundtrip);
            }
            assert_eq!(input, roundtrip, "{label} roundtrip failed for seed {seed}");
        }
    }

    /// Checks that `kernel` agrees with `untranspose_bits_baseline::<T>` on every seed.
    #[cfg(feature = "std")]
    fn assert_untranspose_matches_baseline<T: FastLanes>(label: &str, kernel: Kernel) {
        for seed in [0, 42, 123, 255] {
            let input = generate_test_data(seed);
            let mut expected = [0u64; 16];
            let mut actual = [0u64; 16];
            untranspose_bits_baseline::<T>(&input, &mut expected);
            // SAFETY: every caller checks CPU support for `kernel` before calling this.
            unsafe { kernel(&input, &mut actual) };
            assert_eq!(
                expected,
                actual,
                "{label} untranspose != baseline for type={} seed={seed}",
                core::any::type_name::<T>()
            );
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_bmi2_matches_baseline() {
        if has_bmi2() {
            assert_transpose_matches_baseline("BMI2", transpose_bits_bmi2);
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_bmi2_roundtrip() {
        if has_bmi2() {
            assert_roundtrip("BMI2", transpose_bits_bmi2, untranspose_bits_bmi2::<u64>);
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_bmi2_untranspose_all_widths_match_baseline() {
        if !has_bmi2() {
            return;
        }
        assert_untranspose_matches_baseline::<u8>("BMI2", untranspose_bits_bmi2::<u8>);
        assert_untranspose_matches_baseline::<u16>("BMI2", untranspose_bits_bmi2::<u16>);
        assert_untranspose_matches_baseline::<u32>("BMI2", untranspose_bits_bmi2::<u32>);
        assert_untranspose_matches_baseline::<u64>("BMI2", untranspose_bits_bmi2::<u64>);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_vbmi_matches_baseline() {
        if has_vbmi() {
            assert_transpose_matches_baseline("VBMI", transpose_bits_vbmi);
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_vbmi_roundtrip() {
        if has_vbmi() {
            assert_roundtrip("VBMI", transpose_bits_vbmi, untranspose_bits_vbmi::<u64>);
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_vbmi_untranspose_all_widths_match_baseline() {
        if !has_vbmi() {
            return;
        }
        assert_untranspose_matches_baseline::<u8>("VBMI", untranspose_bits_vbmi::<u8>);
        assert_untranspose_matches_baseline::<u16>("VBMI", untranspose_bits_vbmi::<u16>);
        assert_untranspose_matches_baseline::<u32>("VBMI", untranspose_bits_vbmi::<u32>);
        assert_untranspose_matches_baseline::<u64>("VBMI", untranspose_bits_vbmi::<u64>);
    }
}