use alloc::vec::Vec;
use core::{
borrow::Borrow,
fmt::{Debug, Display},
marker::PhantomData,
num::NonZeroUsize,
ops::Deref,
};
use num_traits::AsPrimitive;
use super::{
model::{DecoderModel, EncoderModel},
Code, Decode, Encode, IntoDecoder,
};
use crate::{
backends::{AsReadWords, BoundedReadWords, Cursor, IntoReadWords, ReadWords, WriteWords},
BitArray, CoderError, DefaultEncoderError, DefaultEncoderFrontendError, NonZeroBitArray, Pos,
PosSeek, Queue, Seek, UnwrapInfallible,
};
/// The interval that a range coder currently operates on.
///
/// The interval starts at `lower` and has size `range`; arithmetic on `lower`
/// elsewhere in this file is wrapping. `Word` is only tracked at the type level.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RangeCoderState<Word, State: BitArray> {
/// Lower end of the current interval.
lower: State,
/// Size of the current interval; the `NonZero` type rules out an empty interval.
range: State::NonZero,
/// Ties the state to a `Word` type without storing a value of that type.
phantom: PhantomData<Word>,
}
impl<Word: BitArray, State: BitArray> RangeCoderState<Word, State> {
    /// Builds a coder state from its raw `lower` and `range` components.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if none of the `Word::BITS` most significant bits of
    /// `range` is set, i.e., if `range >> (State::BITS - Word::BITS)` is zero.
    #[allow(clippy::result_unit_err)]
    pub fn new(lower: State, range: State) -> Result<Self, ()> {
        let most_significant_word = range >> (State::BITS - Word::BITS);
        if most_significant_word == State::zero() {
            return Err(());
        }

        // A `range` with a set bit in its top word is necessarily nonzero.
        let range = range.into_nonzero().expect("We checked above.");
        Ok(Self {
            lower,
            range,
            phantom: PhantomData,
        })
    }

    /// Returns the lower end of the current interval.
    pub fn lower(&self) -> State {
        self.lower
    }

    /// Returns the (nonzero) size of the current interval.
    pub fn range(&self) -> State::NonZero {
        self.range
    }
}
impl<Word: BitArray, State: BitArray> Default for RangeCoderState<Word, State> {
    /// Returns the initial state: `lower` is zero and `range` is `State::max_value()`.
    fn default() -> Self {
        let full_range = State::max_value()
            .into_nonzero()
            .expect("max_value() != 0");

        Self {
            lower: State::zero(),
            range: full_range,
            phantom: PhantomData,
        }
    }
}
/// An encoder for range coding that emits compressed data as a sequence of `Word`s.
///
/// Compressed words are written to `bulk`; the part of the encoder state that has
/// not been written out yet lives in `state` and `situation`.
#[derive(Debug, Clone)]
pub struct RangeEncoder<Word, State, Backend = Vec<Word>>
where
Word: BitArray,
State: BitArray,
Backend: WriteWords<Word>,
{
/// Backend that receives all compressed words emitted so far.
bulk: Backend,
/// Current coder interval (`lower` and `range`).
state: RangeCoderState<Word, State>,
/// Tracks whether words are currently being held back (see `EncoderSituation`).
situation: EncoderSituation<Word>,
}
/// Describes whether a [`RangeEncoder`] currently holds back words.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncoderSituation<Word> {
    /// No words are held back; emitted words go straight to the backend.
    Normal,

    /// Words are being held back because the current interval wraps around the
    /// end of the `State` range.
    ///
    /// The first field counts the held-back words, the second field is the
    /// first held-back word as it was computed before any carry is known.
    Inverted(NonZeroUsize, Word),
}

