use core::{borrow::Borrow, marker::PhantomData};
use alloc::{boxed::Box, vec::Vec};
use num_traits::{float::FloatCore, AsPrimitive};
use crate::{generic_static_asserts, wrapping_pow2, BitArray, NonZeroBitArray};
use super::{
super::{DecoderModel, EntropyModel, IterableEntropyModel},
accumulate_nonzero_probabilities, fast_quantized_cdf, iter_extended_cdf,
non_contiguous::NonContiguousCategoricalDecoderModel,
perfectly_quantized_probabilities,
};
/// Type alias for a [`NonContiguousLookupDecoderModel`] with sane default settings:
/// 16-bit `Probability`, 12-bit fixed-point `PRECISION`, and owned storage
/// (`Vec` for the cdf, boxed slice for the lookup table).
pub type SmallNonContiguousLookupDecoderModel<
    Symbol,
    Cdf = Vec<(u16, Symbol)>,
    LookupTable = Box<[u16]>,
> = NonContiguousLookupDecoderModel<Symbol, u16, Cdf, LookupTable, 12>;
/// A decoder model for a non-contiguous symbol alphabet that decodes each quantile in
/// constant time via a precomputed lookup table.
///
/// The lookup table maps every quantile in `0..(1 << PRECISION)` to an index into `cdf`;
/// `cdf` stores `(left_sided_cumulative, symbol)` pairs plus one trailing sentinel entry
/// (see the constructors below), so decoding can always read the next entry's cumulative.
#[derive(Debug, Clone, Copy)]
pub struct NonContiguousLookupDecoderModel<
    Symbol,
    Probability = u16,
    Cdf = Vec<(Probability, Symbol)>,
    LookupTable = Box<[Probability]>,
    const PRECISION: usize = 12,
> where
    Probability: BitArray,
{
    /// Maps each quantile in `0..(1 << PRECISION)` to an index into `cdf`.
    lookup_table: LookupTable,
    /// `(left_sided_cumulative, symbol)` entries, terminated by a sentinel entry whose
    /// cumulative is `2^PRECISION` (wrapped to the `Probability` type).
    cdf: Cdf,
    /// Zero-sized marker; `Probability` and `Symbol` only appear via `Cdf`/`LookupTable`.
    phantom: PhantomData<(Probability, Symbol)>,
}
impl<Symbol, Probability, const PRECISION: usize>
    NonContiguousLookupDecoderModel<
        Symbol,
        Probability,
        Vec<(Probability, Symbol)>,
        Box<[Probability]>,
        PRECISION,
    >
