use core::arch::x86_64::*;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::vec::Vec;
use super::shuffle::{ENCODE_TABLE, TABLE};
use crate::error::DecodeError;
/// Appends the encoded form of `values` to `out` using AVX2.
///
/// Bytes appended to `out`:
/// - `values.len().div_ceil(8)` control bytes, where bit `j % 8` of control byte
///   `j / 8` is set exactly when `values[j] > 0xFF`;
/// - then the data bytes: one byte for a value that fits in `u8`, otherwise the
///   value's two little-endian bytes (this is the layout the scalar tail writes).
///
/// The first `values.len() / 16 * 16` values are packed sixteen at a time. The
/// compaction masks come from `ENCODE_TABLE` (defined in `super::shuffle`), which is
/// assumed to produce the same byte layout as the scalar tail — NOTE(review): confirm
/// against the table's definition.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2; the function is compiled with
/// `#[target_feature(enable = "avx2")]`.
#[allow(dead_code)]
#[target_feature(enable = "avx2")]
pub(super) unsafe fn encode_into(values: &[u16], out: &mut Vec<u8>) {
    let count = values.len();
    if count == 0 {
        return;
    }
    let ctrl_bytes = count.div_ceil(8);
    let ctrl_off = out.len();
    // Worst case is two bytes per value; the extra 16 bytes absorb the full-width
    // 128-bit stores below, which write past the last meaningful byte.
    out.reserve(ctrl_bytes + 2 * count + 16);
    // Control bytes start zeroed so the scalar tail can OR individual bits in.
    out.resize(ctrl_off + ctrl_bytes, 0u8);
    let data_off = ctrl_off + ctrl_bytes;
    let vector_count = count - count % 16;
    // No reallocation happens after `reserve`, so this pointer stays valid for the
    // raw writes into spare capacity below.
    let dst = out.as_mut_ptr();
    let mut written = 0usize;

    for start in (0..vector_count).step_by(16) {
        let ctrl_idx = start / 8;
        let lanes = unsafe { _mm256_loadu_si256(values.as_ptr().add(start) as *const __m256i) };
        // A lane needs two bytes exactly when its high byte is non-zero.
        let high_bytes = _mm256_srli_epi16(lanes, 8);
        let wide = _mm256_cmpgt_epi16(high_bytes, _mm256_setzero_si256());
        // The saturating pack turns every 0xFFFF/0x0000 lane into a single byte, per
        // 128-bit half; movemask bits 0..8 and 16..24 are then the two control bytes.
        let bits = _mm256_movemask_epi8(_mm256_packs_epi16(wide, wide));
        let lo_ctrl = bits as u8;
        let hi_ctrl = (bits >> 16) as u8;
        unsafe {
            *dst.add(ctrl_off + ctrl_idx) = lo_ctrl;
            *dst.add(ctrl_off + ctrl_idx + 1) = hi_ctrl;
            let lo_mask =
                _mm_loadu_si128(ENCODE_TABLE[lo_ctrl as usize].as_ptr() as *const __m128i);
            let hi_mask =
                _mm_loadu_si128(ENCODE_TABLE[hi_ctrl as usize].as_ptr() as *const __m128i);
            // pshufb works per 128-bit lane, so each half is compacted independently.
            let packed = _mm256_shuffle_epi8(lanes, _mm256_set_m128i(hi_mask, lo_mask));
            _mm_storeu_si128(
                dst.add(data_off + written) as *mut __m128i,
                _mm256_castsi256_si128(packed),
            );
            written += 8 + lo_ctrl.count_ones() as usize;
            // Overlaps (and overwrites) the unused tail of the previous store.
            _mm_storeu_si128(
                dst.add(data_off + written) as *mut __m128i,
                _mm256_extracti128_si256(packed, 1),
            );
            written += 8 + hi_ctrl.count_ones() as usize;
        }
    }

    // SAFETY: every byte in `data_off..data_off + written` was written by the stores
    // above, and the range lies within the capacity reserved earlier.
    unsafe {
        out.set_len(data_off + written);
    }

    for (j, &value) in values.iter().enumerate().skip(vector_count) {
        if value > 0xFF {
            out[ctrl_off + j / 8] |= 1u8 << (j % 8);
            out.extend_from_slice(&value.to_le_bytes());
        } else {
            out.push(value as u8);
        }
    }
}
/// Decodes `n` values from `data` using AVX2 and appends them to `out`.
///
/// `data` is expected to begin with `n.div_ceil(8)` control bytes (one bit per value:
/// bit `j % 8` of byte `j / 8`), followed by the data bytes. Each group of 8 values
/// consumes `8 + popcount(control byte)` data bytes. The per-control-byte shuffle
/// masks come from `TABLE` (defined in `super::shuffle`).
///
/// Values are decoded 16 at a time while both 16-byte loads fit inside the data bytes.
/// Then they are decoded 8 at a time through a zero-padded 16-byte window. Whatever
/// remains goes to `super::scalar::decode_from_raw`: fewer than 8 values, or a group
/// whose bytes are not fully present.
///
/// # Errors
///
/// Returns `DecodeError::ControlStreamTooShort` when `data` is shorter than the control
/// stream. Propagates any error from `super::scalar::decode_from_raw` (NOTE(review):
/// presumably this includes truncated data streams — confirm). When the scalar path
/// fails, `out` already contains the values decoded before it.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2; the function is compiled with
/// `#[target_feature(enable = "avx2")]`.
#[allow(dead_code)]
#[target_feature(enable = "avx2")]
pub(super) unsafe fn decode_into(
    data: &[u8],
    n: usize,
    out: &mut Vec<u16>,
) -> Result<(), DecodeError> {
    if n == 0 {
        return Ok(());
    }
    let ctrl_len = n.div_ceil(8);
    if data.len() < ctrl_len {
        return Err(DecodeError::ControlStreamTooShort {
            need: ctrl_len,
            have: data.len(),
        });
    }
    let (ctrl, data_bytes) = data.split_at(ctrl_len);
    out.reserve(n);
    let base = out.len();
    // Loop invariants: `decoded == 8 * ctrl_pos` and `data_pos <= data_bytes.len()`.
    let mut ctrl_pos = 0usize;
    let mut data_pos = 0usize;
    let mut decoded = 0usize;

    // Fast path: 16 values (two control bytes) per iteration.
    while decoded + 16 <= n {
        let c0 = ctrl[ctrl_pos];
        let c1 = ctrl[ctrl_pos + 1];
        let c0_bytes = 8 + c0.count_ones() as usize;
        // The second 16-byte load starts at `data_pos + c0_bytes`; keeping it in bounds
        // also keeps the first load (at `data_pos`) in bounds.
        if data_pos + c0_bytes + 16 > data_bytes.len() {
            break;
        }
        // SAFETY: both loads were bounds-checked just above; the table rows are read
        // as 16-byte masks.
        let result = unsafe {
            let mask_lo = _mm_loadu_si128(TABLE[c0 as usize].as_ptr() as *const __m128i);
            let mask_hi = _mm_loadu_si128(TABLE[c1 as usize].as_ptr() as *const __m128i);
            let chunk_lo = _mm_loadu_si128(data_bytes.as_ptr().add(data_pos) as *const __m128i);
            let chunk_hi =
                _mm_loadu_si128(data_bytes.as_ptr().add(data_pos + c0_bytes) as *const __m128i);
            let mask256 = _mm256_set_m128i(mask_hi, mask_lo);
            let data256 = _mm256_set_m128i(chunk_hi, chunk_lo);
            _mm256_shuffle_epi8(data256, mask256)
        };
        // SAFETY: `reserve(n)` provides room for `base + n` elements and
        // `decoded + 16 <= n`.
        unsafe {
            let out_ptr = out.as_mut_ptr().add(base + decoded) as *mut __m256i;
            _mm256_storeu_si256(out_ptr, result);
        }
        data_pos += c0_bytes + 8 + c1.count_ones() as usize;
        ctrl_pos += 2;
        decoded += 16;
    }

    // Tail: 8 values (one control byte) per iteration. Every group is checked against
    // the remaining data, so truncated input can neither read out of bounds nor push
    // `data_pos` past the end of `data_bytes`.
    while decoded + 8 <= n {
        let cb = ctrl[ctrl_pos];
        let consumed = 8 + cb.count_ones() as usize;
        let avail = data_bytes.len() - data_pos;
        if consumed > avail {
            // Not enough data for this group: let the scalar decoder deal with it.
            break;
        }
        // Copy into a zero-padded window so the 16-byte load never leaves the slice.
        let mut window = [0u8; 16];
        let take = avail.min(window.len());
        window[..take].copy_from_slice(&data_bytes[data_pos..data_pos + take]);
        // SAFETY: `window` is exactly 16 bytes; the table row is a 16-byte mask.
        let result = unsafe {
            let mask = _mm_loadu_si128(TABLE[cb as usize].as_ptr() as *const __m128i);
            let chunk = _mm_loadu_si128(window.as_ptr() as *const __m128i);
            _mm_shuffle_epi8(chunk, mask)
        };
        // SAFETY: `reserve(n)` provides room for `base + n` elements and
        // `decoded + 8 <= n`.
        unsafe {
            let out_ptr = out.as_mut_ptr().add(base + decoded) as *mut __m128i;
            _mm_storeu_si128(out_ptr, result);
        }
        data_pos += consumed;
        ctrl_pos += 1;
        decoded += 8;
    }

    // SAFETY: elements `base..base + decoded` were all written by the stores above.
    unsafe {
        out.set_len(base + decoded);
    }

    if decoded < n {
        super::scalar::decode_from_raw(
            &ctrl[ctrl_pos..],
            &data_bytes[data_pos..],
            n - decoded,
            out,
        )?;
    }
    Ok(())
}