use crate::error::CodecError;
/// Number of interleaved rANS streams; input byte `i` is coded on stream `i % NUM_STREAMS`.
const NUM_STREAMS: usize = 4;
/// Bit width of the normalized symbol probabilities.
const PROB_BITS: u32 = 14;
/// Total that the normalized frequency table is scaled to sum to (`1 << PROB_BITS`).
const PROB_SCALE: u32 = 1 << PROB_BITS;
/// Lower bound of the rANS state; the decoder refills the state one byte at a time while it is
/// below this value.
const RANS_L: u32 = 1 << 23;
/// Size in bytes of the serialized frequency table: 256 symbols, one little-endian `u32` each.
const FREQ_TABLE_SIZE: usize = 256 * 4;
/// Fixed header size: 4-byte uncompressed length, the frequency table, and a 4-byte payload size.
const HEADER_SIZE: usize = 4 + FREQ_TABLE_SIZE + 4;
/// Compresses `data` with an order-0 rANS coder split over `NUM_STREAMS` interleaved streams.
///
/// Output layout (all integers little-endian `u32`):
///
/// 1. uncompressed length,
/// 2. 256 normalized symbol frequencies,
/// 3. payload size (stream-size table plus all stream bytes),
/// 4. `NUM_STREAMS` stream lengths,
/// 5. the streams back to back, each one ending with its 4-byte final coder state.
///
/// Empty input produces a header of `HEADER_SIZE` zero bytes and nothing else.
///
/// # Panics
///
/// Panics if `data` (or the resulting payload) is too large for the `u32` length fields of the
/// format. Previously such sizes were silently truncated, producing output that cannot be
/// decoded.
pub fn encode(data: &[u8]) -> Vec<u8> {
    if data.is_empty() {
        return vec![0u8; HEADER_SIZE];
    }
    assert!(
        data.len() <= u32::MAX as usize,
        "rANS encode: input of {} bytes does not fit the u32 length field",
        data.len()
    );

    // Symbol histogram; cannot overflow because the input length fits in a u32.
    let mut freqs = [0u32; 256];
    for &b in data {
        freqs[usize::from(b)] += 1;
    }
    let norm_freqs = normalize_frequencies(&freqs, data.len());
    let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);

    let mut streams: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
    let mut states = [RANS_L; NUM_STREAMS];

    // rANS is last-in-first-out, so symbols are encoded back to front; the decoder then
    // recovers them front to back.
    for (i, &byte) in data.iter().enumerate().rev() {
        let stream_idx = i % NUM_STREAMS;
        let sym = usize::from(byte);
        let freq = sym_freqs[sym];
        // Every byte that occurs in `data` is given a frequency of at least 1 by
        // `normalize_frequencies`. Skipping a symbol here would desynchronise the decoder,
        // so treat a zero as a broken invariant rather than dropping the byte.
        assert!(
            freq > 0,
            "rANS encode: symbol {sym} occurs in the input but has zero normalized frequency"
        );
        rans_encode_symbol(
            &mut states[stream_idx],
            &mut streams[stream_idx],
            cum_freqs[sym],
            freq,
        );
    }

    // Each stream ends with its final state so the decoder can start from the stream's tail.
    for (stream, state) in streams.iter_mut().zip(states.iter()) {
        stream.extend_from_slice(&state.to_le_bytes());
    }

    let total_compressed: usize = streams.iter().map(Vec::len).sum();
    let comp_payload_size = total_compressed + NUM_STREAMS * 4;
    assert!(
        comp_payload_size <= u32::MAX as usize,
        "rANS encode: payload of {comp_payload_size} bytes does not fit the u32 size field"
    );

    let mut out = Vec::with_capacity(HEADER_SIZE + comp_payload_size);
    // The casts below are lossless: both sizes were checked against u32::MAX above, and every
    // individual stream is no longer than the whole payload.
    out.extend_from_slice(&(data.len() as u32).to_le_bytes());
    for f in &norm_freqs {
        out.extend_from_slice(&f.to_le_bytes());
    }
    out.extend_from_slice(&(comp_payload_size as u32).to_le_bytes());
    for s in &streams {
        out.extend_from_slice(&(s.len() as u32).to_le_bytes());
    }
    for s in &streams {
        out.extend_from_slice(s);
    }
    out
}
pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
if data.len() < HEADER_SIZE {
return Err(CodecError::Truncated {
expected: HEADER_SIZE,
actual: data.len(),
});
}
let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
if uncompressed_size == 0 {
return Ok(Vec::new());
}
let mut norm_freqs = [0u32; 256];
for (i, freq) in norm_freqs.iter_mut().enumerate() {
let pos = 4 + i * 4;
*freq = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
}
let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);
let lookup = build_decode_table(&cum_freqs, &sym_freqs);
let _comp_size = u32::from_le_bytes([
data[HEADER_SIZE - 4],
data[HEADER_SIZE - 3],
data[HEADER_SIZE - 2],
data[HEADER_SIZE - 1],
]) as usize;
let mut pos = HEADER_SIZE;
if pos + NUM_STREAMS * 4 > data.len() {
return Err(CodecError::Truncated {
expected: pos + NUM_STREAMS * 4,
actual: data.len(),
});
}
let mut stream_sizes = [0usize; NUM_STREAMS];
for size in stream_sizes.iter_mut() {
*size =
u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
pos += 4;
}
let mut stream_data: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
for i in 0..NUM_STREAMS {
let end = pos + stream_sizes[i];
if end > data.len() {
return Err(CodecError::Truncated {
expected: end,
actual: data.len(),
});
}
stream_data[i] = data[pos..end].to_vec();
pos += stream_sizes[i];
}
let mut states = [0u32; NUM_STREAMS];
let mut stream_pos = [0usize; NUM_STREAMS];
for i in 0..NUM_STREAMS {
let s = &stream_data[i];
if s.len() < 4 {
return Err(CodecError::Corrupt {
detail: format!("rANS stream {i} too short for state"),
});
}
let end = s.len();
states[i] = u32::from_le_bytes([s[end - 4], s[end - 3], s[end - 2], s[end - 1]]);
stream_pos[i] = end - 4;
}
let mut output = vec![0u8; uncompressed_size];
for (i, out_byte) in output.iter_mut().enumerate() {
let stream_idx = i % NUM_STREAMS;
let (sym, new_state) =
rans_decode_symbol(states[stream_idx], &lookup, &cum_freqs, &sym_freqs);
*out_byte = sym;
states[stream_idx] = rans_decode_renorm(
new_state,
&stream_data[stream_idx],
&mut stream_pos[stream_idx],
);
}
Ok(output)
}
/// Pushes one symbol, described by its cumulative `start` and its `freq`, onto the rANS `state`.
///
/// Low-order bytes of the state are appended to `bitstream` until the state is below the
/// renormalization limit for this frequency; only then is the symbol folded in.
fn rans_encode_symbol(state: &mut u32, bitstream: &mut Vec<u8>, start: u32, freq: u32) {
    let renorm_limit = ((RANS_L >> PROB_BITS) << 8) * freq;

    let mut x = *state;
    while x >= renorm_limit {
        // Truncating cast keeps exactly the lowest byte.
        bitstream.push(x as u8);
        x >>= 8;
    }

    let (quotient, remainder) = (x / freq, x % freq);
    *state = (quotient << PROB_BITS) + remainder + start;
}
/// Pops one symbol off `state` and returns it together with the state that remains.
///
/// The low `PROB_BITS` bits of the state select a slot, `lookup` maps the slot to its symbol,
/// and the symbol's cumulative start and frequency are used to undo the encoding step. The
/// returned state has not been renormalized.
fn rans_decode_symbol(
    state: u32,
    lookup: &[u8; PROB_SCALE as usize],
    cum_freqs: &[u32; 257],
    sym_freqs: &[u32; 256],
) -> (u8, u32) {
    // PROB_SCALE is a power of two, so remainder/quotient equal the mask/shift forms.
    let slot = state % PROB_SCALE;
    let quotient = state / PROB_SCALE;

    let symbol = lookup[slot as usize];
    let idx = usize::from(symbol);

    let remaining = sym_freqs[idx] * quotient + slot - cum_freqs[idx];
    (symbol, remaining)
}
/// Refills `state` from `stream`, reading backwards from `*pos`.
///
/// Bytes are shifted in one at a time until the state reaches `RANS_L` or the cursor hits the
/// start of the stream. `*pos` is left pointing at the last byte consumed. If the stream runs
/// out first, the partially refilled state is returned unchanged.
fn rans_decode_renorm(mut state: u32, stream: &[u8], pos: &mut usize) -> u32 {
    loop {
        if state >= RANS_L || *pos == 0 {
            return state;
        }
        *pos -= 1;
        state = (state << 8) | u32::from(stream[*pos]);
    }
}
/// Rescales raw symbol counts so that they sum to exactly `PROB_SCALE`.
///
/// Each non-zero count is scaled by `PROB_SCALE / total`, rounded, and clamped to at least 1 so
/// that every symbol that occurs stays encodable. Any rounding surplus is then removed one unit
/// at a time from the currently largest entry above 1, and any shortfall is added one unit at a
/// time to the currently largest entry. Ties go to the highest symbol index. If every count is
/// zero the all-zero table is returned.
fn normalize_frequencies(freqs: &[u32; 256], total: usize) -> [u32; 256] {
    /// Index of the largest entry that is at least `floor`; the last such index wins ties,
    /// and 0 is returned when no entry qualifies.
    fn last_largest(table: &[u32; 256], floor: u32) -> usize {
        let mut best: Option<(usize, u32)> = None;
        for (idx, &value) in table.iter().enumerate() {
            let improves = match best {
                Some((_, current)) => value >= current,
                None => true,
            };
            if value >= floor && improves {
                best = Some((idx, value));
            }
        }
        match best {
            Some((idx, _)) => idx,
            None => 0,
        }
    }

    let denominator = total as f64;
    let scale = f64::from(PROB_SCALE);

    let mut norm = [0u32; 256];
    let mut sum = 0u32;
    for (slot, &count) in norm.iter_mut().zip(freqs.iter()) {
        if count == 0 {
            continue;
        }
        let scaled = (f64::from(count) / denominator * scale).round() as u32;
        *slot = scaled.max(1);
        sum += *slot;
    }

    if sum == 0 {
        return norm;
    }

    // Shave the surplus off entries that can spare it (strictly greater than 1).
    while sum > PROB_SCALE {
        let victim = last_largest(&norm, 2);
        norm[victim] -= 1;
        sum -= 1;
    }
    // Hand any shortfall to the most frequent symbol.
    while sum < PROB_SCALE {
        let winner = last_largest(&norm, 0);
        norm[winner] += 1;
        sum += 1;
    }
    norm
}
/// Builds the cumulative frequency table for `freqs`.
///
/// Returns `(cum, sym_freqs)` where `cum[s]` is the sum of all frequencies below symbol `s`
/// (so `cum[0] == 0` and `cum[256]` is the grand total) and `sym_freqs` is a copy of `freqs`.
fn build_cum_table(freqs: &[u32; 256]) -> ([u32; 257], [u32; 256]) {
    let mut cum = [0u32; 257];
    let mut running = 0u32;
    for (slot, &f) in cum[1..].iter_mut().zip(freqs.iter()) {
        running += f;
        *slot = running;
    }
    (cum, *freqs)
}
/// Builds the slot-to-symbol lookup table used while decoding.
///
/// Every slot in `cum_freqs[s]..cum_freqs[s + 1]` is assigned symbol `s`; slots not covered by
/// any symbol stay 0. `_sym_freqs` is accepted for signature symmetry and is not read.
fn build_decode_table(
    cum_freqs: &[u32; 257],
    _sym_freqs: &[u32; 256],
) -> [u8; PROB_SCALE as usize] {
    let mut table = [0u8; PROB_SCALE as usize];
    // Consecutive pairs of cumulative values are the half-open slot range of one symbol.
    for (sym, bounds) in cum_freqs.windows(2).enumerate() {
        let (lo, hi) = (bounds[0] as usize, bounds[1] as usize);
        table[lo..hi].fill(sym as u8);
    }
    table
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `data`, checks that decoding restores it exactly, and returns the encoded bytes.
    fn assert_roundtrip(data: &[u8]) -> Vec<u8> {
        let encoded = encode(data);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
        encoded
    }

    /// Uncompressed size divided by encoded size.
    fn compression_ratio(raw: &[u8], encoded: &[u8]) -> f64 {
        raw.len() as f64 / encoded.len() as f64
    }

    /// Empty input must survive a round trip as empty output.
    #[test]
    fn empty_roundtrip() {
        let encoded = encode(&[]);
        assert!(decode(&encoded).unwrap().is_empty());
    }

    /// Shortest non-empty input; three of the four streams carry no symbols.
    #[test]
    fn single_byte() {
        assert_roundtrip(&[42]);
    }

    /// A single repeated symbol should shrink well below half its size.
    #[test]
    fn repeated_bytes() {
        let data = vec![0u8; 10_000];
        let encoded = assert_roundtrip(&data);
        let ratio = compression_ratio(&data, &encoded);
        assert!(
            ratio > 2.0,
            "repeated bytes should compress >2x, got {ratio:.1}x"
        );
    }

    /// Repeating English text round-trips and compresses.
    #[test]
    fn text_data() {
        let text = b"the quick brown fox jumps over the lazy dog. ";
        let data: Vec<u8> = text.iter().copied().cycle().take(10_000).collect();
        let encoded = assert_roundtrip(&data);
        let ratio = compression_ratio(&data, &encoded);
        assert!(ratio > 1.5, "text should compress >1.5x, got {ratio:.1}x");
    }

    /// Pseudo-random bytes from a fixed-seed LCG only need to round-trip.
    #[test]
    fn uniform_random_data() {
        let mut rng: u64 = 12345;
        let data: Vec<u8> = (0..5000)
            .map(|_| {
                rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
                (rng >> 33) as u8
            })
            .collect();
        assert_roundtrip(&data);
    }

    /// Every byte value present with equal weight.
    #[test]
    fn all_byte_values() {
        let data: Vec<u8> = (0..=255u8).cycle().take(4096).collect();
        assert_roundtrip(&data);
    }

    /// Ninety percent zeros with a one at every tenth position.
    #[test]
    fn skewed_distribution() {
        let mut data = vec![0u8; 10_000];
        for byte in data.iter_mut().step_by(10) {
            *byte = 1;
        }
        let encoded = assert_roundtrip(&data);
        let ratio = compression_ratio(&data, &encoded);
        assert!(
            ratio > 1.5,
            "skewed data should compress >1.5x, got {ratio:.1}x"
        );
    }

    /// A 16-symbol alphabet must beat raw storage.
    #[test]
    fn better_than_raw_on_structured() {
        let data: Vec<u8> = (0..10_000).map(|i| (i % 16) as u8).collect();
        let encoded = assert_roundtrip(&data);
        let ratio = compression_ratio(&data, &encoded);
        assert!(
            ratio > 1.5,
            "low-entropy data should compress >1.5x, got {ratio:.1}x"
        );
    }

    /// Inputs shorter than the fixed header are rejected instead of panicking.
    #[test]
    fn truncated_input_errors() {
        assert!(decode(&[]).is_err());
        assert!(decode(&[1, 0, 0, 0]).is_err());
    }
}