use crate::alloc::vec::Vec;
/// Static byte-probability table used by [`RansEncoder`] and [`RansDecoder`].
///
/// The quantized frequencies must sum to exactly `1 << 16`. The coder takes the low
/// 16 bits of its state as the slot, so any other total misassigns symbols.
#[derive(Clone, Debug)]
pub struct ProbModel {
    /// Quantized frequency of each byte value; 0 means the byte cannot be encoded.
    pub freq: [u32; 256],
    /// Prefix sums of `freq`: `cum_freq[i]` is the sum of `freq[..i]`, and
    /// `cum_freq[256]` is the total.
    pub cum_freq: [u32; 257],
    /// Log2 of the frequency total; both constructors set it to 16.
    pub total_bits: u32,
}

impl ProbModel {
    /// Log2 of the quantized frequency total produced by the constructors.
    const SCALE_BITS: u32 = 16;

    /// Builds a model from raw symbol counts, quantizing them to sum to exactly `1 << 16`.
    ///
    /// - Every symbol with a non-zero count gets a quantized frequency of at least 1, so
    ///   it stays encodable. Symbols with a zero count get frequency 0.
    /// - A rounding surplus is subtracted from the highest-indexed symbol whose
    ///   frequency exceeds the surplus.
    /// - A rounding deficit is added to the highest-indexed symbol in use.
    /// - If every count is zero, a uniform model is returned (each byte gets 256).
    pub fn from_freqs(freq: [u16; 256]) -> Self {
        let sum: u32 = freq.iter().map(|&x| u32::from(x)).sum();
        if sum == 0 {
            return Self::from_freqs([1u16; 256]);
        }
        let target: u32 = 1 << Self::SCALE_BITS;
        let mut scaled_freq = [0u32; 256];
        let mut scaled_sum = 0u32;
        for (scaled, &count) in scaled_freq.iter_mut().zip(freq.iter()) {
            if count > 0 {
                // count <= u16::MAX, so count * 2^16 < 2^32 and the product cannot overflow.
                // Rounding up to 1 keeps rare-but-present symbols encodable.
                let f = (u32::from(count) * target / sum).max(1);
                *scaled = f;
                scaled_sum += f;
            }
        }
        if scaled_sum > target {
            // The round-ups overshot: take the excess from the highest-indexed symbol
            // that stays non-zero afterwards.
            let excess = scaled_sum - target;
            if let Some(f) = scaled_freq.iter_mut().rev().find(|f| **f > excess) {
                *f -= excess;
            }
        } else if scaled_sum < target {
            // Floor division undershot: give the shortfall to the highest-indexed
            // symbol in use.
            let deficit = target - scaled_sum;
            if let Some(f) = scaled_freq.iter_mut().rev().find(|f| **f > 0) {
                *f += deficit;
            }
        }
        Self::from_scaled_freqs(scaled_freq)
    }

    /// Builds a model from frequencies that are already quantized.
    ///
    /// The caller must supply frequencies summing to exactly `1 << 16`. Any other total
    /// makes encoding/decoding incorrect, and is rejected by a debug assertion.
    pub fn from_scaled_freqs(freq: [u32; 256]) -> Self {
        let mut cum_freq = [0u32; 257];
        let mut acc = 0u32;
        for (cum, &f) in cum_freq.iter_mut().zip(freq.iter()) {
            *cum = acc;
            acc += f;
        }
        cum_freq[256] = acc;
        debug_assert_eq!(
            acc,
            1 << Self::SCALE_BITS,
            "scaled frequencies must sum to 1 << 16"
        );
        Self {
            freq,
            cum_freq,
            total_bits: Self::SCALE_BITS,
        }
    }

    /// Returns the symbol whose slot range `cum_freq[s]..cum_freq[s + 1]` contains `value`.
    ///
    /// Precisely, it returns the largest `s` in `0..=255` with `cum_freq[s] <= value`,
    /// found by binary search over the non-decreasing prefix sums.
    pub fn find_symbol(&self, value: u16) -> u8 {
        let val = u32::from(value);
        // cum_freq[0] == 0 <= val always holds, so only entries 1..=255 are searched.
        // The count of leading entries <= val is then exactly the symbol index (0..=255).
        let idx = self.cum_freq[1..256].partition_point(|&c| c <= val);
        idx as u8
    }
}
/// Lower bound of the normalized rANS state. The encoder starts from this value, and the
/// decoder pulls in 16-bit words whenever its state falls below it.
const L_MIN: u32 = 1 << 16;

/// Streaming rANS encoder with a 32-bit state that emits 16-bit words.
///
/// Feed symbols in reverse order. [`RansDecoder`] pops words from the end of the buffer,
/// so it yields symbols in the opposite order to how they were encoded.
pub struct RansEncoder {
    /// Current coder state; starts at `L_MIN`.
    pub state: u32,
}