where
    Probability: BitArray + Into<usize>,
    Symbol: Clone,
    usize: AsPrimitive<Probability>,
{
    /// Constructs a lookup decoder model from floating-point probabilities, using the
    /// accurate (but slower) quantization from [`perfectly_quantized_probabilities`].
    ///
    /// Returns `Err(())` if quantization fails.
    #[allow(clippy::result_unit_err)]
    pub fn from_symbols_and_floating_point_probabilities_perfect<F>(
        symbols: impl IntoIterator<Item = Symbol>,
        probabilities: &[F],
    ) -> Result<Self, ()>
    where
        F: FloatCore + core::iter::Sum<F> + Into<f64>,
        Probability: Into<f64> + AsPrimitive<usize>,
        f64: AsPrimitive<Probability>,
        usize: AsPrimitive<Probability>,
    {
        let slots = perfectly_quantized_probabilities::<_, _, PRECISION>(probabilities)?;
        // Delegate to the fixed-point constructor with the quantized weights.
        Self::from_symbols_and_nonzero_fixed_point_probabilities(
            symbols,
            slots.into_iter().map(|slot| slot.weight),
            false,
        )
    }

    /// Like [`Self::from_symbols_and_floating_point_probabilities_perfect`], but uses the
    /// faster quantization from [`fast_quantized_cdf`].
    ///
    /// Returns `Err(())` if quantization fails.
    #[allow(clippy::result_unit_err)]
    pub fn from_symbols_and_floating_point_probabilities_fast<F>(
        symbols: impl IntoIterator<Item = Symbol>,
        probabilities: &[F],
        normalization: Option<F>,
    ) -> Result<Self, ()>
    where
        F: FloatCore + core::iter::Sum<F> + AsPrimitive<Probability>,
        Probability: AsPrimitive<usize>,
        f64: AsPrimitive<Probability>,
        usize: AsPrimitive<Probability> + AsPrimitive<F>,
    {
        let mut cdf =
            fast_quantized_cdf::<Probability, F, PRECISION>(probabilities, normalization)?;

        // Pull out the first CDF entry as the left-sided cumulative of the first symbol,
        // then append the total `2^PRECISION` (which wraps to zero if
        // `PRECISION == Probability::BITS`) as the final right-sided cumulative.
        let mut left_cumulative = cdf.next().expect("cdf is not empty");
        let cdf = cdf.chain(core::iter::once(wrapping_pow2(PRECISION)));

        let symbol_table = symbols
            .into_iter()
            .zip(cdf)
            .map(|(symbol, right_cumulative)| {
                // `wrapping_sub` handles the possibly wrapped-around last entry; the
                // difference is each symbol's probability, which quantization guarantees
                // to be nonzero.
                let probability = right_cumulative
                    .wrapping_sub(&left_cumulative)
                    .into_nonzero()
                    .expect("quantization is leaky");
                let old_left_cumulative = left_cumulative;
                left_cumulative = right_cumulative;
                (symbol, old_left_cumulative, probability)
            });

        Ok(Self::from_symbol_table(symbol_table))
    }

    /// Deprecated constructor; delegates to
    /// [`Self::from_symbols_and_floating_point_probabilities_perfect`].
    #[deprecated(
        since = "0.4.0",
        note = "Please use `from_symbols_and_floating_point_probabilities_fast` or \
        `from_symbols_and_floating_point_probabilities_perfect` instead. See documentation for \
        detailed upgrade instructions."
    )]
    #[allow(clippy::result_unit_err)]
    pub fn from_symbols_and_floating_point_probabilities<F>(
        symbols: &[Symbol],
        probabilities: &[F],
    ) -> Result<Self, ()>
    where
        F: FloatCore + core::iter::Sum<F> + Into<f64>,
        Probability: Into<f64> + AsPrimitive<usize>,
        f64: AsPrimitive<Probability>,
        usize: AsPrimitive<Probability>,
    {
        Self::from_symbols_and_floating_point_probabilities_perfect(
            symbols.iter().cloned(),
            probabilities,
        )
    }

    /// Constructs a lookup decoder model directly from fixed-point probabilities.
    ///
    /// Probabilities are accumulated by `accumulate_nonzero_probabilities`; when
    /// `infer_last_probability` is `true`, the last probability comes from that helper
    /// rather than from the `probabilities` iterator.
    ///
    /// Returns `Err(())` if accumulation fails or if `symbols` yields more items than
    /// were consumed by the accumulation.
    #[allow(clippy::result_unit_err)]
    pub fn from_symbols_and_nonzero_fixed_point_probabilities<S, P>(
        symbols: S,
        probabilities: P,
        infer_last_probability: bool,
    ) -> Result<Self, ()>
    where
        S: IntoIterator<Item = Symbol>,
        P: IntoIterator,
        P::Item: Borrow<Probability>,
    {
        generic_static_asserts!(
            (Probability: BitArray; const PRECISION: usize);
            PROBABILITY_MUST_SUPPORT_PRECISION: PRECISION <= Probability::BITS;
            PRECISION_MUST_BE_NONZERO: PRECISION > 0;
            USIZE_MUST_STRICTLY_SUPPORT_PRECISION: PRECISION < <usize as BitArray>::BITS;
        );

        // The lookup table will have exactly one slot per representable quantile.
        let mut lookup_table = Vec::with_capacity(1 << PRECISION);
        let symbols = symbols.into_iter();
        let mut cdf =
            Vec::with_capacity(symbols.size_hint().0 + 1 + infer_last_probability as usize);
        let mut symbols = accumulate_nonzero_probabilities::<_, _, _, _, _, PRECISION>(
            symbols,
            probabilities.into_iter(),
            |symbol, _, probability| {
                // Record this symbol's cdf index *before* pushing, then fill
                // `probability` lookup slots that all point back at this entry.
                let index = cdf.len().as_();
                cdf.push((lookup_table.len().as_(), symbol));
                lookup_table.resize(lookup_table.len() + probability.into(), index);
                Ok(())
            },
            infer_last_probability,
        )?;

        // Append a sentinel entry with right-sided cumulative `2^PRECISION` (wrapped),
        // so `quantile_function` can always read entry `index + 1`.
        let last_symbol = cdf.last().expect("cdf is not empty").1.clone();
        cdf.push((wrapping_pow2(PRECISION), last_symbol));

        // Any leftover symbols mean the caller supplied more symbols than probabilities.
        if symbols.next().is_some() {
            Err(())
        } else {
            Ok(Self {
                lookup_table: lookup_table.into_boxed_slice(),
                cdf,
                phantom: PhantomData,
            })
        }
    }

    /// Constructs a lookup decoder model that mirrors the distribution of an existing
    /// iterable entropy model.
    pub fn from_iterable_entropy_model<'m, M>(model: &'m M) -> Self
    where
        M: IterableEntropyModel<'m, PRECISION, Symbol = Symbol, Probability = Probability> + ?Sized,
    {
        Self::from_symbol_table(model.symbol_table())
    }

    /// Internal constructor from `(symbol, left_sided_cumulative, probability)` triples.
    ///
    /// The triples must be contiguous (each left cumulative equals the previous right
    /// cumulative, starting at zero); this is checked in debug builds only.
    fn from_symbol_table(
        symbol_table: impl Iterator<Item = (Symbol, Probability, Probability::NonZero)>,
    ) -> Self {
        generic_static_asserts!(
            (Probability: BitArray; const PRECISION: usize);
            PROBABILITY_MUST_SUPPORT_PRECISION: PRECISION <= Probability::BITS;
            PRECISION_MUST_BE_NONZERO: PRECISION > 0;
            USIZE_MUST_STRICTLY_SUPPORT_PRECISION: PRECISION < <usize as BitArray>::BITS;
        );

        let mut lookup_table = Vec::with_capacity(1 << PRECISION);
        let mut cdf = Vec::with_capacity(symbol_table.size_hint().0 + 1);
        for (symbol, left_sided_cumulative, probability) in symbol_table {
            // Same slot-filling scheme as in
            // `from_symbols_and_nonzero_fixed_point_probabilities` above.
            let index = cdf.len().as_();
            debug_assert_eq!(left_sided_cumulative, lookup_table.len().as_());
            cdf.push((lookup_table.len().as_(), symbol));
            lookup_table.resize(lookup_table.len() + probability.get().into(), index);
        }

        // Sentinel entry; see `from_symbols_and_nonzero_fixed_point_probabilities`.
        let last_symbol = cdf.last().expect("cdf is not empty").1.clone();
        cdf.push((wrapping_pow2(PRECISION), last_symbol));

        Self {
            lookup_table: lookup_table.into_boxed_slice(),
            cdf,
            phantom: PhantomData,
        }
    }
}
impl<Symbol, Probability, Cdf, LookupTable, const PRECISION: usize>
    NonContiguousLookupDecoderModel<Symbol, Probability, Cdf, LookupTable, PRECISION>
