pub mod scalar;
#[cfg(target_arch = "x86_64")]
pub mod x86;
#[cfg(target_arch = "aarch64")]
pub mod aarch64;
/// The even indices `0..16` in 3-bit bit-reversed order: entry `i` equals
/// `2 * bitrev3(i)`, i.e. `2 * [0, 4, 2, 6, 1, 5, 3, 7][i]`.
///
/// NOTE(review): not referenced in this file; presumably consumed by the SIMD
/// kernels in `x86` / `aarch64`. If `crate::FL_ORDER` is the usual
/// `[0, 4, 2, 6, 1, 5, 3, 7]`, this is `2 * FL_ORDER[i]` — confirm.
pub(crate) const BASE_PATTERN_FIRST: [usize; 8] = [0, 8, 4, 12, 2, 10, 6, 14];
/// [`BASE_PATTERN_FIRST`] shifted by one: the odd indices `0..16` in the same
/// bit-reversed order (`2 * bitrev3(i) + 1`).
pub(crate) const BASE_PATTERN_SECOND: [usize; 8] = [1, 9, 5, 13, 3, 11, 7, 15];
/// Delta-swap mask for stage 1 (shift 7) of an 8x8 bit-matrix transpose packed
/// into a `u64` (byte `i` = row `i`); exchanges bits inside each 2x2 sub-block.
/// See the test helper `transpose_8x8`, which applies the three masks with
/// shifts 7, 14 and 28.
pub(crate) const TRANSPOSE_2X2: u64 = 0x00AA_00AA_00AA_00AA;
/// Delta-swap mask for stage 2 (shift 14): exchanges 2x2 blocks inside each
/// 4x4 sub-block of the 8x8 bit matrix.
pub(crate) const TRANSPOSE_4X4: u64 = 0x0000_CCCC_0000_CCCC;
/// Delta-swap mask for stage 3 (shift 28): exchanges the off-diagonal 4x4
/// quadrants of the 8x8 bit matrix.
pub(crate) const TRANSPOSE_8X8: u64 = 0x0000_0000_F0F0_F0F0;
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
pub(crate) mod group_perm {
    //! Byte-permutation tables for the grouped untranspose path.
    //!
    //! Each supported lane width (8, 16, 32, 64 bits) has a pair of 128-entry
    //! byte tables. The reference model `untranspose_via_tables` in the parent
    //! module's tests applies them as:
    //!
    //! 1. gather:  `grouped[k] = src[GATHER[k]]`
    //! 2. bit-transpose every 8-byte group of `grouped` as an 8x8 bit matrix
    //! 3. scatter: `dst[k] = transposed[SCATTER[k]]`
    //!
    //! All tables are produced at compile time by the `const fn`s below.

    /// Builds the gather table for `tb`-bit lanes.
    ///
    /// With `bytes = tb / 8` bytes per lane word, group `g` (of 16) fills
    /// entries `g * 8 .. g * 8 + 8`. It collects byte `g % bytes` (in memory
    /// order) of the 8 consecutive words starting at word `(g / bytes) * 8`.
    /// That is byte offset `(g / bytes) * tb + g % bytes + llo * bytes` for
    /// `llo` in `0..8`.
    ///
    /// `tb` must be a multiple of 8 in `8..=64`; a `tb` below 8 makes `bytes`
    /// zero and divides by zero.
    const fn gather_indices(tb: usize) -> [u8; 128] {
        let bytes = tb / 8;
        let mut idx = [0u8; 128];
        // `while` rather than `for`: iterators are not usable in `const fn`.
        let mut g = 0;
        while g < 16 {
            let lhi = g / bytes;
            let hi = g % bytes;
            let gather_base = lhi * tb + hi;
            let mut llo = 0;
            while llo < 8 {
                // Every offset is < 128, so the `as u8` cast cannot truncate.
                idx[g * 8 + llo] = (gather_base + llo * bytes) as u8;
                llo += 1;
            }
            g += 1;
        }
        idx
    }

