use alloc::vec::Vec;
use core::{borrow::Borrow, convert::Infallible, fmt::Display};
use num_traits::AsPrimitive;
use super::{
model::{DecoderModel, EncoderModel},
Code, Decode, Encode, TryCodingError,
};
use crate::{
backends::{ReadWords, WriteWords},
BitArray, CoderError, DefaultEncoderFrontendError, NonZeroBitArray, Pos, PosSeek, Seek, Stack,
};
/// A coder that moves information between two word buffers, `compressed` and
/// `remainders`.
///
/// Per the `Decode` and `Encode` impls in this file, decoding reads words from
/// `compressed` and writes words to `remainders`, while encoding reads words from
/// `remainders` and writes words to `compressed`. Partially processed words are kept
/// in `heads`.
///
/// `PRECISION` is the number of low bits of a word that make up one quantile. All
/// constructors go through `ChainCoderHeads::new`, which asserts
/// `0 < PRECISION <= Word::BITS` and `State::BITS >= Word::BITS + PRECISION`.
#[derive(Debug, Clone)]
pub struct ChainCoder<Word, State, CompressedBackend, RemaindersBackend, const PRECISION: usize>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
/// Backend holding the words of the compressed side.
compressed: CompressedBackend,
/// Backend holding the words of the remainders side.
remainders: RemaindersBackend,
/// In-memory heads (partial words / accumulator) for both backends.
heads: ChainCoderHeads<Word, State, PRECISION>,
}
/// The small, copyable part of a `ChainCoder`'s state that is not stored in either
/// backend. It is what `Code::state` returns and what `Seek::seek` restores.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ChainCoderHeads<Word: BitArray, State: BitArray, const PRECISION: usize> {
/// Bits belonging to the `compressed` side that are not (yet) stored as a full word
/// in the backend. Never zero; the value `1` is what `is_whole` tests for.
compressed: Word::NonZero,
/// Accumulator for the `remainders` side. `new` initializes it to at least
/// `1 << (State::BITS - Word::BITS - PRECISION)`.
remainders: State,
}
impl<Word: BitArray, State: BitArray, const PRECISION: usize>
    ChainCoderHeads<Word, State, PRECISION>
{
    /// Returns `true` iff the compressed head equals `1`.
    #[inline(always)]
    pub fn is_whole(self) -> bool {
        let Self { compressed, .. } = self;
        compressed.get() == Word::one()
    }

    /// Builds the initial heads by reading words from `source`.
    ///
    /// The compressed head always starts out as `1`. The remainders head starts as `1`
    /// if `push_one` is set, and otherwise as the first word read from `source`, which
    /// must be nonzero. Further words are then shifted in until the remainders head
    /// reaches `1 << (State::BITS - Word::BITS - PRECISION)`.
    ///
    /// # Errors
    ///
    /// Returns `CoderError::Frontend(())` if `source` runs out of words, or if
    /// `push_one` is `false` and the first word is zero. Read errors of `source` are
    /// propagated.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < PRECISION <= Word::BITS` and
    /// `State::BITS >= Word::BITS + PRECISION`.
    fn new<B: ReadWords<Word, Stack>>(
        source: &mut B,
        push_one: bool,
    ) -> Result<ChainCoderHeads<Word, State, PRECISION>, CoderError<(), B::ReadError>>
    where
        Word: Into<State>,
    {
        assert!(State::BITS >= Word::BITS + PRECISION);
        assert!(PRECISION > 0);
        assert!(PRECISION <= Word::BITS);

        // Smallest value the remainders head may have once construction is done.
        let min_head = State::one() << (State::BITS - Word::BITS - PRECISION);

        let mut head: State = if push_one {
            State::one()
        } else {
            match source.read()? {
                None => return Err(CoderError::Frontend(())),
                Some(word) if word == Word::zero() => return Err(CoderError::Frontend(())),
                Some(word) => word.into(),
            }
        };

        // Shift in whole words until the head is large enough.
        while head < min_head {
            let next_word = source.read()?.ok_or(CoderError::Frontend(()))?;
            head = (head << Word::BITS) | next_word.into();
        }

        let compressed = Word::one().into_nonzero().expect("1 != 0");
        Ok(ChainCoderHeads {
            compressed,
            remainders: head,
        })
    }
}
/// `ChainCoder` with 32-bit words, a 64-bit state, `Vec<u32>` for both backends, and
/// `PRECISION = 24`.
pub type DefaultChainCoder = ChainCoder<u32, u64, Vec<u32>, Vec<u32>, 24>;
/// `ChainCoder` with 16-bit words, a 32-bit state, `Vec<u16>` for both backends, and
/// `PRECISION = 12`.
pub type SmallChainCoder = ChainCoder<u16, u32, Vec<u16>, Vec<u16>, 12>;
impl<Word, State, CompressedBackend, RemaindersBackend, const PRECISION: usize>
ChainCoder<Word, State, CompressedBackend, RemaindersBackend, PRECISION>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
{
/// Creates a coder whose `compressed` backend is `data` and whose `remainders`
/// backend is `RemaindersBackend::default()`.
///
/// The remainders head is seeded with the value `1` and then filled up with words
/// read from `data` (see `ChainCoderHeads::new` with `push_one = true`).
///
/// # Errors
///
/// Returns `CoderError::Frontend(data)`, handing the backend back, if `data` runs out
/// of words while the heads are being initialized. Read errors are returned as
/// `CoderError::Backend`.
pub fn from_binary(
    mut data: CompressedBackend,
) -> Result<Self, CoderError<CompressedBackend, CompressedBackend::ReadError>>
where
    CompressedBackend: ReadWords<Word, Stack>,
    RemaindersBackend: Default,
{
    match ChainCoderHeads::new(&mut data, true) {
        Ok(heads) => Ok(Self {
            compressed: data,
            remainders: RemaindersBackend::default(),
            heads,
        }),
        Err(CoderError::Frontend(())) => Err(CoderError::Frontend(data)),
        Err(CoderError::Backend(err)) => Err(CoderError::Backend(err)),
    }
}
/// Creates a coder whose `compressed` backend is `compressed` and whose `remainders`
/// backend is `RemaindersBackend::default()`.
///
/// Unlike `from_binary`, the remainders head is seeded with the first word read from
/// `compressed`, which must be nonzero (see `ChainCoderHeads::new` with
/// `push_one = false`).
///
/// # Errors
///
/// Returns `CoderError::Frontend(compressed)`, handing the backend back, if the first
/// word is zero or if `compressed` runs out of words while the heads are being
/// initialized. Read errors are returned as `CoderError::Backend`.
pub fn from_compressed(
    mut compressed: CompressedBackend,
) -> Result<Self, CoderError<CompressedBackend, CompressedBackend::ReadError>>
where
    CompressedBackend: ReadWords<Word, Stack>,
    RemaindersBackend: Default,
{
    let initialized = ChainCoderHeads::new(&mut compressed, false);
    match initialized {
        Err(CoderError::Backend(err)) => Err(CoderError::Backend(err)),
        Err(CoderError::Frontend(())) => Err(CoderError::Frontend(compressed)),
        Ok(heads) => Ok(Self {
            compressed,
            remainders: RemaindersBackend::default(),
            heads,
        }),
    }
}
/// Consumes the coder and returns `(compressed, remainders)` after writing both heads
/// to the `remainders` backend.
///
/// The remainders head is written first, one word at a time starting with its lowest
/// `Word::BITS` bits, until it is zero. The compressed head is written last.
///
/// # Errors
///
/// Returns the first write error reported by the `remainders` backend.
pub fn into_remainders(
    mut self,
) -> Result<(CompressedBackend, RemaindersBackend), RemaindersBackend::WriteError>
where
    RemaindersBackend: WriteWords<Word>,
{
    // Flush the remainders head, least significant word first.
    let mut pending = self.heads.remainders;
    while pending != State::zero() {
        self.remainders.write(pending.as_())?;
        pending = pending >> Word::BITS;
    }

    // The compressed head goes on top.
    let compressed_head = self.heads.compressed.get();
    self.remainders.write(compressed_head)?;

    Ok((self.compressed, self.remainders))
}
/// Creates a coder from a `remainders` backend, using `CompressedBackend::default()`
/// as the `compressed` backend.
///
/// The first word read from `remainders` becomes the compressed head and must be
/// nonzero. The remainders head is then initialized from the following words (see
/// `ChainCoderHeads::new` with `push_one = false`). This mirrors the order in which
/// `into_remainders` writes the heads.
///
/// # Errors
///
/// Returns `CoderError::Frontend(remainders)`, handing the backend back, if it does
/// not hold enough words or if either seed word is zero. Read errors are returned as
/// `CoderError::Backend`.
pub fn from_remainders(
    mut remainders: RemaindersBackend,
) -> Result<Self, CoderError<RemaindersBackend, RemaindersBackend::ReadError>>
where
    RemaindersBackend: ReadWords<Word, Stack>,
    CompressedBackend: Default,
{
    let first_word = remainders.read()?;
    let compressed_head = match first_word.and_then(Word::into_nonzero) {
        Some(nonzero) => nonzero,
        None => return Err(CoderError::Frontend(remainders)),
    };

    match ChainCoderHeads::new(&mut remainders, false) {
        Ok(mut heads) => {
            // `new` always yields a compressed head of `1`; restore the real one.
            heads.compressed = compressed_head;
            Ok(Self {
                compressed: CompressedBackend::default(),
                remainders,
                heads,
            })
        }
        Err(CoderError::Frontend(())) => Err(CoderError::Frontend(remainders)),
        Err(CoderError::Backend(err)) => Err(CoderError::Backend(err)),
    }
}
/// Consumes the coder and returns `(remainders, compressed)` after writing the
/// remainders head to the `compressed` backend.
///
/// The head is written one word at a time, starting with its lowest `Word::BITS`
/// bits, until it is zero.
///
/// # Errors
///
/// Returns `CoderError::Frontend(self)`, handing the untouched coder back, if
/// `is_whole()` is `false`. Write errors of the `compressed` backend are propagated.
pub fn into_compressed(
    mut self,
) -> Result<
    (RemaindersBackend, CompressedBackend),
    CoderError<Self, CompressedBackend::WriteError>,
>
where
    CompressedBackend: WriteWords<Word>,
{
    if self.is_whole() {
        // Move the remainders head onto `compressed`, least significant word first.
        let mut pending = self.heads.remainders;
        while pending != State::zero() {
            self.compressed.write(pending.as_())?;
            pending = pending >> Word::BITS;
        }
        Ok((self.remainders, self.compressed))
    } else {
        Err(CoderError::Frontend(self))
    }
}
/// Consumes the coder and returns `(remainders, compressed)` after writing the
/// remainders head, without its topmost set bit, to the `compressed` backend.
///
/// This only succeeds if the highest set bit of the remainders head sits at a bit
/// position that is a multiple of `Word::BITS`, so that the bits below it form whole
/// words. Those words are written lowest first; the remaining `1` is dropped.
///
/// # Errors
///
/// Returns `CoderError::Frontend(self)`, handing the untouched coder back, if
/// `is_whole()` is `false` or the alignment condition above does not hold. Write
/// errors of the `compressed` backend are propagated.
pub fn into_binary(
    mut self,
) -> Result<
    (RemaindersBackend, CompressedBackend),
    CoderError<Self, CompressedBackend::WriteError>,
>
where
    CompressedBackend: WriteWords<Word>,
{
    // `&&` short-circuits, so the bit arithmetic only runs for a whole coder.
    let convertible = self.is_whole() && {
        let significant_bits = State::BITS - self.heads.remainders.leading_zeros() as usize;
        (significant_bits - 1) % Word::BITS == 0
    };
    if !convertible {
        return Err(CoderError::Frontend(self));
    }

    let mut pending = self.heads.remainders;
    while pending > State::one() {
        self.compressed.write(pending.as_())?;
        pending = pending >> Word::BITS;
    }
    debug_assert!(pending == State::one());

    Ok((self.remainders, self.compressed))
}
#[inline(always)]
pub fn is_whole(&self) -> bool {
self.heads.compressed.get() == Word::one()
}
/// Encodes `(symbol, model)` pairs in reverse iteration order.
///
/// Equivalent to calling `encode_symbols` on the reversed iterator, which is why the
/// iterator must be double-ended.
///
/// # Errors
///
/// Returns whatever `encode_symbols` returns.
pub fn encode_symbols_reverse<S, M, I>(
    &mut self,
    symbols_and_models: I,
) -> Result<(), EncoderError<Word, CompressedBackend, RemaindersBackend>>
where
    S: Borrow<M::Symbol>,
    M: EncoderModel<PRECISION>,
    M::Probability: Into<Word>,
    Word: AsPrimitive<M::Probability>,
    I: IntoIterator<Item = (S, M)>,
    I::IntoIter: DoubleEndedIterator,
    CompressedBackend: WriteWords<Word>,
    RemaindersBackend: ReadWords<Word, Stack>,
{
    let back_to_front = symbols_and_models.into_iter().rev();
    self.encode_symbols(back_to_front)
}
/// Encodes fallibly produced `(symbol, model)` pairs in reverse iteration order.
///
/// Equivalent to calling `try_encode_symbols` on the reversed iterator, which is why
/// the iterator must be double-ended.
///
/// # Errors
///
/// Returns whatever `try_encode_symbols` returns.
pub fn try_encode_symbols_reverse<S, M, E, I>(
    &mut self,
    symbols_and_models: I,
) -> Result<(), TryCodingError<EncoderError<Word, CompressedBackend, RemaindersBackend>, E>>
where
    S: Borrow<M::Symbol>,
    M: EncoderModel<PRECISION>,
    M::Probability: Into<Word>,
    Word: AsPrimitive<M::Probability>,
    I: IntoIterator<Item = core::result::Result<(S, M), E>>,
    I::IntoIter: DoubleEndedIterator,
    CompressedBackend: WriteWords<Word>,
    RemaindersBackend: ReadWords<Word, Stack>,
{
    let back_to_front = symbols_and_models.into_iter().rev();
    self.try_encode_symbols(back_to_front)
}
/// Encodes `symbols` in reverse iteration order, using the same `model` for each.
///
/// Equivalent to calling `encode_iid_symbols` on the reversed iterator, which is why
/// the iterator must be double-ended.
///
/// # Errors
///
/// Returns whatever `encode_iid_symbols` returns.
#[inline(always)]
pub fn encode_iid_symbols_reverse<S, M, I>(
    &mut self,
    symbols: I,
    model: M,
) -> Result<(), EncoderError<Word, CompressedBackend, RemaindersBackend>>
where
    S: Borrow<M::Symbol>,
    M: EncoderModel<PRECISION> + Copy,
    M::Probability: Into<Word>,
    Word: AsPrimitive<M::Probability>,
    I: IntoIterator<Item = S>,
    I::IntoIter: DoubleEndedIterator,
    CompressedBackend: WriteWords<Word>,
    RemaindersBackend: ReadWords<Word, Stack>,
{
    let back_to_front = symbols.into_iter().rev();
    self.encode_iid_symbols(back_to_front, model)
}
#[allow(clippy::type_complexity)]
pub fn increase_precision<const NEW_PRECISION: usize>(
mut self,
) -> Result<
ChainCoder<Word, State, CompressedBackend, RemaindersBackend, NEW_PRECISION>,
CoderError<Infallible, BackendError<Infallible, RemaindersBackend::WriteError>>,
>
where
RemaindersBackend: WriteWords<Word>,
{
assert!(NEW_PRECISION >= PRECISION);
assert!(NEW_PRECISION <= Word::BITS);
assert!(State::BITS >= Word::BITS + NEW_PRECISION);
if self.heads.remainders >= State::one() << (State::BITS - NEW_PRECISION) {
self.flush_remainders_head()?;
}
Ok(ChainCoder {
compressed: self.compressed,
remainders: self.remainders,
heads: ChainCoderHeads {
compressed: self.heads.compressed,
remainders: self.heads.remainders,
},
})
}
#[allow(clippy::type_complexity)]
pub fn decrease_precision<const NEW_PRECISION: usize>(
mut self,
) -> Result<
ChainCoder<Word, State, CompressedBackend, RemaindersBackend, NEW_PRECISION>,
CoderError<EncoderFrontendError, BackendError<Infallible, RemaindersBackend::ReadError>>,
>
where
RemaindersBackend: ReadWords<Word, Stack>,
{
assert!(NEW_PRECISION <= PRECISION);
assert!(NEW_PRECISION > 0);
if self.heads.remainders < State::one() << (State::BITS - NEW_PRECISION - Word::BITS) {
self.refill_remainders_head()?
}
Ok(ChainCoder {
compressed: self.compressed,
remainders: self.remainders,
heads: ChainCoderHeads {
compressed: self.heads.compressed,
remainders: self.heads.remainders,
},
})
}
/// Converts this coder into one with precision `NEW_PRECISION`, whichever direction
/// that is.
///
/// Dispatches to `increase_precision` when `NEW_PRECISION > PRECISION` and to
/// `decrease_precision` otherwise (including when both are equal).
///
/// # Errors
///
/// Wraps the error of the chosen conversion in `ChangePrecisionError::Increase` or
/// `ChangePrecisionError::Decrease`, respectively.
#[inline(always)]
pub fn change_precision<const NEW_PRECISION: usize>(
    self,
) -> Result<
    ChainCoder<Word, State, CompressedBackend, RemaindersBackend, NEW_PRECISION>,
    ChangePrecisionError<Word, RemaindersBackend>,
>
where
    RemaindersBackend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    if NEW_PRECISION <= PRECISION {
        self.decrease_precision()
            .map_err(ChangePrecisionError::Decrease)
    } else {
        self.increase_precision()
            .map_err(ChangePrecisionError::Increase)
    }
}
/// Writes the lowest `Word::BITS` bits of the remainders head to the `remainders`
/// backend and then shifts them out of the head.
///
/// The head is left untouched if the write fails.
///
/// # Errors
///
/// Returns `CoderError::Backend(BackendError::Remainders(_))` wrapping the write
/// error. The frontend and read error types are free parameters so that callers with
/// different error types can share this helper.
#[inline(always)]
fn flush_remainders_head<FrontendError, ReadError>(
    &mut self,
) -> Result<(), CoderError<FrontendError, BackendError<ReadError, RemaindersBackend::WriteError>>>
where
    RemaindersBackend: WriteWords<Word>,
{
    let lowest_word: Word = self.heads.remainders.as_();
    match self.remainders.write(lowest_word) {
        Ok(_) => {
            self.heads.remainders = self.heads.remainders >> Word::BITS;
            Ok(())
        }
        Err(err) => Err(CoderError::Backend(BackendError::Remainders(err))),
    }
}
/// Reads one word from the `remainders` backend and shifts it into the low end of the
/// remainders head.
///
/// The head is left untouched on failure.
///
/// # Errors
///
/// Returns `CoderError::Frontend(EncoderFrontendError::OutOfRemainders)` if the
/// backend has no word left, and `CoderError::Backend(BackendError::Remainders(_))`
/// wrapping the read error if reading fails.
#[inline(always)]
fn refill_remainders_head<WriteError>(
    &mut self,
) -> Result<
    (),
    CoderError<EncoderFrontendError, BackendError<WriteError, RemaindersBackend::ReadError>>,
