use alloc::vec::Vec;
use core::{
borrow::Borrow, convert::Infallible, fmt::Debug, iter::Fuse, marker::PhantomData, ops::Deref,
};
use num_traits::AsPrimitive;
use super::{
model::{DecoderModel, EncoderModel},
AsDecoder, Code, Decode, Encode, IntoDecoder, TryCodingError,
};
use crate::{
backends::{
self, AsReadWords, AsSeekReadWords, BoundedReadWords, Cursor, FallibleIteratorReadWords,
IntoReadWords, IntoSeekReadWords, ReadWords, Reverse, WriteWords,
},
bit_array_to_chunks_truncated, BitArray, CoderError, DefaultEncoderError,
DefaultEncoderFrontendError, NonZeroBitArray, Pos, PosSeek, Seek, Stack, UnwrapInfallible,
};
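/// Entropy coder for both encoding and decoding on a stack ("last in, first out").
///
/// An `AnsCoder` implements the range variant of Asymmetric Numeral Systems (ANS).
/// Compressed data is split into a `bulk` backend of already flushed `Word`s and a
/// small `state` that buffers the not yet flushed tail. Encoding pushes symbols onto
/// the stack and decoding pops them off in reverse order, which is why the
/// `encode_*_reverse` convenience methods exist.
///
/// The sketch below shows a typical round trip. It mirrors the unit tests at the
/// bottom of this file; the entropy model constructors (`DefaultLeakyQuantizer` from
/// the `model` module, `Gaussian` from the `probability` crate) and the `Encode` /
/// `Decode` traits are assumed to be in scope, so the fence is marked `ignore`:
///
/// ```ignore
/// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
/// let model = quantizer.quantize(Gaussian::new(0.0, 10.0));
///
/// let mut coder = DefaultAnsCoder::new();
/// // Encode in reverse so that decoding returns symbols in the original order.
/// coder.encode_iid_symbols_reverse(&[5, -3, 20], &model).unwrap();
///
/// let reconstructed = coder
///     .decode_iid_symbols(3, &model)
///     .collect::<Result<Vec<_>, _>>()
///     .unwrap();
/// assert_eq!(reconstructed, [5, -3, 20]);
/// assert!(coder.is_empty());
/// ```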
#[derive(Clone)]
pub struct AnsCoder<Word, State, Backend = Vec<Word>>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
bulk: Backend,
state: State,
phantom: PhantomData<Word>,
}
/// Type alias for an [`AnsCoder`] with parameters suited to typical use cases
/// (32 bit words, 64 bit state).
pub type DefaultAnsCoder<Backend = Vec<u32>> = AnsCoder<u32, u64, Backend>;
/// Type alias for an [`AnsCoder`] with a smaller memory footprint
/// (16 bit words, 32 bit state).
pub type SmallAnsCoder<Backend = Vec<u16>> = AnsCoder<u16, u32, Backend>;
impl<Word, State, Backend> Debug for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
for<'a> &'a Backend: IntoIterator<Item = &'a Word>,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_list().entries(self.iter_compressed()).finish()
}
}
impl<Word, State, Backend, const PRECISION: usize> IntoDecoder<PRECISION>
for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + IntoReadWords<Word, Stack>,
{
type IntoDecoder = AnsCoder<Word, State, Backend::IntoReadWords>;
fn into_decoder(self) -> Self::IntoDecoder {
AnsCoder {
bulk: self.bulk.into_read_words(),
state: self.state,
phantom: PhantomData,
}
}
}
impl<'a, Word, State, Backend> From<&'a AnsCoder<Word, State, Backend>>
for AnsCoder<Word, State, <Backend as AsReadWords<'a, Word, Stack>>::AsReadWords>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: AsReadWords<'a, Word, Stack>,
{
fn from(ans: &'a AnsCoder<Word, State, Backend>) -> Self {
AnsCoder {
bulk: ans.bulk().as_read_words(),
state: ans.state(),
phantom: PhantomData,
}
}
}
impl<'a, Word, State, Backend, const PRECISION: usize> AsDecoder<'a, PRECISION>
for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + AsReadWords<'a, Word, Stack>,
{
type AsDecoder = AnsCoder<Word, State, Backend::AsReadWords>;
fn as_decoder(&'a self) -> Self::AsDecoder {
self.into()
}
}
impl<Word, State> From<AnsCoder<Word, State, Vec<Word>>> for Vec<Word>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
fn from(val: AnsCoder<Word, State, Vec<Word>>) -> Self {
val.into_compressed().unwrap_infallible()
}
}
impl<Word, State> AnsCoder<Word, State, Vec<Word>>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
pub fn new() -> Self {
Self::default()
}
}
impl<Word, State, Backend> Default for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: Default,
{
fn default() -> Self {
assert!(State::BITS >= 2 * Word::BITS);
Self {
state: State::zero(),
bulk: Default::default(),
phantom: PhantomData,
}
}
}
impl<Word, State, Backend> AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
pub fn from_raw_parts(bulk: Backend, state: State) -> Self {
Self {
bulk,
state,
phantom: PhantomData,
}
}
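/// Creates an `AnsCoder` from data previously returned by
/// [`into_compressed`](Self::into_compressed).
///
/// Words are read back from the reading end of `compressed` to rebuild the coder's
/// internal `state`; the remaining words stay in `bulk`. Ownership of the backend is
/// returned as the `Err` value if reading fails or if the data could not have been
/// produced by an `AnsCoder` (i.e., if the word at the reading end is zero).
///
/// A minimal sketch of the round trip, assuming a `Vec<u32>` backend as in
/// [`DefaultAnsCoder`] (`encoder` is a placeholder for an existing coder):
///
/// ```ignore
/// let compressed: Vec<u32> = encoder.into_compressed().unwrap();
/// let mut decoder = DefaultAnsCoder::from_compressed(compressed).unwrap();
/// // `decoder` now decodes symbols in the reverse order of encoding.
/// ```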
pub fn from_compressed(mut compressed: Backend) -> Result<Self, Backend>
where
Backend: ReadWords<Word, Stack>,
{
assert!(State::BITS >= 2 * Word::BITS);
let state = match Self::read_initial_state(|| compressed.read()) {
Ok(state) => state,
Err(_) => return Err(compressed),
};
Ok(Self {
bulk: compressed,
state,
phantom: PhantomData,
})
}
fn read_initial_state<Error>(
mut read_word: impl FnMut() -> Result<Option<Word>, Error>,
) -> Result<State, ()>
where
Backend: ReadWords<Word, Stack>,
{
if let Some(first_word) = read_word().map_err(|_| ())? {
// A valid `AnsCoder` never ends its compressed data in a zero word (the word
// read first here is the most significant nonzero part of `state`), so a zero
// here means the data was not produced by `into_compressed`.
if first_word == Word::zero() {
return Err(());
}
let mut state = first_word.into();
while let Some(word) = read_word().map_err(|_| ())? {
state = state << Word::BITS | word.into();
if state >= State::one() << (State::BITS - Word::BITS) {
break;
}
}
Ok(state)
} else {
Ok(State::zero())
}
}
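/// Creates an `AnsCoder` from raw binary data that was *not* produced by an
/// `AnsCoder` (e.g., output of a different compressor or uncompressed payload).
///
/// In contrast to [`from_compressed`](Self::from_compressed), this prepends an
/// implicit "1" marker bit to the data before loading it into `state`, so that
/// [`into_binary`](Self::into_binary) can later strip it off again. The pair
/// `from_binary` / `into_binary` therefore wraps arbitrary word-aligned data in an
/// `AnsCoder` losslessly (e.g., for bits-back-style constructions).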
pub fn from_binary(mut data: Backend) -> Result<Self, Backend::ReadError>
where
Backend: ReadWords<Word, Stack>,
{
let mut state = State::one();
while state < State::one() << (State::BITS - Word::BITS) {
if let Some(word) = data.read()? {
state = state << Word::BITS | word.into();
} else {
break;
}
}
Ok(Self {
bulk: data,
state,
phantom: PhantomData,
})
}
#[inline(always)]
pub fn bulk(&self) -> &Backend {
&self.bulk
}
pub fn into_raw_parts(self) -> (Backend, State) {
(self.bulk, self.state)
}
pub fn is_empty(&self) -> bool {
// Checking `state` alone suffices: the coder maintains the invariant that
// `state >= 1 << (State::BITS - Word::BITS)` whenever `bulk` is nonempty, so a
// zero `state` implies that no compressed data is left at all.
self.state == State::zero()
}
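/// Returns a view of the compressed data accumulated so far, without consuming the
/// coder.
///
/// The coder's `state` is temporarily flushed into `bulk`; the returned guard
/// dereferences to the backend and undoes the flush when it is dropped, so the
/// coder remains usable afterwards (see `CoderGuard` at the bottom of this file).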
pub fn get_compressed(
&mut self,
) -> Result<impl Deref<Target = Backend> + Debug + Drop + '_, Backend::WriteError>
where
Backend: ReadWords<Word, Stack> + WriteWords<Word> + Debug,
{
CoderGuard::<'_, _, _, _, false>::new(self).map_err(|err| match err {
CoderError::Frontend(()) => unreachable!("Can't happen for SEALED==false."),
CoderError::Backend(err) => err,
})
}
pub fn get_binary(
&mut self,
) -> Result<impl Deref<Target = Backend> + Debug + Drop + '_, CoderError<(), Backend::WriteError>>
where
Backend: ReadWords<Word, Stack> + WriteWords<Word> + Debug,
{
CoderGuard::<'_, _, _, _, true>::new(self)
}
pub fn iter_compressed<'a>(&'a self) -> impl Iterator<Item = Word> + '_
where
&'a Backend: IntoIterator<Item = &'a Word>,
{
let bulk_iter = self.bulk.into_iter().cloned();
let state_iter = bit_array_to_chunks_truncated(self.state).rev();
bulk_iter.chain(state_iter)
}
pub fn num_words(&self) -> usize
where
Backend: BoundedReadWords<Word, Stack>,
{
self.bulk.remaining() + bit_array_to_chunks_truncated::<_, Word>(self.state).len()
}
pub fn num_bits(&self) -> usize
where
Backend: BoundedReadWords<Word, Stack>,
{
Word::BITS * self.num_words()
}
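/// Returns the number of compressed bits currently on the coder, excluding the
/// implicit "1" bit that marks the length of `state`.
///
/// Equals `Word::BITS` times the number of words in `bulk`, plus the position of
/// the highest set bit of `state`. In contrast, [`num_bits`](Self::num_bits)
/// rounds the contribution of `state` up to whole `Word`s.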
pub fn num_valid_bits(&self) -> usize
where
Backend: BoundedReadWords<Word, Stack>,
{
Word::BITS * self.bulk.remaining()
+ core::cmp::max(State::BITS - self.state.leading_zeros() as usize, 1)
- 1
}
pub fn into_decoder(self) -> AnsCoder<Word, State, Backend::IntoReadWords>
where
Backend: IntoReadWords<Word, Stack>,
{
AnsCoder {
bulk: self.bulk.into_read_words(),
state: self.state,
phantom: PhantomData,
}
}
pub fn into_seekable_decoder(self) -> AnsCoder<Word, State, Backend::IntoSeekReadWords>
where
Backend: IntoSeekReadWords<Word, Stack>,
{
AnsCoder {
bulk: self.bulk.into_seek_read_words(),
state: self.state,
phantom: PhantomData,
}
}
pub fn as_decoder<'a>(&'a self) -> AnsCoder<Word, State, Backend::AsReadWords>
where
Backend: AsReadWords<'a, Word, Stack>,
{
AnsCoder {
bulk: self.bulk.as_read_words(),
state: self.state,
phantom: PhantomData,
}
}
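/// Returns a decoder view of the coder whose backend supports random-access
/// seeking (see the [`Seek`] impl below).
///
/// Together with [`Pos::pos`], this allows recording "jump points" during encoding
/// and later decoding from any of them, as exercised by the `seek` unit test at the
/// bottom of this file. A hedged sketch (`encoder` is a placeholder for an existing
/// [`DefaultAnsCoder`]):
///
/// ```ignore
/// let snapshot = encoder.pos();            // record (position, state) while encoding
/// let mut decoder = encoder.as_seekable_decoder();
/// decoder.seek(snapshot).unwrap();         // jump back to the recorded point
/// ```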
pub fn as_seekable_decoder<'a>(&'a self) -> AnsCoder<Word, State, Backend::AsSeekReadWords>
where
Backend: AsSeekReadWords<'a, Word, Stack>,
{
AnsCoder {
bulk: self.bulk.as_seek_read_words(),
state: self.state,
phantom: PhantomData,
}
}
}
impl<Word, State> AnsCoder<Word, State>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
pub fn clear(&mut self) {
self.bulk.clear();
self.state = State::zero();
}
}
impl<'bulk, Word, State> AnsCoder<Word, State, Cursor<Word, &'bulk [Word]>>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
#[allow(clippy::result_unit_err)]
pub fn from_compressed_slice(compressed: &'bulk [Word]) -> Result<Self, ()> {
Self::from_compressed(backends::Cursor::new_at_write_end(compressed)).map_err(|_| ())
}
pub fn from_binary_slice(data: &'bulk [Word]) -> Self {
Self::from_binary(backends::Cursor::new_at_write_end(data)).unwrap_infallible()
}
}
impl<Word, State, Buf> AnsCoder<Word, State, Reverse<Cursor<Word, Buf>>>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Buf: AsRef<[Word]>,
{
pub fn from_reversed_compressed(compressed: Buf) -> Result<Self, Buf> {
Self::from_compressed(Reverse(Cursor::new_at_write_beginning(compressed)))
.map_err(|Reverse(cursor)| cursor.into_buf_and_pos().0)
}
pub fn from_reversed_binary(data: Buf) -> Self {
Self::from_binary(Reverse(Cursor::new_at_write_beginning(data))).unwrap_infallible()
}
}
impl<Word, State, Iter, ReadError> AnsCoder<Word, State, FallibleIteratorReadWords<Iter>>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Iter: Iterator<Item = Result<Word, ReadError>>,
FallibleIteratorReadWords<Iter>: ReadWords<Word, Stack, ReadError = ReadError>,
{
pub fn from_reversed_compressed_iter(compressed: Iter) -> Result<Self, Fuse<Iter>> {
Self::from_compressed(FallibleIteratorReadWords::new(compressed))
.map_err(|iterator_backend| iterator_backend.into_iter())
}
pub fn from_reversed_binary_iter(data: Iter) -> Result<Self, ReadError> {
Self::from_binary(FallibleIteratorReadWords::new(data))
}
}
impl<Word, State, Backend> AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word>,
{
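/// Encodes the symbols in `symbols_and_models` in *reverse* order, so that they can
/// later be decoded in forward order.
///
/// Because an `AnsCoder` is a stack, decoding always yields symbols in the opposite
/// order of encoding. The `*_reverse` methods in this impl block simply call their
/// forward counterparts on a reversed iterator, which is usually what one wants when
/// the decoding order matters.
///
/// A hedged sketch of pairing each symbol with its own model (the quantizer and
/// `Gaussian` model are placeholders, as in the unit tests below):
///
/// ```ignore
/// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
/// coder.encode_symbols_reverse(
///     symbols
///         .iter()
///         .zip(&means)
///         .map(|(&symbol, &mean)| (symbol, quantizer.quantize(Gaussian::new(mean, 10.0)))),
/// )?;
/// ```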
pub fn encode_symbols_reverse<S, M, I, const PRECISION: usize>(
&mut self,
symbols_and_models: I,
) -> Result<(), DefaultEncoderError<Backend::WriteError>>
where
S: Borrow<M::Symbol>,
M: EncoderModel<PRECISION>,
M::Probability: Into<Word>,
Word: AsPrimitive<M::Probability>,
I: IntoIterator<Item = (S, M)>,
I::IntoIter: DoubleEndedIterator,
{
self.encode_symbols(symbols_and_models.into_iter().rev())
}
pub fn try_encode_symbols_reverse<S, M, E, I, const PRECISION: usize>(
&mut self,
symbols_and_models: I,
) -> Result<(), TryCodingError<DefaultEncoderError<Backend::WriteError>, E>>
where
S: Borrow<M::Symbol>,
M: EncoderModel<PRECISION>,
M::Probability: Into<Word>,
Word: AsPrimitive<M::Probability>,
I: IntoIterator<Item = core::result::Result<(S, M), E>>,
I::IntoIter: DoubleEndedIterator,
{
self.try_encode_symbols(symbols_and_models.into_iter().rev())
}
pub fn encode_iid_symbols_reverse<S, M, I, const PRECISION: usize>(
&mut self,
symbols: I,
model: M,
) -> Result<(), DefaultEncoderError<Backend::WriteError>>
where
S: Borrow<M::Symbol>,
M: EncoderModel<PRECISION> + Copy,
M::Probability: Into<Word>,
Word: AsPrimitive<M::Probability>,
I: IntoIterator<Item = S>,
I::IntoIter: DoubleEndedIterator,
{
self.encode_iid_symbols(symbols.into_iter().rev(), model)
}
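/// Consumes the coder and returns all compressed data, including the current
/// `state`.
///
/// The words of `state` are appended to `bulk` in the order expected by
/// [`from_compressed`](Self::from_compressed), so the returned backend can later be
/// turned back into an equivalent coder or handed to a decoder.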
pub fn into_compressed(mut self) -> Result<Backend, Backend::WriteError> {
self.bulk
.extend_from_iter(bit_array_to_chunks_truncated(self.state).rev())?;
Ok(self.bulk)
}
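/// Consumes the coder and returns its contents as raw binary data, stripping the
/// implicit "1" marker bit (the inverse of [`from_binary`](Self::from_binary)).
///
/// Returns `Err(None)` if the marker is absent or if the data below the marker does
/// not form a whole number of `Word`s, and `Err(Some(write_error))` if flushing to
/// the backend fails.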
pub fn into_binary(mut self) -> Result<Backend, Option<Backend::WriteError>> {
// Position of the highest set bit of `state`, i.e., of the implicit "1" marker
// bit; wraps to `usize::MAX` if `state` is zero (no marker present).
let valid_bits = (State::BITS - 1).wrapping_sub(self.state.leading_zeros() as usize);
if valid_bits % Word::BITS != 0 || valid_bits == usize::max_value() {
Err(None)
} else {
// Strip the marker bit; what remains is a whole number of `Word`s.
let truncated_state = self.state ^ (State::one() << valid_bits);
self.bulk
.extend_from_iter(bit_array_to_chunks_truncated(truncated_state).rev())?;
Ok(self.bulk)
}
}
}
impl<Word, State, Buf> AnsCoder<Word, State, Cursor<Word, Buf>>
where
Word: BitArray,
State: BitArray + AsPrimitive<Word> + From<Word>,
Buf: AsRef<[Word]> + AsMut<[Word]>,
{
pub fn into_reversed(self) -> AnsCoder<Word, State, Reverse<Cursor<Word, Buf>>> {
let (bulk, state) = self.into_raw_parts();
AnsCoder {
bulk: bulk.into_reversed(),
state,
phantom: PhantomData,
}
}
}
impl<Word, State, Buf> AnsCoder<Word, State, Reverse<Cursor<Word, Buf>>>
where
Word: BitArray,
State: BitArray + AsPrimitive<Word> + From<Word>,
Buf: AsRef<[Word]> + AsMut<[Word]>,
{
pub fn into_reversed(self) -> AnsCoder<Word, State, Cursor<Word, Buf>> {
let (bulk, state) = self.into_raw_parts();
AnsCoder {
bulk: bulk.into_reversed(),
state,
phantom: PhantomData,
}
}
}
impl<Word, State, Backend> Code for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
type Word = Word;
type State = State;
#[inline(always)]
fn state(&self) -> Self::State {
self.state
}
}
impl<Word, State, Backend, const PRECISION: usize> Encode<PRECISION>
for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word>,
{
type FrontendError = DefaultEncoderFrontendError;
type BackendError = Backend::WriteError;
fn encode_symbol<M>(
&mut self,
symbol: impl Borrow<M::Symbol>,
model: M,
) -> Result<(), DefaultEncoderError<Self::BackendError>>
where
M: EncoderModel<PRECISION>,
M::Probability: Into<Self::Word>,
Self::Word: AsPrimitive<M::Probability>,
{
assert!(State::BITS >= Word::BITS + PRECISION);
let (left_sided_cumulative, probability) = model
.left_cumulative_and_probability(symbol)
.ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;
// Renormalize: if encoding `symbol` now would overflow `state`, first flush the
// least significant word of `state` to `bulk`. This temporarily drops `state`
// below its usual invariant, which is exactly how `decode_symbol` later detects
// that it has to read a word back in at the corresponding point.
if (self.state >> (State::BITS - PRECISION)) >= probability.get().into().into() {
self.bulk.write(self.state.as_())?;
self.state = self.state >> Word::BITS;
}
// Core ANS update: split `state` into a quotient ("prefix") and a remainder with
// respect to the symbol's probability, then place the prefix above the symbol's
// quantile in the lowest `PRECISION` bits.
let remainder = (self.state % probability.get().into().into()).as_().as_();
let prefix = self.state / probability.get().into().into();
let quantile = left_sided_cumulative + remainder;
self.state = prefix << PRECISION | quantile.into().into();
Ok(())
}
fn maybe_full(&self) -> bool {
self.bulk.maybe_full()
}
}
impl<Word, State, Backend, const PRECISION: usize> Decode<PRECISION>
for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: ReadWords<Word, Stack>,
{
type FrontendError = Infallible;
type BackendError = Backend::ReadError;
#[inline(always)]
fn decode_symbol<M>(
&mut self,
model: M,
) -> Result<M::Symbol, CoderError<Self::FrontendError, Self::BackendError>>
where
M: DecoderModel<PRECISION>,
M::Probability: Into<Self::Word>,
Self::Word: AsPrimitive<M::Probability>,
{
assert!(State::BITS >= Word::BITS + PRECISION);
// The lowest `PRECISION` bits of `state` hold the quantile that identifies the symbol.
let quantile = (self.state % (State::one() << PRECISION)).as_().as_();
let (symbol, left_sided_cumulative, probability) = model.quantile_function(quantile);
let remainder = quantile - left_sided_cumulative;
// Invert the encoder's update: recover the prefix and fold the remainder back in.
self.state =
(self.state >> PRECISION) * probability.get().into().into() + remainder.into().into();
// Renormalize: if `state` fell below the invariant threshold, read back the word
// that the encoder flushed at the corresponding point (if there is one left).
if self.state < State::one() << (State::BITS - Word::BITS) {
if let Some(word) = self.bulk.read()? {
self.state = (self.state << Word::BITS) | word.into();
}
}
Ok(symbol)
}
fn maybe_exhausted(&self) -> bool {
self.is_empty()
}
}
impl<Word, State, Backend> PosSeek for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: PosSeek,
Self: Code,
{
type Position = (Backend::Position, <Self as Code>::State);
}
impl<Word, State, Backend> Seek for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: Seek,
{
fn seek(&mut self, (pos, state): Self::Position) -> Result<(), ()> {
self.bulk.seek(pos)?;
self.state = state;
Ok(())
}
}
impl<Word, State, Backend> Pos for AnsCoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: Pos,
{
fn pos(&self) -> Self::Position {
(self.bulk.pos(), self.state())
}
}
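/// RAII helper returned by [`AnsCoder::get_compressed`] and [`AnsCoder::get_binary`].
///
/// On construction it temporarily flushes the coder's `state` into `bulk` so that the
/// backend can be inspected as one contiguous buffer of compressed words; on `Drop`
/// it reads those words back, restoring the coder to a usable condition. With
/// `SEALED == true` it additionally verifies that the most significant word of
/// `state` is exactly `1` (the marker expected by [`AnsCoder::from_binary`] /
/// [`AnsCoder::into_binary`]) and excludes that marker from the flushed data.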
struct CoderGuard<'a, Word, State, Backend, const SEALED: bool>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
inner: &'a mut AnsCoder<Word, State, Backend>,
}
impl<'a, Word, State, Backend, const SEALED: bool> CoderGuard<'a, Word, State, Backend, SEALED>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
#[inline(always)]
fn new(
ans: &'a mut AnsCoder<Word, State, Backend>,
) -> Result<Self, CoderError<(), Backend::WriteError>> {
let mut chunks_rev = bit_array_to_chunks_truncated(ans.state);
// In the "sealed" (binary) case, the most significant word of `state` must be the
// lone marker bit written by `from_binary`; it is checked here and not flushed.
if SEALED && chunks_rev.next() != Some(Word::one()) {
return Err(CoderError::Frontend(()));
}
// Temporarily flush the (remaining) words of `state` into `bulk` so that the
// backend can be inspected as one contiguous buffer; `Drop` reads them back.
for chunk in chunks_rev.rev() {
ans.bulk.write(chunk)?
}
Ok(Self { inner: ans })
}
}
impl<'a, Word, State, Backend, const SEALED: bool> Drop
for CoderGuard<'a, Word, State, Backend, SEALED>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
fn drop(&mut self) {
// Read back exactly as many words from `bulk` as `new` flushed into it, restoring
// the coder to its condition from before the guard was created.
let mut chunks_rev = bit_array_to_chunks_truncated(self.inner.state);
if SEALED {
chunks_rev.next();
}
for _ in chunks_rev {
core::mem::drop(self.inner.bulk.read());
}
}
}
impl<'a, Word, State, Backend, const SEALED: bool> Deref
for CoderGuard<'a, Word, State, Backend, SEALED>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
type Target = Backend;
fn deref(&self) -> &Self::Target {
&self.inner.bulk
}
}
impl<Word, State, Backend, const SEALED: bool> Debug
for CoderGuard<'_, Word, State, Backend, SEALED>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + ReadWords<Word, Stack> + Debug,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
Debug::fmt(&**self, f)
}
}
#[cfg(test)]
mod tests {
use super::super::model::{
ContiguousCategoricalEntropyModel, DefaultLeakyQuantizer, IterableEntropyModel,
LeakyQuantizer,
};
use super::*;
extern crate std;
use std::dbg;
use probability::distribution::{Gaussian, Inverse};
use rand_xoshiro::{
rand_core::{RngCore, SeedableRng},
Xoshiro256StarStar,
};
#[test]
fn compress_none() {
let coder1 = DefaultAnsCoder::new();
assert!(coder1.is_empty());
let compressed = coder1.into_compressed().unwrap();
assert!(compressed.is_empty());
let coder2 = DefaultAnsCoder::from_compressed(compressed).unwrap();
assert!(coder2.is_empty());
}
#[test]
fn compress_one() {
generic_compress_few(core::iter::once(5), 1)
}
#[test]
fn compress_two() {
generic_compress_few([2, 8].iter().cloned(), 1)
}
#[test]
fn compress_ten() {
generic_compress_few(0..10, 2)
}
#[test]
fn compress_twenty() {
generic_compress_few(-10..10, 4)
}
fn generic_compress_few<I>(symbols: I, expected_size: usize)
where
I: IntoIterator<Item = i32>,
I::IntoIter: Clone + DoubleEndedIterator,
{
let symbols = symbols.into_iter();
let mut encoder = DefaultAnsCoder::new();
let quantizer = DefaultLeakyQuantizer::new(-127..=127);
let model = quantizer.quantize(Gaussian::new(3.2, 5.1));
encoder.encode_iid_symbols(symbols.clone(), &model).unwrap();
let compressed = encoder.into_compressed().unwrap();
assert_eq!(compressed.len(), expected_size);
let mut decoder = DefaultAnsCoder::from_compressed(compressed).unwrap();
for symbol in symbols.rev() {
assert_eq!(decoder.decode_symbol(&model).unwrap(), symbol);
}
assert!(decoder.is_empty());
}
#[test]
fn compress_many_u32_u64_32() {
generic_compress_many::<u32, u64, u32, 32>();
}
#[test]
fn compress_many_u32_u64_24() {
generic_compress_many::<u32, u64, u32, 24>();
}
#[test]
fn compress_many_u32_u64_16() {
generic_compress_many::<u32, u64, u16, 16>();
}
#[test]
fn compress_many_u32_u64_8() {
generic_compress_many::<u32, u64, u8, 8>();
}
#[test]
fn compress_many_u16_u64_16() {
generic_compress_many::<u16, u64, u16, 16>();
}
#[test]
fn compress_many_u16_u64_12() {
generic_compress_many::<u16, u64, u16, 12>();
}
#[test]
fn compress_many_u16_u64_8() {
generic_compress_many::<u16, u64, u8, 8>();
}
#[test]
fn compress_many_u8_u64_8() {
generic_compress_many::<u8, u64, u8, 8>();
}
#[test]
fn compress_many_u16_u32_16() {
generic_compress_many::<u16, u32, u16, 16>();
}
#[test]
fn compress_many_u16_u32_12() {
generic_compress_many::<u16, u32, u16, 12>();
}
#[test]
fn compress_many_u16_u32_8() {
generic_compress_many::<u16, u32, u8, 8>();
}
#[test]
fn compress_many_u8_u32_8() {
generic_compress_many::<u8, u32, u8, 8>();
}
#[test]
fn compress_many_u8_u16_8() {
generic_compress_many::<u8, u16, u8, 8>();
}
fn generic_compress_many<Word, State, Probability, const PRECISION: usize>()
where
State: BitArray + AsPrimitive<Word>,
Word: BitArray + Into<State> + AsPrimitive<Probability>,
Probability: BitArray + Into<Word> + AsPrimitive<usize> + Into<f64>,
u32: AsPrimitive<Probability>,
usize: AsPrimitive<Probability>,
f64: AsPrimitive<Probability>,
i32: AsPrimitive<Probability>,
{
#[cfg(not(miri))]
const AMT: usize = 1000;
#[cfg(miri)]
const AMT: usize = 100;
let mut symbols_gaussian = Vec::with_capacity(AMT);
let mut means = Vec::with_capacity(AMT);
let mut stds = Vec::with_capacity(AMT);
let mut rng = Xoshiro256StarStar::seed_from_u64(
(Word::BITS as u64).rotate_left(3 * 16)
^ (State::BITS as u64).rotate_left(2 * 16)
^ (Probability::BITS as u64).rotate_left(16)
^ PRECISION as u64,
);
for _ in 0..AMT {
let mean = (200.0 / u32::MAX as f64) * rng.next_u32() as f64 - 100.0;
let std_dev = (10.0 / u32::MAX as f64) * rng.next_u32() as f64 + 0.001;
let quantile = (rng.next_u32() as f64 + 0.5) / (1u64 << 32) as f64;
let dist = Gaussian::new(mean, std_dev);
let symbol = (dist.inverse(quantile).round() as i32).clamp(-127, 127);
symbols_gaussian.push(symbol);
means.push(mean);
stds.push(std_dev);
}
let hist = [
1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
];
let categorical_probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
let categorical =
ContiguousCategoricalEntropyModel::<Probability, _, PRECISION>::from_floating_point_probabilities(
&categorical_probabilities,
)
.unwrap();
let mut symbols_categorical = Vec::with_capacity(AMT);
let max_probability = Probability::max_value() >> (Probability::BITS - PRECISION);
for _ in 0..AMT {
let quantile = rng.next_u32().as_() & max_probability;
let symbol = categorical.quantile_function(quantile).0;
symbols_categorical.push(symbol);
}
let mut ans = AnsCoder::<Word, State>::new();
ans.encode_iid_symbols_reverse(&symbols_categorical, &categorical)
.unwrap();
dbg!(
ans.num_valid_bits(),
AMT as f64 * categorical.entropy_base2::<f64>()
);
let quantizer = LeakyQuantizer::<_, _, Probability, PRECISION>::new(-127..=127);
ans.encode_symbols_reverse(symbols_gaussian.iter().zip(&means).zip(&stds).map(
|((&symbol, &mean), &std_dev)| (symbol, quantizer.quantize(Gaussian::new(mean, std_dev))),
))
.unwrap();
dbg!(ans.num_valid_bits());
let compressed = ans.into_compressed().unwrap();
let mut ans = AnsCoder::from_compressed(compressed).unwrap();
let reconstructed_gaussian = ans
.decode_symbols(
means
.iter()
.zip(&stds)
.map(|(&mean, &std_dev)| quantizer.quantize(Gaussian::new(mean, std_dev))),
)
.collect::<Result<Vec<_>, CoderError<Infallible, Infallible>>>()
.unwrap();
let reconstructed_categorical = ans
.decode_iid_symbols(AMT, &categorical)
.collect::<Result<Vec<_>, CoderError<Infallible, Infallible>>>()
.unwrap();
assert!(ans.is_empty());
assert_eq!(symbols_gaussian, reconstructed_gaussian);
assert_eq!(symbols_categorical, reconstructed_categorical);
}
#[test]
fn seek() {
#[cfg(not(miri))]
let (num_chunks, symbols_per_chunk) = (100, 100);
#[cfg(miri)]
let (num_chunks, symbols_per_chunk) = (10, 10);
let quantizer = DefaultLeakyQuantizer::new(-100..=100);
let model = quantizer.quantize(Gaussian::new(0.0, 10.0));
let mut encoder = DefaultAnsCoder::new();
let mut rng = Xoshiro256StarStar::seed_from_u64(123);
let mut symbols = Vec::with_capacity(num_chunks);
let mut jump_table = Vec::with_capacity(num_chunks);
let (initial_pos, initial_state) = encoder.pos();
for _ in 0..num_chunks {
let chunk = (0..symbols_per_chunk)
.map(|_| model.quantile_function(rng.next_u32() % (1 << 24)).0)
.collect::<Vec<_>>();
encoder.encode_iid_symbols_reverse(&chunk, &model).unwrap();
symbols.push(chunk);
jump_table.push(encoder.pos());
}
{
let mut seekable_decoder = encoder.as_seekable_decoder();
for (chunk, &(pos, state)) in symbols.iter().zip(&jump_table).rev() {
assert_eq!(seekable_decoder.pos(), (pos, state));
let decoded = seekable_decoder
.decode_iid_symbols(symbols_per_chunk, &model)
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(&decoded, chunk)
}
assert_eq!(seekable_decoder.pos(), (initial_pos, initial_state));
assert!(seekable_decoder.is_empty());
for _ in 0..100 {
let chunk_index = rng.next_u32() as usize % num_chunks;
let (pos, state) = jump_table[chunk_index];
seekable_decoder.seek((pos, state)).unwrap();
let decoded = seekable_decoder
.decode_iid_symbols(symbols_per_chunk, &model)
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(&decoded, &symbols[chunk_index])
}
}
let mut compressed = encoder.into_compressed().unwrap();
compressed.reverse();
for (pos, _state) in jump_table.iter_mut() {
*pos = compressed.len() - *pos;
}
let initial_pos = compressed.len() - initial_pos;
{
let mut seekable_decoder = AnsCoder::from_reversed_compressed(compressed).unwrap();
for (chunk, &(pos, state)) in symbols.iter().zip(&jump_table).rev() {
assert_eq!(seekable_decoder.pos(), (pos, state));
let decoded = seekable_decoder
.decode_iid_symbols(symbols_per_chunk, &model)
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(&decoded, chunk)
}
assert_eq!(seekable_decoder.pos(), (initial_pos, initial_state));
assert!(seekable_decoder.is_empty());
for _ in 0..100 {
let chunk_index = rng.next_u32() as usize % num_chunks;
let (pos, state) = jump_table[chunk_index];
seekable_decoder.seek((pos, state)).unwrap();
let decoded = seekable_decoder
.decode_iid_symbols(symbols_per_chunk, &model)
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(&decoded, &symbols[chunk_index])
}
}
}
}