// atlas-archive-core 1.1.0
//
// High-performance compression library with adaptive context modeling (Loom) and .nyx archives.
// Documentation
//! Asymmetric Numeral Systems (ANS) implementation.
//! Provides an rANS entropy coder (32-bit state, 16-bit I/O words, 2^16 frequency
//! total); a tANS variant is not defined in this module.

use crate::alloc::vec::Vec;

/// A simple probability model for rANS.
///
/// Maps each byte symbol to a scaled frequency and to the start of its
/// cumulative-frequency interval. Total frequency MUST be a power of 2
/// (target 2^16, i.e. `total_bits == 16`). [`ProbModel::from_freqs`] always
/// produces such a model; [`ProbModel::from_scaled_freqs`] trusts its caller.
#[derive(Clone, Debug)]
pub struct ProbModel {
    /// Scaled frequency of each byte value; 0 means the symbol cannot be coded.
    pub freq: [u32; 256],
    /// Cumulative frequencies: `cum_freq[s]` is the sum of `freq[..s]`.
    /// `cum_freq[256]` is the grand total (65536), which does not fit in `u16`.
    pub cum_freq: [u32; 257],
    /// log2 of the frequency total; always 16 for models built here.
    pub total_bits: u32,
}

impl ProbModel {
    /// Build a model from raw byte counts, rescaling them to sum to exactly 2^16.
    ///
    /// An all-zero histogram is treated as uniform. Every symbol with a non-zero
    /// count keeps a scaled frequency of at least 1.
    pub fn from_freqs(freq: [u16; 256]) -> Self {
        const TARGET: u32 = 1 << 16;

        let total: u32 = freq.iter().map(|&count| u32::from(count)).sum();
        if total == 0 {
            // No statistics at all: fall back to a flat distribution.
            return Self::from_freqs([1u16; 256]);
        }

        // Proportional scaling. Observed symbols never round down to zero,
        // otherwise they would become unencodable.
        let mut scaled = [0u32; 256];
        for (slot, &count) in scaled.iter_mut().zip(freq.iter()) {
            if count != 0 {
                *slot = (u32::from(count) * TARGET / total).max(1);
            }
        }

        // Rounding and the minimum-of-one bump can leave the sum slightly off.
        // Fold the discrepancy into the highest-indexed active symbol that can
        // absorb it without reaching zero.
        let scaled_sum: u32 = scaled.iter().sum();
        if scaled_sum > TARGET {
            let excess = scaled_sum - TARGET;
            if let Some(slot) = scaled.iter_mut().rev().find(|slot| **slot > excess) {
                *slot -= excess;
            }
        } else if scaled_sum < TARGET {
            let shortfall = TARGET - scaled_sum;
            if let Some(slot) = scaled.iter_mut().rev().find(|slot| **slot > 0) {
                *slot += shortfall;
            }
        }

        Self::from_scaled_freqs(scaled)
    }

    /// Build a model from frequencies that ALREADY sum to 2^16.
    ///
    /// No validation is performed; the caller is responsible for the total.
    pub fn from_scaled_freqs(freq: [u32; 256]) -> Self {
        let mut cum_freq = [0u32; 257];
        let mut running = 0u32;
        for (slot, &f) in cum_freq[1..].iter_mut().zip(freq.iter()) {
            running += f;
            *slot = running;
        }
        Self {
            freq,
            cum_freq,
            total_bits: 16,
        }
    }

    /// Find the symbol whose interval `[cum_freq[s], cum_freq[s + 1])` contains `value`.
    ///
    /// This is a binary search for the largest index in `0..=255` with
    /// `cum_freq[index] <= value`. That index is the symbol sought whenever
    /// `cum_freq` is non-decreasing, as it is for models built by the constructors.
    pub fn find_symbol(&self, value: u16) -> u8 {
        let target = u32::from(value);
        let (mut lo, mut hi) = (0usize, 255usize);
        while lo < hi {
            // Round the midpoint up so `lo = mid` always makes progress.
            let mid = lo + (hi - lo + 1) / 2;
            if self.cum_freq[mid] > target {
                hi = mid - 1;
            } else {
                lo = mid;
            }
        }
        lo as u8
    }
}

/// Lower bound of the normalised rANS state interval `[L_MIN, 2^32)`.
///
/// The encoder starts at this value, and the decoder refills 16-bit words
/// while its state is below it.
const L_MIN: u32 = 1 << 16;

/// rANS Encoder.
///
/// Symbols are encoded in reverse order of decoding: feed them back to front,
/// then call [`RansEncoder::finish`]. The produced `Vec<u16>` is consumed from
/// its end by the matching decoder.
#[derive(Clone, Debug)]
pub struct RansEncoder {
    /// Current coder state.
    pub state: u32,
}