>
where
    RemaindersBackend: ReadWords<Word, Stack>,
{
    match self.remainders.read() {
        Ok(Some(word)) => {
            self.heads.remainders = (self.heads.remainders << Word::BITS) | word.into();
            Ok(())
        }
        Ok(None) => Err(CoderError::Frontend(EncoderFrontendError::OutOfRemainders)),
        Err(err) => Err(CoderError::Backend(BackendError::Remainders(err))),
    }
}
}
impl<Word, State, CompressedBackend, RemaindersBackend, const PRECISION: usize> Code
    for ChainCoder<Word, State, CompressedBackend, RemaindersBackend, PRECISION>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    type Word = Word;
    type State = ChainCoderHeads<Word, State, PRECISION>;

    /// Returns a copy of the coder's heads; the backends are not part of the state.
    fn state(&self) -> Self::State {
        let Self { heads, .. } = self;
        *heads
    }
}
/// Error type returned when decoding with a `ChainCoder`.
///
/// The frontend part is a `DecoderFrontendError`; the backend part is either a read
/// error of the compressed backend or a write error of the remainders backend.
// The bounds are needed to name the associated error types below, hence the `allow`.
#[allow(type_alias_bounds)]
pub type DecoderError<
Word,
CompressedBackend: ReadWords<Word, Stack>,
RemaindersBackend: WriteWords<Word>,
> = CoderError<
DecoderFrontendError,
BackendError<CompressedBackend::ReadError, RemaindersBackend::WriteError>,
>;
/// Error type returned when encoding with a `ChainCoder`.
///
/// The frontend part is an `EncoderFrontendError`; the backend part is either a write
/// error of the compressed backend or a read error of the remainders backend.
// The bounds are needed to name the associated error types below, hence the `allow`.
#[allow(type_alias_bounds)]
pub type EncoderError<
Word,
CompressedBackend: WriteWords<Word>,
RemaindersBackend: ReadWords<Word, Stack>,
> = CoderError<
EncoderFrontendError,
BackendError<CompressedBackend::WriteError, RemaindersBackend::ReadError>,
>;
/// Frontend error that can occur while decoding with a `ChainCoder`.
#[derive(Debug, PartialEq, Eq)]
pub enum DecoderFrontendError {
    /// The compressed backend returned no word when decoding needed one.
    OutOfCompressedData,
}

