use core::{borrow::Borrow, marker::PhantomData};
use alloc::{boxed::Box, vec::Vec};
use num_traits::{float::FloatCore, AsPrimitive};
use crate::{
stream::model::{DecoderModel, EncoderModel, EntropyModel, IterableEntropyModel},
wrapping_pow2, BitArray,
};
use super::{
accumulate_nonzero_probabilities, fast_quantized_cdf, iter_extended_cdf,
lookup_contiguous::ContiguousLookupDecoderModel, perfectly_quantized_probabilities,
};
/// Type alias with sane default settings for typical use: 32-bit `Probability`
/// words and a fixed-point `PRECISION` of 24 bits.
pub type DefaultContiguousCategoricalEntropyModel<Cdf = Vec<u32>> =
    ContiguousCategoricalEntropyModel<u32, Cdf, 24>;

/// Type alias for a compact variant: 16-bit `Probability` words and a
/// fixed-point `PRECISION` of 12 bits.
pub type SmallContiguousCategoricalEntropyModel<Cdf = Vec<u16>> =
    ContiguousCategoricalEntropyModel<u16, Cdf, 12>;
/// An entropy model for a categorical probability distribution over the
/// contiguous symbol alphabet `0..support_size`, backed by a fixed-point
/// cumulative distribution function (CDF) with `PRECISION` bits.
#[derive(Debug, Clone, Copy)]
pub struct ContiguousCategoricalEntropyModel<Probability, Cdf, const PRECISION: usize> {
    // Extended CDF: entry `i` is the left-sided cumulative of symbol `i`, and
    // the final entry is `2^PRECISION` (wrapped to zero if it overflows
    // `Probability`, see `wrapping_pow2` in the constructors). So the container
    // holds `support_size + 1` entries.
    pub(super) cdf: Cdf,
    // Marker so the `Probability` type parameter is used even when `Cdf` is a
    // generic container; zero-sized at runtime.
    pub(super) phantom: PhantomData<Probability>,
}
impl<Probability: BitArray, const PRECISION: usize>
    ContiguousCategoricalEntropyModel<Probability, Vec<Probability>, PRECISION>
{
    /// Constructs a model from floating-point probabilities using the fast
    /// (but not entropy-optimal) quantization scheme.
    ///
    /// `normalization` may supply a precomputed sum of `probabilities`;
    /// `None` lets the quantizer compute it. Returns `Err(())` when the
    /// probabilities cannot be quantized (e.g. invalid or degenerate input).
    #[allow(clippy::result_unit_err)]
    pub fn from_floating_point_probabilities_fast<F>(
        probabilities: &[F],
        normalization: Option<F>,
    ) -> Result<Self, ()>
    where
        F: FloatCore + core::iter::Sum<F> + AsPrimitive<Probability>,
        Probability: BitArray + AsPrimitive<usize>,
        usize: AsPrimitive<Probability> + AsPrimitive<F>,
    {
        fast_quantized_cdf::<_, _, PRECISION>(probabilities, normalization)
            .and_then(|quantized| Self::from_fixed_point_cdf(quantized))
    }

    /// Constructs a model from floating-point probabilities using the slower,
    /// entropy-optimal ("perfect") quantization scheme.
    ///
    /// Returns `Err(())` when quantization fails.
    #[allow(clippy::result_unit_err)]
    pub fn from_floating_point_probabilities_perfect<F>(probabilities: &[F]) -> Result<Self, ()>
    where
        F: FloatCore + core::iter::Sum<F> + Into<f64>,
        Probability: Into<f64> + AsPrimitive<usize>,
        f64: AsPrimitive<Probability>,
        usize: AsPrimitive<Probability>,
    {
        let quantized = perfectly_quantized_probabilities::<_, _, PRECISION>(probabilities)?;
        let weights = quantized.into_iter().map(|slot| slot.weight);
        Self::from_nonzero_fixed_point_probabilities(weights, false)
    }

    /// Deprecated alias for [`Self::from_floating_point_probabilities_perfect`].
    #[deprecated(
        since = "0.4.0",
        note = "Please use `from_floating_point_probabilities_fast` or \
        `from_floating_point_probabilities_perfect` instead. See documentation for detailed \
        upgrade instructions."
    )]
    #[allow(clippy::result_unit_err)]
    #[inline(always)]
    pub fn from_floating_point_probabilities<F>(probabilities: &[F]) -> Result<Self, ()>
    where
        F: FloatCore + core::iter::Sum<F> + Into<f64>,
        Probability: Into<f64> + AsPrimitive<usize>,
        f64: AsPrimitive<Probability>,
        usize: AsPrimitive<Probability>,
    {
        Self::from_floating_point_probabilities_perfect(probabilities)
    }

    /// Constructs a model directly from (nonzero) fixed-point probabilities.
    ///
    /// If `infer_last_probability` is `true`, the final symbol's probability
    /// is derived so the total sums to `2^PRECISION`. Returns `Err(())` when
    /// the probabilities are invalid (e.g. don't sum to `2^PRECISION`).
    #[allow(clippy::result_unit_err)]
    pub fn from_nonzero_fixed_point_probabilities<I>(
        probabilities: I,
        infer_last_probability: bool,
    ) -> Result<Self, ()>
    where
        I: IntoIterator,
        I::Item: Borrow<Probability>,
    {
        let probabilities = probabilities.into_iter();
        // One extra slot for the final boundary entry, plus one more when the
        // last probability is inferred (it isn't counted by `size_hint`).
        let extra_entries = 1 + infer_last_probability as usize;
        let mut cdf = Vec::with_capacity(probabilities.size_hint().0 + extra_entries);
        accumulate_nonzero_probabilities::<_, _, _, _, _, PRECISION>(
            core::iter::repeat(()),
            probabilities,
            |(), left_sided_cumulative, _| {
                cdf.push(left_sided_cumulative);
                Ok(())
            },
            infer_last_probability,
        )?;
        // Terminate the extended CDF with `2^PRECISION` (wraps to zero when it
        // doesn't fit in `Probability`).
        cdf.push(wrapping_pow2(PRECISION));
        Ok(Self {
            cdf,
            phantom: PhantomData,
        })
    }

    /// Builds the model from an already-quantized CDF (left-sided cumulatives),
    /// appending the final `2^PRECISION` boundary entry.
    fn from_fixed_point_cdf<I>(cdf: I) -> Result<Self, ()>
    where
        I: ExactSizeIterator<Item = Probability>,
    {
        let mut extended_cdf = Vec::with_capacity(cdf.len() + 1);
        extended_cdf.extend(cdf);
        extended_cdf.push(wrapping_pow2(PRECISION));
        Ok(Self {
            cdf: extended_cdf,
            phantom: PhantomData,
        })
    }
}
impl<Probability, Cdf, const PRECISION: usize>
    ContiguousCategoricalEntropyModel<Probability, Cdf, PRECISION>