impl Default for RansEncoder {
    fn default() -> Self {
        Self::new()
    }
}

impl RansEncoder {
    /// Creates an encoder in the initial state `L_MIN`.
    pub fn new() -> Self {
        Self { state: L_MIN }
    }

    /// Encodes one symbol, pushing any renormalization words onto `output`.
    ///
    /// # Panics
    ///
    /// Panics if `symbol` has zero frequency in `model`, because such a symbol has no
    /// slot and cannot be represented.
    pub fn encode(&mut self, model: &ProbModel, symbol: u8, output: &mut Vec<u16>) {
        let f = model.freq[symbol as usize];
        let c = model.cum_freq[symbol as usize];
        // With f == 0 the bound below is 0 and the loop would never terminate.
        assert_ne!(f, 0, "cannot encode a symbol with zero frequency");
        // Shift words out until the state is below f << 16. The bound is computed in u64
        // because f may be 1 << 16 (a single-symbol model). Then f << 16 == 2^32, which
        // would wrap to 0 in u32 and turn this into an infinite loop.
        let limit = u64::from(f) << 16;
        while u64::from(self.state) >= limit {
            output.push((self.state & 0xFFFF) as u16);
            self.state >>= 16;
        }
        // x' = floor(x / f) * 2^16 + c + (x mod f). Since x < f << 16 and c + f <= 2^16,
        // the result fits in u32.
        self.state = ((self.state / f) << 16) + c + (self.state % f);
    }

    /// Flushes the final 32-bit state as two words (low half first), consuming the encoder.
    pub fn finish(self, output: &mut Vec<u16>) {
        output.push((self.state & 0xFFFF) as u16);
        output.push((self.state >> 16) as u16);
    }
}
/// Streaming rANS decoder matching [`RansEncoder`].
///
/// It consumes the encoder's output from the end of the buffer, so symbols come out in
/// the reverse of the order they were encoded.
pub struct RansDecoder {
    /// Current coder state.
    pub state: u32,
}

impl RansDecoder {
    /// Initializes the decoder from the two state words that [`RansEncoder::finish`]
    /// pushed last, popping them (high word first) off `input`.
    ///
    /// # Panics
    ///
    /// Panics with "Empty input" if `input` holds fewer than two words.
    pub fn new(input: &mut Vec<u16>) -> Self {
        let high = u32::from(input.pop().expect("Empty input"));
        let low = u32::from(input.pop().expect("Empty input"));
        Self {
            state: (high << 16) | low,
        }
    }

    /// Decodes one symbol, then refills the state from `input` while it is below `L_MIN`.
    ///
    /// If `input` runs out before the state is refilled (a truncated stream, or decoding
    /// past the end), the state is left as is rather than panicking.
    pub fn decode(&mut self, model: &ProbModel, input: &mut Vec<u16>) -> u8 {
        // The low 16 bits select the slot; the model's frequencies total 1 << 16.
        let slot = self.state & 0xFFFF;
        let symbol = model.find_symbol(slot as u16);
        let f = model.freq[symbol as usize];
        let c = model.cum_freq[symbol as usize];
        // Inverse of the encoder step: x = f * floor(x' / 2^16) + (slot - c).
        self.state = f * (self.state >> 16) + (slot - c);
        while self.state < L_MIN {
            match input.pop() {
                Some(word) => self.state = (self.state << 16) | u32::from(word),
                None => break,
            }
        }
        symbol
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `data` back-to-front with `model`, then decodes `data.len()` symbols from
    /// the resulting word stream and returns them in decode order.
    fn roundtrip(model: &ProbModel, data: &[u8]) -> Vec<u8> {
        let mut encoder = RansEncoder::new();
        let mut stream = Vec::new();
        // rANS is last-in-first-out, so encoding in reverse decodes in forward order.
        for &symbol in data.iter().rev() {
            encoder.encode(model, symbol, &mut stream);
        }
        encoder.finish(&mut stream);
        let mut decoder = RansDecoder::new(&mut stream);
        (0..data.len())
            .map(|_| decoder.decode(model, &mut stream))
            .collect()
    }

    #[test]
    fn test_rans_roundtrip_basic() {
        let mut counts = [0u16; 256];
        for (symbol, count) in [(b'a', 10), (b'b', 5), (b'c', 1)] {
            counts[usize::from(symbol)] = count;
        }
        let model = ProbModel::from_freqs(counts);
        let data: &[u8] = b"abcbaaaaaaaaaabbbbbc";
        assert_eq!(roundtrip(&model, data).as_slice(), data);
    }

    #[test]
    fn test_rans_long_roundtrip() {
        let model = ProbModel::from_freqs([1u16; 256]);
        let data: Vec<u8> = (0..2000u32).map(|i| (i & 0xFF) as u8).collect();
        assert_eq!(data, roundtrip(&model, &data));
    }
}