impl core::fmt::Display for DecoderFrontendError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::OutOfCompressedData => "Out of compressed data.",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for DecoderFrontendError {}
/// Frontend error that can occur while encoding with a `ChainCoder`.
#[derive(Debug, PartialEq, Eq)]
pub enum EncoderFrontendError {
    /// The remainders backend returned no word when the remainders head had to be
    /// refilled.
    OutOfRemainders,
    /// The entropy model reported no probability for the symbol to be encoded.
    ImpossibleSymbol,
}

impl core::fmt::Display for EncoderFrontendError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // Reuse the wording of the crate's default encoder error.
            Self::ImpossibleSymbol => DefaultEncoderFrontendError::ImpossibleSymbol.fmt(f),
            Self::OutOfRemainders => {
                f.write_str("Out of remainders information from previous decoding.")
            }
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for EncoderFrontendError {}
/// Error reported by one of a `ChainCoder`'s two backends, tagged with the backend it
/// came from.
#[derive(Debug, PartialEq, Eq)]
pub enum BackendError<CompressedBackendError, RemaindersBackendError> {
    /// Error from the `compressed` backend.
    Compressed(CompressedBackendError),
    /// Error from the `remainders` backend.
    Remainders(RemaindersBackendError),
}

impl<CompressedBackendError: Display, RemaindersBackendError: Display> core::fmt::Display
    for BackendError<CompressedBackendError, RemaindersBackendError>
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Compressed(err) => f.write_fmt(format_args!(
                "Read/write error when accessing compressed: {err}"
            )),
            Self::Remainders(err) => f.write_fmt(format_args!(
                "Read/write error when accessing remainders: {err}"
            )),
        }
    }
}

