arithmetify 0.1.5

A simple implementation of arithmetic coding
Documentation
use std::io::Write;

use rand::{thread_rng, Rng};

use super::*;
use crate::{
    ArithmeticDecoder, ArithmeticEncoder, Distribution, SequenceModel,
};

/// Three-symbol alphabet that the tests in this file encode and decode.
///
/// `PD` maps the variants to slots 1, 2 and 3 of `PD::WEIGHTS`; slot 0 is
/// used for `None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Alphabet {
    // Slot 1 of `PD::WEIGHTS` (weight 1000).
    A,
    // Slot 2 of `PD::WEIGHTS` (weight 10).
    B,
    // Slot 3 of `PD::WEIGHTS` (weight 1).
    C,
}

/// Fixed distribution over `Option<Alphabet>` used by the tests.
///
/// The inherent associated functions below are receiver-less shortcuts for
/// the `Distribution` trait methods implemented on this type.
struct PD;
impl PD {
    /// Unnormalised weights: slot 0 is `None`, slots 1..=3 are `A`, `B`, `C`.
    const WEIGHTS: [u32; 4] = [10, 1000, 10, 1];

    /// Draws a uniform value in `0..denominator` from `rng` and returns the
    /// symbol (or `None`) that `symbol_lookup` assigns to that value.
    fn sample(&self, rng: &mut impl rand::Rng) -> Option<Alphabet> {
        let upper = self.denominator();
        let draw = rng.gen_range(0..upper);
        self.symbol_lookup(draw)
    }

    /// Receiver-less shortcut for the trait method `denominator`.
    fn denominator() -> u32 {
        PD.denominator()
    }

    /// Receiver-less shortcut for the trait method `symbol_lookup`.
    fn symbol_lookup(p: u32) -> Option<Alphabet> {
        PD.symbol_lookup(p)
    }

    /// Receiver-less shortcut for the trait method `numerator_range`.
    fn numerator_range(symbol: Option<&Alphabet>) -> std::ops::Range<u32> {
        PD.numerator_range(symbol)
    }
}

impl Distribution<Alphabet, u32> for PD {
    fn denominator(&self) -> u32 {
        Self::WEIGHTS.iter().sum()
    }

    fn numerator_range(
        &self,
        symbol: Option<&Alphabet>,
    ) -> std::ops::Range<u32> {
        use Alphabet::*;
        let index = symbol
            .map(|s| match s {
                A => 1,
                B => 2,
                C => 3,
            })
            .unwrap_or(0);

        let cum = Self::WEIGHTS.iter().take(index).sum();
        cum..cum + Self::WEIGHTS[index]
    }

    fn symbol_lookup(&self, p: u32) -> Option<Alphabet> {
        use Alphabet::*;
        let mut cums = vec![0];
        cums.extend(Self::WEIGHTS);
        (1..cums.len()).for_each(|i| cums[i] += cums[i - 1]);
        let idx = cums.binary_search(&p).unwrap_or_else(|idx| idx - 1);

        match idx {
            0 => None,
            1 => Some(A),
            2 => Some(B),
            3 => Some(C),
            _ => panic!(
                "Cummulative probability density (p={p}) is out of bounds (0..{})",
                Self::denominator()
            ),
        }
    }
}

/// Sequence model for the tests: a plain list of the symbols pushed so far.
struct SM(Vec<Alphabet>);
impl SM {
    /// Appends symbols drawn from `self.pd()` until the distribution yields
    /// `None`, checking along the way that every drawn symbol owns a
    /// non-empty cumulative range.
    pub fn sample(&mut self, rng: &mut impl rand::Rng) {
        loop {
            // `None` ends the sampled sequence.
            let Some(s) = self.pd().sample(rng) else {
                return;
            };

            let std::ops::Range { start, end } =
                PD::numerator_range(Some(&s));
            assert!(start < end, "pd has empty range for {s:?}");

            self.push(s);
        }
    }

    /// Consumes the model and hands back the recorded symbols.
    pub fn into_sequence(self) -> Vec<Alphabet> {
        let Self(symbols) = self;
        symbols
    }
}
/// `SequenceModel` implementation: records pushed symbols and always reports
/// the same distribution, `PD`, whatever has been pushed before.
impl SequenceModel<Alphabet, u32> for SM {
    type ProbabilityDensity = PD;

    /// Records `symbol` at the end of the stored sequence.
    fn push(&mut self, symbol: Alphabet) {
        let Self(history) = self;
        history.push(symbol);
    }

    /// Returns the distribution for the next symbol; the stored history is
    /// not consulted.
    fn pd(&self) -> Self::ProbabilityDensity {
        PD
    }
}

