#![allow(clippy::indexing_slicing, reason = "SIMD lane indices bounded by j<32 invariant")]
#[cfg(target_arch = "x86_64")]
use super::reader::CramError;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn decode_32state_loop(
src: &mut &[u8],
dst: &mut [u8],
frequencies: &[u32; 256],
cumulative_frequencies: &[u32; 256],
sym_table: &[u8; 4096],
states: &mut [u32],
) -> Result<(), CramError> {
debug_assert_eq!(states.len(), 32, "AVX2 32-state loop requires exactly 32 states");
let full_chunks = dst.len() / 32;
let mask_12bit = _mm256_set1_epi32(0xFFF);
for chunk_idx in 0..full_chunks {
let start = chunk_idx.wrapping_mul(32);
let chunk = dst
.get_mut(start..start.wrapping_add(32))
.ok_or(CramError::Truncated { context: "avx2 chunk range" })?;
for j in 0..32 {
let f = (states[j] & 0xFFF) as usize;
chunk[j] = sym_table[f];
}
for j in (0..32).step_by(8) {
let s = unsafe { _mm256_loadu_si256(states.as_ptr().add(j) as *const __m256i) };
let hi = _mm256_srli_epi32(s, 12);
let f = _mm256_and_si256(s, mask_12bit);
let syms = [
usize::from(chunk[j]),
usize::from(chunk[j.wrapping_add(1)]),
usize::from(chunk[j.wrapping_add(2)]),
usize::from(chunk[j.wrapping_add(3)]),
usize::from(chunk[j.wrapping_add(4)]),
usize::from(chunk[j.wrapping_add(5)]),
usize::from(chunk[j.wrapping_add(6)]),
usize::from(chunk[j.wrapping_add(7)]),
];
let freq_arr = [
frequencies[syms[0]],
frequencies[syms[1]],
frequencies[syms[2]],
frequencies[syms[3]],
frequencies[syms[4]],
frequencies[syms[5]],
frequencies[syms[6]],
frequencies[syms[7]],
];
let cum_arr = [
cumulative_frequencies[syms[0]],
cumulative_frequencies[syms[1]],
cumulative_frequencies[syms[2]],
cumulative_frequencies[syms[3]],
cumulative_frequencies[syms[4]],
cumulative_frequencies[syms[5]],
cumulative_frequencies[syms[6]],
cumulative_frequencies[syms[7]],
];
let (freqs, cums) = unsafe {
(
_mm256_loadu_si256(freq_arr.as_ptr() as *const __m256i),
_mm256_loadu_si256(cum_arr.as_ptr() as *const __m256i),
)
};
let prod = _mm256_mullo_epi32(freqs, hi);
let sum = _mm256_add_epi32(prod, f);
let new_s = _mm256_sub_epi32(sum, cums);
unsafe { _mm256_storeu_si256(states.as_mut_ptr().add(j) as *mut __m256i, new_s) };
}
for state in states.iter_mut() {
*state = super::rans_nx16::state_renormalize(*state, src)
.ok_or(CramError::Truncated { context: "avx2 renorm" })?;
}
}
let remainder_start = full_chunks.wrapping_mul(32);
if let Some(remainder) = dst.get_mut(remainder_start..) {
for (j, d) in remainder.iter_mut().enumerate() {
let state = states
.get_mut(j)
.ok_or(CramError::Truncated { context: "avx2 remainder state index" })?;
let f = *state & 0xFFF;
let sym = sym_table
.get(f as usize)
.ok_or(CramError::Truncated { context: "avx2 remainder sym_table" })?;
*d = *sym;
let i = usize::from(*sym);
*state = super::rans_nx16::state_step(
*state,
frequencies[i],
cumulative_frequencies[i],
super::rans_nx16::ORDER_0_BITS,
);
*state = super::rans_nx16::state_renormalize(*state, src)
.ok_or(CramError::Truncated { context: "avx2 remainder renorm" })?;
}
}
Ok(())
}