where
    Probability: BitArray + Into<usize>,
    usize: AsPrimitive<Probability>,
    Cdf: AsRef<[(Probability, Symbol)]>,
    LookupTable: AsRef<[Probability]>,
{
    /// Returns a cheap view of this model that borrows its lookup table and cdf.
    pub fn as_view(
        &self,
    ) -> NonContiguousLookupDecoderModel<
        Symbol,
        Probability,
        &[(Probability, Symbol)],
        &[Probability],
        PRECISION,
    > {
        let (lookup_table, cdf) = (self.lookup_table.as_ref(), self.cdf.as_ref());
        NonContiguousLookupDecoderModel {
            phantom: PhantomData,
            cdf,
            lookup_table,
        }
    }

    /// Returns a borrowed [`NonContiguousCategoricalDecoderModel`] over the same cdf,
    /// without the constant-time lookup table.
    pub fn as_non_contiguous_categorical(
        &self,
    ) -> NonContiguousCategoricalDecoderModel<
        Symbol,
        Probability,
        &[(Probability, Symbol)],
        PRECISION,
    > {
        NonContiguousCategoricalDecoderModel {
            phantom: PhantomData,
            cdf: self.cdf.as_ref(),
        }
    }

    /// Consumes this model, discarding the lookup table and reusing the owned cdf for a
    /// [`NonContiguousCategoricalDecoderModel`].
    pub fn into_non_contiguous_categorical(
        self,
    ) -> NonContiguousCategoricalDecoderModel<Symbol, Probability, Cdf, PRECISION> {
        let Self { cdf, .. } = self;
        NonContiguousCategoricalDecoderModel {
            phantom: PhantomData,
            cdf,
        }
    }
}
/// Marks the lookup model as an entropy model over `Symbol` with fixed-point
/// probabilities of type `Probability`.
impl<Symbol, Probability, Cdf, LookupTable, const PRECISION: usize> EntropyModel<PRECISION>
    for NonContiguousLookupDecoderModel<Symbol, Probability, Cdf, LookupTable, PRECISION>
