use crate::bit_transpose::as_byte_array;
use crate::bit_transpose::as_byte_array_mut;
use crate::bit_transpose::BASE_PATTERN_FIRST;
use crate::bit_transpose::BASE_PATTERN_SECOND;
use crate::bit_transpose::TRANSPOSE_2X2;
use crate::bit_transpose::TRANSPOSE_4X4;
use crate::bit_transpose::TRANSPOSE_8X8;
use crate::FastLanes;
use crate::FL_ORDER;
/// Gather base offsets for the two output halves: `HALVES[half][group]` is the
/// starting byte index that `transpose_bits` hands to `gather` when it fills
/// group `group` of output half `half`.
const HALVES: [[usize; 8]; 2] = [BASE_PATTERN_FIRST, BASE_PATTERN_SECOND];
/// Runs three delta-swap stages over `x` (shift 7 with `TRANSPOSE_2X2`, shift 14
/// with `TRANSPOSE_4X4`, shift 28 with `TRANSPOSE_8X8`) and returns the result.
///
/// NOTE(review): the mask values live in `bit_transpose` and are not visible
/// here; with the conventional masks this sequence transposes the 8x8 bit
/// matrix stored one row per byte in `x` — confirm against the mask definitions.
#[inline]
fn transpose_8x8(x: u64) -> u64 {
    /// Swaps the bit groups selected by `mask` with the groups `shift` bits above them.
    #[inline(always)]
    fn delta_swap(word: u64, mask: u64, shift: u32) -> u64 {
        let diff = (word ^ (word >> shift)) & mask;
        word ^ diff ^ (diff << shift)
    }

    let stage_2x2 = delta_swap(x, TRANSPOSE_2X2, 7);
    let stage_4x4 = delta_swap(stage_2x2, TRANSPOSE_4X4, 14);
    delta_swap(stage_4x4, TRANSPOSE_8X8, 28)
}
/// Packs eight bytes of `input`, taken at `base`, `base + 16`, ..., `base + 112`,
/// into a `u64`: the byte at `base` becomes the least significant byte and the
/// byte at `base + 112` the most significant.
///
/// # Panics
///
/// Panics if any of the eight indices is out of bounds, i.e. when `base >= 16`.
#[inline]
fn gather(input: &[u8; 128], base: usize) -> u64 {
    // Little-endian assembly places byte `row` at bit offset `row * 8`.
    u64::from_le_bytes(core::array::from_fn(|row| input[base + row * 16]))
}
/// Writes the eight bytes of `val` into `output` at `base`, `base + 16`, ...,
/// `base + 112`, least significant byte first. All other bytes of `output`
/// are left untouched.
///
/// # Panics
///
/// Panics once an index falls outside `output`, i.e. when `base >= 16`;
/// rows written before the failing one remain written.
#[inline]
fn scatter(output: &mut [u8; 128], base: usize, val: u64) {
    let bytes = val.to_le_bytes();
    for (row, &byte) in bytes.iter().enumerate() {
        output[base + row * 16] = byte;
    }
}
/// Scalar bit transpose of sixteen `u64` words, processed as 128 bytes.
///
/// For each of the two entries of `HALVES` and each of the eight base offsets
/// inside it, the input bytes at `base`, `base + 16`, ..., `base + 112` are
/// packed into a `u64` by `gather`, passed through `transpose_8x8`, and the
/// eight resulting bytes (least significant first) are stored in `output` at
/// byte index `half * 64 + bit * 8 + group`.
///
/// Every byte of `output` is overwritten; its previous contents are ignored.
#[inline]
pub fn transpose_bits(input: &[u64; 16], output: &mut [u64; 16]) {
    let src = as_byte_array(input);
    let dst = as_byte_array_mut(output);

    // Each 64-byte chunk of the output corresponds to one entry of `HALVES`.
    for (half_out, bases) in dst.chunks_exact_mut(64).zip(HALVES.iter()) {
        for (group, &base) in bases.iter().enumerate() {
            let transposed = transpose_8x8(gather(src, base)).to_le_bytes();
            // Byte `bit` of the transposed word lands in column `group` of row `bit`.
            for (bit, &byte) in transposed.iter().enumerate() {
                half_out[bit * 8 + group] = byte;
            }
        }
    }
}
/// Scalar inverse-direction bit transpose for element type `T`, processed as 128 bytes.
///
/// With `T::T / 8` bytes per element (presumably `T::T` is the bit width of
/// `T` — confirm against the `FastLanes` trait), the input is walked as
/// `128 / T::T` blocks of `T::T` bytes. For block `lhi` and byte position `hi`,
/// the eight bytes at `lhi * T::T + hi + llo * (T::T / 8)` for `llo` in `0..8`
/// are packed least-significant-first into a `u64`, run through
/// `transpose_8x8`, and written by `scatter` starting at byte
/// `FL_ORDER[hi] * 2 + lhi` with a stride of 16.
#[inline]
pub fn untranspose_bits<T: FastLanes>(input: &[u64; 16], output: &mut [u64; 16]) {
    let src = as_byte_array(input);
    let dst = as_byte_array_mut(output);

    let lane_bytes = T::T / 8;
    let block_count = 128 / T::T;

    for lhi in 0..block_count {
        let block_start = lhi * T::T;
        for hi in 0..lane_bytes {
            let first = block_start + hi;
            // Collect one byte from each of eight consecutive elements.
            let packed =
                u64::from_le_bytes(core::array::from_fn(|llo| src[first + llo * lane_bytes]));
            scatter(dst, FL_ORDER[hi] * 2 + lhi, transpose_8x8(packed));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bit_transpose::generate_test_data;
    use crate::bit_transpose::transpose_bits_baseline;
    use crate::bit_transpose::untranspose_bits_baseline;

    /// The scalar transpose must produce the same words as the baseline transpose.
    #[test]
    fn test_scalar_matches_baseline() {
        for seed in [0, 42, 123, 255] {
            let data = generate_test_data(seed);
            let (mut expected, mut actual) = ([0u64; 16], [0u64; 16]);

            transpose_bits_baseline(&data, &mut expected);
            transpose_bits(&data, &mut actual);

            assert_eq!(
                expected, actual,
                "scalar transpose doesn't match baseline for seed {seed}"
            );
        }
    }

    /// Transposing and then untransposing as `u64` must give back the input.
    #[test]
    fn test_scalar_roundtrip() {
        for seed in [0, 42, 123, 255] {
            let data = generate_test_data(seed);
            let (mut forward, mut back) = ([0u64; 16], [0u64; 16]);

            transpose_bits(&data, &mut forward);
            untranspose_bits::<u64>(&forward, &mut back);

            assert_eq!(data, back, "scalar roundtrip failed for seed {seed}");
        }
    }

    /// An all-zero block maps to an all-zero block in both directions.
    #[test]
    fn test_all_zeros() {
        let zeros = [0u64; 16];
        // Start from all ones so the first call has to overwrite every word.
        let mut buffer = [u64::MAX; 16];

        transpose_bits(&zeros, &mut buffer);
        assert_eq!(buffer, [0u64; 16]);

        untranspose_bits::<u64>(&zeros, &mut buffer);
        assert_eq!(buffer, [0u64; 16]);
    }

    /// An all-ones block maps to an all-ones block in both directions.
    #[test]
    fn test_all_ones() {
        let ones = [u64::MAX; 16];
        let mut buffer = [0u64; 16];

        transpose_bits(&ones, &mut buffer);
        assert_eq!(buffer, [u64::MAX; 16]);

        untranspose_bits::<u64>(&ones, &mut buffer);
        assert_eq!(buffer, [u64::MAX; 16]);
    }

    /// The scalar untranspose must agree with the baseline for every element width.
    #[test]
    fn test_untranspose_all_widths_match_baseline() {
        fn assert_width_matches<T: FastLanes>() {
            for seed in [0, 42, 123, 255] {
                let data = generate_test_data(seed);
                let (mut expected, mut actual) = ([0u64; 16], [0u64; 16]);

                untranspose_bits_baseline::<T>(&data, &mut expected);
                untranspose_bits::<T>(&data, &mut actual);

                assert_eq!(
                    expected,
                    actual,
                    "scalar untranspose != baseline for type={} seed={seed}",
                    core::any::type_name::<T>()
                );
            }
        }

        assert_width_matches::<u8>();
        assert_width_matches::<u16>();
        assert_width_matches::<u32>();
        assert_width_matches::<u64>();
    }
}