impl<Word> Default for EncoderSituation<Word> {
    /// A fresh encoder starts out in the [`Normal`](Self::Normal) situation.
    fn default() -> Self {
        EncoderSituation::Normal
    }
}
/// A [`RangeEncoder`] with 32-bit words and a 64-bit state.
pub type DefaultRangeEncoder<Backend = Vec<u32>> = RangeEncoder<u32, u64, Backend>;
/// A [`RangeEncoder`] with 16-bit words and a 32-bit state.
pub type SmallRangeEncoder<Backend = Vec<u16>> = RangeEncoder<u16, u32, Backend>;
impl<Word, State, Backend> Code for RangeEncoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word>,
{
type State = RangeCoderState<Word, State>;
type Word = Word;
/// Returns a copy of the current coder interval (`lower` and `range`).
///
/// The returned value does not include `situation`, i.e., any words that are
/// currently held back.
fn state(&self) -> Self::State {
self.state
}
}
impl<Word, State, Backend> PosSeek for RangeEncoder<Word, State, Backend>
where
Word: BitArray,
State: BitArray,
Backend: WriteWords<Word> + PosSeek,
Self: Code,
{
// A position is the backend's position paired with the coder state at that point.
type Position = (Backend::Position, <Self as Code>::State);
}
impl<Word, State, Backend> Pos for RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + Pos<Position = usize>,
{
    /// Returns the current position together with the current coder state.
    ///
    /// Words that are held back in an inverted situation have not reached the
    /// backend yet, so they are added to the backend's own position.
    fn pos(&self) -> Self::Position {
        let held_back_words = match self.situation {
            EncoderSituation::Inverted(num_inverted, _) => num_inverted.get(),
            EncoderSituation::Normal => 0,
        };

        (self.bulk.pos() + held_back_words, self.state())
    }
}
impl<Word, State, Backend> Default for RangeEncoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + Default,
{
/// Creates an empty encoder on top of `Backend::default()`.
///
/// Panics under the same conditions as `with_backend`, to which it delegates.
fn default() -> Self {
Self::with_backend(Backend::default())
}
}
impl<Word, State> RangeEncoder<Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Creates an empty encoder that writes its compressed words to a fresh `Vec`.
    ///
    /// # Panics
    ///
    /// Panics if `State::BITS < 2 * Word::BITS` or if `State::BITS` is not a
    /// multiple of `Word::BITS`.
    pub fn new() -> Self {
        // `with_backend` performs the bit-width checks and sets up the default
        // state in the `Normal` situation.
        Self::with_backend(Vec::new())
    }
}
impl<Word, State> From<RangeEncoder<Word, State>> for Vec<Word>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Seals the encoder and returns the resulting compressed words.
    fn from(val: RangeEncoder<Word, State>) -> Self {
        // Writing to a `Vec` backend cannot fail, so the error case is unwrapped
        // via `unwrap_infallible`.
        let sealed = val.into_compressed();
        sealed.unwrap_infallible()
    }
}
impl<Word, State, Backend> RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word>,
{
    /// Creates an empty encoder that writes its compressed words to `backend`.
    ///
    /// # Panics
    ///
    /// Panics if `State::BITS < 2 * Word::BITS` or if `State::BITS` is not a
    /// multiple of `Word::BITS`.
    pub fn with_backend(backend: Backend) -> Self {
        Self::from_raw_parts(
            backend,
            RangeCoderState::default(),
            EncoderSituation::Normal,
        )
    }

