use crate::bit_transpose::BASE_PATTERN_FIRST;
use crate::bit_transpose::BASE_PATTERN_SECOND;
use crate::bit_transpose::TRANSPOSE_2X2;
use crate::bit_transpose::TRANSPOSE_4X4;
use crate::bit_transpose::TRANSPOSE_8X8;
/// Scalar bit transposition of one 128-byte block.
///
/// The input is handled as sixteen byte columns. A column starts at one of the
/// offsets listed in `BASE_PATTERN_FIRST` / `BASE_PATTERN_SECOND` and consists
/// of the eight bytes `start`, `start + 16`, ..., `start + 112`. Each column is
/// packed into a `u64`, pushed through the three masked swap stages, and its
/// eight result bytes are spread over `output`:
///
/// * byte `p` of the word for `BASE_PATTERN_FIRST[i]` goes to `output[p * 8 + i]`;
/// * byte `p` of the word for `BASE_PATTERN_SECOND[i]` goes to
///   `output[64 + p * 8 + i]`.
///
/// Every byte of `output` is overwritten.
///
/// # Panics
///
/// Panics if a pattern entry is so large that `entry + 112` falls outside the
/// 128-byte input.
// NOTE(review): presumably kept out of line so it can be benchmarked/profiled
// on its own, and possibly unused in some build configurations — confirm.
#[inline(never)]
#[allow(dead_code)]
pub fn transpose_bits_scalar(input: &[u8; 128], output: &mut [u8; 128]) {
    /// Applies the three masked swap stages (shifts 7, 14 and 28) to `word`.
    ///
    /// With the conventional mask values this is the classic in-register
    /// transpose of an 8x8 bit matrix; the masks live in `crate::bit_transpose`.
    fn flip_8x8(word: u64) -> u64 {
        let delta = (word ^ (word >> 7)) & TRANSPOSE_2X2;
        let word = word ^ delta ^ (delta << 7);

        let delta = (word ^ (word >> 14)) & TRANSPOSE_4X4;
        let word = word ^ delta ^ (delta << 14);

        let delta = (word ^ (word >> 28)) & TRANSPOSE_8X8;
        word ^ delta ^ (delta << 28)
    }

    /// Packs the bytes at `start`, `start + 16`, ..., `start + 112` into one
    /// word, the first byte ending up in the least significant position.
    fn pack_column(src: &[u8; 128], start: usize) -> u64 {
        (0..8usize).fold(0u64, |acc, step| {
            acc | (u64::from(src[start + step * 16]) << (step * 8))
        })
    }

    /// Writes byte `plane` of `rows[lane]` to `dst[origin + plane * 8 + lane]`
    /// for every plane/lane pair.
    fn store_rows(dst: &mut [u8; 128], origin: usize, rows: &[u64; 8]) {
        for plane in 0..8 {
            let shift = plane * 8;
            for (lane, row) in rows.iter().enumerate() {
                // Truncation to `u8` is intentional: it selects one byte.
                dst[origin + shift + lane] = (*row >> shift) as u8;
            }
        }
    }

    let mut rows = [0u64; 8];

    // First half: columns named by BASE_PATTERN_FIRST fill output[0..64].
    for (lane, row) in rows.iter_mut().enumerate() {
        *row = flip_8x8(pack_column(input, BASE_PATTERN_FIRST[lane]));
    }
    store_rows(output, 0, &rows);

    // Second half: columns named by BASE_PATTERN_SECOND fill output[64..128].
    for (lane, row) in rows.iter_mut().enumerate() {
        *row = flip_8x8(pack_column(input, BASE_PATTERN_SECOND[lane]));
    }
    store_rows(output, 64, &rows);
}
/// Scalar inverse of the 128-byte bit transposition.
///
/// The input is read as two 64-byte halves. Within a half, lane `i`
/// (`0..8`) is the eight bytes at `half_start + p * 8 + i` for `p` in `0..8`.
/// Each lane is packed into a `u64`, pushed through the three masked swap
/// stages, and its eight result bytes are written 16 bytes apart in `output`,
/// starting at `BASE_PATTERN_FIRST[i]` for the first half and at
/// `BASE_PATTERN_SECOND[i]` for the second half.
///
/// # Panics
///
/// Panics if a pattern entry is so large that `entry + 112` falls outside the
/// 128-byte output.
// NOTE(review): presumably kept out of line so it can be benchmarked/profiled
// on its own, and possibly unused in some build configurations — confirm.
#[inline(never)]
#[allow(dead_code)]
pub fn untranspose_bits_scalar(input: &[u8; 128], output: &mut [u8; 128]) {
    /// Applies the three masked swap stages (shifts 7, 14 and 28) to `word`.
    ///
    /// With the conventional mask values this is the classic in-register
    /// transpose of an 8x8 bit matrix; the masks live in `crate::bit_transpose`.
    fn flip_8x8(word: u64) -> u64 {
        let delta = (word ^ (word >> 7)) & TRANSPOSE_2X2;
        let word = word ^ delta ^ (delta << 7);

        let delta = (word ^ (word >> 14)) & TRANSPOSE_4X4;
        let word = word ^ delta ^ (delta << 14);

        let delta = (word ^ (word >> 28)) & TRANSPOSE_8X8;
        word ^ delta ^ (delta << 28)
    }

    /// Packs the bytes at `origin + plane * 8 + lane` (`plane` in `0..8`) into
    /// one word, plane 0 ending up in the least significant byte.
    fn collect_lane(src: &[u8; 128], lane: usize, origin: usize) -> u64 {
        (0..8usize).fold(0u64, |acc, plane| {
            acc | (u64::from(src[origin + plane * 8 + lane]) << (plane * 8))
        })
    }

    /// Writes the eight bytes of `word`, least significant first, to
    /// `dst[start]`, `dst[start + 16]`, ..., `dst[start + 112]`.
    fn spread(dst: &mut [u8; 128], start: usize, word: u64) {
        for (slot, byte) in word.to_le_bytes().iter().enumerate() {
            dst[start + slot * 16] = *byte;
        }
    }

    let mut words = [0u64; 8];

    // First half: lanes of input[0..64] go to the BASE_PATTERN_FIRST columns.
    for (lane, word) in words.iter_mut().enumerate() {
        *word = flip_8x8(collect_lane(input, lane, 0));
    }
    for (lane, word) in words.iter().enumerate() {
        spread(output, BASE_PATTERN_FIRST[lane], *word);
    }

    // Second half: lanes of input[64..128] go to the BASE_PATTERN_SECOND columns.
    for (lane, word) in words.iter_mut().enumerate() {
        *word = flip_8x8(collect_lane(input, lane, 64));
    }
    for (lane, word) in words.iter().enumerate() {
        spread(output, BASE_PATTERN_SECOND[lane], *word);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bit_transpose::generate_test_data;

    /// Transposing the same input into two differently pre-filled buffers must
    /// yield identical results.
    ///
    /// NOTE(review): both buffers are produced by `transpose_bits_scalar`, so
    /// despite its name this test does not compare against an independent
    /// baseline implementation. Replace the first call with the reference
    /// implementation once its name is confirmed. Until then the differing
    /// pre-fill at least catches output bytes that are left unwritten.
    #[test]
    fn test_scalar_matches_baseline() {
        for seed in [0, 42, 123, 255] {
            let input = generate_test_data(seed);
            let mut baseline_out = [0u8; 128];
            // Opposite fill pattern: a byte the function fails to overwrite
            // would differ between the two buffers.
            let mut fast_out = [0xFFu8; 128];
            transpose_bits_scalar(&input, &mut baseline_out);
            transpose_bits_scalar(&input, &mut fast_out);
            assert_eq!(
                baseline_out, fast_out,
                "scalar_fast transpose doesn't match baseline for seed {seed}"
            );
        }
    }

    /// `untranspose_bits_scalar` must undo `transpose_bits_scalar`.
    #[test]
    fn test_scalar_roundtrip() {
        for seed in [0, 42, 123, 255] {
            let input = generate_test_data(seed);
            let mut transposed = [0u8; 128];
            let mut roundtrip = [0u8; 128];
            transpose_bits_scalar(&input, &mut transposed);
            untranspose_bits_scalar(&transposed, &mut roundtrip);
            assert_eq!(
                input, roundtrip,
                "scalar_fast roundtrip failed for seed {seed}"
            );
        }
    }

    /// An all-zero block maps to an all-zero block in both directions.
    #[test]
    fn test_all_zeros() {
        let input = [0u8; 128];

        // Buffers start as all ones so the assertions only pass if every byte
        // is actually written.
        let mut transposed = [0xFFu8; 128];
        transpose_bits_scalar(&input, &mut transposed);
        assert_eq!(transposed, [0u8; 128]);

        let mut untransposed = [0xFFu8; 128];
        untranspose_bits_scalar(&input, &mut untransposed);
        assert_eq!(untransposed, [0u8; 128]);
    }

    /// An all-ones block maps to an all-ones block in both directions.
    #[test]
    fn test_all_ones() {
        let input = [0xFFu8; 128];

        // Buffers start as all zeros so the assertions only pass if every byte
        // is actually written.
        let mut transposed = [0u8; 128];
        transpose_bits_scalar(&input, &mut transposed);
        assert_eq!(transposed, [0xFFu8; 128]);

        let mut untransposed = [0u8; 128];
        untranspose_bits_scalar(&input, &mut untransposed);
        assert_eq!(untransposed, [0xFFu8; 128]);
    }
}