    /// Builds the scatter table for `tb`-bit lanes, stored in gather form:
    /// entry `k` names the transposed byte that lands in output byte `k`.
    ///
    /// With `bytes = tb / 8`, transposed byte `lo` of group `g` is sent to
    /// output byte `FL_ORDER[g % bytes] * 2 + g / bytes + 16 * lo`.
    const fn scatter_indices(tb: usize) -> [u8; 128] {
        let bytes = tb / 8;
        let mut idx = [0u8; 128];
        let mut g = 0;
        while g < 16 {
            let lhi = g / bytes;
            let hi = g % bytes;
            let scatter_base = crate::FL_ORDER[hi] * 2 + lhi;
            let mut lo = 0;
            while lo < 8 {
                idx[scatter_base + lo * 16] = (g * 8 + lo) as u8;
                lo += 1;
            }
            g += 1;
        }
        idx
    }

    static GATHER_8: [u8; 128] = gather_indices(8);
    static GATHER_16: [u8; 128] = gather_indices(16);
    static GATHER_32: [u8; 128] = gather_indices(32);
    static GATHER_64: [u8; 128] = gather_indices(64);
    static SCATTER_8: [u8; 128] = scatter_indices(8);
    static SCATTER_16: [u8; 128] = scatter_indices(16);
    static SCATTER_32: [u8; 128] = scatter_indices(32);
    static SCATTER_64: [u8; 128] = scatter_indices(64);

    /// Returns the `(gather, scatter)` tables for lane type `T`.
    ///
    /// # Panics
    ///
    /// Panics if `T::T` is not 8, 16, 32 or 64. No such table exists, and
    /// silently reusing another width's tables would corrupt the output.
    #[inline]
    pub(crate) fn group_tables<T: crate::FastLanes>() -> (&'static [u8; 128], &'static [u8; 128]) {
        match T::T {
            8 => (&GATHER_8, &SCATTER_8),
            16 => (&GATHER_16, &SCATTER_16),
            32 => (&GATHER_32, &SCATTER_32),
            64 => (&GATHER_64, &SCATTER_64),
            width => unreachable!("no group_perm tables for {width}-bit lanes"),
        }
    }
}
/// Views a 16-word block as its 128 underlying bytes, in native memory order
/// (byte `8 * w + i` is byte `i` of `block[w].to_ne_bytes()`).
#[inline]
#[must_use]
pub(crate) fn as_byte_array(block: &[u64; 16]) -> &[u8; 128] {
    let words: *const [u64; 16] = block;
    // SAFETY: `[u64; 16]` and `[u8; 128]` both occupy exactly 128 bytes, `u8`
    // needs alignment 1 (never stricter than `u64`), every bit pattern is a
    // valid `u8`, and the returned shared borrow inherits `block`'s lifetime.
    unsafe { &*words.cast::<[u8; 128]>() }
}
/// Mutable counterpart of `as_byte_array`: writes through the returned bytes
/// land in `block` in native memory order.
#[inline]
#[must_use]
pub(crate) fn as_byte_array_mut(block: &mut [u64; 16]) -> &mut [u8; 128] {
    let words: *mut [u64; 16] = block;
    // SAFETY: same size (128 bytes), `u8` alignment 1 never exceeds `u64`'s,
    // any byte pattern written back is a valid `u64`, and the exclusive borrow
    // of `block` is transferred to the returned reference.
    unsafe { &mut *words.cast::<[u8; 128]>() }
}
/// Reports whether the AVX-512 VBMI kernels may run. Requires `avx512vbmi`,
/// `avx512bw` and `avx512f`.
///
/// With the `std` feature this asks the CPU at runtime.
#[cfg(all(target_arch = "x86_64", feature = "std"))]
#[inline]
fn detect_vbmi() -> bool {
    std::is_x86_feature_detected!("avx512vbmi")
        && std::is_x86_feature_detected!("avx512bw")
        && std::is_x86_feature_detected!("avx512f")
}