where
    Probability: BitArray,
    Cdf: AsRef<[Probability]>,
{
    /// Number of symbols in the model's support.
    ///
    /// The stored CDF carries one extra right-sided boundary entry, hence the
    /// `- 1`.
    #[inline(always)]
    pub fn support_size(&self) -> usize {
        let cdf = self.cdf.as_ref();
        cdf.len() - 1
    }

    /// Returns a cheap, borrowed view of this model (shares the CDF slice).
    #[inline]
    pub fn as_view(
        &self,
    ) -> ContiguousCategoricalEntropyModel<Probability, &[Probability], PRECISION> {
        let borrowed_cdf = self.cdf.as_ref();
        ContiguousCategoricalEntropyModel {
            cdf: borrowed_cdf,
            phantom: PhantomData,
        }
    }

    /// Converts this model into a lookup-table based decoder model
    /// (trades memory for faster decoding; delegates to the `From` impl).
    #[inline(always)]
    pub fn to_lookup_decoder_model(
        &self,
    ) -> ContiguousLookupDecoderModel<Probability, Vec<Probability>, Box<[Probability]>, PRECISION>
    where
        Probability: Into<usize>,
        usize: AsPrimitive<Probability>,
    {
        self.into()
    }
}
// Associated-type wiring: symbols of this model are plain `usize` indices into
// the contiguous support `0..support_size()`.
impl<Probability, Cdf, const PRECISION: usize> EntropyModel<PRECISION>
    for ContiguousCategoricalEntropyModel<Probability, Cdf, PRECISION>