where
    Probability: BitArray + Into<usize>,
{
    type Symbol = Symbol;
    type Probability = Probability;
}
impl<Symbol, Probability, Cdf, LookupTable, const PRECISION: usize> DecoderModel<PRECISION>
    for NonContiguousLookupDecoderModel<Symbol, Probability, Cdf, LookupTable, PRECISION>
where
    Probability: BitArray + Into<usize>,
    Cdf: AsRef<[(Probability, Symbol)]>,
    LookupTable: AsRef<[Probability]>,
    Symbol: Clone,
{
    /// Maps `quantile` to `(symbol, left_sided_cumulative, probability)` in O(1):
    /// one lookup-table read plus two adjacent cdf reads.
    ///
    /// # Panics
    ///
    /// If `PRECISION != Probability::BITS` and `quantile >= 2^PRECISION`.
    #[inline(always)]
    fn quantile_function(
        &self,
        quantile: Probability,
    ) -> (Symbol, Probability, Probability::NonZero) {
        generic_static_asserts!(
            (Probability: BitArray; const PRECISION: usize);
            PROBABILITY_MUST_SUPPORT_PRECISION: PRECISION <= Probability::BITS;
            PRECISION_MUST_BE_NONZERO: PRECISION > 0;
        );
        // When `PRECISION == Probability::BITS`, every `Probability` value is a valid
        // quantile (and `one() << PRECISION` would overflow), so skip the bounds check.
        if Probability::BITS != PRECISION {
            assert!(quantile < Probability::one() << PRECISION);
        }
        // SAFETY (relies on invariants maintained by the constructors in this file):
        // - `lookup_table` has `1 << PRECISION` entries and `quantile < 1 << PRECISION`
        //   (asserted above, or guaranteed by the type width), so the first
        //   `get_unchecked` is in bounds;
        // - every value stored in `lookup_table` is the index of a regular cdf entry,
        //   and the constructors push one extra sentinel entry, so both `index` and
        //   `index + 1` are in bounds of `cdf`.
        let ((left_sided_cumulative, symbol), next_cumulative) = unsafe {
            let index = *self.lookup_table.as_ref().get_unchecked(quantile.into());
            let index = index.into();
            let cdf = self.cdf.as_ref();
            (
                cdf.get_unchecked(index).clone(),
                cdf.get_unchecked(index + 1).0,
            )
        };
        // SAFETY: the constructors only admit nonzero probabilities, so adjacent
        // cumulatives always differ (modulo `2^PRECISION`, handled by `wrapping_sub`).
        let probability = unsafe {
            next_cumulative
                .wrapping_sub(&left_sided_cumulative)
                .into_nonzero_unchecked()
        };
        (symbol, left_sided_cumulative, probability)
    }
}
/// Conversion from any iterable entropy model; thin wrapper around
/// [`NonContiguousLookupDecoderModel::from_iterable_entropy_model`].
impl<'m, Symbol, Probability, M, const PRECISION: usize> From<&'m M>
    for NonContiguousLookupDecoderModel<
        Symbol,
        Probability,
        Vec<(Probability, Symbol)>,
        Box<[Probability]>,
        PRECISION,
    >