    /// Returns `true` if the coder is still in its initial state (`range` equals
    /// `State::max_value()`) and the backend holds no words.
    pub fn is_empty<'a>(&'a self) -> bool
    where
        Backend: AsReadWords<'a, Word, Queue>,
        Backend::AsReadWords: BoundedReadWords<Word, Queue>,
    {
        let in_initial_state = self.state.range.get() == State::max_value();
        in_initial_state && self.bulk.as_read_words().is_exhausted()
    }

    /// Forwards to the backend's `maybe_full`.
    pub fn maybe_full(&self) -> bool {
        self.bulk.maybe_full()
    }

    /// Seals the encoder and turns it into a decoder over the compressed data.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if sealing fails to write to the backend or if the
    /// decoder fails to read its initial words; the underlying error is dropped.
    #[allow(clippy::result_unit_err)]
    pub fn into_decoder(self) -> Result<RangeDecoder<Word, State, Backend::IntoReadWords>, ()>
    where
        Backend: IntoReadWords<Word, Queue>,
    {
        let compressed = self.into_compressed().map_err(|_| ())?;
        RangeDecoder::from_compressed(compressed).map_err(|_| ())
    }

    /// Seals the encoder and returns the backend with all compressed words.
    ///
    /// # Errors
    ///
    /// Returns the backend's write error if writing the final words fails.
    pub fn into_compressed(mut self) -> Result<Backend, Backend::WriteError> {
        match self.seal() {
            Ok(()) => Ok(self.bulk),
            Err(write_error) => Err(write_error),
        }
    }

    /// Writes the final words that make the compressed data decodable.
    ///
    /// Does nothing if the coder is still in its initial state. Neither `state`
    /// nor `situation` are modified, so `num_seal_words` keeps reporting how
    /// many words were appended.
    fn seal(&mut self) -> Result<(), Backend::WriteError> {
        if self.state.range.get() == State::max_value() {
            return Ok(());
        }

        // Shift amount that moves the most significant word to the bottom.
        let top_shift = State::BITS - Word::BITS;
        let lower = self.state.lower;
        let point = lower.wrapping_add(&((State::one() << top_shift) - State::one()));

        if let EncoderSituation::Inverted(num_inverted, first_inverted_lower_word) = self.situation
        {
            // `point < lower` means that the addition above wrapped around, so a
            // carry has to be propagated into the held-back words.
            let carry = point < lower;
            let (first_word, consecutive_words) = if carry {
                (first_inverted_lower_word + Word::one(), Word::zero())
            } else {
                (first_inverted_lower_word, Word::max_value())
            };

            self.bulk.write(first_word)?;
            for _ in 1..num_inverted.get() {
                self.bulk.write(consecutive_words)?;
            }
        }

        let point_word = (point >> top_shift).as_();
        self.bulk.write(point_word)?;

        let upper = lower.wrapping_add(&self.state.range.get());
        let upper_word = (upper >> top_shift).as_();
        if upper_word == point_word {
            self.bulk.write(Word::zero())?;
        }

        Ok(())
    }

    /// Returns the number of words that `seal` writes in the current state.
    fn num_seal_words(&self) -> usize {
        if self.state.range.get() == State::max_value() {
            return 0;
        }

        let top_shift = State::BITS - Word::BITS;
        let lower = self.state.lower;
        let point = lower.wrapping_add(&((State::one() << top_shift) - State::one()));
        let point_word = (point >> top_shift).as_();
        let upper_word = (lower.wrapping_add(&self.state.range.get()) >> top_shift).as_();

        // `seal` always writes `point_word` and appends a zero word if the top
        // words of `point` and of the interval's upper end coincide.
        let closing_words = if upper_word == point_word { 2 } else { 1 };
        let held_back_words = match self.situation {
            EncoderSituation::Inverted(num_inverted, _) => num_inverted.get(),
            EncoderSituation::Normal => 0,
        };

        closing_words + held_back_words
    }

    /// Returns the number of words the compressed data would have if the encoder
    /// was sealed now.
    pub fn num_words<'a>(&'a self) -> usize
    where
        Backend: AsReadWords<'a, Word, Queue>,
        Backend::AsReadWords: BoundedReadWords<Word, Queue>,
    {
        let words_in_backend = self.bulk.as_read_words().remaining();
        words_in_backend + self.num_seal_words()
    }

    /// Returns `num_words` expressed in bits.
    pub fn num_bits<'a>(&'a self) -> usize
    where
        Backend: AsReadWords<'a, Word, Queue>,
        Backend::AsReadWords: BoundedReadWords<Word, Queue>,
    {
        Word::BITS * self.num_words()
    }

    /// Returns a reference to the backend (without any sealing words).
    pub fn bulk(&self) -> &Backend {
        &self.bulk
    }

    /// Assembles an encoder from its three components.
    ///
    /// # Panics
    ///
    /// Panics if `State::BITS < 2 * Word::BITS` or if `State::BITS` is not a
    /// multiple of `Word::BITS`.
    pub fn from_raw_parts(
        bulk: Backend,
        state: RangeCoderState<Word, State>,
        situation: EncoderSituation<Word>,
    ) -> Self {
        assert!(State::BITS >= 2 * Word::BITS);
        assert_eq!(State::BITS % Word::BITS, 0);

        Self {
            bulk,
            state,
            situation,
        }
    }

    /// Splits the encoder into its three components without sealing it.
    pub fn into_raw_parts(
        self,
    ) -> (
        Backend,
        RangeCoderState<Word, State>,
        EncoderSituation<Word>,
    ) {
        let Self {
            bulk,
            state,
            situation,
        } = self;
        (bulk, state, situation)
    }
}
impl<Word, State> RangeEncoder<Word, State>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
pub fn clear(&mut self) {
self.bulk.clear();
self.state = RangeCoderState::default();
}
pub fn get_compressed(&mut self) -> EncoderGuard<'_, Word, State> {
EncoderGuard::new(self)
}
pub fn decoder(
&mut self,
) -> RangeDecoder<Word, State, Cursor<Word, EncoderGuard<'_, Word, State>>> {
RangeDecoder::from_compressed(self.get_compressed()).unwrap_infallible()
}
fn unseal(&mut self) {
for _ in 0..self.num_seal_words() {
let word = self.bulk.pop();
debug_assert!(word.is_some());
}
}
}
impl<Word, State, Backend, const PRECISION: usize> IntoDecoder<PRECISION>
for RangeEncoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + IntoReadWords<Word, Queue>,
{
type IntoDecoder = RangeDecoder<Word, State, Backend::IntoReadWords>;
/// Converts the encoder into a decoder via the `From<RangeEncoder>`
/// implementation for `RangeDecoder`.
///
/// That conversion unwraps the result of the fallible inherent `into_decoder`,
/// so this method panics if sealing or the initial read fails.
fn into_decoder(self) -> Self::IntoDecoder {
self.into()
}
}
impl<Word, State, Backend, const PRECISION: usize> Encode<PRECISION>
for RangeEncoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word>,
{
type FrontendError = DefaultEncoderFrontendError;
type BackendError = Backend::WriteError;
/// Encodes `symbol` with the entropy model `model`.
///
/// Narrows the coder interval to the sub-interval that `model` assigns to
/// `symbol`. If the interval becomes smaller than `1 << (State::BITS - Word::BITS)`,
/// the most significant word of `lower` is emitted (or held back) and the
/// interval is scaled up by `Word::BITS` bits.
///
/// # Errors
///
/// - `ImpossibleSymbol` (frontend error) if `model` has no entry for `symbol`
///   or if the resulting interval would be empty.
/// - A backend error if writing a word to the backend fails.
///
/// # Panics
///
/// Panics if the counter of held-back words wraps around to zero.
fn encode_symbol<D>(
&mut self,
symbol: impl Borrow<D::Symbol>,
model: D,
) -> Result<(), DefaultEncoderError<Self::BackendError>>
where
D: EncoderModel<PRECISION>,
D::Probability: Into<Self::Word>,
Self::Word: AsPrimitive<D::Probability>,
{
let (left_sided_cumulative, probability) = model
.left_cumulative_and_probability(symbol)
.ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;
// Size of the sub-interval that corresponds to one unit of probability.
let scale = self.state.range.get() >> PRECISION;
self.state.range = (scale * probability.get().into().into())
.into_nonzero()
.ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;
// `lower` is tracked with wrapping arithmetic; `self.state.lower` still holds
// the old value until the carry handling below is done.
let new_lower = self
.state
.lower
.wrapping_add(&(scale * left_sided_cumulative.into().into()));
if let EncoderSituation::Inverted(num_inverted, first_inverted_lower_word) = self.situation
{
// The new interval no longer wraps around, so the held-back words can be
// resolved and written out.
if new_lower.wrapping_add(&self.state.range.get()) > new_lower {
// `new_lower < self.state.lower` means that `lower` wrapped around, i.e.,
// a carry propagates into the held-back words.
let (first_word, consecutive_words) = if new_lower < self.state.lower {
(first_inverted_lower_word + Word::one(), Word::zero())
} else {
(first_inverted_lower_word, Word::max_value())
};
self.bulk.write(first_word)?;
for _ in 1..num_inverted.get() {
self.bulk.write(consecutive_words)?;
}
self.situation = EncoderSituation::Normal;
}
}
self.state.lower = new_lower;
if self.state.range.get() < State::one() << (State::BITS - Word::BITS) {
// SAFETY: `range` is nonzero and smaller than `1 << (State::BITS - Word::BITS)`,
// so shifting it left by `Word::BITS` drops no set bit and the result is nonzero.
self.state.range = unsafe {
(self.state.range.get() << Word::BITS).into_nonzero_unchecked()
};
// The most significant word of `lower` leaves the state.
let lower_word = (self.state.lower >> (State::BITS - Word::BITS)).as_();
self.state.lower = self.state.lower << Word::BITS;
if let EncoderSituation::Inverted(num_inverted, _) = &mut self.situation {
// Still inverted: hold back one more word.
*num_inverted = NonZeroUsize::new(num_inverted.get().wrapping_add(1))
.expect("Cannot encode more symbols than what's addressable with usize.");
} else if self.state.lower.wrapping_add(&self.state.range.get()) > self.state.lower {
// The interval does not wrap around: the word is final and can be written.
self.bulk.write(lower_word)?;
} else {
// The interval wraps around: hold the word back until the carry is known.
self.situation =
EncoderSituation::Inverted(NonZeroUsize::new(1).expect("1 != 0"), lower_word);
}
}
Ok(())
}
/// Forwards to the inherent `RangeEncoder::maybe_full`.
fn maybe_full(&self) -> bool {
RangeEncoder::maybe_full(self)
}
}
/// A decoder for data that was compressed with a [`RangeEncoder`].
#[derive(Debug, Clone)]
pub struct RangeDecoder<Word, State, Backend>
where
Word: BitArray,
State: BitArray,
Backend: ReadWords<Word, Queue>,
{
/// Source of compressed words that have not been read yet.
bulk: Backend,
/// Current coder interval (`lower` and `range`).
state: RangeCoderState<Word, State>,
/// Window of compressed words that have already been read from `bulk`.
/// `from_raw_parts` requires `point - lower < range` (wrapping subtraction).
point: State,
}
/// A [`RangeDecoder`] with 32-bit words and a 64-bit state.
pub type DefaultRangeDecoder<Backend = Cursor<u32, Vec<u32>>> = RangeDecoder<u32, u64, Backend>;
/// A [`RangeDecoder`] with 16-bit words and a 32-bit state.
pub type SmallRangeDecoder<Backend> = RangeDecoder<u16, u32, Backend>;
impl<Word, State, Backend> RangeDecoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: ReadWords<Word, Queue>,
{
    /// Creates a decoder that takes ownership of `compressed`.
    ///
    /// # Errors
    ///
    /// Returns the backend's read error if reading the initial words fails.
    ///
    /// # Panics
    ///
    /// Panics if `State::BITS < 2 * Word::BITS` or if `State::BITS` is not a
    /// multiple of `Word::BITS`.
    pub fn from_compressed<Buf>(compressed: Buf) -> Result<Self, Backend::ReadError>
    where
        Buf: IntoReadWords<Word, Queue, IntoReadWords = Backend>,
    {
        Self::with_backend(compressed.into_read_words())
    }