#[cfg(feature = "std")]
impl<
        CompressedBackendError: std::error::Error + 'static,
        RemaindersBackendError: std::error::Error + 'static,
    > std::error::Error for BackendError<CompressedBackendError, RemaindersBackendError>
{
    /// Exposes the wrapped backend error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Compressed(err) => err,
            Self::Remainders(err) => err,
        };
        Some(cause)
    }
}
/// Error returned by `ChainCoder::change_precision`, recording which direction the
/// conversion took.
#[derive(Debug, PartialEq, Eq)]
pub enum ChangePrecisionError<Word, RemaindersBackend>
where
    RemaindersBackend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    /// Error from `ChainCoder::increase_precision`.
    Increase(CoderError<Infallible, BackendError<Infallible, RemaindersBackend::WriteError>>),
    /// Error from `ChainCoder::decrease_precision`.
    Decrease(
        CoderError<EncoderFrontendError, BackendError<Infallible, RemaindersBackend::ReadError>>,
    ),
}

impl<Word, RemaindersBackend> Display for ChangePrecisionError<Word, RemaindersBackend>
where
    RemaindersBackend: WriteWords<Word> + ReadWords<Word, Stack>,
    RemaindersBackend::WriteError: Display,
    RemaindersBackend::ReadError: Display,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Increase(err) => f.write_fmt(format_args!(
                "Error while increasing precision of chain coder: {err}"
            )),
            Self::Decrease(err) => f.write_fmt(format_args!(
                "Error while decreasing precision of chain coder: {err}"
            )),
        }
    }
}