where
    Probability: BitArray + Into<usize>,
    Symbol: Clone,
    usize: AsPrimitive<Probability>,
    M: IterableEntropyModel<'m, PRECISION, Symbol = Symbol, Probability = Probability> + ?Sized,
{
    #[inline(always)]
    fn from(model: &'m M) -> Self {
        Self::from_iterable_entropy_model(model)
    }
}
impl<'m, Symbol, Probability, Cdf, LookupTable, const PRECISION: usize>
    IterableEntropyModel<'m, PRECISION>
    for NonContiguousLookupDecoderModel<Symbol, Probability, Cdf, LookupTable, PRECISION>
where
    Symbol: Clone + 'm,
    Probability: BitArray + Into<usize>,
    usize: AsPrimitive<Probability>,
    Cdf: AsRef<[(Probability, Symbol)]>,
    LookupTable: AsRef<[Probability]>,
{
    /// Iterates over `(symbol, left_sided_cumulative, probability)` triples by walking
    /// the stored cdf entries (including the trailing sentinel, which
    /// `iter_extended_cdf` uses to compute the last probability).
    #[inline(always)]
    fn symbol_table(
        &'m self,
    ) -> impl Iterator<
        Item = (
            Self::Symbol,
            Self::Probability,
            <Self::Probability as BitArray>::NonZero,
        ),
    > {
        iter_extended_cdf(self.cdf.as_ref().iter().cloned())
    }
}
#[cfg(test)]
mod tests {
use alloc::string::String;
use crate::stream::{
model::{EncoderModel, NonContiguousCategoricalEncoderModel},
stack::DefaultAnsCoder,
Decode,
};
use super::*;
#[test]
fn lookup_noncontiguous() {
let symbols = "axcy";
let probabilities = [3u8, 18, 1, 42];
let encoder_model = NonContiguousCategoricalEncoderModel::<_, u8, 6>::from_symbols_and_nonzero_fixed_point_probabilities(
symbols.chars(),probabilities.iter(),false
)
.unwrap();
let decoder_model = NonContiguousCategoricalDecoderModel::<_, _,_, 6>::from_symbols_and_nonzero_fixed_point_probabilities(
symbols.chars(),probabilities.iter(),false
)
.unwrap();
let lookup_decoder_model =
NonContiguousLookupDecoderModel::from_iterable_entropy_model(&decoder_model);
for symbol in symbols.chars() {
let (left_cumulative, probability) = encoder_model
.left_cumulative_and_probability(symbol)
.unwrap();
for quantile in left_cumulative..left_cumulative + probability.get() {
assert_eq!(
decoder_model.quantile_function(quantile),
(symbol, left_cumulative, probability)
);
assert_eq!(
lookup_decoder_model.quantile_function(quantile),
(symbol, left_cumulative, probability)
);
}
}
for quantile in 0..1 << 6 {
let (symbol, left_cumulative, probability) = decoder_model.quantile_function(quantile);
assert_eq!(
lookup_decoder_model.quantile_function(quantile),
(symbol, left_cumulative, probability)
);
assert_eq!(
encoder_model
.left_cumulative_and_probability(symbol)
.unwrap(),
(left_cumulative, probability)
);
}
let symbols = "axcxcyaac";
let mut ans = DefaultAnsCoder::new();
ans.encode_iid_symbols_reverse(symbols.chars(), &encoder_model)
.unwrap();
assert!(!ans.is_empty());
let decoded = ans
.decode_iid_symbols(9, &decoder_model)
.collect::<Result<String, _>>()
.unwrap();
assert_eq!(decoded, symbols);
assert!(ans.is_empty());
}
}