use core::{borrow::Borrow, hash::Hash, marker::PhantomData};
#[cfg(feature = "std")]
use std::collections::{
hash_map::Entry::{Occupied, Vacant},
HashMap,
};
#[cfg(not(feature = "std"))]
use hashbrown::hash_map::{
Entry::{Occupied, Vacant},
HashMap,
};
use alloc::{boxed::Box, vec::Vec};
use num_traits::{float::FloatCore, AsPrimitive};
use crate::{wrapping_pow2, BitArray, NonZeroBitArray};
use super::{
super::{DecoderModel, EncoderModel, EntropyModel, IterableEntropyModel},
accumulate_nonzero_probabilities, fast_quantized_cdf, iter_extended_cdf,
lookup_noncontiguous::NonContiguousLookupDecoderModel,
perfectly_quantized_probabilities,
};
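/// Type alias for a [`NonContiguousCategoricalEncoderModel`] with `Probability = u32` and
/// `PRECISION = 24`.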
pub type DefaultNonContiguousCategoricalEncoderModel<Symbol> =
NonContiguousCategoricalEncoderModel<Symbol, u32, 24>;
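/// Type alias for a [`NonContiguousCategoricalEncoderModel`] with `Probability = u16` and
/// `PRECISION = 12`.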
pub type SmallNonContiguousCategoricalEncoderModel<Symbol> =
NonContiguousCategoricalEncoderModel<Symbol, u16, 12>;
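/// Type alias for a [`NonContiguousCategoricalDecoderModel`] with `Probability = u32` and
/// `PRECISION = 24`.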
pub type DefaultNonContiguousCategoricalDecoderModel<Symbol, Cdf = Vec<(u32, Symbol)>> =
NonContiguousCategoricalDecoderModel<Symbol, u32, Cdf, 24>;
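/// Type alias for a [`NonContiguousCategoricalDecoderModel`] with `Probability = u16` and
/// `PRECISION = 12`.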
pub type SmallNonContiguousCategoricalDecoderModel<Symbol, Cdf = Vec<(u16, Symbol)>> =
NonContiguousCategoricalDecoderModel<Symbol, u16, Cdf, 12>;
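/// A decoder model for categorical distributions over an arbitrary ("non-contiguous")
/// alphabet of symbols.
///
/// Internally, the model stores an extended CDF: a list of `(left_sided_cumulative, symbol)`
/// pairs, terminated by a sentinel entry whose cumulative is `2^PRECISION` (which wraps to
/// zero when `PRECISION` equals the bit width of `Probability`).
///
/// # Example
///
/// A minimal sketch (not part of the original sources; the surrounding import paths are
/// assumptions, so the snippet is marked `ignore`):
///
/// ```ignore
/// let symbols = ['a', 'x', '?'];
/// let probabilities = [0.5f64, 0.3, 0.2];
/// // Quantize the floating-point probabilities and build the decoder model.
/// let model = DefaultNonContiguousCategoricalDecoderModel::<char>
///     ::from_symbols_and_floating_point_probabilities_fast(
///         symbols.iter().cloned(),
///         &probabilities,
///         None,
///     )
///     .unwrap();
/// assert_eq!(model.support_size(), 3);
/// ```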
#[derive(Debug, Clone, Copy)]
pub struct NonContiguousCategoricalDecoderModel<Symbol, Probability, Cdf, const PRECISION: usize> {
pub(super) cdf: Cdf,
pub(super) phantom: PhantomData<(Symbol, Probability)>,
}
impl<Symbol, Probability: BitArray, const PRECISION: usize>
NonContiguousCategoricalDecoderModel<Symbol, Probability, Vec<(Probability, Symbol)>, PRECISION>
where
Symbol: Clone,
{
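/// Constructs a decoder model from floating-point `probabilities`, quantizing them to
/// fixed-point weights with the fast heuristic in [`fast_quantized_cdf`]. `symbols` is
/// expected to yield one symbol per entry of `probabilities`, and the `normalization`
/// argument is forwarded unchanged to [`fast_quantized_cdf`]. Returns `Err(())` if
/// [`fast_quantized_cdf`] rejects the input.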
#[allow(clippy::result_unit_err)]
pub fn from_symbols_and_floating_point_probabilities_fast<F>(
symbols: impl IntoIterator<Item = Symbol>,
probabilities: &[F],
normalization: Option<F>,
) -> Result<Self, ()>
where
F: FloatCore + core::iter::Sum<F> + AsPrimitive<Probability>,
Probability: AsPrimitive<usize>,
usize: AsPrimitive<Probability> + AsPrimitive<F>,
{
let cdf = fast_quantized_cdf::<Probability, F, PRECISION>(probabilities, normalization)?;
let mut extended_cdf = Vec::with_capacity(probabilities.len() + 1);
extended_cdf.extend(cdf.zip(symbols));
let last_symbol = extended_cdf.last().expect("`len` >= 2").1.clone();
extended_cdf.push((wrapping_pow2(PRECISION), last_symbol));
Ok(Self::from_extended_cdf(extended_cdf))
}
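/// Like [`Self::from_symbols_and_floating_point_probabilities_fast`], but quantizes with the
/// slower [`perfectly_quantized_probabilities`], which typically approximates the
/// floating-point distribution more closely (compare the `kl_perfect < kl_fast` assertion in
/// the test at the bottom of this file).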
#[allow(clippy::result_unit_err)]
pub fn from_symbols_and_floating_point_probabilities_perfect<F>(
symbols: impl IntoIterator<Item = Symbol>,
probabilities: &[F],
) -> Result<Self, ()>
where
F: FloatCore + core::iter::Sum<F> + Into<f64>,
Probability: Into<f64> + AsPrimitive<usize>,
f64: AsPrimitive<Probability>,
usize: AsPrimitive<Probability>,
{
let slots = perfectly_quantized_probabilities::<_, _, PRECISION>(probabilities)?;
Self::from_symbols_and_nonzero_fixed_point_probabilities(
symbols,
slots.into_iter().map(|slot| slot.weight),
false,
)
}
#[deprecated(
since = "0.4.0",
note = "Please use `from_symbols_and_floating_point_probabilities_fast` or \
`from_symbols_and_floating_point_probabilities_perfect` instead. See documentation for \
detailed upgrade instructions."
)]
#[allow(clippy::result_unit_err)]
pub fn from_symbols_and_floating_point_probabilities<F>(
symbols: &[Symbol],
probabilities: &[F],
) -> Result<Self, ()>
where
F: FloatCore + core::iter::Sum<F> + Into<f64>,
Probability: Into<f64> + AsPrimitive<usize>,
f64: AsPrimitive<Probability>,
usize: AsPrimitive<Probability>,
{
Self::from_symbols_and_floating_point_probabilities_perfect(
symbols.iter().cloned(),
probabilities,
)
}
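/// Constructs a decoder model directly from nonzero fixed-point probabilities. If
/// `infer_last_probability` is `true`, the last symbol's probability is not read from
/// `probabilities` but inferred from the remaining probability mass (see
/// [`accumulate_nonzero_probabilities`], which also defines the error conditions).
/// Returns `Err(())` if `symbols` yields more items than `probabilities` accounts for.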
#[allow(clippy::result_unit_err)]
pub fn from_symbols_and_nonzero_fixed_point_probabilities<S, P>(
symbols: S,
probabilities: P,
infer_last_probability: bool,
) -> Result<Self, ()>
where
S: IntoIterator<Item = Symbol>,
P: IntoIterator,
P::Item: Borrow<Probability>,
{
let symbols = symbols.into_iter();
let mut cdf = Vec::with_capacity(symbols.size_hint().0 + 1);
let mut symbols = accumulate_nonzero_probabilities::<_, _, _, _, _, PRECISION>(
symbols,
probabilities.into_iter(),
|symbol, left_sided_cumulative, _| {
cdf.push((left_sided_cumulative, symbol));
Ok(())
},
infer_last_probability,
)?;
cdf.push((
wrapping_pow2(PRECISION),
cdf.last().expect("`symbols` is not empty").1.clone(),
));
if symbols.next().is_some() {
Err(())
} else {
Ok(Self::from_extended_cdf(cdf))
}
}
#[inline(always)]
fn from_extended_cdf(cdf: Vec<(Probability, Symbol)>) -> Self {
Self {
cdf,
phantom: PhantomData,
}
}
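/// Creates a decoder model with the same symbols and probabilities as the provided
/// [`IterableEntropyModel`], by copying its symbol table into an owned extended CDF.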
pub fn from_iterable_entropy_model<'m, M>(model: &'m M) -> Self
where
M: IterableEntropyModel<'m, PRECISION, Symbol = Symbol, Probability = Probability> + ?Sized,
{
let symbol_table = model.symbol_table();
let mut cdf = Vec::with_capacity(symbol_table.size_hint().0 + 1);
cdf.extend(
symbol_table.map(|(symbol, left_sided_cumulative, _)| (left_sided_cumulative, symbol)),
);
cdf.push((
wrapping_pow2(PRECISION),
cdf.last().expect("`symbol_table` is not empty").1.clone(),
));
Self {
cdf,
phantom: PhantomData,
}
}
}
impl<Symbol, Probability, Cdf, const PRECISION: usize>
NonContiguousCategoricalDecoderModel<Symbol, Probability, Cdf, PRECISION>
where
Symbol: Clone,
Probability: BitArray,
Cdf: AsRef<[(Probability, Symbol)]>,
{
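/// Returns the number of symbols in the support of the model (the extended CDF holds one
/// additional sentinel entry, hence the `- 1`).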
#[inline(always)]
pub fn support_size(&self) -> usize {
self.cdf.as_ref().len() - 1
}
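/// Returns a lightweight view of the model that borrows the CDF slice instead of owning it.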
#[inline]
pub fn as_view(
&self,
) -> NonContiguousCategoricalDecoderModel<
Symbol,
Probability,
&[(Probability, Symbol)],
PRECISION,
> {
NonContiguousCategoricalDecoderModel {
cdf: self.cdf.as_ref(),
phantom: PhantomData,
}
}
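/// Converts the model into a [`NonContiguousLookupDecoderModel`], which trades memory for
/// faster decoding (see its documentation for details).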
#[inline(always)]
pub fn to_lookup_decoder_model(
&self,
) -> NonContiguousLookupDecoderModel<
Symbol,
Probability,
Vec<(Probability, Symbol)>,
Box<[Probability]>,
PRECISION,
>
where
Probability: Into<usize>,
usize: AsPrimitive<Probability>,
{
self.into()
}
}
impl<Symbol, Probability, Cdf, const PRECISION: usize> EntropyModel<PRECISION>
for NonContiguousCategoricalDecoderModel<Symbol, Probability, Cdf, PRECISION>
where
Probability: BitArray,
{
type Symbol = Symbol;
type Probability = Probability;
}
impl<'m, Symbol, Probability, Cdf, const PRECISION: usize> IterableEntropyModel<'m, PRECISION>
for NonContiguousCategoricalDecoderModel<Symbol, Probability, Cdf, PRECISION>
where
Symbol: Clone + 'm,
Probability: BitArray,
Cdf: AsRef<[(Probability, Symbol)]>,
{
#[inline(always)]
fn symbol_table(
&'m self,
) -> impl Iterator<
Item = (
Self::Symbol,
Self::Probability,
<Self::Probability as BitArray>::NonZero,
),
> {
iter_extended_cdf(self.cdf.as_ref().iter().cloned())
}
fn floating_point_symbol_table<F>(&'m self) -> impl Iterator<Item = (Self::Symbol, F, F)>
where
F: FloatCore + From<Self::Probability> + 'm,
Self::Probability: Into<F>,
{
let whole = (F::one() + F::one()) * (Self::Probability::one() << (PRECISION - 1)).into();
self.symbol_table()
.map(move |(symbol, cumulative, probability)| {
(
symbol,
cumulative.into() / whole,
probability.get().into() / whole,
)
})
}
fn entropy_base2<F>(&'m self) -> F
where
F: num_traits::Float + core::iter::Sum,
Self::Probability: Into<F>,
{
let entropy_scaled = self
.symbol_table()
.map(|(_, _, probability)| {
let probability = probability.get().into();
probability * probability.log2()
})
.sum::<F>();
let whole = (F::one() + F::one()) * (Self::Probability::one() << (PRECISION - 1)).into();
F::from(PRECISION).unwrap() - entropy_scaled / whole
}
fn to_generic_encoder_model(
&'m self,
) -> NonContiguousCategoricalEncoderModel<Self::Symbol, Self::Probability, PRECISION>
where
Self::Symbol: core::hash::Hash + Eq,
{
self.into()
}
fn to_generic_decoder_model(
&'m self,
) -> NonContiguousCategoricalDecoderModel<
Self::Symbol,
Self::Probability,
Vec<(Self::Probability, Self::Symbol)>,
PRECISION,
>
where
Self::Symbol: Clone,
{
self.into()
}
fn to_generic_lookup_decoder_model(
&'m self,
) -> NonContiguousLookupDecoderModel<
Self::Symbol,
Self::Probability,
Vec<(Self::Probability, Self::Symbol)>,
Box<[Self::Probability]>,
PRECISION,
>
where
Self::Probability: Into<usize>,
usize: AsPrimitive<Self::Probability>,
Self::Symbol: Clone + Default,
{
self.into()
}
}
impl<Symbol, Probability, Cdf, const PRECISION: usize> DecoderModel<PRECISION>
for NonContiguousCategoricalDecoderModel<Symbol, Probability, Cdf, PRECISION>
where
Symbol: Clone,
Probability: BitArray,
Cdf: AsRef<[(Probability, Symbol)]>,
{
#[inline(always)]
fn quantile_function(
&self,
quantile: Self::Probability,
) -> (Symbol, Probability, Probability::NonZero) {
let cdf = self.cdf.as_ref();
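// Exclude the sentinel entry at `2^PRECISION` so that the searched slice is strictly
// increasing. The comparator below never returns `Equal`, so `binary_search_by` always
// returns `Err(i)`, where `i` is the index of the first entry whose left-sided cumulative
// exceeds `quantile`; the `Ok` arm is therefore unreachable.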
let monotonic_part_of_cdf = unsafe { cdf.get_unchecked(..cdf.len() - 1) };
let Err(next_index) = monotonic_part_of_cdf.binary_search_by(|(cumulative, _symbol)| {
if *cumulative <= quantile {
core::cmp::Ordering::Less
} else {
core::cmp::Ordering::Greater
}
}) else {
unsafe { core::hint::unreachable_unchecked() }
};
let (right_cumulative, (left_cumulative, symbol)) = unsafe {
(
cdf.get_unchecked(next_index).0,
cdf.get_unchecked(next_index - 1).clone(),
)
};
let probability = unsafe {
right_cumulative
.wrapping_sub(&left_cumulative)
.into_nonzero_unchecked()
};
(symbol, left_cumulative, probability)
}
}
impl<'m, Symbol, Probability, M, const PRECISION: usize> From<&'m M>
for NonContiguousCategoricalDecoderModel<
Symbol,
Probability,
Vec<(Probability, Symbol)>,
PRECISION,
>
where
Symbol: Clone,
Probability: BitArray,
M: IterableEntropyModel<'m, PRECISION, Symbol = Symbol, Probability = Probability> + ?Sized,
{
#[inline(always)]
fn from(model: &'m M) -> Self {
Self::from_iterable_entropy_model(model)
}
}
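/// An encoder model for categorical distributions over an arbitrary ("non-contiguous")
/// alphabet of symbols, backed by a `HashMap` from each symbol to its
/// `(left_sided_cumulative, probability)` pair. This is the encoder-side counterpart of
/// [`NonContiguousCategoricalDecoderModel`].
///
/// # Example
///
/// A minimal sketch (not part of the original sources; import paths are assumptions, hence
/// `ignore`):
///
/// ```ignore
/// let symbols = ['a', 'x', '?'];
/// let probabilities = [0.5f64, 0.3, 0.2];
/// let model = DefaultNonContiguousCategoricalEncoderModel::<char>
///     ::from_symbols_and_floating_point_probabilities_fast(
///         symbols.iter().cloned(),
///         &probabilities,
///         None,
///     )
///     .unwrap();
/// assert_eq!(model.support_size(), 3);
/// // Looking up a symbol requires the `EncoderModel` trait to be in scope.
/// assert!(model.left_cumulative_and_probability('x').is_some());
/// assert!(model.left_cumulative_and_probability('y').is_none());
/// ```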
#[derive(Debug, Clone)]
pub struct NonContiguousCategoricalEncoderModel<Symbol, Probability, const PRECISION: usize>
where
Symbol: Hash,
Probability: BitArray,
{
table: HashMap<Symbol, (Probability, Probability::NonZero)>,
}
impl<Symbol, Probability, const PRECISION: usize>
NonContiguousCategoricalEncoderModel<Symbol, Probability, PRECISION>
where
Symbol: Hash + Eq,
Probability: BitArray,
{
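/// Encoder-side counterpart of the decoder constructor
/// `from_symbols_and_floating_point_probabilities_fast` on
/// [`NonContiguousCategoricalDecoderModel`]; quantizes `probabilities` with
/// [`fast_quantized_cdf`] and collects the result into a symbol table.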
#[allow(clippy::result_unit_err)]
pub fn from_symbols_and_floating_point_probabilities_fast<F>(
symbols: impl IntoIterator<Item = Symbol>,
probabilities: &[F],
normalization: Option<F>,
) -> Result<Self, ()>
where
F: FloatCore + core::iter::Sum<F> + AsPrimitive<Probability>,
Probability: AsPrimitive<usize>,
usize: AsPrimitive<Probability> + AsPrimitive<F>,
{
let cdf = fast_quantized_cdf::<Probability, F, PRECISION>(probabilities, normalization)?;
Self::from_symbols_and_cdf(symbols, cdf)
}
#[allow(clippy::result_unit_err)]
pub fn from_symbols_and_floating_point_probabilities_perfect<F>(
symbols: impl IntoIterator<Item = Symbol>,
probabilities: &[F],
) -> Result<Self, ()>
where
F: FloatCore + core::iter::Sum<F> + Into<f64>,
Probability: Into<f64> + AsPrimitive<usize>,
f64: AsPrimitive<Probability>,
usize: AsPrimitive<Probability>,
{
let slots = perfectly_quantized_probabilities::<_, _, PRECISION>(probabilities)?;
Self::from_symbols_and_nonzero_fixed_point_probabilities(
symbols,
slots.into_iter().map(|slot| slot.weight),
false,
)
}
#[deprecated(
since = "0.4.0",
note = "Please use `from_symbols_and_floating_point_probabilities_fast` or \
`from_symbols_and_floating_point_probabilities_perfect` instead. See documentation for \
detailed upgrade instructions."
)]
#[allow(clippy::result_unit_err)]
pub fn from_symbols_and_floating_point_probabilities<F>(
symbols: impl IntoIterator<Item = Symbol>,
probabilities: &[F],
) -> Result<Self, ()>
where
F: FloatCore + core::iter::Sum<F> + Into<f64>,
Probability: Into<f64> + AsPrimitive<usize>,
f64: AsPrimitive<Probability>,
usize: AsPrimitive<Probability>,
{
Self::from_symbols_and_floating_point_probabilities_perfect(symbols, probabilities)
}
#[allow(clippy::result_unit_err)]
pub fn from_symbols_and_nonzero_fixed_point_probabilities<S, P>(
symbols: S,
probabilities: P,
infer_last_probability: bool,
) -> Result<Self, ()>
where
S: IntoIterator<Item = Symbol>,
P: IntoIterator,
P::Item: Borrow<Probability>,
{
let symbols = symbols.into_iter();
let mut table =
HashMap::with_capacity(symbols.size_hint().0 + infer_last_probability as usize);
let mut symbols = accumulate_nonzero_probabilities::<_, _, _, _, _, PRECISION>(
symbols,
probabilities.into_iter(),
|symbol, left_sided_cumulative, probability| match table.entry(symbol) {
Occupied(_) => Err(()),
Vacant(slot) => {
let probability = probability.into_nonzero().ok_or(())?;
slot.insert((left_sided_cumulative, probability));
Ok(())
}
},
infer_last_probability,
)?;
if symbols.next().is_some() {
Err(())
} else {
Ok(Self { table })
}
}
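/// Internal helper: builds the symbol table from an iterator over left-sided cumulatives and
/// an iterator over symbols. Expects exactly one symbol per cumulative; the last symbol's
/// probability extends up to `2^PRECISION`. Returns `Err(())` on duplicate symbols,
/// zero-width probabilities, or mismatched lengths.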
#[allow(clippy::result_unit_err)]
fn from_symbols_and_cdf<S, P>(symbols: S, cdf: P) -> Result<Self, ()>
where
S: IntoIterator<Item = Symbol>,
P: IntoIterator<Item = Probability>,
{
let mut symbols = symbols.into_iter();
let mut cdf = cdf.into_iter();
let mut table = HashMap::with_capacity(symbols.size_hint().0);
let mut left_cumulative = cdf.next().ok_or(())?;
for right_cumulative in cdf {
let symbol = symbols.next().ok_or(())?;
match table.entry(symbol) {
Occupied(_) => return Err(()),
Vacant(slot) => {
let probability = (right_cumulative - left_cumulative)
.into_nonzero()
.ok_or(())?;
slot.insert((left_cumulative, probability));
}
}
left_cumulative = right_cumulative;
}
let last_symbol = symbols.next().ok_or(())?;
let right_cumulative = wrapping_pow2::<Probability>(PRECISION);
match table.entry(last_symbol) {
Occupied(_) => return Err(()),
Vacant(slot) => {
let probability = right_cumulative
.wrapping_sub(&left_cumulative)
.into_nonzero()
.ok_or(())?;
slot.insert((left_cumulative, probability));
}
}
if symbols.next().is_some() {
Err(())
} else {
Ok(Self { table })
}
}
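/// Creates an encoder model with the same symbols and probabilities as the provided
/// [`IterableEntropyModel`], by collecting its symbol table into a `HashMap`.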
pub fn from_iterable_entropy_model<'m, M>(model: &'m M) -> Self
where
M: IterableEntropyModel<'m, PRECISION, Symbol = Symbol, Probability = Probability> + ?Sized,
{
let table = model
.symbol_table()
.map(|(symbol, left_sided_cumulative, probability)| {
(symbol, (left_sided_cumulative, probability))
})
.collect::<HashMap<_, _>>();
Self { table }
}
pub fn support_size(&self) -> usize {
self.table.len()
}
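/// Returns the entropy in bits (base 2) of the distribution represented by this model,
/// computed from the quantized fixed-point probabilities (not from any floating-point
/// probabilities they may have been derived from).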
pub fn entropy_base2<F>(&self) -> F
where
F: num_traits::Float + core::iter::Sum,
Probability: Into<F>,
{
let entropy_scaled = self
.table
.values()
.map(|&(_, probability)| {
let probability = probability.get().into();
probability * probability.log2()
})
.sum::<F>();
let whole = (F::one() + F::one()) * (Probability::one() << (PRECISION - 1)).into();
F::from(PRECISION).unwrap() - entropy_scaled / whole
}
}
impl<'m, Symbol, Probability, M, const PRECISION: usize> From<&'m M>
for NonContiguousCategoricalEncoderModel<Symbol, Probability, PRECISION>
where
Symbol: Hash + Eq,
Probability: BitArray,
M: IterableEntropyModel<'m, PRECISION, Symbol = Symbol, Probability = Probability> + ?Sized,
{
#[inline(always)]
fn from(model: &'m M) -> Self {
Self::from_iterable_entropy_model(model)
}
}
impl<Symbol, Probability, const PRECISION: usize> EntropyModel<PRECISION>
for NonContiguousCategoricalEncoderModel<Symbol, Probability, PRECISION>
where
Symbol: Hash,
Probability: BitArray,
{
type Probability = Probability;
type Symbol = Symbol;
}
impl<Symbol, Probability, const PRECISION: usize> EncoderModel<PRECISION>
for NonContiguousCategoricalEncoderModel<Symbol, Probability, PRECISION>
where
Symbol: Hash + Eq,
Probability: BitArray,
{
#[inline(always)]
fn left_cumulative_and_probability(
&self,
symbol: impl Borrow<Self::Symbol>,
) -> Option<(Self::Probability, Probability::NonZero)> {
self.table.get(symbol.borrow()).cloned()
}
}
#[cfg(test)]
mod tests {
use super::super::super::tests::{test_iterable_entropy_model, verify_iterable_entropy_model};
use super::*;
#[test]
fn non_contiguous_categorical() {
let hist = [
1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
];
let probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
let symbols = "QWERTYUIOPASDFGHJKLZXCVBNM 1234567890"
.chars()
.collect::<Vec<_>>();
let fast =
NonContiguousCategoricalDecoderModel::<_, u32, _, 32>::from_symbols_and_floating_point_probabilities_fast(
symbols.iter().cloned(),
&probabilities,
None
)
.unwrap();
test_iterable_entropy_model(&fast, symbols.iter().cloned());
let kl_fast = verify_iterable_entropy_model(&fast, &hist, 1e-8);
let perfect =
NonContiguousCategoricalDecoderModel::<_, u32, _, 32>::from_symbols_and_floating_point_probabilities_perfect(
symbols.iter().cloned(),
&probabilities,
)
.unwrap();
test_iterable_entropy_model(&perfect, symbols.iter().cloned());
let kl_perfect = verify_iterable_entropy_model(&perfect, &hist, 1e-8);
assert!(kl_perfect < kl_fast);
}
}