    /// Creates a decoder that reads compressed words from `backend`.
    ///
    /// # Errors
    ///
    /// Returns the backend's read error if reading the initial words fails.
    ///
    /// # Panics
    ///
    /// Panics if `State::BITS < 2 * Word::BITS` or if `State::BITS` is not a
    /// multiple of `Word::BITS`.
    pub fn with_backend(backend: Backend) -> Result<Self, Backend::ReadError> {
        assert!(State::BITS >= 2 * Word::BITS);
        assert_eq!(State::BITS % Word::BITS, 0);

        let mut bulk = backend;
        let point = Self::read_point(&mut bulk)?;
        Ok(Self {
            bulk,
            state: RangeCoderState::default(),
            point,
        })
    }

    /// Creates a decoder that borrows `compressed`.
    ///
    /// # Errors
    ///
    /// Returns the backend's read error if reading the initial words fails.
    ///
    /// # Panics
    ///
    /// Panics if `State::BITS < 2 * Word::BITS` or if `State::BITS` is not a
    /// multiple of `Word::BITS`.
    pub fn for_compressed<'a, Buf>(compressed: &'a Buf) -> Result<Self, Backend::ReadError>
    where
        Buf: AsReadWords<'a, Word, Queue, AsReadWords = Backend>,
    {
        Self::with_backend(compressed.as_read_words())
    }

    /// Assembles a decoder from its three components.
    ///
    /// # Errors
    ///
    /// Hands `bulk` back as `Err(bulk)` if `point` does not lie within the
    /// interval described by `state`, i.e., unless
    /// `point.wrapping_sub(lower) < range`.
    ///
    /// # Panics
    ///
    /// Panics if `State::BITS < 2 * Word::BITS` or if `State::BITS` is not a
    /// multiple of `Word::BITS`.
    pub fn from_raw_parts(
        bulk: Backend,
        state: RangeCoderState<Word, State>,
        point: State,
    ) -> Result<Self, Backend> {
        assert!(State::BITS >= 2 * Word::BITS);
        assert_eq!(State::BITS % Word::BITS, 0);

        let offset = point.wrapping_sub(&state.lower);
        if offset < state.range.get() {
            Ok(Self { bulk, state, point })
        } else {
            Err(bulk)
        }
    }

    /// Splits the decoder into its three components.
    pub fn into_raw_parts(self) -> (Backend, RangeCoderState<Word, State>, State) {
        let Self { bulk, state, point } = self;
        (bulk, state, point)
    }

    /// Reads up to `State::BITS / Word::BITS` words from `bulk` and assembles
    /// them into a `State`, most significant word first.
    ///
    /// If `bulk` runs out early, the words that were read are moved to the most
    /// significant end and the remaining low bits are left at zero.
    fn read_point<B: ReadWords<Word, Queue>>(bulk: &mut B) -> Result<State, B::ReadError> {
        let words_per_state = State::BITS / Word::BITS;
        let mut point = State::zero();
        let mut num_read = 0;

        while num_read < words_per_state {
            match bulk.read()? {
                Some(word) => {
                    point = (point << Word::BITS) | word.into();
                    num_read += 1;
                }
                None => break,
            }
        }

        // With zero words read, `point` is already zero (and a shift by the full
        // bit width must be avoided).
        if num_read != 0 && num_read < words_per_state {
            point = point << (State::BITS - num_read * Word::BITS);
        }

        Ok(point)
    }

    /// Returns `false` if there is definitely still data left to decode.
    ///
    /// Returns `true` only if the backend reports that it may be exhausted and,
    /// in addition, either the decoder is still in its initial state or `point`
    /// lies close enough to `lower`.
    pub fn maybe_exhausted(&self) -> bool {
        let max_difference =
            ((State::one() << (State::BITS - Word::BITS)) << 1).wrapping_sub(&State::one());

        if !self.bulk.maybe_exhausted() {
            return false;
        }

        let in_initial_state = self.state.range.get() == State::max_value();
        in_initial_state || self.point.wrapping_sub(&self.state.lower) < max_difference
    }
}
impl<Word, State, Backend> Code for RangeDecoder<Word, State, Backend>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: ReadWords<Word, Queue>,
{
type State = RangeCoderState<Word, State>;
type Word = Word;
/// Returns a copy of the current coder interval (`lower` and `range`).
///
/// The returned value does not include `point`.
fn state(&self) -> Self::State {
self.state
}
}
impl<Word, State, Backend> PosSeek for RangeDecoder<Word, State, Backend>
where
Word: BitArray,
State: BitArray,
Backend: ReadWords<Word, Queue>,
Backend: PosSeek,
Self: Code,
{
// A position is the backend's position paired with the coder state at that point.
type Position = (Backend::Position, <Self as Code>::State);
}
impl<Word, State, Backend> Seek for RangeDecoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: ReadWords<Word, Queue> + Seek,
{
    /// Jumps to a position that was recorded as a `(backend position, coder state)`
    /// pair.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if the backend cannot seek to the position or if reading
    /// the words at the new position fails. In the latter case the backend has
    /// already been moved, but `point` and `state` are left unchanged.
    fn seek(&mut self, pos_and_state: Self::Position) -> Result<(), ()> {
        let (backend_pos, coder_state) = pos_and_state;
        self.bulk.seek(backend_pos)?;

        match Self::read_point(&mut self.bulk) {
            Ok(point) => {
                self.point = point;
                self.state = coder_state;
                Ok(())
            }
            Err(_) => Err(()),
        }
    }
}
impl<Word, State, Backend> From<RangeEncoder<Word, State, Backend>>
for RangeDecoder<Word, State, Backend::IntoReadWords>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
Backend: WriteWords<Word> + IntoReadWords<Word, Queue>,
{
/// Seals `encoder` and turns it into a decoder over its compressed data.
///
/// # Panics
///
/// Panics if `encoder.into_decoder()` returns an error, i.e., if writing the
/// sealing words or reading the initial words fails.
fn from(encoder: RangeEncoder<Word, State, Backend>) -> Self {
encoder.into_decoder().unwrap()
}
}
impl<Word, State, Backend, const PRECISION: usize> Decode<PRECISION>
    for RangeDecoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: ReadWords<Word, Queue>,
{
    type FrontendError = DecoderFrontendError;
    type BackendError = Backend::ReadError;

    /// Decodes a single symbol with the entropy model `model`.
    ///
    /// Locates `point` within the current interval, asks `model` for the symbol
    /// at the resulting quantile, and then updates `lower` and `range` with the
    /// same arithmetic that the encoder uses.
    ///
    /// # Errors
    ///
    /// - `DecoderFrontendError::InvalidData` if the quantile derived from `point`
    ///   is not smaller than `1 << PRECISION`.
    /// - A backend error if reading the next word fails.
    ///
    /// # Panics
    ///
    /// Panics if the updated `range` is zero. `quantile_function` hands out a
    /// nonzero probability, so this indicates a broken coder state or a
    /// misbehaving `model` rather than invalid compressed data.
    fn decode_symbol<D>(
        &mut self,
        model: D,
    ) -> Result<D::Symbol, CoderError<Self::FrontendError, Self::BackendError>>
    where
        D: DecoderModel<PRECISION>,
        D::Probability: Into<Self::Word>,
        Self::Word: AsPrimitive<D::Probability>,
    {
        // Size of the sub-interval that corresponds to one unit of probability.
        let scale = self.state.range.get() >> PRECISION;
        let quantile = self.point.wrapping_sub(&self.state.lower) / scale;
        if quantile >= State::one() << PRECISION {
            return Err(CoderError::Frontend(DecoderFrontendError::InvalidData));
        }

        let (symbol, left_sided_cumulative, probability) =
            model.quantile_function(quantile.as_().as_());

        // Narrow the interval to the sub-interval of the decoded symbol.
        self.state.lower = self
            .state
            .lower
            .wrapping_add(&(scale * left_sided_cumulative.into().into()));
        self.state.range = (scale * probability.get().into().into())
            .into_nonzero()
            .expect("`scale` and `probability` are nonzero and `scale * probability <= range`");

        if self.state.range.get() < State::one() << (State::BITS - Word::BITS) {
            // The interval became too small: scale everything up by one word and
            // refill the freed low bits of `point` from the backend.
            self.state.lower = self.state.lower << Word::BITS;
            // SAFETY: `range` is nonzero and smaller than
            // `1 << (State::BITS - Word::BITS)`, so shifting it left by `Word::BITS`
            // drops no set bit and the result is nonzero.
            self.state.range = unsafe {
                (self.state.range.get() << Word::BITS).into_nonzero_unchecked()
            };

            self.point = self.point << Word::BITS;
            // If the backend has no more words, the low bits of `point` stay zero.
            if let Some(word) = self.bulk.read()? {
                self.point = self.point | word.into();
            }
        }

        Ok(symbol)
    }

    /// Forwards to the inherent `RangeDecoder::maybe_exhausted`.
    fn maybe_exhausted(&self) -> bool {
        RangeDecoder::maybe_exhausted(self)
    }
}
/// A guard that gives read access to the compressed words of a temporarily
/// sealed [`RangeEncoder`].
///
/// The encoder is sealed when the guard is created (unless it is empty) and
/// unsealed again when the guard is dropped.
pub struct EncoderGuard<'a, Word, State>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
/// The encoder that is kept sealed for the lifetime of the guard.
inner: &'a mut RangeEncoder<Word, State>,
}
impl<Word, State> Debug for EncoderGuard<'_, Word, State>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
/// Formats the guard like the slice of compressed words it dereferences to.
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
Debug::fmt(&**self, f)
}
}
impl<'a, Word, State> EncoderGuard<'a, Word, State>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
/// Seals `encoder` (unless it is empty) and wraps it in a guard.
fn new(encoder: &'a mut RangeEncoder<Word, State>) -> Self {
// An empty encoder has nothing to seal; `seal` writes no words in that case
// anyway because `range` still has its initial value.
if !encoder.is_empty() {
encoder.seal().unwrap_infallible();
}
Self { inner: encoder }
}
}
impl<'a, Word, State> Drop for EncoderGuard<'a, Word, State>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
/// Removes the sealing words again so that the encoder can continue encoding.
fn drop(&mut self) {
self.inner.unseal();
}
}
impl<'a, Word, State> Deref for EncoderGuard<'a, Word, State>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
type Target = [Word];
/// Returns the encoder's words, including the words appended by sealing.
fn deref(&self) -> &Self::Target {
&self.inner.bulk
}
}
impl<'a, Word, State> AsRef<[Word]> for EncoderGuard<'a, Word, State>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
/// Returns the same slice as `Deref::deref` (via deref coercion of `self`).
fn as_ref(&self) -> &[Word] {
self
}
}
#[cfg(test)]
mod tests {
extern crate std;
use std::dbg;
use super::super::model::{
ContiguousCategoricalEntropyModel, IterableEntropyModel, LeakyQuantizer,
};
use super::*;
use probability::distribution::{Gaussian, Inverse};
use rand_xoshiro::{
rand_core::{RngCore, SeedableRng},
Xoshiro256StarStar,
};
// Encoding no symbols at all must produce no compressed words, and a decoder
// over that empty data must report that it may be exhausted.
#[test]
fn compress_none() {
let encoder = DefaultRangeEncoder::new();
assert!(encoder.is_empty());
let compressed = encoder.into_compressed().unwrap();
assert!(compressed.is_empty());
let decoder = DefaultRangeDecoder::from_compressed(compressed).unwrap();
assert!(decoder.maybe_exhausted());
}
// Round trips of short symbol sequences; the second argument of
// `generic_compress_few` is the expected number of compressed words.
#[test]
fn compress_one() {
generic_compress_few(core::iter::once(5), 1)
}
#[test]
fn compress_two() {
generic_compress_few([2, 8].iter().cloned(), 1)
}
#[test]
fn compress_ten() {
generic_compress_few(0..10, 2)
}
#[test]
fn compress_twenty() {
generic_compress_few(-10..10, 4)
}
/// Encodes `symbols` with a fixed quantized Gaussian model, checks that the
/// compressed data has `expected_size` words, and verifies the round trip.
fn generic_compress_few<I>(symbols: I, expected_size: usize)
where
    I: IntoIterator<Item = i32>,
    I::IntoIter: Clone,
{
    let symbols = symbols.into_iter();
    let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-127..=127);
    let model = quantizer.quantize(Gaussian::new(3.2, 5.1));