impl Default for RansEncoder {
    fn default() -> Self {
        Self::new()
    }
}

impl RansEncoder {
    /// Create an encoder whose state starts at `L_MIN`.
    pub fn new() -> Self {
        Self { state: L_MIN }
    }

    /// Encode `symbol` under `model` into the state.
    ///
    /// Any 16-bit words shifted out during renormalisation are appended to
    /// `output`.
    ///
    /// # Panics
    /// Panics if `symbol` has zero frequency in `model`, because such a symbol
    /// has no representable interval.
    pub fn encode(&mut self, model: &ProbModel, symbol: u8, output: &mut Vec<u16>) {
        let f = model.freq[symbol as usize];
        let c = model.cum_freq[symbol as usize];
        assert!(
            f != 0,
            "rANS: cannot encode symbol {} with zero frequency",
            symbol
        );

        // Renormalise until state < f * 2^16, which keeps the step below within
        // 32 bits. The bound is tested as `state >> 16 >= f` rather than
        // `state >= f << 16`: for f == 2^16 (a model with a single active symbol)
        // `f << 16` wraps to 0 in u32, and the loop would never terminate.
        while (self.state >> 16) >= f {
            output.push((self.state & 0xFFFF) as u16);
            self.state >>= 16;
        }

        // rANS step: x = (floor(x/f) << 16) + c + (x%f)
        self.state = ((self.state / f) << 16) + c + (self.state % f);
    }

    /// Flush the final state as two words: the low half first, then the high half.
    ///
    /// The decoder pops words from the end, so it reads the high half first.
    pub fn finish(self, output: &mut Vec<u16>) {
        output.push((self.state & 0xFFFF) as u16);
        output.push((self.state >> 16) as u16);
    }
}

/// rANS Decoder.
///
/// Consumes a word stream produced by `RansEncoder` from its end and yields
/// symbols in the original (front-to-back) order.
pub struct RansDecoder {
    /// Current coder state.
    pub state: u32,
}

impl RansDecoder {
    /// Initialise the decoder from the last two words of `input`.
    ///
    /// The final word is the high half of the state and the one before it the
    /// low half; both are removed from `input`.
    ///
    /// # Panics
    /// Panics with "Empty input" if `input` holds fewer than two words.
    pub fn new(input: &mut Vec<u16>) -> Self {
        let hi = u32::from(input.pop().expect("Empty input"));
        let lo = u32::from(input.pop().expect("Empty input"));
        let state = (hi << 16) | lo;
        Self { state }
    }

    /// Decode one symbol under `model`, refilling the state from the end of `input`.
    ///
    /// Refilling stops early if `input` runs out.
    pub fn decode(&mut self, model: &ProbModel, input: &mut Vec<u16>) -> u8 {
        // The low 16 bits select a slot inside the 2^16 frequency range.
        let slot = (self.state & 0xFFFF) as u16;
        let symbol = model.find_symbol(slot);

        let idx = usize::from(symbol);
        let (freq, start) = (model.freq[idx], model.cum_freq[idx]);

        // Inverse of the encoder step.
        self.state = freq * (self.state >> 16) + (u32::from(slot) - start);

        // Pull 16-bit words back in until the state is normalised again.
        while self.state < L_MIN {
            match input.pop() {
                Some(word) => self.state = (self.state << 16) | u32::from(word),
                None => break,
            }
        }

        symbol
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `data` back to front, then decode the same number of symbols front
    /// to back with the same model, returning the decoded bytes.
    fn roundtrip(model: &ProbModel, data: &[u8]) -> Vec<u8> {
        let mut encoder = RansEncoder::new();
        let mut stream = Vec::new();
        for &sym in data.iter().rev() {
            encoder.encode(model, sym, &mut stream);
        }
        encoder.finish(&mut stream);

        let mut decoder = RansDecoder::new(&mut stream);
        (0..data.len())
            .map(|_| decoder.decode(model, &mut stream))
            .collect()
    }

    #[test]
    fn test_rans_roundtrip_basic() {
        let mut freqs = [0u16; 256];
        for &(sym, count) in &[(b'a', 10u16), (b'b', 5), (b'c', 1)] {
            freqs[usize::from(sym)] = count;
        }
        let model = ProbModel::from_freqs(freqs);

        let data = b"abcbaaaaaaaaaabbbbbc";
        assert_eq!(roundtrip(&model, data), &data[..]);
    }

    #[test]
    fn test_rans_long_roundtrip() {
        let model = ProbModel::from_freqs([1u16; 256]);
        // 0, 1, ..., 255, 0, 1, ... (2000 bytes).
        let data: Vec<u8> = (0..2000usize).map(|i| i as u8).collect();
        assert_eq!(roundtrip(&model, &data), data);
    }
}