constriction 0.4.0

Entropy coders for research and production (Rust and Python).
Documentation
use std::any::type_name;

use constriction::{
    stream::{
        model::{NonContiguousCategoricalEncoderModel, NonContiguousLookupDecoderModel},
        queue::RangeEncoder,
        stack::AnsCoder,
        Code, Decode, Encode,
    },
    BitArray, Pos, Seek,
};
use criterion::{black_box, criterion_group, Criterion};
use num_traits::AsPrimitive;
use rand::{RngCore, SeedableRng};
use rand_xoshiro::Xoshiro256StarStar;

criterion_group!(
    benches,
    round_trip_u32_u64_u16_12,
    round_trip_u32_u64_u16_16,
    round_trip_u16_u32_u8_8,
    round_trip_u16_u32_u16_8,
    round_trip_u16_u32_u16_12
);

#[cfg(not(miri))]
criterion::criterion_main!(benches);
#[cfg(miri)]
fn main() {} // miri currently doesn't seem to be able to run criterion benchmarks as tests.

fn round_trip_u32_u64_u16_12(c: &mut Criterion) {
    round_trip::<u32, u64, u16, 12>(c);
}

fn round_trip_u32_u64_u16_16(c: &mut Criterion) {
    round_trip::<u32, u64, u16, 16>(c);
}

fn round_trip_u16_u32_u8_8(c: &mut Criterion) {
    round_trip::<u16, u32, u8, 8>(c);
}

fn round_trip_u16_u32_u16_8(c: &mut Criterion) {
    round_trip::<u16, u32, u16, 8>(c);
}

fn round_trip_u16_u32_u16_12(c: &mut Criterion) {
    round_trip::<u16, u32, u16, 12>(c);
}

fn round_trip<Word, State, Probability, const PRECISION: usize>(c: &mut Criterion)
where
    Probability: BitArray,
    u32: AsPrimitive<Probability>,
    Word: BitArray + AsPrimitive<Probability> + From<Probability>,
    State: BitArray + AsPrimitive<Word> + AsPrimitive<usize> + From<Word>,
    u64: From<Probability>,
    usize: From<Probability> + AsPrimitive<Probability>,
{
    ans_round_trip::<Word, State, Probability, PRECISION>(c);
    range_round_trip::<Word, State, Probability, PRECISION>(c);
}

fn ans_round_trip<Word, State, Probability, const PRECISION: usize>(c: &mut Criterion)
where
    Probability: BitArray,
    u32: AsPrimitive<Probability>,
    Word: BitArray + AsPrimitive<Probability> + From<Probability>,
    State: BitArray + AsPrimitive<Word> + AsPrimitive<usize> + From<Word>,
    u64: From<Probability>,
    usize: From<Probability> + AsPrimitive<Probability>,
{
    let (symbols, probabilities) = make_symbols_and_probabilities(PRECISION, 100);
    let encoder_model =
        NonContiguousCategoricalEncoderModel::<u16,Probability,PRECISION>::from_symbols_and_nonzero_fixed_point_probabilities(
            symbols.iter().cloned(),&probabilities[..],false
        )
        .unwrap();

    let data = make_data(&symbols, 10_000);
    let mut encoder = AnsCoder::<Word, State>::new();

    let label_suffix = format!(
        "{}_{}_{}_{}",
        type_name::<Word>(),
        type_name::<State>(),
        type_name::<Probability>(),
        PRECISION
    );
    c.bench_function(&format!("ans_encoding_{label_suffix}"), |b| {
        b.iter(|| {
            encoder.clear();
            encoder
                .encode_iid_symbols_reverse(black_box(&data), &encoder_model)
                .unwrap();

            // Access `encoder.state()` and `encoder.bulk()` at an unpredictable position.
            let index = AsPrimitive::<usize>::as_(encoder.state()) % encoder.bulk().len();
            black_box(encoder.bulk()[index]);
        })
    });

    encoder.clear();
    encoder
        .encode_iid_symbols_reverse(black_box(&data), &encoder_model)
        .unwrap();

    let decoder_model =
    NonContiguousLookupDecoderModel::<u16, Probability, _, _, PRECISION>::from_symbols_and_nonzero_fixed_point_probabilities(
        symbols,probabilities,false
    )
    .unwrap();

    let mut backward_decoder = encoder.into_seekable_decoder();
    let reset_snapshot = backward_decoder.pos();

    c.bench_function(&format!("ans_backward_decoding_{label_suffix}"), |b| {
        b.iter(|| {
            backward_decoder.seek(black_box(reset_snapshot)).unwrap();
            let mut checksum = 1234u16;
            for symbol in backward_decoder.decode_iid_symbols(data.len(), &decoder_model) {
                checksum ^= symbol.unwrap();
            }
            black_box(checksum);
        })
    });

    backward_decoder.seek(reset_snapshot).unwrap();
    let decoded = backward_decoder
        .decode_iid_symbols(data.len(), &decoder_model)
        .map(Result::unwrap)
        .collect::<Vec<_>>();
    assert_eq!(decoded, data);
    assert!(backward_decoder.is_empty());

    backward_decoder.seek(reset_snapshot).unwrap();
    let mut forward_decoder = backward_decoder.into_reversed();
    let reset_snapshot = forward_decoder.pos();

    c.bench_function(&format!("ans_forward_decoding_{label_suffix}"), |b| {
        b.iter(|| {
            forward_decoder.seek(black_box(reset_snapshot)).unwrap();
            let mut checksum = 1234u16;
            for symbol in forward_decoder.decode_iid_symbols(data.len(), &decoder_model) {
                checksum ^= symbol.unwrap();
            }
            black_box(checksum);
        })
    });

    forward_decoder.seek(reset_snapshot).unwrap();
    let decoded = forward_decoder
        .decode_iid_symbols(data.len(), &decoder_model)
        .map(Result::unwrap)
        .collect::<Vec<_>>();
    assert_eq!(decoded, data);
    assert!(forward_decoder.is_empty());
}

