#![allow(clippy::indexing_slicing, reason = "SIMD lane indices bounded by j<32 invariant")]
#[cfg(target_arch = "aarch64")]
use super::reader::CramError;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
/// NEON-assisted order-0 decode loop over 32 interleaved states.
///
/// `dst` is processed in 32-byte chunks. For every chunk, output byte `j` is
/// `sym_table[states[j] & 0xFFF]`, after which each state is advanced, four
/// lanes at a time, to
/// `frequencies[sym] * (state >> 12) + (state & 0xFFF) - cumulative_frequencies[sym]`
/// (modular `u32` lane arithmetic) and then passed through
/// `rans_nx16::state_renormalize` together with `src`.
///
/// Any trailing `dst.len() % 32` bytes are decoded one at a time with
/// `rans_nx16::state_step` / `state_renormalize`, using `states[0]`,
/// `states[1]`, ... in order.
///
/// # Errors
///
/// Returns `CramError::Truncated` when `state_renormalize` yields `None`, or
/// when the trailing remainder needs a state index that `states` does not have.
///
/// # Panics
///
/// Panics if `states` holds fewer than 32 elements while `dst` contains at
/// least one full 32-byte chunk. Debug builds additionally assert that
/// `states.len() == 32`.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature this function is compiled with.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub(crate) unsafe fn decode_32state_loop(
    src: &mut &[u8],
    dst: &mut [u8],
    frequencies: &[u32; 256],
    cumulative_frequencies: &[u32; 256],
    sym_table: &[u8; 4096],
    states: &mut [u32],
) -> Result<(), CramError> {
    debug_assert_eq!(states.len(), 32, "NEON 32-state loop requires exactly 32 states");
    let mask_12bit = vdupq_n_u32(0xFFF);

    // `chunks_exact_mut` hands out only full 32-byte chunks, so no manual
    // range arithmetic (and no unreachable range error) is needed.
    let mut chunks = dst.chunks_exact_mut(32);
    for chunk in chunks.by_ref() {
        // Enforce the 32-state requirement in release builds too: this
        // panics up front instead of letting the SIMD stage see a short slice.
        let lanes = &mut states[..32];

        // Scalar symbol lookup: the low 12 bits of each state select the
        // symbol. The mask keeps the index below 4096, the table length.
        for (out, state) in chunk.iter_mut().zip(lanes.iter()) {
            *out = sym_table[(*state & 0xFFF) as usize];
        }

        // Vector state update, four states per iteration. Each `lane_group`
        // and `sym_group` is exactly four elements long.
        for (lane_group, sym_group) in lanes.chunks_exact_mut(4).zip(chunk.chunks_exact(4)) {
            let syms = [
                usize::from(sym_group[0]),
                usize::from(sym_group[1]),
                usize::from(sym_group[2]),
                usize::from(sym_group[3]),
            ];
            // Gather the per-symbol table entries into contiguous arrays so
            // they can be loaded as one vector each.
            let freq_arr = [
                frequencies[syms[0]],
                frequencies[syms[1]],
                frequencies[syms[2]],
                frequencies[syms[3]],
            ];
            let cum_arr = [
                cumulative_frequencies[syms[0]],
                cumulative_frequencies[syms[1]],
                cumulative_frequencies[syms[2]],
                cumulative_frequencies[syms[3]],
            ];
            // SAFETY: `lane_group` comes from `chunks_exact_mut(4)` and
            // `freq_arr` / `cum_arr` are `[u32; 4]`, so every pointer is valid
            // for reading four consecutive `u32` values.
            let (s, freqs, cums) = unsafe {
                (
                    vld1q_u32(lane_group.as_ptr()),
                    vld1q_u32(freq_arr.as_ptr()),
                    vld1q_u32(cum_arr.as_ptr()),
                )
            };
            let hi = vshrq_n_u32(s, 12);
            let f = vandq_u32(s, mask_12bit);
            let new_s = vsubq_u32(vaddq_u32(vmulq_u32(freqs, hi), f), cums);
            // SAFETY: `lane_group` is an exclusive four-element `u32` slice,
            // so it is valid for writing four consecutive `u32` values.
            unsafe { vst1q_u32(lane_group.as_mut_ptr(), new_s) };
        }

        // Renormalize in state order; `src` is threaded through every call.
        // This walks every element of `states`, as the original loop did.
        for state in states.iter_mut() {
            *state = super::rans_nx16::state_renormalize(*state, src)
                .ok_or(CramError::Truncated { context: "neon renorm" })?;
        }
    }

    // Fewer than 32 bytes are left: finish them on the scalar path, one state
    // per output byte.
    for (j, d) in chunks.into_remainder().iter_mut().enumerate() {
        let state = states
            .get_mut(j)
            .ok_or(CramError::Truncated { context: "neon remainder state index" })?;
        // Same 12-bit masked lookup as the chunked path; always in bounds.
        let sym = sym_table[(*state & 0xFFF) as usize];
        *d = sym;
        let i = usize::from(sym);
        *state = super::rans_nx16::state_step(
            *state,
            frequencies[i],
            cumulative_frequencies[i],
            super::rans_nx16::ORDER_0_BITS,
        );
        *state = super::rans_nx16::state_renormalize(*state, src)
            .ok_or(CramError::Truncated { context: "neon remainder renorm" })?;
    }
    Ok(())
}