/// Round-trip helper: encodes `symbols` with a fresh `SM`, decodes the
/// resulting bytes into a second fresh `SM`, and asserts that the decoder's
/// model ends up holding exactly `symbols`.
fn test_symbols(symbols: &[Alphabet]) {
    // Encoding side.
    let mut encoder = ArithmeticEncoder32::new();
    let mut encode_model = SM(Vec::new());
    encoder.encode(&mut encode_model, symbols.iter().copied());
    let bytes = encoder.finalize();

    // Decoding side, starting again from an empty model.
    let mut decoder = ArithmeticDecoder32::new(bytes);
    let mut decode_model = SM(Vec::new());
    decoder.decode(&mut decode_model);

    assert_eq!(&decode_model.0, &symbols);
}

/// Round-trips a run of 320 `A`s, the symbol with the largest weight in
/// `PD::WEIGHTS`.
#[test]
fn test_as() {
    let run_of_a = [Alphabet::A; 320];
    test_symbols(&run_of_a);
}

/// Round-trips a run of 640 `C`s, the symbol with the smallest weight in
/// `PD::WEIGHTS`.
#[test]
fn test_cs() {
    let run_of_c = [Alphabet::C; 640];
    test_symbols(&run_of_c);
}

/// Round-trips `num_tests` sequences sampled from `SM`/`PD`.
///
/// Progress is written to stdout at most once every 200 ms, rewriting the
/// same terminal line via a leading carriage return.
#[test]
pub fn test_random() {
    let rng = &mut thread_rng();
    let num_tests = 100000;
    let mut last_print = std::time::Instant::now();

    println!();
    for i in 0..num_tests {
        let mut sm = SM(Vec::new());
        sm.sample(rng);
        let symbols = sm.into_sequence();

        test_symbols(&symbols);
        if last_print.elapsed() > std::time::Duration::from_millis(200) {
            print!("\r{}/{num_tests}", i + 1);
            last_print = std::time::Instant::now();
            // Progress output is best-effort; a failed flush is ignored.
            let _ = std::io::stdout().flush();
        }
    }

    // Start with `\r` so the final count overwrites the last progress text
    // instead of being appended to it on the same line.
    println!("\r{num_tests}/{num_tests}");
}

/// Runs the weight-based round-trip check on three weight tables: a mixed
/// five-entry table, a single-entry table, and a uniform three-entry table.
#[test]
fn test_encode_by_weight() {
    let cases: [&[u32]; 3] = [&[100, 345, 102, 534, 435], &[10], &[1, 1, 1]];
    for weights in cases {
        test_distribution(weights);
    }
}

/// Round-trips one randomly sampled symbol sequence through
/// `encode_by_weights` / `decode_by_weights` using the given `weights`.
///
/// Symbols are drawn as indices in `0..weights.len()`, each with probability
/// proportional to its weight. Index 0 acts as the terminator: it is still
/// pushed/encoded, and both the sampling loop and the decoding loop stop
/// right after it.
///
/// # Panics
///
/// Panics if `weights` is empty, or if the decoded sequence differs from the
/// sampled one.
fn test_distribution(weights: &[u32]) {
    let rng = &mut thread_rng();

    // Running totals: cum_weights[i] == weights[0] + ... + weights[i].
    let cum_weights: Vec<u32> = weights
        .iter()
        .scan(0u32, |total, &weight| {
            *total += weight;
            Some(*total)
        })
        .collect();

    let total_weights = *cum_weights
        .last()
        .expect("There should be at least one symbol");

    let mut input = Vec::new();
    loop {
        let p = rng.gen_range(0..total_weights);
        // Index of the first running total that exceeds `p`. Unlike
        // `binary_search`, this stays well-defined when zero weights produce
        // repeated totals, so a zero-weight symbol is never selected.
        let symbol = cum_weights.partition_point(|&cum| cum <= p);

        input.push(symbol);
        if symbol == 0 {
            break;
        }
    }

    let mut encoder = ArithmeticEncoder::new();
    for symbol in &input {
        encoder.encode_by_weights(weights.iter().copied(), *symbol);
    }

    let compressed = encoder.finalize();
    let mut decoder = ArithmeticDecoder::new(compressed);

    let mut decoded = vec![];
    loop {
        let symbol = decoder.decode_by_weights(weights.iter().copied());
        decoded.push(symbol);

        // The terminator (symbol 0) is the last thing that was encoded.
        if symbol == 0 {
            break;
        }
    }

    assert_eq!(input, decoded.as_slice());
}