#![allow(
clippy::unnecessary_lazy_evaluations,
reason = "lazy form avoids per-call drop_in_place<CramError> on hot path"
)]
use super::codec_io::{self, Uint7Error};
use super::reader::CramError;
fn uint7_to_cram_error(e: Uint7Error) -> CramError {
match e {
Uint7Error::Truncated => CramError::Truncated { context: "rans_nx16 uint7" },
Uint7Error::Overflow => CramError::Uint7Overflow,
}
}
const ALPHABET_SIZE: usize = 256;
const FLAG_ORDER: u8 = 0x01;
const FLAG_N32: u8 = 0x04;
const FLAG_STRIPE: u8 = 0x08;
const FLAG_NO_SIZE: u8 = 0x10;
const FLAG_CAT: u8 = 0x20;
const FLAG_RLE: u8 = 0x40;
const FLAG_PACK: u8 = 0x80;
type Frequencies1 = Box<[[u32; ALPHABET_SIZE]; ALPHABET_SIZE]>;
type CumulativeFrequencies1 = Box<[[u32; ALPHABET_SIZE]; ALPHABET_SIZE]>;
type SymTables1 = Box<[[u8; 4096]; ALPHABET_SIZE]>;
pub(crate) struct Nx16Order1Buf {
pub frequencies: Frequencies1,
pub cumulative_frequencies: CumulativeFrequencies1,
pub sym_tables: SymTables1,
pub states: Vec<u32>,
pub prev_syms: Vec<u8>,
}
impl Nx16Order1Buf {
pub fn new() -> Self {
Self {
frequencies: Box::new([[0u32; ALPHABET_SIZE]; ALPHABET_SIZE]),
cumulative_frequencies: Box::new([[0u32; ALPHABET_SIZE]; ALPHABET_SIZE]),
sym_tables: Box::new([[0u8; 4096]; ALPHABET_SIZE]),
states: Vec::with_capacity(32),
prev_syms: Vec::with_capacity(32),
}
}
}
impl Default for Nx16Order1Buf {
fn default() -> Self {
Self::new()
}
}
pub fn decode(src: &[u8], uncompressed_size: usize) -> Result<Vec<u8>, CramError> {
let mut buf = Nx16Order1Buf::new();
decode_with_buf(src, uncompressed_size, &mut buf)
}
pub(crate) fn decode_with_buf(
src: &[u8],
mut uncompressed_size: usize,
buf: &mut Nx16Order1Buf,
) -> Result<Vec<u8>, CramError> {
let mut cur: &[u8] = src;
let flags =
read_u8(&mut cur).ok_or_else(|| CramError::Truncated { context: "rans_nx16 flags" })?;
let state_count = if flags & FLAG_N32 != 0 { 32 } else { 4 };
if flags & FLAG_NO_SIZE == 0 {
uncompressed_size = read_uint7(&mut cur)? as usize;
}
super::reader::check_alloc_size(uncompressed_size, "rANS Nx16 output")?;
if flags & FLAG_STRIPE != 0 {
return decode_stripe_with_buf(&mut cur, uncompressed_size, buf);
}
let bit_pack_ctx = if flags & FLAG_PACK != 0 {
let (ctx, len) = read_bit_pack_context(&mut cur, uncompressed_size)?;
uncompressed_size = len;
Some(ctx)
} else {
None
};
let rle_ctx = if flags & FLAG_RLE != 0 {
let (ctx, len) = read_rle_context(&mut cur, state_count, uncompressed_size)?;
uncompressed_size = len;
Some(ctx)
} else {
None
};
let mut dst = vec![0u8; uncompressed_size];
if flags & FLAG_CAT != 0 {
let data = split_off(&mut cur, uncompressed_size)?;
dst.copy_from_slice(data);
} else if flags & FLAG_ORDER == 0 {
decode_order_0(&mut cur, &mut dst, state_count)?;
} else {
decode_order_1_with_buf(&mut cur, &mut dst, state_count, buf)?;
}
if let Some(ctx) = rle_ctx {
dst = apply_rle(&dst, &ctx)?;
}
if let Some(ctx) = bit_pack_ctx {
dst = apply_bit_unpack(&dst, &ctx)?;
}
Ok(dst)
}
pub(crate) use super::codec_io::read_u16_le as read_u16_le_prv;
use super::codec_io::{read_u8, read_u32_le};
#[inline]
fn read_uint7(src: &mut &[u8]) -> Result<u32, CramError> {
codec_io::read_uint7(src).map_err(uint7_to_cram_error)
}
#[inline]
fn split_off<'a>(src: &mut &'a [u8], len: usize) -> Result<&'a [u8], CramError> {
codec_io::split_off(src, len).ok_or(CramError::Truncated { context: "rans_nx16 split_off" })
}
fn read_states(src: &mut &[u8], state_count: usize) -> Result<Vec<u32>, CramError> {
(0..state_count)
.map(|_| {
read_u32_le(src).ok_or_else(|| CramError::Truncated { context: "rans_nx16 state" })
})
.collect()
}
fn state_cumulative_frequency(s: u32, bits: u32) -> u32 {
let mask = 1u32
.checked_shl(bits)
.and_then(|v| v.checked_sub(1))
.expect("bits < 32 for valid rANS state");
s & mask
}
#[cfg(test)]
fn cumulative_frequencies_symbol(
cumulative_frequencies: &[u32; ALPHABET_SIZE],
frequency: u32,
) -> u8 {
let mut sym: u8 = 0;
while sym < 255
&& frequency >= *cumulative_frequencies.get(usize::from(sym) + 1).unwrap_or(&u32::MAX)
{
sym = sym.wrapping_add(1);
}
sym
}
pub(crate) fn state_step(s: u32, f: u32, g: u32, bits: u32) -> u32 {
let result =
f.wrapping_mul(s >> bits).wrapping_add(s & (1u32.wrapping_shl(bits).wrapping_sub(1)));
result.wrapping_sub(g)
}
#[inline]
pub(crate) fn state_renormalize(mut s: u32, src: &mut &[u8]) -> Option<u32> {
if s < (1 << 15) {
s = s.wrapping_shl(16) | u32::from(read_u16_le_prv(src)?);
if s < (1 << 15) {
s = s.wrapping_shl(16) | u32::from(read_u16_le_prv(src)?);
while s < (1 << 15) {
s = s.wrapping_shl(16) | u32::from(read_u16_le_prv(src)?);
}
}
}
Some(s)
}
#[allow(
clippy::indexing_slicing,
reason = "sym is u8 so usize::from(sym) ≤ 255 < ALPHABET_SIZE=256"
)]
fn read_alphabet(src: &mut &[u8]) -> Result<[bool; ALPHABET_SIZE], CramError> {
let truncated = || CramError::Truncated { context: "rans_nx16 alphabet" };
let mut alphabet = [false; ALPHABET_SIZE];
let mut sym = read_u8(src).ok_or_else(truncated)?;
let mut prev_sym = sym;
loop {
alphabet[usize::from(sym)] = true;
sym = read_u8(src).ok_or_else(truncated)?;
if sym == 0 {
break;
}
if sym == prev_sym.wrapping_add(1) {
let len = read_u8(src).ok_or_else(truncated)?;
let end = u32::from(sym).wrapping_add(u32::from(len));
if end > 255 {
return Err(CramError::MalformedAlphabetRun { start: sym, len });
}
for _ in 0..len {
alphabet[usize::from(sym)] = true;
sym = sym.wrapping_add(1);
}
}
prev_sym = sym;
}
Ok(alphabet)
}
pub(crate) const ORDER_0_BITS: u32 = 12;
fn decode_order_0(src: &mut &[u8], dst: &mut [u8], state_count: usize) -> Result<(), CramError> {
if state_count == 32 {
return decode_order_0_32state(src, dst);
}
decode_order_0_generic(src, dst, state_count)
}
#[allow(clippy::indexing_slicing, reason = "sym ≤ 255 (u8), f < 4096 (12-bit mask)")]
fn decode_order_0_32state(src: &mut &[u8], dst: &mut [u8]) -> Result<(), CramError> {
let frequencies = read_frequencies_0(src)?;
let cumulative_frequencies = build_cumulative_frequencies(&frequencies);
let sym_table = build_symbol_table_nx16(&cumulative_frequencies);
let mut states = read_states(src, 32)?;
#[cfg(target_arch = "aarch64")]
{
let states_snapshot = states.clone();
let src_snapshot = *src;
unsafe {
if super::rans_nx16_neon::decode_32state_loop(
src,
dst,
&frequencies,
&cumulative_frequencies,
&sym_table,
&mut states,
)
.is_ok()
{
return Ok(());
}
}
states = states_snapshot;
*src = src_snapshot;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
let states_snapshot = states.clone();
let src_snapshot = *src;
unsafe {
if super::rans_nx16_avx2::decode_32state_loop(
src,
dst,
&frequencies,
&cumulative_frequencies,
&sym_table,
&mut states,
)
.is_ok()
{
return Ok(());
}
}
states = states_snapshot;
*src = src_snapshot;
}
}
let truncated = || CramError::Truncated { context: "rans_nx16 order-0 truncated" };
for chunk in dst.chunks_mut(32) {
for (d, state) in chunk.iter_mut().zip(states.iter_mut()) {
let f = state_cumulative_frequency(*state, ORDER_0_BITS);
let sym = sym_table[f as usize];
*d = sym;
let i = usize::from(sym);
*state = state_step(
*state,
*frequencies.get(i).unwrap_or(&0),
*cumulative_frequencies.get(i).unwrap_or(&0),
ORDER_0_BITS,
);
*state = state_renormalize(*state, src).ok_or_else(truncated)?;
}
}
Ok(())
}
fn decode_order_0_generic(
src: &mut &[u8],
dst: &mut [u8],
state_count: usize,
) -> Result<(), CramError> {
let truncated = || CramError::Truncated { context: "rans_nx16 order-0 truncated" };
let frequencies = read_frequencies_0(src)?;
let cumulative_frequencies = build_cumulative_frequencies(&frequencies);
let sym_table = build_symbol_table_nx16(&cumulative_frequencies);
let mut states = read_states(src, state_count)?;
for chunk in dst.chunks_mut(states.len()) {
for (d, state) in chunk.iter_mut().zip(states.iter_mut()) {
let f = state_cumulative_frequency(*state, ORDER_0_BITS);
debug_assert!((f as usize) < sym_table.len());
#[allow(clippy::indexing_slicing, reason = "bounds checked by debug_assert above")]
let sym = sym_table[f as usize];
*d = sym;
let i = usize::from(sym);
*state = state_step(
*state,
*frequencies.get(i).unwrap_or(&0),
*cumulative_frequencies.get(i).unwrap_or(&0),
ORDER_0_BITS,
);
*state = state_renormalize(*state, src).ok_or_else(truncated)?;
}
}
Ok(())
}
fn read_frequencies_0(src: &mut &[u8]) -> Result<[u32; ALPHABET_SIZE], CramError> {
let alphabet = read_alphabet(src)?;
let mut frequencies = [0u32; ALPHABET_SIZE];
for (i, frequency) in alphabet.iter().zip(&mut frequencies) {
if *i {
*frequency = read_uint7(src)?;
}
}
normalize_frequencies(&mut frequencies, ORDER_0_BITS)?;
Ok(frequencies)
}
fn normalize_frequencies(
frequencies: &mut [u32; ALPHABET_SIZE],
bits: u32,
) -> Result<(), CramError> {
let sum: u32 = frequencies.iter().copied().try_fold(0u32, |acc, f| acc.checked_add(f)).ok_or(
CramError::FrequencyNormalizationOverflow {
sum: frequencies.iter().copied().map(u64::from).sum(),
},
)?;
if sum == 0 || sum == (1 << bits) {
return Ok(());
}
let mut running = sum;
let mut shift = 0u32;
while running < (1 << bits) {
running = running
.checked_mul(2)
.ok_or(CramError::FrequencyNormalizationOverflow { sum: u64::from(sum) << shift })?;
shift = shift
.checked_add(1)
.ok_or(CramError::FrequencyNormalizationOverflow { sum: u64::from(sum) })?;
}
for f in frequencies {
*f <<= shift;
}
Ok(())
}
fn build_cumulative_frequencies(frequencies: &[u32; ALPHABET_SIZE]) -> [u32; ALPHABET_SIZE] {
let mut cumulative_frequencies = [0u32; ALPHABET_SIZE];
let mut f = 0u32;
for (i, cf) in cumulative_frequencies.iter_mut().enumerate().skip(1) {
let prev = i.checked_sub(1).expect("skip(1) guarantees i ≥ 1");
f = f.saturating_add(*frequencies.get(prev).unwrap_or(&0));
*cf = f;
}
cumulative_frequencies
}
#[allow(
clippy::indexing_slicing,
reason = "sym ≤ 254 (loop guard), sym+1 ≤ 255 < ALPHABET_SIZE=256"
)]
fn build_symbol_table_nx16(cum: &[u32; ALPHABET_SIZE]) -> [u8; 4096] {
let mut table = [0u8; 4096];
build_symbol_table_nx16_into(cum, &mut table);
table
}
#[allow(
clippy::indexing_slicing,
reason = "sym ≤ 254 (loop guard), sym+1 ≤ 255 < ALPHABET_SIZE=256"
)]
fn build_symbol_table_nx16_into(cum: &[u32; ALPHABET_SIZE], table: &mut [u8; 4096]) {
let mut sym = 0u8;
for (f, entry) in (0u32..4096).zip(table.iter_mut()) {
while sym < 255 && f >= *cum.get(usize::from(sym).wrapping_add(1)).unwrap_or(&u32::MAX) {
sym = sym.wrapping_add(1);
}
*entry = sym;
}
}
fn decode_order_1_with_buf(
src: &mut &[u8],
dst: &mut [u8],
state_count: usize,
buf: &mut Nx16Order1Buf,
) -> Result<(), CramError> {
let bits = read_frequencies_1(src, &mut buf.frequencies)?;
build_cumulative_frequencies_1_into(&buf.frequencies, &mut buf.cumulative_frequencies);
for (cum, table) in buf.cumulative_frequencies.iter().zip(buf.sym_tables.iter_mut()) {
build_symbol_table_nx16_into(cum, table);
}
let states = &mut buf.states;
states.clear();
(0..state_count)
.try_for_each(|_| -> Option<()> {
states.push(read_u32_le(src)?);
Some(())
})
.ok_or_else(|| CramError::Truncated { context: "rans_nx16 order-1 states" })?;
let prev_syms = &mut buf.prev_syms;
prev_syms.clear();
prev_syms.resize(state_count, 0);
let chunk_size = dst
.len()
.checked_div(state_count)
.ok_or_else(|| CramError::Truncated { context: "rans_nx16 order-1 zero state count" })?;
#[allow(
clippy::indexing_slicing,
reason = "k/l ≤ 255, state/prev_sym indices from enumerate < state_count, f < 4096"
)]
for i in 0..chunk_size {
for (j, (state, prev_sym)) in states.iter_mut().zip(prev_syms.iter_mut()).enumerate() {
let k = usize::from(*prev_sym);
let f = state_cumulative_frequency(*state, bits);
let sym = buf.sym_tables[k][f as usize];
let out_idx =
j.checked_mul(chunk_size).and_then(|v| v.checked_add(i)).ok_or_else(|| {
CramError::Truncated { context: "rans_nx16 order-1 index overflow" }
})?;
*dst.get_mut(out_idx)
.ok_or_else(|| CramError::Truncated { context: "rans_nx16 order-1 output" })? = sym;
let l = usize::from(sym);
*state =
state_step(*state, buf.frequencies[k][l], buf.cumulative_frequencies[k][l], bits);
*state = state_renormalize(*state, src)
.ok_or_else(|| CramError::Truncated { context: "rans_nx16 order-1 renormalize" })?;
*prev_sym = sym;
}
}
let last_chunk_start = chunk_size.checked_mul(state_count).ok_or_else(|| {
CramError::Truncated { context: "rans_nx16 order-1 last chunk offset overflow" }
})?;
debug_assert!(last_chunk_start <= dst.len());
#[allow(clippy::indexing_slicing, reason = "last_chunk_start bounded by checked_mul above")]
let last_chunk = &mut dst[last_chunk_start..];
if !last_chunk.is_empty() {
let mut state = *states
.last()
.ok_or_else(|| CramError::Truncated { context: "rans_nx16 order-1 last state" })?;
let mut prev_sym = *prev_syms
.last()
.ok_or_else(|| CramError::Truncated { context: "rans_nx16 order-1 last prev_sym" })?;
for d in last_chunk {
let k = usize::from(prev_sym);
debug_assert!(k < ALPHABET_SIZE, "prev_sym {prev_sym} out of range");
let f = state_cumulative_frequency(state, bits);
#[allow(clippy::indexing_slicing, reason = "k < 256, f < 4096")]
let sym = buf.sym_tables[k][f as usize];
*d = sym;
let l = usize::from(sym);
debug_assert!(l < ALPHABET_SIZE, "sym {sym} out of range");
#[allow(clippy::indexing_slicing, reason = "k < 256, l < 256")]
{
state = state_step(
state,
buf.frequencies[k][l],
buf.cumulative_frequencies[k][l],
bits,
);
}
state = state_renormalize(state, src).ok_or_else(|| CramError::Truncated {
context: "rans_nx16 order-1 last renormalize",
})?;
prev_sym = sym;
}
}
Ok(())
}
fn build_cumulative_frequencies_1_into(
frequencies: &Frequencies1,
cf: &mut CumulativeFrequencies1,
) {
for (f, g) in frequencies.iter().zip(cf.iter_mut()) {
*g = build_cumulative_frequencies(f);
}
}
fn read_frequencies_1(src: &mut &[u8], frequencies: &mut Frequencies1) -> Result<u32, CramError> {
let n =
read_u8(src).ok_or_else(|| CramError::Truncated { context: "rans_nx16 freq1 header" })?;
let bits = u32::from(n >> 4);
let is_compressed = (n & 0x01) != 0;
if is_compressed {
let uncompressed_size = read_uint7(src)? as usize;
let compressed_size = read_uint7(src)? as usize;
let mut compressed_data = split_off(src, compressed_size)?;
let mut tmp = vec![0u8; uncompressed_size];
decode_order_0(&mut compressed_data, &mut tmp, 4)?;
read_frequencies_1_inner(&mut &tmp[..], frequencies, bits)?;
} else {
read_frequencies_1_inner(src, frequencies, bits)?;
}
Ok(bits)
}
#[allow(
clippy::indexing_slicing,
reason = "ctx_idx/sym_idx come from enumerate() over [bool;256], so ≤ 255 < ALPHABET_SIZE=256"
)]
fn read_frequencies_1_inner(
src: &mut &[u8],
frequencies: &mut Frequencies1,
bits: u32,
) -> Result<(), CramError> {
let alphabet = read_alphabet(src)?;
for (ctx_idx, ctx_active) in alphabet.iter().enumerate() {
if !*ctx_active {
continue;
}
let fs = &mut frequencies[ctx_idx];
let mut sym_iter = alphabet.iter().enumerate().filter(|(_, b)| **b).peekable();
while let Some((sym_idx, _)) = sym_iter.next() {
let f = read_uint7(src)?;
fs[sym_idx] = f;
if f == 0 {
let n = read_u8(src)
.ok_or_else(|| CramError::Truncated { context: "rans_nx16 freq1 inner run" })?
as usize;
for _ in 0..n {
let _ = sym_iter.next();
}
}
}
normalize_frequencies(fs, bits)?;
}
Ok(())
}
fn decode_stripe_with_buf(
src: &mut &[u8],
uncompressed_size: usize,
buf: &mut Nx16Order1Buf,
) -> Result<Vec<u8>, CramError> {
let chunk_count = read_u8(src)
.ok_or_else(|| CramError::Truncated { context: "rans_nx16 stripe chunk count" })?
as usize;
if chunk_count == 0 {
return Err(CramError::RansStripeZeroChunks);
}
let compressed_sizes: Vec<usize> =
(0..chunk_count).map(|_| read_uint7(src).map(|n| n as usize)).collect::<Result<_, _>>()?;
let q = uncompressed_size.checked_div(chunk_count).ok_or(CramError::RansStripeZeroChunks)?;
let r = uncompressed_size.checked_rem(chunk_count).ok_or(CramError::RansStripeZeroChunks)?;
let uncompressed_sizes: Vec<usize> =
(0..chunk_count).map(|i| if r > i { q.saturating_add(1) } else { q }).collect();
let chunks: Vec<Vec<u8>> = compressed_sizes
.iter()
.zip(&uncompressed_sizes)
.map(|(&cs, &us)| {
let sub = split_off(src, cs)?;
decode_with_buf(sub, us, buf)
})
.collect::<Result<_, _>>()?;
let mut dst = vec![0u8; uncompressed_size];
for (i, chunk) in chunks.iter().enumerate() {
for (j, &s) in chunk.iter().enumerate() {
let idx = j
.checked_mul(chunk_count)
.and_then(|v| v.checked_add(i))
.ok_or(CramError::Truncated { context: "rans_nx16 stripe index overflow" })?;
if let Some(d) = dst.get_mut(idx) {
*d = s;
}
}
}
Ok(dst)
}
struct BitPackContext {
symbol_count: usize,
mapping_table: Vec<u8>,
uncompressed_size: usize,
}
fn read_bit_pack_context(
src: &mut &[u8],
uncompressed_size: usize,
) -> Result<(BitPackContext, usize), CramError> {
let symbol_count = read_u8(src)
.ok_or_else(|| CramError::Truncated { context: "rans_nx16 bit_pack symbol count" })?
as usize;
if symbol_count == 0 {
return Err(CramError::RansBitPackZeroSymbols);
}
let mapping_table = split_off(src, symbol_count)?.to_vec();
let packed_len = read_uint7(src)? as usize;
Ok((BitPackContext { symbol_count, mapping_table, uncompressed_size }, packed_len))
}
fn apply_bit_unpack(src: &[u8], ctx: &BitPackContext) -> Result<Vec<u8>, CramError> {
let mut dst = vec![0u8; ctx.uncompressed_size];
match ctx.symbol_count {
1 => {
let sym = *ctx
.mapping_table
.first()
.ok_or(CramError::Truncated { context: "bit_pack mapping" })?;
dst.fill(sym);
}
2 => unpack(src, &ctx.mapping_table, 8, &mut dst),
3..=4 => unpack(src, &ctx.mapping_table, 4, &mut dst),
5..=16 => unpack(src, &ctx.mapping_table, 2, &mut dst),
n => {
return Err(CramError::RansBitPackTooManySymbols { symbol_count: n });
}
}
Ok(dst)
}
fn unpack(src: &[u8], mapping_table: &[u8], chunk_size: usize, dst: &mut [u8]) {
let bits = u8::BITS as usize;
let shift = bits
.checked_div(chunk_size)
.expect("chunk_size is always 2, 4, or 8 from the match in apply_bit_unpack");
let mask: u8 = (1u8 << shift).wrapping_sub(1);
for (mut s, chunk) in src.iter().copied().zip(dst.chunks_mut(chunk_size)) {
for d in chunk {
let idx = usize::from(s & mask);
*d = mapping_table.get(idx).copied().unwrap_or(0);
s >>= shift;
}
}
}
struct RleContext {
rle_meta: Vec<u8>,
output_len: usize,
}
fn read_rle_context(
src: &mut &[u8],
state_count: usize,
uncompressed_size: usize,
) -> Result<(RleContext, usize), CramError> {
let header = read_uint7(src)?;
let context_size = (header >> 1) as usize;
let is_compressed = (header & 0x01) == 0;
let rle_encoded_len = read_uint7(src)? as usize;
let rle_meta = if is_compressed {
let compressed_size = read_uint7(src)? as usize;
let mut buf = split_off(src, compressed_size)?;
let mut tmp = vec![0u8; context_size];
decode_order_0(&mut buf, &mut tmp, state_count)?;
tmp
} else {
split_off(src, context_size)?.to_vec()
};
Ok((RleContext { rle_meta, output_len: uncompressed_size }, rle_encoded_len))
}
#[allow(
clippy::indexing_slicing,
reason = "sym is u8 so usize::from(sym) ≤ 255 < ALPHABET_SIZE=256"
)]
fn apply_rle(src: &[u8], ctx: &RleContext) -> Result<Vec<u8>, CramError> {
let mut meta_cur: &[u8] = &ctx.rle_meta;
let rle_alphabet = read_rle_alphabet(&mut meta_cur)?;
let mut dst = vec![0u8; ctx.output_len];
let mut dst_iter = dst.iter_mut();
let mut src_iter = src.iter();
while let Some(d) = dst_iter.next() {
let &sym = src_iter.next().ok_or(CramError::Truncated { context: "rans_nx16 rle src" })?;
*d = sym;
if rle_alphabet[usize::from(sym)] {
let len = read_uint7(&mut meta_cur)? as usize;
for e in dst_iter.by_ref().take(len) {
*e = sym;
}
}
}
Ok(dst)
}
#[allow(
clippy::indexing_slicing,
reason = "sym is u8 so usize::from(sym) ≤ 255 < ALPHABET_SIZE=256"
)]
fn read_rle_alphabet(src: &mut &[u8]) -> Result<[bool; ALPHABET_SIZE], CramError> {
let truncated = || CramError::Truncated { context: "rans_nx16 rle alphabet" };
let mut alphabet = [false; ALPHABET_SIZE];
let n = read_u8(src).ok_or_else(truncated)? as usize;
let symbol_count = if n == 0 { ALPHABET_SIZE } else { n };
for _ in 0..symbol_count {
let sym = read_u8(src).ok_or_else(truncated)?;
alphabet[usize::from(sym)] = true;
}
Ok(alphabet)
}
#[cfg(test)]
#[allow(clippy::cast_possible_truncation, reason = "test code")]
#[allow(clippy::arithmetic_side_effects, reason = "test code")]
mod tests {
use super::*;
#[test]
fn rans_stripe_zero_chunks_returns_error() {
let src = vec![
FLAG_STRIPE, 4u8, 0u8, ];
let err = decode(&src, 0).unwrap_err();
assert!(matches!(err, CramError::RansStripeZeroChunks));
}
#[test]
fn rans_bit_pack_zero_symbols_returns_error() {
let src = vec![
FLAG_PACK, 5u8, 0u8, ];
let err = decode(&src, 0).unwrap_err();
assert!(matches!(err, CramError::RansBitPackZeroSymbols));
}
#[test]
fn rans_bit_pack_too_many_symbols_returns_error() {
let symbol_count: u8 = 17;
let packed_len: u8 = 0; let mut src = Vec::new();
src.push(FLAG_PACK); src.push(10u8); src.push(symbol_count); src.extend(std::iter::repeat_n(b'A', symbol_count as usize));
src.push(packed_len);
let mut src2 = Vec::new();
src2.push(FLAG_PACK | FLAG_CAT); src2.push(10u8); src2.push(symbol_count); src2.extend(std::iter::repeat_n(b'A', symbol_count as usize));
src2.push(packed_len);
let err = decode(&src2, 0).unwrap_err();
assert!(
matches!(err, CramError::RansBitPackTooManySymbols { symbol_count: 17 }),
"expected RansBitPackTooManySymbols, got: {err:?}"
);
}
#[test]
fn decode_order_0_noodles_test_vector() {
let src = [
0x00, 0x07, 0x64, 0x65, 0x00, 0x6c, 0x6e, 0x6f, 0x00, 0x73, 0x00, 0x01, 0x01, 0x01, 0x01, 0x03,
0x01, 0x00, 0x26, 0x20, 0x00, 0x00, 0xb8, 0x0a, 0x00, 0x00, 0xd8, 0x0a, 0x00, 0x00,
0x00, 0x04, 0x00,
];
assert_eq!(decode(&src, 0).unwrap(), b"noodles");
}
#[test]
fn decode_order_1_noodles_test_vector() {
let src = [
0x01, 0x4d, 0xa0, 0x00, 0x64, 0x65, 0x00, 0x6c, 0x6e, 0x6f, 0x00, 0x73, 0x00, 0x00, 0x00, 0x01,
0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x01, 0x00,
0x02, 0x00, 0x01, 0x0f, 0x00, 0x02, 0x01, 0x00, 0x01, 0x01, 0x0f, 0x00, 0x02, 0x00,
0x03, 0x0f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x0f, 0x00, 0x00, 0x00,
0x05, 0x10, 0x80, 0x72, 0x60, 0x00, 0x80, 0x8b, 0x5f, 0x00, 0xc0, 0xb0, 0x60, 0x00,
0x40, 0x49, 0x39, 0x00,
];
assert_eq!(
decode(&src, 0).unwrap(),
b"nnnnnnnnnnnnooooooooooooooooddddddddddddddllllllllllllllleeeeeeeeeessssssssss"
);
}
#[test]
fn decode_stripe_noodles_test_vector() {
let src = [
0x08, 0x07, 0x04, 0x17, 0x17, 0x17, 0x15, 0x00, 0x02, 0x6c, 0x6e, 0x00, 0x01, 0x01, 0x00, 0x08,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00,
0x00, 0x02, 0x65, 0x6f, 0x00, 0x01, 0x01, 0x00, 0x08, 0x01, 0x00, 0x00, 0x00, 0x01,
0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x02, 0x6f, 0x73, 0x00,
0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x08, 0x01, 0x00, 0x00, 0x80, 0x00, 0x00,
0x00, 0x80, 0x00, 0x00, 0x00, 0x01, 0x64, 0x00, 0x01, 0x00, 0x80, 0x00, 0x00, 0x00,
0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x02, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x22, 0x00, 0x81, 0x11, 0x01, 0x7f, 0x00,
];
assert_eq!(decode(&src, 0).unwrap(), b"noodles");
}
#[test]
fn decode_uncompressed_noodles_test_vector() {
let src = [
0x20, 0x07, 0x6e, 0x6f, 0x6f, 0x64, 0x6c, 0x65, 0x73,
];
assert_eq!(decode(&src, 0).unwrap(), b"noodles");
}
#[test]
fn decode_rle_noodles_test_vector() {
let src = [
0x40, 0x0d, 0x06, 0x06, 0x17, 0x01, 0x07, 0x6f, 0x00, 0x02, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00,
0x00, 0x0c, 0x02, 0x00, 0x00, 0x08, 0x02, 0x00, 0x00, 0x80, 0x00, 0x00, 0x64, 0x65,
0x00, 0x6c, 0x6e, 0x6f, 0x00, 0x73, 0x00, 0x03, 0x01, 0x01, 0x01, 0x01, 0x01, 0x00,
0x3a, 0x20, 0x00, 0x00, 0x7c, 0x20, 0x00, 0x00, 0x52, 0x01, 0x00, 0x00, 0x08, 0x04,
0x00,
];
assert_eq!(decode(&src, 0).unwrap(), b"noooooooodles");
}
#[test]
fn decode_bit_packing_noodles_test_vector() {
let src = [
0x80, 0x07, 0x06, 0x64, 0x65, 0x6c, 0x6e, 0x6f, 0x73, 0x04, 0x04, 0x05, 0x00, 0x12, 0x43, 0x00,
0x01, 0x01, 0x01, 0x01, 0x00, 0x0c, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x08,
0x02, 0x00, 0x00, 0x04, 0x02, 0x00,
];
assert_eq!(decode(&src, 0).unwrap(), b"noodles");
}
#[test]
fn read_alphabet_run_overflow_returns_error() {
let mut src: &[u8] = &[250, 251, 10, 0];
let err = read_alphabet(&mut src).unwrap_err();
assert!(
matches!(err, CramError::MalformedAlphabetRun { start: 251, len: 10 }),
"expected MalformedAlphabetRun for run from 251 with len=10, got: {err:?}"
);
}
proptest::proptest! {
#[test]
fn read_alphabet_run_bounds_match_spec(start in 1u8..=255, len in 0u8..=255) {
let prev = start.checked_sub(1).expect("start >= 1");
let stream = [prev, start, len, 0];
let mut cur: &[u8] = &stream;
let result = read_alphabet(&mut cur);
let end = u32::from(start) + u32::from(len);
if end > 255 {
proptest::prop_assert!(
matches!(result, Err(CramError::MalformedAlphabetRun { start: s, len: l })
if s == start && l == len),
"expected MalformedAlphabetRun for start={start}, len={len}, got: {result:?}",
);
} else {
let alphabet = result.expect("valid bounds should not error");
proptest::prop_assert!(alphabet[usize::from(prev)], "alphabet[{prev}] must be set");
if len > 0 {
let last = start.checked_add(len.checked_sub(1).expect("len>0"))
.expect("end ≤ 255 so add fits in u8");
proptest::prop_assert!(alphabet[usize::from(last)], "alphabet[{last}] must be set");
}
}
}
}
#[test]
fn read_alphabet_run_exactly_to_255_ok() {
let mut src: &[u8] = &[200, 201, 54, 0];
let alphabet = read_alphabet(&mut src).unwrap();
for s in 200..=255u32 {
assert!(alphabet[s as usize], "alphabet[{s}] should be set");
}
assert!(!alphabet[199], "alphabet[199] should not be set");
assert!(!alphabet[0], "alphabet[0] should not be set (no wraparound)");
}
#[test]
fn read_uint7_overflow_returns_error() {
let mut src: &[u8] = &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00];
let err = read_uint7(&mut src).unwrap_err();
assert!(matches!(err, CramError::Uint7Overflow));
}
#[test]
fn read_uint7_five_bytes_ok() {
let mut src: &[u8] = &[0x80, 0x80, 0x80, 0x80, 0x01];
let result = read_uint7(&mut src);
assert!(result.is_ok());
}
#[test]
fn read_uint7_spec_vectors() {
let cases: &[(u32, &[u8])] = &[
(0, &[0x00]),
(1, &[0x01]),
(127, &[0x7F]),
(128, &[0x81, 0x00]),
(200, &[0x81, 0x48]),
(16_383, &[0xFF, 0x7F]),
(16_384, &[0x81, 0x80, 0x00]),
(0x12345, &[0x84, 0xC6, 0x45]),
((1u32 << 28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]),
(1u32 << 28, &[0x81, 0x80, 0x80, 0x80, 0x00]),
(u32::MAX, &[0x8F, 0xFF, 0xFF, 0xFF, 0x7F]),
];
for (val, encoded) in cases {
let mut cur: &[u8] = encoded;
let decoded = read_uint7(&mut cur)
.unwrap_or_else(|e| panic!("decode of {encoded:02x?} failed: {e:?}"));
assert_eq!(decoded, *val, "decoded value mismatch for {encoded:02x?}");
assert!(
cur.is_empty(),
"decoder consumed wrong byte count for {val} encoded as {encoded:02x?}",
);
}
}
proptest::proptest! {
#[test]
fn state_renormalize_matches_spec(
initial_state in 0u32..=u32::MAX,
bytes in proptest::collection::vec(0u8..=255, 0..=16),
) {
let mut src: &[u8] = &bytes;
let initial_src_len = src.len();
let result = state_renormalize(initial_state, &mut src);
if initial_state >= (1 << 15) {
proptest::prop_assert_eq!(result, Some(initial_state));
proptest::prop_assert_eq!(src.len(), initial_src_len);
return Ok(());
}
let mut expected_state = initial_state;
let mut expected_src: &[u8] = &bytes;
while expected_state < (1 << 15) {
let Some((head, rest)) = expected_src.split_first_chunk::<2>() else {
proptest::prop_assert_eq!(result, None,
"expected None when src exhausted mid-renorm");
return Ok(());
};
let lo = u32::from(u16::from_le_bytes(*head));
expected_state = expected_state.wrapping_shl(16) | lo;
expected_src = rest;
}
proptest::prop_assert_eq!(result, Some(expected_state));
proptest::prop_assert_eq!(src.len(), expected_src.len());
}
#[test]
fn read_uint7_roundtrip(val in 0u32..=u32::MAX) {
let mut stream = Vec::with_capacity(256);
encode_uint7_prv(&mut stream, val);
let mut cur: &[u8] = &stream;
let decoded = read_uint7(&mut cur).unwrap();
proptest::prop_assert_eq!(decoded, val);
proptest::prop_assert!(cur.is_empty(), "undecoded trailing bytes");
}
#[test]
fn read_uint7_roundtrip_max_continuation(val in (1u32 << 28)..=u32::MAX) {
let mut stream = Vec::with_capacity(256);
encode_uint7_prv(&mut stream, val);
proptest::prop_assert_eq!(stream.len(), 5, "max-continuation encodes to 5 bytes");
let mut cur: &[u8] = &stream;
let decoded = read_uint7(&mut cur).unwrap();
proptest::prop_assert_eq!(decoded, val);
}
}
#[test]
fn normalize_frequencies_overflow_returns_error() {
let mut frequencies = [0u32; ALPHABET_SIZE];
frequencies.fill(0x0100_0000);
let result = normalize_frequencies(&mut frequencies, ORDER_0_BITS);
assert!(
matches!(result, Err(CramError::FrequencyNormalizationOverflow { .. })),
"expected overflow error, got: {result:?}"
);
}
#[test]
fn build_symbol_table_matches_linear_scan() {
let freq: [u32; 256] = {
let mut f = [0u32; 256];
f[0] = 100;
f[1] = 200;
f[2] = 50;
f[3] = 3646; f[4] = 100;
f
};
let cum = build_cumulative_frequencies(&freq);
let table = build_symbol_table_nx16(&cum);
for f_val in 0u32..4096 {
let expected = cumulative_frequencies_symbol(&cum, f_val);
let actual = table[f_val as usize];
assert_eq!(
actual, expected,
"mismatch at f={f_val}: table gives {actual}, linear scan gives {expected}"
);
}
}
#[test]
fn build_symbol_table_matches_linear_scan_uniform() {
let freq = [16u32; 256]; let cum = build_cumulative_frequencies(&freq);
let table = build_symbol_table_nx16(&cum);
for f_val in 0u32..4096 {
let expected = cumulative_frequencies_symbol(&cum, f_val);
let actual = table[f_val as usize];
assert_eq!(actual, expected);
}
}
proptest::proptest! {
#[test]
fn build_symbol_table_proptest(seeds in proptest::collection::vec(0u32..=4096, 8)) {
let total: u64 = seeds.iter().map(|&s| u64::from(s)).sum();
proptest::prop_assume!(total > 0);
let mut freq = [0u32; 256];
for (i, &s) in seeds.iter().enumerate() {
let scaled = (u64::from(s) * 4096 / total) as u32;
freq[i % 256] = freq[i % 256].saturating_add(scaled);
}
let sum: u32 = freq.iter().sum();
if sum < 4096 {
freq[0] = freq[0].saturating_add(4096 - sum);
}
proptest::prop_assume!(freq.iter().sum::<u32>() == 4096);
let cum = build_cumulative_frequencies(&freq);
let table = build_symbol_table_nx16(&cum);
let mut table_into = [0u8; 4096];
build_symbol_table_nx16_into(&cum, &mut table_into);
proptest::prop_assert_eq!(table, table_into);
for f_val in 0u32..4096 {
let expected = cumulative_frequencies_symbol(&cum, f_val);
let actual = table[f_val as usize];
proptest::prop_assert_eq!(actual, expected, "mismatch at f={}", f_val);
}
}
}
#[allow(clippy::cast_possible_truncation, reason = "len bounded by proptest range")]
#[test]
fn order0_32state_and_generic_produce_same_output() {
let mut stream = Vec::with_capacity(256);
stream.push(0x04); stream.push(4); stream.push(0); stream.push(0); stream.push(0xA0); stream.push(0x20);
for _ in 0..32 {
stream.extend_from_slice(&0x01000000u32.to_le_bytes());
}
let result = decode(&stream, 0).unwrap();
assert_eq!(result, &[0, 0, 0, 0]);
let mut src1: &[u8] = &stream;
let mut src2: &[u8] = &stream;
let flags = read_u8(&mut src1).unwrap();
let _ = read_u8(&mut src2).unwrap();
assert_eq!(flags & FLAG_N32, FLAG_N32);
let _ = read_uint7(&mut src1).unwrap();
let _ = read_uint7(&mut src2).unwrap();
let mut dst1 = vec![0u8; 4];
let mut dst2 = vec![0u8; 4];
decode_order_0_32state(&mut src1, &mut dst1).unwrap();
decode_order_0_generic(&mut src2, &mut dst2, 32).unwrap();
assert_eq!(dst1, dst2);
}
#[allow(
clippy::cast_possible_truncation,
clippy::arithmetic_side_effects,
reason = "val masked to 7 bits; n bounded to ≤5 by loop over 32-bit value shifted by 7"
)]
fn encode_uint7_prv(stream: &mut Vec<u8>, val: u32) {
if val == 0 {
stream.push(0);
return;
}
let mut tmp = [0u8; 5];
let mut n = 0;
let mut v = val;
while v > 0 {
tmp[n] = v as u8 & 0x7F;
v >>= 7;
n += 1;
}
while n > 1 {
n -= 1;
stream.push(tmp[n] | 0x80);
}
stream.push(tmp[0]);
}
#[test]
fn neon_simd_handles_len_32() {
let len = 32;
let mut stream = Vec::with_capacity(256);
stream.push(FLAG_N32);
encode_uint7_prv(&mut stream, len as u32);
stream.extend_from_slice(&[0, 0]);
stream.push(0x80);
stream.push(0x20);
for _ in 0..32 {
stream.extend_from_slice(&0x01000000u32.to_le_bytes());
}
let mut cur: &[u8] = &stream;
read_u8(&mut cur).unwrap();
read_uint7(&mut cur).unwrap();
let mut dst = vec![0u8; len];
decode_order_0_generic(&mut cur, &mut dst, 32).unwrap();
assert_eq!(dst, vec![0u8; 32], "scalar path failed");
let result = decode(&stream, 0).unwrap();
assert_eq!(result, vec![0u8; 32], "SIMD path failed");
}
#[test]
fn neon_fallback_handles_len_128_direct() {
let len = 128;
let mut stream = Vec::with_capacity(256);
stream.push(FLAG_N32);
encode_uint7_prv(&mut stream, len as u32);
stream.extend_from_slice(&[0, 0]);
stream.push(0x80);
stream.push(0x20);
for _ in 0..32 {
stream.extend_from_slice(&0x01000000u32.to_le_bytes());
}
let mut cur: &[u8] = &stream;
read_u8(&mut cur).unwrap();
read_uint7(&mut cur).unwrap();
let mut dst = vec![0u8; len];
decode_order_0_32state(&mut cur, &mut dst).unwrap();
assert!(dst.iter().all(|&b| b == 0), "decode_order_0_32state produced non-zero output");
}
proptest::proptest! {
#[test]
fn simd_matches_scalar_order0_32state(
len in 0usize..1024,
) {
let mut stream = Vec::with_capacity(256);
stream.push(FLAG_N32);
encode_uint7_prv(&mut stream, len as u32);
stream.extend_from_slice(&[0, 0]);
stream.push(0x80);
stream.push(0x20); for _ in 0..32 {
stream.extend_from_slice(&0x01000000u32.to_le_bytes());
}
let simd_result = decode(&stream, 0).unwrap();
assert_eq!(simd_result.len(), len);
let mut cur: &[u8] = &stream;
read_u8(&mut cur).unwrap();
read_uint7(&mut cur).unwrap();
let mut scalar_dst = vec![0u8; len];
decode_order_0_generic(&mut cur, &mut scalar_dst, 32).unwrap();
assert_eq!(simd_result, scalar_dst);
}
#[test]
fn simd_matches_scalar_with_renorm(
renorm_bytes in proptest::collection::vec(0u8..=255, 256..512),
len in 32usize..=128,
) {
let len = (len / 32).checked_mul(32).expect("len ≤ 128 → fits in usize");
let mut stream = Vec::with_capacity(256);
stream.push(FLAG_N32);
encode_uint7_prv(&mut stream, len as u32);
stream.extend_from_slice(&[0, 0]);
stream.clear();
stream.push(FLAG_N32);
encode_uint7_prv(&mut stream, len as u32);
stream.push(0); stream.push(1); stream.push(0); stream.push(0); encode_uint7_prv(&mut stream, 2048); encode_uint7_prv(&mut stream, 2048); for _ in 0..32 {
stream.extend_from_slice(&0x0000_8001u32.to_le_bytes());
}
stream.extend_from_slice(&renorm_bytes);
let mut cur: &[u8] = &stream;
read_u8(&mut cur).unwrap();
read_uint7(&mut cur).unwrap();
let mut scalar_dst = vec![0u8; len];
let scalar_result = decode_order_0_generic(&mut cur, &mut scalar_dst, 32);
let simd_result = decode(&stream, 0);
match (simd_result, scalar_result) {
(Ok(simd_dst), Ok(())) => {
proptest::prop_assert_eq!(simd_dst, scalar_dst);
}
(Err(_), Err(_)) => {
}
(Ok(_), Err(scalar_err)) => {
proptest::prop_assert!(
false,
"SIMD succeeded where scalar failed: {:?}",
scalar_err,
);
}
(Err(simd_err), Ok(())) => {
proptest::prop_assert!(
false,
"SIMD failed ({:?}) where scalar succeeded — SIMD bug",
simd_err,
);
}
}
}
#[test]
fn simd_remainder_uses_correct_lanes(len in 0usize..256) {
let mut stream = Vec::with_capacity(512);
stream.push(FLAG_N32);
encode_uint7_prv(&mut stream, len as u32);
stream.push(5);
stream.push(100);
stream.push(0); encode_uint7_prv(&mut stream, 2048);
encode_uint7_prv(&mut stream, 2048);
for j in 0u32..32 {
let s = (4096 << 12) | (j * 80);
stream.extend_from_slice(&s.to_le_bytes());
}
let mut cur_simd: &[u8] = &stream;
read_u8(&mut cur_simd).unwrap();
let _ = read_uint7(&mut cur_simd).unwrap();
let mut dst_simd = vec![0u8; len];
decode_order_0_32state(&mut cur_simd, &mut dst_simd).unwrap();
let mut cur_gen: &[u8] = &stream;
read_u8(&mut cur_gen).unwrap();
let _ = read_uint7(&mut cur_gen).unwrap();
let mut dst_gen = vec![0u8; len];
decode_order_0_generic(&mut cur_gen, &mut dst_gen, 32).unwrap();
assert_eq!(dst_simd, dst_gen, "SIMD and scalar diverge for len={len}");
}
}
}