#[cfg(feature = "std")]
impl<Word, RemaindersBackend> std::error::Error for ChangePrecisionError<Word, RemaindersBackend>
where
    Self: core::fmt::Debug,
    RemaindersBackend: WriteWords<Word> + ReadWords<Word, Stack>,
    RemaindersBackend::WriteError: std::error::Error + 'static,
    RemaindersBackend::ReadError: std::error::Error + 'static,
{
    /// Exposes the wrapped conversion error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Increase(err) => err,
            Self::Decrease(err) => err,
        };
        Some(cause)
    }
}
// Declares what a "position" of a `ChainCoder` is, for use by the `Pos` and `Seek`
// impls in this file.
impl<Word, State, CompressedBackend, RemaindersBackend, const PRECISION: usize> PosSeek
for ChainCoder<Word, State, CompressedBackend, RemaindersBackend, PRECISION>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
CompressedBackend: PosSeek,
RemaindersBackend: PosSeek,
{
// A position is the pair of both backends' positions together with the coder's heads.
type Position = (
BackendPosition<CompressedBackend::Position, RemaindersBackend::Position>,
ChainCoderHeads<Word, State, PRECISION>,
);
}
/// The positions of both backends of a `ChainCoder`, as used in its
/// `PosSeek::Position`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BackendPosition<CompressedPosition, RemaindersPosition> {
/// Position within the `compressed` backend.
pub compressed: CompressedPosition,
/// Position within the `remainders` backend.
pub remainders: RemaindersPosition,
}
impl<Word, State, CompressedBackend, RemaindersBackend, const PRECISION: usize> Pos
    for ChainCoder<Word, State, CompressedBackend, RemaindersBackend, PRECISION>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    CompressedBackend: Pos,
    RemaindersBackend: Pos,
{
    /// Returns the current positions of both backends together with the coder's heads.
    fn pos(&self) -> Self::Position {
        let compressed = self.compressed.pos();
        let remainders = self.remainders.pos();
        let backends = BackendPosition {
            compressed,
            remainders,
        };
        (backends, self.state())
    }
}
impl<Word, State, CompressedBackend, RemaindersBackend, const PRECISION: usize> Seek
for ChainCoder<Word, State, CompressedBackend, RemaindersBackend, PRECISION>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
CompressedBackend: Seek,
RemaindersBackend: Seek,
{
fn seek(&mut self, (pos, state): Self::Position) -> Result<(), ()> {
self.compressed.seek(pos.compressed)?;
self.remainders.seek(pos.remainders)?;
self.heads = state;
Ok(())
}
}
// Decoding direction: consumes bits from the `compressed` side and pushes the leftover
// information onto the `remainders` side.
impl<Word, State, CompressedBackend, RemaindersBackend, const PRECISION: usize> Decode<PRECISION>
for ChainCoder<Word, State, CompressedBackend, RemaindersBackend, PRECISION>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
CompressedBackend: ReadWords<Word, Stack>,
RemaindersBackend: WriteWords<Word>,
{
type FrontendError = DecoderFrontendError;
type BackendError = BackendError<CompressedBackend::ReadError, RemaindersBackend::WriteError>;
/// Decodes one symbol using `model`.
///
/// Takes a `PRECISION`-bit quantile from the compressed side, maps it to a symbol via
/// `model.quantile_function`, and folds the quantile's offset within the symbol's
/// range into the remainders head.
///
/// # Errors
///
/// Returns `DecoderFrontendError::OutOfCompressedData` if a word is needed from the
/// compressed backend but none is available. Backend read/write errors are returned
/// as the backend variant of the error.
///
/// # Panics
///
/// Panics unless `0 < PRECISION <= Word::BITS` and
/// `State::BITS >= Word::BITS + PRECISION`.
fn decode_symbol<M>(
&mut self,
model: M,
) -> Result<M::Symbol, DecoderError<Word, CompressedBackend, RemaindersBackend>>
where
M: DecoderModel<PRECISION>,
M::Probability: Into<Self::Word>,
Self::Word: AsPrimitive<M::Probability>,
{
assert!(PRECISION <= Word::BITS);
assert!(PRECISION != 0);
assert!(State::BITS >= Word::BITS + PRECISION);
// Obtain a word whose lowest `PRECISION` bits are the quantile. A fresh word is read
// from the backend if `PRECISION == Word::BITS` or if the compressed head is below
// `1 << PRECISION`; otherwise the bits are taken from the compressed head itself.
// The `||` short-circuits, so `Word::one() << PRECISION` is only evaluated when
// `PRECISION < Word::BITS`.
let word = if PRECISION == Word::BITS
|| self.heads.compressed.get() < Word::one() << PRECISION
{
let word = self
.compressed
.read()
.map_err(BackendError::Compressed)?
.ok_or(CoderError::Frontend(
DecoderFrontendError::OutOfCompressedData,
))?;
if PRECISION != Word::BITS {
// Keep the unused upper `Word::BITS - PRECISION` bits of `word` in the head, below
// the bits that were already there.
self.heads.compressed = unsafe {
// SAFETY (assuming `new_unchecked` only requires a nonzero argument):
// - `0 < PRECISION < Word::BITS` here, so both shift amounts are below `Word::BITS`;
// - the head is nonzero and, by the enclosing condition, below `1 << PRECISION`,
//   so shifting it left by `Word::BITS - PRECISION` drops no set bit and the
//   result of the `|` is nonzero.
Word::NonZero::new_unchecked(
self.heads.compressed.get() << (Word::BITS - PRECISION) | word >> PRECISION,
)
};
}
word
} else {
let word = self.heads.compressed.get();
self.heads.compressed = unsafe {
// SAFETY (assuming `new_unchecked` only requires a nonzero argument):
// - this branch is only reached if `PRECISION != Word::BITS`, so the shift amount
//   is below `Word::BITS`;
// - the head is at least `1 << PRECISION`, so shifting it right by `PRECISION`
//   leaves a nonzero value.
Word::NonZero::new_unchecked(self.heads.compressed.get() >> PRECISION)
};
word
};
// The quantile is the lowest `PRECISION` bits of `word`.
let quantile = if PRECISION == Word::BITS {
word
} else {
word % (Word::one() << PRECISION)
};
let quantile = quantile.as_();
let (symbol, left_sided_cumulative, probability) = model.quantile_function(quantile);
// Offset of the quantile relative to `left_sided_cumulative`; this is the
// information that gets pushed onto the remainders side.
let remainder = quantile - left_sided_cumulative;
self.heads.remainders =
self.heads.remainders * probability.get().into().into() + remainder.into().into();
// Move the lowest word of the head to the backend once the head gets too large.
if self.heads.remainders >= State::one() << (State::BITS - PRECISION) {
self.flush_remainders_head()?;
}
Ok(symbol)
}
/// Returns `true` if the compressed backend may be exhausted or the remainders
/// backend may be full.
fn maybe_exhausted(&self) -> bool {
self.compressed.maybe_exhausted() || self.remainders.maybe_full()
}
}
// Encoding direction: consumes information from the `remainders` side and pushes
// quantiles onto the `compressed` side.
impl<Word, State, CompressedBackend, RemaindersBackend, const PRECISION: usize> Encode<PRECISION>
for ChainCoder<Word, State, CompressedBackend, RemaindersBackend, PRECISION>
where
Word: BitArray + Into<State>,
State: BitArray + AsPrimitive<Word>,
CompressedBackend: WriteWords<Word>,
RemaindersBackend: ReadWords<Word, Stack>,
{
type FrontendError = EncoderFrontendError;
type BackendError = BackendError<CompressedBackend::WriteError, RemaindersBackend::ReadError>;
/// Encodes one symbol using `model`.
///
/// Splits the remainders head into a remainder modulo the symbol's probability and a
/// quotient; the quotient becomes the new remainders head, and
/// `left_sided_cumulative + remainder` is pushed onto the compressed side as a
/// `PRECISION`-bit quantile.
///
/// # Errors
///
/// Returns `EncoderFrontendError::ImpossibleSymbol` if `model` yields no probability
/// for `symbol`, and `EncoderFrontendError::OutOfRemainders` if the remainders head
/// must be refilled but the remainders backend is empty. Backend read/write errors
/// are returned as the backend variant of the error.
///
/// # Panics
///
/// Panics unless `0 < PRECISION <= Word::BITS`.
fn encode_symbol<M>(
&mut self,
symbol: impl Borrow<M::Symbol>,
model: M,
) -> Result<(), EncoderError<Word, CompressedBackend, RemaindersBackend>>
where
M: EncoderModel<PRECISION>,
M::Probability: Into<Self::Word>,
Self::Word: AsPrimitive<M::Probability>,
{
assert!(PRECISION <= Word::BITS);
assert!(PRECISION > 0);
// NOTE(review): unlike `decode_symbol`, this does not assert
// `State::BITS >= Word::BITS + PRECISION`, although the subtraction below relies on
// it. The constructors in this file do assert it; confirm that this is sufficient.
let (left_sided_cumulative, probability) = model
.left_cumulative_and_probability(symbol)
.ok_or(CoderError::Frontend(EncoderFrontendError::ImpossibleSymbol))?;
// Pull one word back from the remainders backend if the head is too small for this
// symbol's probability.
if self.heads.remainders
< probability.get().into().into() << (State::BITS - Word::BITS - PRECISION)
{
self.refill_remainders_head()?;
}
// `remainder` is below `probability`; it is narrowed from `State` to `Word` and then
// to `M::Probability` by the two `as_()` calls.
let remainder = (self.heads.remainders % probability.get().into().into())
.as_()
.as_();
let quantile = (left_sided_cumulative + remainder).into();
self.heads.remainders = self.heads.remainders / probability.get().into().into();
// Push the quantile onto the compressed side. If the compressed head is below
// `1 << (Word::BITS - PRECISION)`, the quantile is appended to the head; otherwise a
// full word is written to the backend.
if PRECISION != Word::BITS
&& self.heads.compressed.get() < Word::one() << (Word::BITS - PRECISION)
{
unsafe {
// SAFETY (assuming `into_nonzero_unchecked` only requires a nonzero value):
// - `0 < PRECISION < Word::BITS`, so the shift amount is below `Word::BITS`;
// - the head is nonzero and below `1 << (Word::BITS - PRECISION)`, so shifting it
//   left by `PRECISION` drops no set bit and the result of the `|` is nonzero.
self.heads.compressed =
(self.heads.compressed.get() << PRECISION | quantile).into_nonzero_unchecked();
}
} else {
let word = if PRECISION == Word::BITS {
quantile
} else {
// The word written out consists of the head's lower `Word::BITS - PRECISION` bits
// followed by the quantile; the head's remaining upper bits stay in the head.
let word = self.heads.compressed.get() << PRECISION | quantile;
unsafe {
// SAFETY (assuming `into_nonzero_unchecked` only requires a nonzero value):
// - `0 < PRECISION < Word::BITS` here, so the shift amount is below `Word::BITS`;
// - the outer condition was false with `PRECISION != Word::BITS`, so the head is
//   at least `1 << (Word::BITS - PRECISION)` and the shifted value is nonzero.
self.heads.compressed = (self.heads.compressed.get()
>> (Word::BITS - PRECISION))
.into_nonzero_unchecked();
}
word
};
self.compressed
.write(word)
.map_err(BackendError::Compressed)?;
}
Ok(())
}
/// Returns `true` if the remainders backend may be exhausted or the compressed
/// backend may be full.
fn maybe_full(&self) -> bool {
self.remainders.maybe_exhausted() || self.compressed.maybe_full()
}
}
// Round-trip tests: decode symbols from random "compressed" words, then re-encode the
// symbols and check that the original words are recovered.
#[cfg(test)]
mod tests {
use super::super::model::LeakyQuantizer;
use super::*;
use probability::distribution::Gaussian;
use rand_xoshiro::{
rand_core::{RngCore, SeedableRng},
Xoshiro256StarStar,
};
use alloc::vec;
/// Round trip with 4 compressed words and no symbols at all.
#[test]
fn restore_none() {
generic_restore_many::<u32, u64, u32, 24>(4, 0);
}
/// Round trip with 5 compressed words and a single symbol.
#[test]
fn restore_one() {
generic_restore_many::<u32, u64, u32, 24>(5, 1);
}
/// Round trip with 5 compressed words and two symbols.
#[test]
fn restore_two() {
generic_restore_many::<u32, u64, u32, 24>(5, 2);
}
/// Round trip with 20 compressed words and ten symbols.
#[test]
fn restore_ten() {
generic_restore_many::<u32, u64, u32, 24>(20, 10);
}
/// Round trip with 19 compressed words and twenty symbols.
#[test]
fn restore_twenty() {
generic_restore_many::<u32, u64, u32, 24>(19, 20);
}
// The remaining tests run 1000 symbols over 1024 compressed words for various
// combinations of `Word`, `State`, `Probability`, and `PRECISION`, as encoded in
// their names (`restore_many_<Word>_<State>_<PRECISION>`).
#[test]
fn restore_many_u32_u64_32() {
generic_restore_many::<u32, u64, u32, 32>(1024, 1000);
}
#[test]
fn restore_many_u32_u64_24() {
generic_restore_many::<u32, u64, u32, 24>(1024, 1000);
}
#[test]
fn restore_many_u32_u64_16() {
generic_restore_many::<u32, u64, u16, 16>(1024, 1000);
}
#[test]
fn restore_many_u16_u64_16() {
generic_restore_many::<u16, u64, u16, 16>(1024, 1000);
}
#[test]
fn restore_many_u32_u64_8() {
generic_restore_many::<u32, u64, u8, 8>(1024, 1000);
}
#[test]
fn restore_many_u16_u64_8() {
generic_restore_many::<u16, u64, u8, 8>(1024, 1000);
}
#[test]
fn restore_many_u8_u64_8() {
generic_restore_many::<u8, u64, u8, 8>(1024, 1000);
}
#[test]
fn restore_many_u16_u32_16() {
generic_restore_many::<u16, u32, u16, 16>(1024, 1000);
}
#[test]
fn restore_many_u16_u32_8() {
generic_restore_many::<u16, u32, u8, 8>(1024, 1000);
}
#[test]
fn restore_many_u8_u32_8() {
generic_restore_many::<u8, u32, u8, 8>(1024, 1000);
}
/// Shared driver for all tests in this module.
///
/// Generates `amt_compressed_words` pseudo-random words, decodes `amt_symbols` symbols
/// from them with randomly parameterized quantized Gaussians, and then checks that
/// re-encoding the symbols in reverse order reproduces the original words. The
/// re-encoding is checked for the decoder itself and for two coders rebuilt via
/// `from_remainders`.
fn generic_restore_many<Word, State, Probability, const PRECISION: usize>(
amt_compressed_words: usize,
amt_symbols: usize,
) where
State: BitArray + AsPrimitive<Word>,
Word: BitArray + Into<State> + AsPrimitive<Probability>,
Probability: BitArray + Into<Word> + AsPrimitive<usize> + Into<f64>,
u64: AsPrimitive<Word>,
u32: AsPrimitive<Probability>,
usize: AsPrimitive<Probability>,
f64: AsPrimitive<Probability>,
i32: AsPrimitive<Probability>,
{
// Use smaller workloads when running under miri.
#[cfg(miri)]
let (amt_compressed_words, amt_symbols) =
(amt_compressed_words.min(128), amt_symbols.min(100));
// Seed the RNG from the test parameters so that every test is deterministic.
let mut rng = Xoshiro256StarStar::seed_from_u64(
(amt_compressed_words as u64).rotate_left(32) ^ amt_symbols as u64,
);
let mut compressed = (0..amt_compressed_words)
.map(|_| rng.next_u64().as_())
.collect::<Vec<_>>();
// Force the last word to have exactly `leading_zeros` leading zero bits (between 0
// and `Word::BITS - 2`), which in particular makes it nonzero. Presumably this is
// the first word `from_compressed` reads, which must not be zero.
let leading_zeros = (rng.next_u32() % (Word::BITS as u32 - 1)) as usize;
let last_word = compressed.last_mut().unwrap();
*last_word = *last_word | Word::one() << (Word::BITS - leading_zeros - 1);
*last_word = *last_word & Word::max_value() >> leading_zeros;
// One Gaussian per symbol, with mean in [-100, 100] and std_dev in [0.001, 10.001].
let distributions = (0..amt_symbols)
.map(|_| {
let mean = (200.0 / u32::MAX as f64) * rng.next_u32() as f64 - 100.0;
let std_dev = (10.0 / u32::MAX as f64) * rng.next_u32() as f64 + 0.001;
Gaussian::new(mean, std_dev)
})
.collect::<Vec<_>>();
let quantizer = LeakyQuantizer::<_, _, Probability, PRECISION>::new(-100..=100);
let mut coder =
ChainCoder::<Word, State, Vec<Word>, Vec<Word>, PRECISION>::from_compressed(
compressed.clone(),
)
.unwrap();
let symbols = coder
.decode_symbols(
distributions
.iter()
.map(|&distribution| quantizer.quantize(distribution)),
)
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert!(!coder.maybe_exhausted());
// `into_remainders` returns `(compressed, remainders)`: the "prefix" is what is left
// of the compressed backend, the "suffix" is the remainders backend.
let (remainders_prefix, remainders_suffix) = coder.clone().into_remainders().unwrap();
let mut remainders = remainders_prefix.clone();
remainders.extend_from_slice(&remainders_suffix);
// `coder2` is rebuilt from prefix and suffix concatenated, `coder3` from the suffix
// alone (its prefix is prepended to the reconstruction below instead).
let coder2 = ChainCoder::from_remainders(remainders).unwrap();
let coder3 = ChainCoder::from_remainders(remainders_suffix).unwrap();
for (mut coder, prefix) in vec![
(coder, vec![]),
(coder2, vec![]),
(coder3, remainders_prefix),
] {
coder
.encode_symbols_reverse(
symbols
.iter()
.zip(&distributions)
.map(|(&symbol, &distribution)| (symbol, quantizer.quantize(distribution))),
)
.unwrap();
// `into_compressed` returns `(remainders, compressed)`; concatenated after `prefix`,
// they must reproduce the original compressed words.
let (compressed_prefix, compressed_suffix) = coder.into_compressed().unwrap();
let mut reconstructed = prefix;
reconstructed.extend(compressed_prefix);
reconstructed.extend(compressed_suffix);
assert_eq!(reconstructed, compressed);
}
}
}