fn range_round_trip<Word, State, Probability, const PRECISION: usize>(c: &mut Criterion)
where
    Probability: BitArray,
    u32: AsPrimitive<Probability>,
    Word: BitArray + AsPrimitive<Probability> + From<Probability>,
    State: BitArray + AsPrimitive<Word> + AsPrimitive<usize> + From<Word>,
    u64: From<Probability>,
    usize: From<Probability> + AsPrimitive<Probability>,
{
    let (symbols, probabilities) = make_symbols_and_probabilities(PRECISION, 100);
    let encoder_model =
    NonContiguousCategoricalEncoderModel::<u16,Probability,PRECISION>::from_symbols_and_nonzero_fixed_point_probabilities(
            symbols.iter().cloned(),probabilities.iter(),false
        )
        .unwrap();

    let data = make_data(&symbols, 10_000);
    let mut encoder = RangeEncoder::<Word, State>::new();
    let reset_snapshot = encoder.pos();

    let label_suffix = format!(
        "{}_{}_{}_{}",
        type_name::<Word>(),
        type_name::<State>(),
        type_name::<Probability>(),
        PRECISION
    );
    c.bench_function(&format!("range_encoding_{label_suffix}"), |b| {
        b.iter(|| {
            encoder.clear();
            encoder
                .encode_iid_symbols(black_box(&data), &encoder_model)
                .unwrap();

            // Access `encoder.state()` and `encoder.bulk()` at an unpredictable position.
            let index = AsPrimitive::<usize>::as_(encoder.state().lower()) % encoder.bulk().len();
            black_box(encoder.bulk()[index]);
        })
    });

    encoder.clear();
    encoder
        .encode_iid_symbols(black_box(&data), &encoder_model)
        .unwrap();

    let decoder_model =
    NonContiguousLookupDecoderModel::<u16, Probability, _, _, PRECISION>::from_symbols_and_nonzero_fixed_point_probabilities(
        symbols,probabilities,false
    )
    .unwrap();

    let mut decoder = encoder.into_decoder().unwrap();

    c.bench_function(&format!("range_decoding_{label_suffix}"), |b| {
        b.iter(|| {
            decoder.seek(black_box(reset_snapshot)).unwrap();
            let mut checksum = 1234u16;
            for symbol in decoder.decode_iid_symbols(data.len(), &decoder_model) {
                checksum ^= symbol.unwrap();
            }
            black_box(checksum);
        })
    });

    decoder.seek(reset_snapshot).unwrap();
    let decoded = decoder
        .decode_iid_symbols(data.len(), &decoder_model)
        .map(Result::unwrap)
        .collect::<Vec<_>>();
    assert_eq!(decoded, data);
    assert!(decoder.decoder_maybe_exhausted::<PRECISION>());
}

fn make_symbols_and_probabilities<Probability>(
    precision: usize,
    num_symbols: u32,
) -> (Vec<u16>, Vec<Probability>)
where
    Probability: BitArray + Into<u64>,
    u32: AsPrimitive<Probability>,
{
    assert!(precision <= Probability::BITS);
    assert!(num_symbols as u64 <= 1u64 << precision);

    let mut rng = Xoshiro256StarStar::seed_from_u64(1234);
    let mut total_probability = 0;
    let bound_minus_one: u32 = ((2 * (1u64 << precision)) as f64 / num_symbols as f64 - 1.5) as u32;
    assert!(bound_minus_one > 0);

    // Try to get roughly the right total total amount, then fix up later.
    let (symbols, mut probabilities): (Vec<_>, Vec<_>) = (0..num_symbols)
        .map(|_| {
            let probability = ((rng.next_u32() % bound_minus_one) + 1).as_();
            total_probability += probability.into();
            let symbol = rng.next_u32() as u16;
            (symbol, probability)
        })
        .unzip();

    if total_probability < 1u64 << precision {
        let remaining = (1u64 << precision) - total_probability;
        for _ in 0..remaining {
            let index = rng.next_u32() % probabilities.len() as u32;
            let x = &mut probabilities[index as usize];
            *x = *x + Probability::one();
        }
    } else {
        let mut exceeding = total_probability - (1u64 << precision);
        while exceeding != 0 {
            let index = rng.next_u32() % probabilities.len() as u32;
            let x = &mut probabilities[index as usize];
            if *x > Probability::one() {
                *x = *x - Probability::one();
                exceeding -= 1;
            }
        }
    }

    (symbols, probabilities)
}

fn make_data<Symbol: Copy>(symbols: &[Symbol], amt: usize) -> Vec<Symbol> {
    let mut rng = Xoshiro256StarStar::seed_from_u64(5678 ^ amt as u64);
    (0..amt)
        .map(|_| symbols[(rng.next_u32() % symbols.len() as u32) as usize])
        .collect::<Vec<_>>()
}