#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Unpacks one 32-byte chunk of packed 4-bit values into 64 bytes using AVX2.
///
/// Every input byte produces two consecutive output bytes: its high nibble
/// (`byte >> 4`) followed by its low nibble (`byte & 0x0F`). Input order is
/// preserved, so `output[2 * i]` and `output[2 * i + 1]` come from `packed[i]`.
///
/// # Safety
///
/// * The executing CPU must support AVX2.
/// * `packed` must be valid for reads of 32 bytes.
/// * `output` must be valid for writes of 64 bytes.
///
/// Unaligned loads and stores are used, so neither pointer needs any
/// particular alignment.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn unpack_4bit_avx2(packed: *const u8, output: *mut u8) {
    let nibble_mask = _mm256_set1_epi8(0x0F);
    let bytes = _mm256_loadu_si256(packed.cast::<__m256i>());

    // There is no per-byte shift, so shift 16-bit words and mask away the
    // bits that leaked in from the neighbouring byte.
    let upper_nibbles = _mm256_and_si256(_mm256_srli_epi16::<4>(bytes), nibble_mask);
    let lower_nibbles = _mm256_and_si256(bytes, nibble_mask);

    // The unpack instructions interleave within each 128-bit lane separately:
    //   even_parts = [nibbles of bytes 0..8  | nibbles of bytes 16..24]
    //   odd_parts  = [nibbles of bytes 8..16 | nibbles of bytes 24..32]
    let even_parts = _mm256_unpacklo_epi8(upper_nibbles, lower_nibbles);
    let odd_parts = _mm256_unpackhi_epi8(upper_nibbles, lower_nibbles);

    // Regroup the lanes so the output is in input order:
    //   0x20 -> [even_parts.low  | odd_parts.low ]  (bytes 0..16)
    //   0x31 -> [even_parts.high | odd_parts.high]  (bytes 16..32)
    let first_half = _mm256_permute2x128_si256::<0x20>(even_parts, odd_parts);
    let second_half = _mm256_permute2x128_si256::<0x31>(even_parts, odd_parts);

    let dst = output.cast::<__m256i>();
    _mm256_storeu_si256(dst, first_half);
    _mm256_storeu_si256(dst.add(1), second_half);
}
/// Unpacks `packed_len` bytes of packed 4-bit values into `2 * packed_len`
/// output bytes, using AVX2 for every complete 32-byte chunk and a scalar
/// loop for the trailing bytes.
///
/// Each input byte produces its high nibble followed by its low nibble.
///
/// # Safety
///
/// * The executing CPU must support AVX2.
/// * `packed` must be valid for reads of `packed_len` bytes.
/// * `output` must be valid for writes of `2 * packed_len` bytes.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn unpack_4bit_buffer_avx2(packed: *const u8, packed_len: usize, output: *mut u8) {
    const CHUNK: usize = 32;

    // Number of leading input bytes that form complete SIMD chunks.
    let simd_bytes = packed_len - packed_len % CHUNK;

    let mut offset = 0;
    while offset < simd_bytes {
        // Each input byte expands to two output bytes.
        unpack_4bit_avx2(packed.add(offset), output.add(offset * 2));
        offset += CHUNK;
    }

    // Scalar tail for the final `packed_len % 32` bytes.
    for idx in simd_bytes..packed_len {
        let byte = *packed.add(idx);
        let dst = output.add(idx * 2);
        *dst = byte >> 4;
        *dst.add(1) = byte & 0x0F;
    }
}
/// Portable fallback that unpacks packed 4-bit values one byte at a time.
///
/// For every byte in `packed`, writes its high nibble and then its low nibble
/// to consecutive positions in `output`. Bytes of `output` beyond
/// `2 * packed.len()` are left untouched.
///
/// # Panics
///
/// Panics if `output` holds fewer than `2 * packed.len()` bytes. Debug builds
/// check this up front; otherwise the panic comes from the first
/// out-of-bounds write, after earlier bytes have already been written.
#[inline]
pub fn unpack_4bit_scalar(packed: &[u8], output: &mut [u8]) {
    debug_assert!(output.len() >= packed.len() * 2);
    let mut cursor = 0usize;
    for &byte in packed {
        output[cursor] = byte >> 4;
        output[cursor + 1] = byte & 0x0F;
        cursor += 2;
    }
}
/// Unpacks packed 4-bit values into one value per byte, choosing the fastest
/// available implementation at runtime.
///
/// For every byte in `packed`, its high nibble and then its low nibble are
/// written to consecutive positions in `output`. On x86_64 CPUs that report
/// AVX2 support the vectorised routine is used; everywhere else the scalar
/// fallback runs.
///
/// # Panics
///
/// Panics if `output.len() < 2 * packed.len()`. The check happens before any
/// byte is written, so the SIMD path can never write past the end of
/// `output`.
#[inline]
pub fn unpack_4bit(packed: &[u8], output: &mut [u8]) {
    // This is a safe function, so the size requirement of the raw-pointer
    // SIMD routine must be enforced here rather than trusted from the caller.
    // A slice never exceeds `isize::MAX` bytes, so `len() * 2` cannot overflow.
    let required = packed.len() * 2;
    assert!(
        output.len() >= required,
        "output buffer too small: need {} bytes, got {}",
        required,
        output.len()
    );

    #[cfg(target_arch = "x86_64")]
    {
        if std::arch::is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just detected at runtime. `packed` is
            // readable for `packed.len()` bytes, and the assertion above
            // guarantees `output` is writable for `2 * packed.len()` bytes,
            // which is exactly what the routine reads and writes.
            unsafe {
                unpack_4bit_buffer_avx2(packed.as_ptr(), packed.len(), output.as_mut_ptr());
            }
            return;
        }
    }

    unpack_4bit_scalar(packed, output);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The scalar routine emits the high nibble before the low nibble.
    #[test]
    fn test_unpack_scalar() {
        let packed = [0x01u8, 0x23, 0x45, 0x67];
        let mut output = [0u8; 8];
        unpack_4bit_scalar(&packed, &mut output);
        assert_eq!(output, [0, 1, 2, 3, 4, 5, 6, 7]);
    }

    /// Every possible byte value is split correctly by the scalar routine.
    #[test]
    fn test_unpack_all_values() {
        let packed: Vec<u8> = (0..=u8::MAX).collect();
        let mut output = vec![0u8; packed.len() * 2];
        unpack_4bit_scalar(&packed, &mut output);

        // Input byte `i` has the value `i`, so the expected nibbles follow
        // directly from the pair index.
        for (i, pair) in output.chunks_exact(2).enumerate() {
            assert_eq!(pair[0], (i >> 4) as u8, "High nibble mismatch at {}", i);
            assert_eq!(pair[1], (i & 0x0F) as u8, "Low nibble mismatch at {}", i);
        }
    }

    /// The dispatching entry point agrees with the scalar reference for
    /// lengths with and without a non-multiple-of-32 tail.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_unpack_avx2_matches_scalar() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        const SIZES: [usize; 6] = [32, 64, 100, 128, 256, 1000];
        for size in SIZES {
            let packed: Vec<u8> = (0..size).map(|i| i as u8).collect();
            let mut expected = vec![0u8; size * 2];
            let mut actual = vec![0u8; size * 2];
            unpack_4bit_scalar(&packed, &mut expected);
            unpack_4bit(&packed, &mut actual);
            assert_eq!(expected, actual, "Mismatch for size {}", size);
        }
    }

    /// A single 32-byte chunk passed straight to the AVX2 kernel comes out in
    /// input order.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_unpack_avx2_exact_chunk() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }

        let packed: Vec<u8> = (0..32).collect();
        let mut output = vec![0u8; 64];
        unsafe {
            unpack_4bit_avx2(packed.as_ptr(), output.as_mut_ptr());
        }

        for (i, &byte) in packed.iter().enumerate() {
            assert_eq!(
                output[i * 2],
                byte >> 4,
                "High nibble mismatch at packed byte {}",
                i
            );
            assert_eq!(
                output[i * 2 + 1],
                byte & 0x0F,
                "Low nibble mismatch at packed byte {}",
                i
            );
        }
    }

    /// Whatever implementation the runtime dispatch picks, the result is
    /// correct for a length that is not a multiple of 32.
    #[test]
    fn test_unpack_runtime_dispatch() {
        let packed: Vec<u8> = (0..100).collect();
        let mut output = vec![0u8; 200];
        unpack_4bit(&packed, &mut output);

        for (i, &byte) in packed.iter().enumerate() {
            assert_eq!(output[i * 2], byte >> 4);
            assert_eq!(output[i * 2 + 1], byte & 0x0F);
        }
    }
}