/// Reports whether the AVX-512 VBMI kernels may run. Requires `avx512vbmi`,
/// `avx512bw` and `avx512f`.
///
/// Without `std` there is no runtime detection, so only features enabled at
/// compile time count.
#[cfg(all(target_arch = "x86_64", not(feature = "std")))]
#[inline]
fn detect_vbmi() -> bool {
    cfg!(all(
        target_feature = "avx512vbmi",
        target_feature = "avx512bw",
        target_feature = "avx512f"
    ))
}
/// Reports whether the BMI2 kernels may run; with `std`, via runtime CPU detection.
#[cfg(all(target_arch = "x86_64", feature = "std"))]
#[inline]
fn detect_bmi2() -> bool {
    std::is_x86_feature_detected!("bmi2")
}

/// Reports whether the BMI2 kernels may run. Without `std` only the
/// compile-time target features are consulted.
#[cfg(all(target_arch = "x86_64", not(feature = "std")))]
#[inline]
fn detect_bmi2() -> bool {
    cfg!(target_feature = "bmi2")
}
/// Bit-transposes the 1024-bit block `input` into `output`.
///
/// Kernel selection:
/// - on `x86_64`: the AVX-512 VBMI kernel if available, else the BMI2 kernel,
///   else the portable scalar kernel;
/// - on `aarch64`: the NEON kernel;
/// - elsewhere: the scalar kernel.
///
/// The tests check the result against `transpose_bits_baseline`.
#[inline]
pub fn transpose_bits(input: &[u64; 16], output: &mut [u64; 16]) {
    #[cfg(target_arch = "x86_64")]
    {
        if detect_vbmi() {
            // SAFETY: `detect_vbmi` reported AVX-512 F, BW and VBMI as
            // available — presumably exactly what this kernel requires.
            return unsafe { x86::transpose_bits_vbmi(input, output) };
        }
        if detect_bmi2() {
            // SAFETY: `detect_bmi2` reported BMI2 as available.
            return unsafe { x86::transpose_bits_bmi2(input, output) };
        }
        scalar::transpose_bits(input, output);
    }

    // SAFETY: NOTE(review): assumes the NEON kernel's only precondition is
    // NEON support, part of the standard aarch64 baseline — confirm against
    // `aarch64::transpose_bits_neon`.
    #[cfg(target_arch = "aarch64")]
    unsafe {
        aarch64::transpose_bits_neon(input, output);
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    scalar::transpose_bits(input, output);
}
/// Inverse bit-transposition of `input` into `output` for lane type `T`.
///
/// Kernel selection mirrors `transpose_bits`: VBMI, then BMI2, then scalar on
/// `x86_64`; NEON on `aarch64`; scalar elsewhere. The tests check the result
/// against `untranspose_bits_baseline::<T>` for `u8`, `u16`, `u32` and `u64`.
#[inline]
pub fn untranspose_bits<T: crate::FastLanes>(input: &[u64; 16], output: &mut [u64; 16]) {
    #[cfg(target_arch = "x86_64")]
    {
        if detect_vbmi() {
            // SAFETY: `detect_vbmi` reported AVX-512 F, BW and VBMI as
            // available — presumably exactly what this kernel requires.
            return unsafe { x86::untranspose_bits_vbmi::<T>(input, output) };
        }
        if detect_bmi2() {
            // SAFETY: `detect_bmi2` reported BMI2 as available.
            return unsafe { x86::untranspose_bits_bmi2::<T>(input, output) };
        }
        scalar::untranspose_bits::<T>(input, output);
    }

    // SAFETY: NOTE(review): assumes the NEON kernel's only precondition is
    // NEON support, part of the standard aarch64 baseline — confirm against
    // `aarch64::untranspose_bits_neon`.
    #[cfg(target_arch = "aarch64")]
    unsafe {
        aarch64::untranspose_bits_neon::<T>(input, output);
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    scalar::untranspose_bits::<T>(input, output);
}
/// Deterministic pseudo-random block for tests.
///
/// Byte `i` of the block (native memory order, `i` in `0..128`) is
/// `((seed * 17) + i) * 31`, with every operation wrapping in `u8`.
#[cfg(test)]
pub(crate) fn generate_test_data(seed: u8) -> [u64; 16] {
    let scrambled_seed = seed.wrapping_mul(17);
    let mut data = [0u64; 16];
    for (word_idx, word) in data.iter_mut().enumerate() {
        let mut bytes = [0u8; 8];
        for (offset, byte) in bytes.iter_mut().enumerate() {
            let byte_idx = (word_idx * 8 + offset) as u8;
            *byte = scrambled_seed.wrapping_add(byte_idx).wrapping_mul(31);
        }
        // Native endianness matches the in-memory byte view of the block.
        *word = u64::from_ne_bytes(bytes);
    }
    data
}
/// Bit-by-bit reference transpose: bit `i` of `input` (byte `i / 8`, bit
/// `i % 8` in native memory order) is copied to bit `crate::transpose(i)` of
/// `output`. `output` is cleared first.
#[cfg(test)]
pub(crate) fn transpose_bits_baseline(input: &[u64; 16], output: &mut [u64; 16]) {
    let src = as_byte_array(input);
    let dst = as_byte_array_mut(output);
    dst.fill(0);
    for (byte_idx, &byte) in src.iter().enumerate() {
        for bit in 0..8 {
            let in_bit = byte_idx * 8 + bit;
            let out_bit = crate::transpose(in_bit);
            let value = (byte >> bit) & 1;
            dst[out_bit / 8] |= value << (out_bit % 8);
        }
    }
}
/// Bit-by-bit reference untranspose for lane type `T`.
///
/// Input bit `b` is read as lane `b / T::T`, row `b % T::T`. It is written to
/// output bit `FL_ORDER[row / 8] * 16 + (row % 8) * 128 + lane`. `output` is
/// cleared first.
#[cfg(test)]
pub(crate) fn untranspose_bits_baseline<T: crate::FastLanes>(
    input: &[u64; 16],
    output: &mut [u64; 16],
) {
    let src = as_byte_array(input);
    let dst = as_byte_array_mut(output);
    dst.fill(0);
    for (byte_idx, &byte) in src.iter().enumerate() {
        for bit in 0..8 {
            let packed = byte_idx * 8 + bit;
            let (lane, row) = (packed / T::T, packed % T::T);
            let logical = crate::FL_ORDER[row / 8] * 16 + (row % 8) * 128 + lane;
            let value = (byte >> bit) & 1;
            dst[logical / 8] |= value << (logical % 8);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Seeds shared by the baseline and dispatch tests.
    const SEEDS: [u8; 4] = [0, 42, 123, 255];

    /// The two reference implementations must invert each other (64-bit lanes).
    #[test]
    fn test_baseline_roundtrip() {
        for seed in SEEDS {
            let input = generate_test_data(seed);
            let mut transposed = [0u64; 16];
            let mut roundtrip = [0u64; 16];
            transpose_bits_baseline(&input, &mut transposed);
            untranspose_bits_baseline::<u64>(&transposed, &mut roundtrip);
            assert_eq!(
                input, roundtrip,
                "baseline roundtrip failed for seed {seed}"
            );
        }
    }

    /// The dispatched `transpose_bits` kernel must match the baseline bit for bit.
    #[test]
    fn test_dispatch_matches_baseline() {
        for seed in SEEDS {
            let input = generate_test_data(seed);
            let mut baseline_out = [0u64; 16];
            let mut out = [0u64; 16];
            transpose_bits_baseline(&input, &mut baseline_out);
            transpose_bits(&input, &mut out);
            assert_eq!(
                baseline_out, out,
                "transpose dispatch doesn't match baseline for seed {seed}"
            );
        }
    }

    /// The dispatched `untranspose_bits` kernel must match the baseline for every lane width.
    #[test]
    fn test_untranspose_dispatch_matches_baseline() {
        fn check<T: crate::FastLanes>() {
            for seed in SEEDS {
                let input = generate_test_data(seed);
                let mut baseline_out = [0u64; 16];
                let mut out = [0u64; 16];
                untranspose_bits_baseline::<T>(&input, &mut baseline_out);
                untranspose_bits::<T>(&input, &mut out);
                assert_eq!(
                    baseline_out,
                    out,
                    "untranspose dispatch doesn't match baseline for type={} seed={seed}",
                    core::any::type_name::<T>()
                );
            }
        }
        check::<u8>();
        check::<u16>();
        check::<u32>();
        check::<u64>();
    }

    /// Dispatched transpose followed by dispatched untranspose (64-bit lanes) is the identity.
    #[test]
    fn test_dispatch_roundtrip() {
        for seed in SEEDS {
            let input = generate_test_data(seed);
            let mut transposed = [0u64; 16];
            let mut roundtrip = [0u64; 16];
            transpose_bits(&input, &mut transposed);
            untranspose_bits::<u64>(&transposed, &mut roundtrip);
            assert_eq!(
                input, roundtrip,
                "dispatch roundtrip failed for seed {seed}"
            );
        }
    }

    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    mod group_perm_tables {
        use super::*;

        /// Scalar model of the table-driven untranspose. It gathers bytes via
        /// the gather table, bit-transposes each 8-byte group, then scatters
        /// via the scatter table.
        fn untranspose_via_tables<T: crate::FastLanes>(input: &[u64; 16]) -> [u64; 16] {
            /// Transposes the 8x8 bit matrix packed in `x` (byte `i` = row `i`)
            /// with three delta-swap stages.
            fn transpose_8x8(mut x: u64) -> u64 {
                let t = (x ^ (x >> 7)) & TRANSPOSE_2X2;
                x = x ^ t ^ (t << 7);
                let t = (x ^ (x >> 14)) & TRANSPOSE_4X4;
                x = x ^ t ^ (t << 14);
                let t = (x ^ (x >> 28)) & TRANSPOSE_8X8;
                x ^ t ^ (t << 28)
            }

            let (gather, scatter) = group_perm::group_tables::<T>();
            let src = as_byte_array(input);

            // Step 1: gather — grouped[k] = src[gather[k]].
            let mut grouped = [0u8; 128];
            for (slot, &from) in grouped.iter_mut().zip(gather.iter()) {
                *slot = src[usize::from(from)];
            }

            // Step 2: treat each 8-byte group as a little-endian u64 and bit-transpose it.
            let mut transposed = [0u8; 128];
            for (out, group) in transposed
                .chunks_exact_mut(8)
                .zip(grouped.chunks_exact(8))
            {
                let word = group
                    .iter()
                    .enumerate()
                    .fold(0u64, |acc, (b, &byte)| acc | (u64::from(byte) << (b * 8)));
                out.copy_from_slice(&transpose_8x8(word).to_le_bytes());
            }

            // Step 3: scatter (stored in gather form) — dst[k] = transposed[scatter[k]].
            let mut out = [0u64; 16];
            let dst = as_byte_array_mut(&mut out);
            for (slot, &from) in dst.iter_mut().zip(scatter.iter()) {
                *slot = transposed[usize::from(from)];
            }
            out
        }

        /// The table-driven model must reproduce the baseline for every lane width.
        #[test]
        fn tables_match_baseline_all_widths() {
            fn check<T: crate::FastLanes>() {
                for seed in [0, 1, 42, 123, 200, 255] {
                    let input = generate_test_data(seed);
                    let mut baseline = [0u64; 16];
                    untranspose_bits_baseline::<T>(&input, &mut baseline);
                    assert_eq!(
                        untranspose_via_tables::<T>(&input),
                        baseline,
                        "group_perm tables != baseline for type={} seed={seed}",
                        core::any::type_name::<T>()
                    );
                }
            }
            check::<u8>();
            check::<u16>();
            check::<u32>();
            check::<u64>();
        }

        /// Every gather/scatter table must be a permutation of `0..128`.
        #[test]
        fn tables_are_permutations() {
            fn is_permutation(t: &[u8; 128]) -> bool {
                let mut seen = [false; 128];
                for &i in t {
                    if i as usize >= 128 || seen[i as usize] {
                        return false;
                    }
                    seen[i as usize] = true;
                }
                seen.iter().all(|&b| b)
            }
            fn check<T: crate::FastLanes>() {
                let (gather, scatter) = group_perm::group_tables::<T>();
                assert!(
                    is_permutation(gather),
                    "gather table for {} is not a permutation",
                    core::any::type_name::<T>()
                );
                assert!(
                    is_permutation(scatter),
                    "scatter table for {} is not a permutation",
                    core::any::type_name::<T>()
                );
            }
            check::<u8>();
            check::<u16>();
            check::<u32>();
            check::<u64>();
        }
    }
}