where
    Probability: BitArray,
{
    type Symbol = usize;
    type Probability = Probability;
}
impl<'m, Probability, Cdf, const PRECISION: usize> IterableEntropyModel<'m, PRECISION>
    for ContiguousCategoricalEntropyModel<Probability, Cdf, PRECISION>
where
    Probability: BitArray,
    Cdf: AsRef<[Probability]>,
{
    /// Iterates over `(symbol, left_cumulative, probability)` triples for every
    /// symbol in the support, derived from adjacent entries of the extended CDF.
    fn symbol_table(
        &'m self,
    ) -> impl Iterator<
        Item = (
            Self::Symbol,
            Self::Probability,
            <Self::Probability as BitArray>::NonZero,
        ),
    > {
        // `iter_extended_cdf` expects `(cumulative, symbol)` pairs, so swap the
        // order produced by `enumerate`.
        let indexed_cumulatives = self.cdf.as_ref().iter().enumerate();
        iter_extended_cdf(indexed_cumulatives.map(|(symbol, &cumulative)| (cumulative, symbol)))
    }
}
impl<Probability, Cdf, const PRECISION: usize> DecoderModel<PRECISION>
    for ContiguousCategoricalEntropyModel<Probability, Cdf, PRECISION>
where
    Probability: BitArray,
    Cdf: AsRef<[Probability]>,
{
    /// Maps a `quantile` back to `(symbol, left_cumulative, probability)` by
    /// binary-searching the stored CDF.
    // assumes `quantile < 2^PRECISION` (a valid quantile) — TODO(review):
    // confirm the decoder trait guarantees this; the unsafe code below relies
    // on it.
    #[inline(always)]
    fn quantile_function(
        &self,
        quantile: Self::Probability,
    ) -> (usize, Probability, Probability::NonZero) {
        let cdf = self.cdf.as_ref();
        // SAFETY: the constructors always append a final boundary entry, so
        // `cdf.len() >= 1` and the subrange `..cdf.len() - 1` is in bounds.
        let monotonic_part_of_cdf = unsafe { cdf.get_unchecked(..cdf.len() - 1) };
        // The comparator never returns `Equal`, so `binary_search_by` always
        // yields `Err(i)` with `i` = index of the first entry `> quantile`.
        let Err(next_symbol) = monotonic_part_of_cdf.binary_search_by(|&x| {
            if x <= quantile {
                core::cmp::Ordering::Less
            } else {
                core::cmp::Ordering::Greater
            }
        }) else {
            // SAFETY: `Ok` is unreachable because the comparator above never
            // returns `Ordering::Equal`.
            unsafe { core::hint::unreachable_unchecked() }
        };
        // `next_symbol >= 1`: the first CDF entry is the left-sided cumulative
        // of symbol 0, which compares `<= quantile` for any valid quantile
        // (presumably the first entry is 0 — verify against the constructors).
        let symbol = next_symbol - 1;
        // SAFETY: `next_symbol <= cdf.len() - 1` (it indexes the extended CDF,
        // which has exactly one more entry than the searched subslice), and
        // `symbol < next_symbol`.
        let (right_cumulative, left_cumulative) =
            unsafe { (*cdf.get_unchecked(next_symbol), *cdf.get_unchecked(symbol)) };
        // SAFETY: adjacent CDF entries differ by the symbol's probability,
        // which is nonzero — NOTE(review): relies on the constructors only
        // accepting nonzero probabilities; confirm no zero-probability CDFs
        // can reach this point.
        let probability = unsafe {
            right_cumulative
                .wrapping_sub(&left_cumulative)
                .into_nonzero_unchecked()
        };
        (symbol, left_cumulative, probability)
    }
}
impl<Probability, Cdf, const PRECISION: usize> EncoderModel<PRECISION>
    for ContiguousCategoricalEntropyModel<Probability, Cdf, PRECISION>
