pub use probability::distribution::{Distribution, Inverse};
mod categorical;
mod quantize;
mod uniform;
use core::{borrow::Borrow, hash::Hash};
use alloc::{boxed::Box, vec::Vec};
use num_traits::{float::FloatCore, AsPrimitive, One, Zero};
use crate::{BitArray, NonZeroBitArray};
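
/// Base trait for all entropy models in this module.
///
/// An entropy model assigns probabilities to `Symbol`s in fixed-point representation:
/// each probability is an integer of type `Probability`, to be interpreted as a
/// fraction with implied denominator `1 << PRECISION`. The probabilities of all
/// symbols of a model add up to exactly `1 << PRECISION` (the unit tests at the
/// bottom of this file verify this invariant).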
pub trait EntropyModel<const PRECISION: usize> {
type Symbol;
type Probability: BitArray;
}
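
/// An entropy model that can be used for *encoding* (compressing) symbols.
///
/// The required method `left_cumulative_and_probability` maps a symbol to the
/// left-sided cumulative and the (nonzero) probability of the symbol's quantile
/// interval, both in fixed point with implied denominator `1 << PRECISION`. It
/// returns `None` for symbols to which the model assigns zero probability, i.e.,
/// symbols that cannot be encoded.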
pub trait EncoderModel<const PRECISION: usize>: EntropyModel<PRECISION> {
fn left_cumulative_and_probability(
&self,
symbol: impl Borrow<Self::Symbol>,
) -> Option<(Self::Probability, <Self::Probability as BitArray>::NonZero)>;
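
    /// Returns the probability of `symbol` as a floating point number in `[0, 1]`.
    ///
    /// Returns zero for symbols that the model cannot encode. This is mainly useful
    /// for testing and diagnostics; encoding itself operates on the fixed-point
    /// representation returned by `left_cumulative_and_probability`.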
#[inline]
fn floating_point_probability<F>(&self, symbol: Self::Symbol) -> F
where
F: FloatCore,
Self::Probability: Into<F>,
{
        // `whole` is `2.0^PRECISION`, computed as `2.0 * 2^(PRECISION - 1)` because shifting
        // by `PRECISION` itself would overflow `Self::Probability` if `PRECISION` is its full
        // bit width.
        let whole = (F::one() + F::one()) * (Self::Probability::one() << (PRECISION - 1)).into();
let probability = self
.left_cumulative_and_probability(symbol)
.map_or(Self::Probability::zero(), |(_, p)| p.get());
probability.into() / whole
}
}
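
/// An entropy model that can be used for *decoding* (decompressing) symbols.
///
/// `quantile_function` is the inverse of `EncoderModel::left_cumulative_and_probability`:
/// for any `quantile < (1 << PRECISION)`, it returns the unique symbol whose quantile
/// interval contains `quantile`, together with the interval's left-sided cumulative
/// and its probability.
///
/// # Example
///
/// A minimal sketch of the round trip between the encoder and decoder views. It
/// assumes that the items below are re-exported at `constriction::stream::model`
/// (adjust the paths to your crate layout) and that the `probability` crate provides
/// the `Gaussian` distribution, as in the unit tests at the bottom of this file:
///
/// ```
/// use constriction::stream::model::{DecoderModel, EncoderModel, LeakyQuantizer};
/// use probability::prelude::*;
///
/// let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-100..=100);
/// let model = quantizer.quantize(Gaussian::new(0.0, 10.0));
///
/// // `quantile_function` inverts `left_cumulative_and_probability`:
/// let (left_cumulative, probability) = model.left_cumulative_and_probability(12).unwrap();
/// let (symbol, cumulative, prob) = model.quantile_function(left_cumulative);
/// assert_eq!(symbol, 12);
/// assert_eq!((cumulative, prob), (left_cumulative, probability));
/// ```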
pub trait DecoderModel<const PRECISION: usize>: EntropyModel<PRECISION> {
fn quantile_function(
&self,
quantile: Self::Probability,
) -> (
Self::Symbol,
Self::Probability,
<Self::Probability as BitArray>::NonZero,
);
}
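
/// An entropy model whose symbols and probabilities can be iterated over, e.g., to
/// serialize the model or to compute information-theoretic quantities.
///
/// `symbol_table` yields `(symbol, left_sided_cumulative, probability)` triples with
/// cumulatives in ascending order starting at zero; the probabilities add up to
/// `1 << PRECISION` (see `test_iterable_entropy_model` below). The lifetime `'m` ties
/// the returned iterator to the model it borrows from.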
pub trait IterableEntropyModel<'m, const PRECISION: usize>: EntropyModel<PRECISION> {
fn symbol_table(
&'m self,
) -> impl Iterator<
Item = (
Self::Symbol,
Self::Probability,
<Self::Probability as BitArray>::NonZero,
),
>;
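
    /// Like `symbol_table`, but with cumulatives and probabilities converted from
    /// fixed point to floating point numbers in the interval `[0, 1]`.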
fn floating_point_symbol_table<F>(&'m self) -> impl Iterator<Item = (Self::Symbol, F, F)>
where
F: FloatCore + From<Self::Probability> + 'm,
Self::Probability: Into<F>,
{
let whole = (F::one() + F::one()) * (Self::Probability::one() << (PRECISION - 1)).into();
self.symbol_table()
.map(move |(symbol, cumulative, probability)| {
(
symbol,
cumulative.into() / whole,
probability.get().into() / whole,
)
})
}
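
    /// Returns the entropy of the model in bits (base 2).
    ///
    /// With fixed-point weights `w_i` and `W = 2^PRECISION`, the entropy is
    /// `-sum_i (w_i / W) * log2(w_i / W) = PRECISION - (sum_i w_i * log2(w_i)) / W`,
    /// which is the shifted and scaled form computed below.
    ///
    /// # Example
    ///
    /// A minimal sketch, under the same crate-path assumptions as the example on
    /// `DecoderModel`. The entropy of a quantized Gaussian should be close to the
    /// differential entropy `0.5 * log2(2 * pi * e * sigma^2)` of the continuous
    /// distribution (cf. the `entropy` unit test below):
    ///
    /// ```
    /// use constriction::stream::model::{IterableEntropyModel, LeakyQuantizer};
    /// use probability::prelude::*;
    ///
    /// let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-100..=100);
    /// let model = quantizer.quantize(Gaussian::new(0.0, 10.0));
    /// let entropy = model.entropy_base2::<f64>();
    /// assert!((entropy - (2.047095585180641 + 10.0f64.log2())).abs() < 0.01);
    /// ```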
fn entropy_base2<F>(&'m self) -> F
where
F: num_traits::Float + core::iter::Sum,
Self::Probability: Into<F>,
{
let scaled_shifted = self
.symbol_table()
.map(|(_, _, probability)| {
let probability = probability.get().into();
                probability * probability.log2() // `probability` is nonzero, so `log2` is finite
            })
.sum::<F>();
let whole = (F::one() + F::one()) * (Self::Probability::one() << (PRECISION - 1)).into();
F::from(PRECISION).unwrap() - scaled_shifted / whole
}
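
    /// Returns the cross entropy `H(p, q) = -sum_i p_i * log2(q_i)` in bits, where `q`
    /// is this model's distribution and `p` is supplied by the caller, iterated in the
    /// same order as `symbol_table`. With `q_i = w_i / 2^PRECISION`, each term becomes
    /// `p_i * (PRECISION - log2(w_i))`.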
fn cross_entropy_base2<F>(&'m self, p: impl IntoIterator<Item = F>) -> F
where
F: num_traits::Float + core::iter::Sum,
Self::Probability: Into<F>,
{
let shift = F::from(PRECISION).unwrap();
self.symbol_table()
.zip(p)
.map(|((_, _, probability), p)| {
let probability = probability.get().into();
                p * (shift - probability.log2())
            })
.sum::<F>()
}
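
    /// Returns the reverse cross entropy `H(q, p) = -sum_i q_i * log2(p_i)` in bits,
    /// where `q` is this model's distribution; `p` must be strictly positive on the
    /// model's support and iterate in the same order as `symbol_table`.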
fn reverse_cross_entropy_base2<F>(&'m self, p: impl IntoIterator<Item = F>) -> F
where
F: num_traits::Float + core::iter::Sum,
Self::Probability: Into<F>,
{
let scaled = self
.symbol_table()
.zip(p)
.map(|((_, _, probability), p)| {
let probability = probability.get().into();
probability * p.log2()
})
.sum::<F>();
let whole = (F::one() + F::one()) * (Self::Probability::one() << (PRECISION - 1)).into();
-scaled / whole
}
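
    /// Returns the Kullback-Leibler divergence `D_KL(p || q) = sum_i p_i * log2(p_i / q_i)`
    /// in bits, where `q` is this model's distribution. Terms with `p_i == 0` contribute
    /// zero (the usual convention `0 * log 0 = 0`). Assumes that `p` sums to one (the
    /// final `+ PRECISION` shift relies on this) and iterates in the same order as
    /// `symbol_table`.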
fn kl_divergence_base2<F>(&'m self, p: impl IntoIterator<Item = F>) -> F
where
F: num_traits::Float + core::iter::Sum,
Self::Probability: Into<F>,
{
let shifted = self
.symbol_table()
.zip(p)
.map(|((_, _, probability), p)| {
if p == F::zero() {
F::zero()
} else {
let probability = probability.get().into();
p * (p.log2() - probability.log2())
}
})
.sum::<F>();
        shifted + F::from(PRECISION).unwrap()
    }
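
    /// Returns the reverse Kullback-Leibler divergence `D_KL(q || p)` in bits, where
    /// `q` is this model's distribution; `p` must be strictly positive on the model's
    /// support (there is no zero-handling here, unlike in `kl_divergence_base2`) and
    /// iterate in the same order as `symbol_table`.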
fn reverse_kl_divergence_base2<F>(&'m self, p: impl IntoIterator<Item = F>) -> F
where
F: num_traits::Float + core::iter::Sum,
Self::Probability: Into<F>,
{
let scaled_shifted = self
.symbol_table()
.zip(p)
.map(|((_, _, probability), p)| {
let probability = probability.get().into();
probability * (probability.log2() - p.log2())
})
.sum::<F>();
let whole = (F::one() + F::one()) * (Self::Probability::one() << (PRECISION - 1)).into();
scaled_shifted / whole - F::from(PRECISION).unwrap()
}
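
    /// Creates an `EncoderModel` with the same distribution, backed by a lookup table
    /// from symbols to quantile intervals (hence the `Hash + Eq` bound on `Symbol`).
    /// Useful when a model type implements `IterableEntropyModel` but an
    /// `EncoderModel` is needed.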
#[inline(always)]
fn to_generic_encoder_model(
&'m self,
) -> NonContiguousCategoricalEncoderModel<Self::Symbol, Self::Probability, PRECISION>
where
Self::Symbol: Hash + Eq,
{
self.into()
}
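
    /// Creates a `DecoderModel` with the same distribution, backed by a vector of
    /// `(left_sided_cumulative, symbol)` pairs that can be searched by quantile.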
#[inline(always)]
fn to_generic_decoder_model(
&'m self,
) -> NonContiguousCategoricalDecoderModel<
Self::Symbol,
Self::Probability,
Vec<(Self::Probability, Self::Symbol)>,
PRECISION,
>
where
Self::Symbol: Clone,
{
self.into()
}
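
    /// Creates a `DecoderModel` with the same distribution, backed by an additional
    /// precomputed lookup table with one entry per quantile (hence the `Into<usize>`
    /// bound on `Probability`; only advisable for small `PRECISION`).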
#[inline(always)]
fn to_generic_lookup_decoder_model(
&'m self,
) -> NonContiguousLookupDecoderModel<
Self::Symbol,
Self::Probability,
Vec<(Self::Probability, Self::Symbol)>,
Box<[Self::Probability]>,
PRECISION,
>
where
Self::Probability: Into<usize>,
usize: AsPrimitive<Self::Probability>,
Self::Symbol: Clone + Default,
{
self.into()
}
}
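
// Blanket implementations so that references to entropy models (including references
// to unsized models, such as trait objects) can be used wherever an entropy model is
// expected, simply by delegating to the referent.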
impl<M, const PRECISION: usize> EntropyModel<PRECISION> for &M
where
M: EntropyModel<PRECISION> + ?Sized,
{
type Probability = M::Probability;
type Symbol = M::Symbol;
}
impl<M, const PRECISION: usize> EncoderModel<PRECISION> for &M
where
M: EncoderModel<PRECISION> + ?Sized,
{
#[inline(always)]
fn left_cumulative_and_probability(
&self,
symbol: impl Borrow<Self::Symbol>,
) -> Option<(Self::Probability, <Self::Probability as BitArray>::NonZero)> {
(*self).left_cumulative_and_probability(symbol)
}
}
impl<M, const PRECISION: usize> DecoderModel<PRECISION> for &M
where
M: DecoderModel<PRECISION> + ?Sized,
{
#[inline(always)]
fn quantile_function(
&self,
quantile: Self::Probability,
) -> (
Self::Symbol,
Self::Probability,
<Self::Probability as BitArray>::NonZero,
) {
(*self).quantile_function(quantile)
}
}
impl<'m, M, const PRECISION: usize> IterableEntropyModel<'m, PRECISION> for &'m M
where
M: IterableEntropyModel<'m, PRECISION>,
{
fn symbol_table(
&'m self,
) -> impl Iterator<
Item = (
Self::Symbol,
Self::Probability,
<Self::Probability as BitArray>::NonZero,
),
> {
(*self).symbol_table()
}
fn entropy_base2<F>(&'m self) -> F
where
F: num_traits::Float + core::iter::Sum,
Self::Probability: Into<F>,
{
(*self).entropy_base2()
}
#[inline(always)]
fn to_generic_encoder_model(
&'m self,
) -> NonContiguousCategoricalEncoderModel<Self::Symbol, Self::Probability, PRECISION>
where
Self::Symbol: Hash + Eq,
{
(*self).to_generic_encoder_model()
}
#[inline(always)]
fn to_generic_decoder_model(
&'m self,
) -> NonContiguousCategoricalDecoderModel<
Self::Symbol,
Self::Probability,
Vec<(Self::Probability, Self::Symbol)>,
PRECISION,
>
where
Self::Symbol: Clone,
{
(*self).to_generic_decoder_model()
}
}
pub use categorical::{
contiguous::{
ContiguousCategoricalEntropyModel, DefaultContiguousCategoricalEntropyModel,
SmallContiguousCategoricalEntropyModel,
},
lazy_contiguous::{
DefaultLazyContiguousCategoricalEntropyModel, LazyContiguousCategoricalEntropyModel,
SmallLazyContiguousCategoricalEntropyModel,
},
lookup_contiguous::{ContiguousLookupDecoderModel, SmallContiguousLookupDecoderModel},
lookup_noncontiguous::{NonContiguousLookupDecoderModel, SmallNonContiguousLookupDecoderModel},
non_contiguous::{
DefaultNonContiguousCategoricalDecoderModel, DefaultNonContiguousCategoricalEncoderModel,
NonContiguousCategoricalDecoderModel, NonContiguousCategoricalEncoderModel,
SmallNonContiguousCategoricalDecoderModel, SmallNonContiguousCategoricalEncoderModel,
},
};
pub use quantize::{
DefaultLeakyQuantizer, LeakilyQuantizedDistribution, LeakyQuantizer, SmallLeakyQuantizer,
};
pub use uniform::{DefaultUniformModel, SmallUniformModel, UniformModel};
#[cfg(test)]
mod tests {
use probability::prelude::*;
use super::*;
#[test]
fn entropy() {
#[cfg(not(miri))]
let (support, std_devs, means) = (-1000..=1000, [100., 200., 300.], [-10., 2.3, 50.1]);
#[cfg(miri)]
let (support, std_devs, means) = (-100..=100, [10., 20., 30.], [-1., 0.23, 5.01]);
let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(support);
for &std_dev in &std_devs {
for &mean in &means {
let distribution = Gaussian::new(mean, std_dev);
let model = quantizer.quantize(distribution);
let entropy = model.entropy_base2::<f64>();
let expected_entropy = 2.047095585180641 + std_dev.log2();
assert!((entropy - expected_entropy).abs() < 0.01);
}
}
}
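
    /// Checks that `model` is internally consistent over `support`: the quantile
    /// intervals reported by `left_cumulative_and_probability` must tile the range
    /// `[0, 1 << PRECISION)` without gaps or overlaps, and `quantile_function` must
    /// invert them at the left edge, right edge, and midpoint of each interval.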
pub(super) fn test_entropy_model<'m, D, const PRECISION: usize>(
model: &'m D,
support: impl Clone + Iterator<Item = D::Symbol>,
) where
D: IterableEntropyModel<'m, PRECISION>
+ EncoderModel<PRECISION>
+ DecoderModel<PRECISION>
+ 'm,
D::Symbol: Copy + core::fmt::Debug + PartialEq,
D::Probability: Into<u64>,
u64: AsPrimitive<D::Probability>,
{
let mut sum = 0;
for symbol in support.clone() {
let (left_cumulative, prob) = model.left_cumulative_and_probability(symbol).unwrap();
assert_eq!(left_cumulative.into(), sum);
sum += prob.get().into();
let expected = (symbol, left_cumulative, prob);
assert_eq!(model.quantile_function(left_cumulative), expected);
assert_eq!(model.quantile_function((sum - 1).as_()), expected);
assert_eq!(
model.quantile_function((left_cumulative.into() + prob.get().into() / 2).as_()),
expected
);
}
assert_eq!(sum, 1 << PRECISION);
test_iterable_entropy_model(model, support);
}
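
    /// Checks that `symbol_table` visits exactly the symbols in `support`, in order,
    /// with contiguous cumulatives that add up to `1 << PRECISION`.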
pub(super) fn test_iterable_entropy_model<'m, M, const PRECISION: usize>(
model: &'m M,
support: impl Clone + Iterator<Item = M::Symbol>,
) where
M: IterableEntropyModel<'m, PRECISION> + 'm,
M::Symbol: Copy + core::fmt::Debug + PartialEq,
M::Probability: Into<u64>,
u64: AsPrimitive<M::Probability>,
{
let mut expected_cumulative = 0u64;
let mut count = 0;
for (expected_symbol, (symbol, left_sided_cumulative, probability)) in
support.clone().zip(model.symbol_table())
{
assert_eq!(symbol, expected_symbol);
assert_eq!(left_sided_cumulative.into(), expected_cumulative);
expected_cumulative += probability.get().into();
count += 1;
}
assert_eq!(count, support.size_hint().0);
assert_eq!(expected_cumulative, 1 << PRECISION);
}
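
    /// Checks that `model` plausibly fits the empirical histogram `hist`: all weights
    /// must be nonzero and sum to `1 << PRECISION`, ranking symbols by weight must be
    /// consistent with ranking them by histogram count, and the KL divergence from the
    /// normalized histogram to the model must be below `tol` bits. Returns the KL
    /// divergence.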
pub(super) fn verify_iterable_entropy_model<'m, M, P, const PRECISION: usize>(
model: &'m M,
hist: &[P],
tol: f64,
) -> f64
where
M: IterableEntropyModel<'m, PRECISION> + 'm,
M::Probability: BitArray + Into<u64> + Into<f64>,
P: num_traits::Zero + Into<f64> + Copy + PartialOrd,
{
let weights: Vec<_> = model
.symbol_table()
.map(|(_, _, probability)| probability.get())
.collect();
assert_eq!(weights.len(), hist.len());
assert_eq!(
weights.iter().map(|&x| Into::<u64>::into(x)).sum::<u64>(),
1 << PRECISION
);
for &w in &weights {
assert!(w > M::Probability::zero());
}
let mut weights_and_hist = weights
.iter()
.cloned()
.zip(hist.iter().cloned())
.collect::<Vec<_>>();
        // Ranking by weight must be compatible with ranking by histogram count.
        weights_and_hist.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
let mut previous = P::zero();
for (_, hist) in weights_and_hist {
assert!(hist >= previous);
previous = hist;
}
let normalization = hist.iter().map(|&x| x.into()).sum::<f64>();
let normalized_hist = hist
.iter()
.map(|&x| Into::<f64>::into(x) / normalization)
.collect::<Vec<_>>();
let kl = model.kl_divergence_base2::<f64>(normalized_hist);
assert!(kl < tol);
kl
}
}