    // Compress.
    let mut encoder = DefaultRangeEncoder::new();
    encoder.encode_iid_symbols(symbols.clone(), &model).unwrap();
    let compressed = encoder.into_compressed().unwrap();
    assert_eq!(compressed.len(), expected_size);

    // Decompress and compare symbol by symbol.
    let mut decoder = DefaultRangeDecoder::from_compressed(&compressed).unwrap();
    for expected in symbols {
        let decoded = decoder.decode_symbol(&model).unwrap();
        assert_eq!(decoded, expected);
    }
    assert!(decoder.maybe_exhausted());
}
// Round trips of many symbols for various combinations of word type, state
// type, probability type, and precision. The test names spell out
// `<Word>_<State>_<PRECISION>`.
#[test]
fn compress_many_u32_u64_32() {
generic_compress_many::<u32, u64, u32, 32>();
}
#[test]
fn compress_many_u32_u64_24() {
generic_compress_many::<u32, u64, u32, 24>();
}
#[test]
fn compress_many_u32_u64_16() {
generic_compress_many::<u32, u64, u16, 16>();
}
#[test]
fn compress_many_u32_u64_8() {
generic_compress_many::<u32, u64, u8, 8>();
}
#[test]
fn compress_many_u16_u64_16() {
generic_compress_many::<u16, u64, u16, 16>();
}
#[test]
fn compress_many_u16_u64_12() {
generic_compress_many::<u16, u64, u16, 12>();
}
#[test]
fn compress_many_u16_u64_8() {
generic_compress_many::<u16, u64, u8, 8>();
}
#[test]
fn compress_many_u8_u64_8() {
generic_compress_many::<u8, u64, u8, 8>();
}
#[test]
fn compress_many_u16_u32_16() {
generic_compress_many::<u16, u32, u16, 16>();
}
#[test]
fn compress_many_u16_u32_12() {
generic_compress_many::<u16, u32, u16, 12>();
}
#[test]
fn compress_many_u16_u32_8() {
generic_compress_many::<u16, u32, u8, 8>();
}
#[test]
fn compress_many_u8_u32_8() {
generic_compress_many::<u8, u32, u8, 8>();
}
#[test]
fn compress_many_u8_u16_8() {
generic_compress_many::<u8, u16, u8, 8>();
}
/// Round-trips `AMT` symbols from a fixed categorical model followed by `AMT`
/// symbols from Gaussian models with pseudo-random parameters.
fn generic_compress_many<Word, State, Probability, const PRECISION: usize>()
where
    State: BitArray + AsPrimitive<Word>,
    Word: BitArray + Into<State> + AsPrimitive<Probability>,
    Probability: BitArray + Into<Word> + AsPrimitive<usize> + Into<f64>,
    u32: AsPrimitive<Probability>,
    usize: AsPrimitive<Probability>,
    f64: AsPrimitive<Probability>,
    i32: AsPrimitive<Probability>,
{
    #[cfg(not(miri))]
    const AMT: usize = 1000;

    #[cfg(miri)]
    const AMT: usize = 100;

    let mut symbols_gaussian = Vec::with_capacity(AMT);
    let mut means = Vec::with_capacity(AMT);
    let mut stds = Vec::with_capacity(AMT);

    // The order of the `rng` calls below determines the test data.
    let mut rng = Xoshiro256StarStar::seed_from_u64(1234);

    for _ in 0..AMT {
        let mean = (200.0 / u32::MAX as f64) * rng.next_u32() as f64 - 100.0;
        let std_dev = (10.0 / u32::MAX as f64) * rng.next_u32() as f64 + 0.001;
        let quantile = (rng.next_u32() as f64 + 0.5) / (1u64 << 32) as f64;
        let dist = Gaussian::new(mean, std_dev);
        let symbol = (dist.inverse(quantile).round() as i32).clamp(-127, 127);

        symbols_gaussian.push(symbol);
        means.push(mean);
        stds.push(std_dev);
    }

    // Unnormalized histogram that defines the categorical model.
    let hist = [
        1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
        896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
        347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
    ];
    let categorical_probabilities: Vec<f64> = hist.iter().map(|&count| count as f64).collect();
    let categorical =
        ContiguousCategoricalEntropyModel::<Probability, _, PRECISION>::from_floating_point_probabilities(
            &categorical_probabilities,
        )
        .unwrap();

    let mut symbols_categorical = Vec::with_capacity(AMT);
    let max_probability = Probability::max_value() >> (Probability::BITS - PRECISION);
    for _ in 0..AMT {
        let quantile = rng.next_u32().as_() & max_probability;
        let symbol = categorical.quantile_function(quantile).0;
        symbols_categorical.push(symbol);
    }

    // Encode the categorical symbols first, then the Gaussian ones.
    let mut encoder = RangeEncoder::<Word, State>::new();

    encoder
        .encode_iid_symbols(&symbols_categorical, &categorical)
        .unwrap();
    dbg!(
        encoder.num_bits(),
        AMT as f64 * categorical.entropy_base2::<f64>()
    );

    let quantizer = LeakyQuantizer::<_, _, Probability, PRECISION>::new(-127..=127);
    encoder
        .encode_symbols(symbols_gaussian.iter().zip(&means).zip(&stds).map(
            |((&symbol, &mean), &std_dev)| {
                (symbol, quantizer.quantize(Gaussian::new(mean, std_dev)))
            },
        ))
        .unwrap();
    dbg!(encoder.num_bits());

    // Decode in the same order and compare with the original symbols.
    let mut decoder = encoder.into_decoder().unwrap();

    let reconstructed_categorical = decoder
        .decode_iid_symbols(AMT, &categorical)
        .collect::<Result<Vec<_>, _>>()
        .unwrap();
    let reconstructed_gaussian = decoder
        .decode_symbols(
            means
                .iter()
                .zip(&stds)
                .map(|(&mean, &std_dev)| quantizer.quantize(Gaussian::new(mean, std_dev))),
        )
        .collect::<Result<Vec<_>, _>>()
        .unwrap();

    assert!(decoder.maybe_exhausted());

    assert_eq!(symbols_categorical, reconstructed_categorical);
    assert_eq!(symbols_gaussian, reconstructed_gaussian);
}
// Encodes several chunks of symbols while recording a jump table, then checks
// that the decoder can decode sequentially and after seeking to recorded
// positions.
#[test]
fn seek() {
    #[cfg(not(miri))]
    let (num_chunks, symbols_per_chunk) = (100, 100);

    #[cfg(miri)]
    let (num_chunks, symbols_per_chunk) = (10, 10);

    let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-100..=100);
    let model = quantizer.quantize(Gaussian::new(0.0, 10.0));

    let mut encoder = DefaultRangeEncoder::new();

    let mut rng = Xoshiro256StarStar::seed_from_u64(123);
    let mut symbols = Vec::with_capacity(num_chunks);
    let mut jump_table = Vec::with_capacity(num_chunks);

    for _ in 0..num_chunks {
        // Record the position at which this chunk starts.
        jump_table.push(encoder.pos());

        let chunk: Vec<_> = (0..symbols_per_chunk)
            .map(|_| model.quantile_function(rng.next_u32() % (1 << 24)).0)
            .collect();
        encoder.encode_iid_symbols(&chunk, &model).unwrap();
        symbols.push(chunk);
    }
    let final_pos_and_state = encoder.pos();

    let mut decoder = encoder.decoder();

    // Sequential decoding of all chunks.
    for chunk in &symbols {
        let decoded = decoder
            .decode_iid_symbols(symbols_per_chunk, &model)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(&decoded, chunk);
    }
    assert!(decoder.maybe_exhausted());

    // Decoding after seeking to pseudo-random chunks.
    for i in 0..100 {
        // Iteration 3 always jumps to the very first chunk.
        let chunk_index = if i == 3 {
            0
        } else {
            rng.next_u32() as usize % num_chunks
        };

        let pos_and_state = jump_table[chunk_index];
        decoder.seek(pos_and_state).unwrap();
        let decoded = decoder
            .decode_iid_symbols(symbols_per_chunk, &model)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(&decoded, &symbols[chunk_index])
    }

    // Seeking to the beginning and to the end.
    decoder.seek(jump_table[0]).unwrap();
    assert!(!decoder.maybe_exhausted());
    decoder.seek(final_pos_and_state).unwrap();
    assert!(decoder.maybe_exhausted());
}
}
/// Frontend error type for decoding with a [`RangeDecoder`].
#[derive(Debug)]
#[non_exhaustive]
pub enum DecoderFrontendError {
    /// The compressed data is not valid for the entropy model that was used to
    /// decode it.
    InvalidData,
}

impl Display for DecoderFrontendError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive `match` so that a new variant cannot be added without a message.
        let message = match self {
            Self::InvalidData => {
                "Tried to decode from compressed data that is invalid for the employed entropy model."
            }
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for DecoderFrontendError {}