where
    Probability: BitArray,
    Cdf: AsRef<[Probability]>,
{
    /// Looks up `(left_cumulative, probability)` for `symbol`.
    ///
    /// Returns `None` when the symbol lies outside the contiguous support
    /// `0..support_size()`.
    fn left_cumulative_and_probability(
        &self,
        symbol: impl Borrow<usize>,
    ) -> Option<(Probability, Probability::NonZero)> {
        let index = *symbol.borrow();
        if index >= self.support_size() {
            return None;
        }
        let cdf = self.cdf.as_ref();
        // SAFETY: `index < support_size() == cdf.len() - 1`, so both `index`
        // and `index + 1` are in bounds.
        let (left_cumulative, right_cumulative) =
            unsafe { (*cdf.get_unchecked(index), *cdf.get_unchecked(index + 1)) };
        // SAFETY: adjacent CDF entries differ by the symbol's probability,
        // which is nonzero — NOTE(review): relies on the constructors only
        // accepting nonzero probabilities; confirm if a CDF can be constructed
        // by other means.
        let probability = unsafe {
            right_cumulative
                .wrapping_sub(&left_cumulative)
                .into_nonzero_unchecked()
        };
        Some((left_cumulative, probability))
    }
}
#[cfg(test)]
mod tests {
    use super::super::super::tests::{test_entropy_model, verify_iterable_entropy_model};
    use super::*;

    /// When the input weights already sum to exactly `1 << 32`, perfect
    /// quantization at `PRECISION = 32` must reproduce them verbatim.
    #[test]
    fn trivial_optimal_weights() {
        let hist = [
            56319u32, 134860032, 47755520, 60775168, 75699200, 92529920, 111023616, 130420736,
            150257408, 169970176, 188869632, 424260864, 229548800, 236082432, 238252287, 234666240,
            1, 1, 227725568, 216746240, 202127104, 185095936, 166533632, 146508800, 126643712,
            107187968, 88985600, 72576000, 57896448, 45617664, 34893056, 26408448, 19666688,
            14218240, 10050048, 7164928, 13892864,
        ];
        // Precondition of this test: the histogram is already normalized to 2^32.
        assert_eq!(hist.iter().map(|&x| x as u64).sum::<u64>(), 1 << 32);
        let probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
        let categorical =
            ContiguousCategoricalEntropyModel::<u32, _, 32>::from_floating_point_probabilities_perfect(
                &probabilities,
            )
            .unwrap();
        // The quantized weights must equal the original histogram exactly.
        let weights: Vec<_> = categorical
            .symbol_table()
            .map(|(_, _, probability)| probability.get())
            .collect();
        assert_eq!(&weights[..], &hist[..]);
    }

    /// With unnormalized `f64` weights (including zeros), the "perfect"
    /// quantizer should achieve strictly lower KL divergence than the "fast"
    /// one, both at `PRECISION = 32` and at the default precision (24).
    #[test]
    fn nontrivial_optimal_weights_f64() {
        let hist = [
            1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
            896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
            347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
        ];
        // Deliberately NOT normalized to 2^32 (unlike `trivial_optimal_weights`).
        assert_ne!(hist.iter().map(|&x| x as u64).sum::<u64>(), 1 << 32);
        let probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
        {
            // PRECISION = 32.
            let fast =
                ContiguousCategoricalEntropyModel::<u32, _, 32>::from_floating_point_probabilities_fast(
                    &probabilities,
                    None
                )
                .unwrap();
            let kl_fast = verify_iterable_entropy_model(&fast, &hist, 1e-6);
            let perfect =
                ContiguousCategoricalEntropyModel::<u32, _, 32>::from_floating_point_probabilities_perfect(
                    &probabilities,
                )
                .unwrap();
            let kl_perfect = verify_iterable_entropy_model(&perfect, &hist, 1e-6);
            assert!(kl_perfect < kl_fast);
        }
        {
            // Default model (PRECISION = 24).
            let fast =
                DefaultContiguousCategoricalEntropyModel::from_floating_point_probabilities_fast(
                    &probabilities,
                    None,
                )
                .unwrap();
            let kl_fast = verify_iterable_entropy_model(&fast, &hist, 1e-6);
            let perfect =
                DefaultContiguousCategoricalEntropyModel::from_floating_point_probabilities_perfect(
                    &probabilities,
                )
                .unwrap();
            let kl_perfect = verify_iterable_entropy_model(&perfect, &hist, 1e-6);
            assert!(kl_perfect < kl_fast);
        }
    }

    /// Same as `nontrivial_optimal_weights_f64` but with `f32` inputs, to
    /// exercise the lower-precision floating-point code path.
    #[test]
    fn nontrivial_optimal_weights_f32() {
        let hist = [
            1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
            896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
            347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
        ];
        assert_ne!(hist.iter().map(|&x| x as u64).sum::<u64>(), 1 << 32);
        let probabilities = hist.iter().map(|&x| x as f32).collect::<Vec<_>>();
        {
            // PRECISION = 32.
            let fast =
                ContiguousCategoricalEntropyModel::<u32, _, 32>::from_floating_point_probabilities_fast(
                    &probabilities,
                    None
                )
                .unwrap();
            let kl_fast = verify_iterable_entropy_model(&fast, &hist, 1e-6);
            let perfect =
                ContiguousCategoricalEntropyModel::<u32, _, 32>::from_floating_point_probabilities_perfect(
                    &probabilities,
                )
                .unwrap();
            let kl_perfect = verify_iterable_entropy_model(&perfect, &hist, 1e-6);
            assert!(kl_perfect < kl_fast);
        }
        {
            // Default model (PRECISION = 24).
            let fast =
                DefaultContiguousCategoricalEntropyModel::from_floating_point_probabilities_fast(
                    &probabilities,
                    None,
                )
                .unwrap();
            let kl_fast = verify_iterable_entropy_model(&fast, &hist, 1e-6);
            let perfect =
                DefaultContiguousCategoricalEntropyModel::from_floating_point_probabilities_perfect(
                    &probabilities,
                )
                .unwrap();
            let kl_perfect = verify_iterable_entropy_model(&perfect, &hist, 1e-6);
            assert!(kl_perfect < kl_fast);
        }
    }

    /// The perfect quantizer should converge and preserve (near-)symmetry:
    /// for a symmetric distribution, symmetric entries' quantized weights may
    /// differ by at most 1 (tie-breaking), and KL stays tiny.
    #[test]
    fn perfect_converges() {
        // Symmetric 3-entry distribution.
        let example1 = [0.15, 0.69, 0.15];
        // Nearly-symmetric bell-shaped distribution.
        let example2 = [
            1.34673042e-04,
            6.52306480e-04,
            3.14999325e-03,
            1.49921896e-02,
            6.67127371e-02,
            2.26679876e-01,
            3.75356406e-01,
            2.26679876e-01,
            6.67127594e-02,
            1.49922138e-02,
            3.14990873e-03,
            6.52299321e-04,
            1.34715927e-04,
        ];
        let categorical1 =
            DefaultContiguousCategoricalEntropyModel::from_floating_point_probabilities_perfect(
                &example1,
            )
            .unwrap();
        let prob0 = categorical1.left_cumulative_and_probability(0).unwrap().1;
        let prob2 = categorical1.left_cumulative_and_probability(2).unwrap().1;
        // Symbols 0 and 2 have equal input probability; allow a difference of
        // at most 1 quantum due to rounding/tie-breaking.
        assert!((-1..=1).contains(&(prob0.get() as i64 - prob2.get() as i64)));
        verify_iterable_entropy_model(&categorical1, &example1, 1e-10);
        let categorical2 =
            DefaultContiguousCategoricalEntropyModel::from_floating_point_probabilities_perfect(
                &example2,
            )
            .unwrap();
        verify_iterable_entropy_model(&categorical2, &example2, 1e-10);
    }

    /// End-to-end check: both construction schemes produce models that
    /// round-trip through encoding/decoding (`test_entropy_model`), and the
    /// perfect scheme again yields lower KL divergence than the fast one.
    #[test]
    fn contiguous_categorical() {
        let hist = [
            1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
            896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
            347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
        ];
        let probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
        let fast =
            ContiguousCategoricalEntropyModel::<u32, _, 32>::from_floating_point_probabilities_fast(
                &probabilities,
                None
            )
            .unwrap();
        test_entropy_model(&fast, 0..probabilities.len());
        let kl_fast = verify_iterable_entropy_model(&fast, &hist, 1e-8);
        let perfect =
            ContiguousCategoricalEntropyModel::<u32, _, 32>::from_floating_point_probabilities_perfect(
                &probabilities,
            )
            .unwrap();
        test_entropy_model(&perfect, 0..probabilities.len());
        let kl_perfect = verify_iterable_entropy_model(&perfect, &hist, 1e-8);
        assert!(kl_perfect < kl_fast);
    }
}