constriction/stream/model/categorical/lazy_contiguous.rs

use core::{borrow::Borrow, marker::PhantomData};

use alloc::vec::Vec;
use num_traits::{float::FloatCore, AsPrimitive};

use crate::{generic_static_asserts, wrapping_pow2, BitArray};

use super::super::{DecoderModel, EncoderModel, EntropyModel};

/// Type alias for a typical [`LazyContiguousCategoricalEntropyModel`].
///
/// See:
/// - [`LazyContiguousCategoricalEntropyModel`]
/// - [discussion of presets](crate::stream#presets)
pub type DefaultLazyContiguousCategoricalEntropyModel<F = f32, Pmf = Vec<F>> =
    LazyContiguousCategoricalEntropyModel<u32, F, Pmf, 24>;

/// Type alias for a [`LazyContiguousCategoricalEntropyModel`] that can be used with coders
/// that use `u16` for their word size.
///
/// Note that, unlike for the other type aliases with a `Small...` prefix, creating a lookup
/// table for a *lazy* categorical model is rarely useful. Lazy models are optimized for
/// applications where a model gets used only a few times (e.g., as part of an
/// autoregressive model), whereas lookup tables are useful if you use the same model many
/// times.
///
/// See:
/// - [`LazyContiguousCategoricalEntropyModel`]
/// - [discussion of presets](crate::stream#presets)
pub type SmallLazyContiguousCategoricalEntropyModel<F = f32, Pmf = Vec<F>> =
    LazyContiguousCategoricalEntropyModel<u16, F, Pmf, 12>;

/// Lazily constructed variant of [`ContiguousCategoricalEntropyModel`].
///
/// This type is similar to [`ContiguousCategoricalEntropyModel`], and data encoded with
/// either of the two models can be decoded with either of the two models (provided that
/// both models are constructed with constructors of the same name; see [compatibility
/// table for `ContiguousCategoricalEntropyModel`]).
///
/// The difference between this type and `ContiguousCategoricalEntropyModel` is that this
/// type is lazy, i.e., it defers most of the work required for approximating a given
/// floating-point probability mass function in fixed-point precision until encoding or
/// decoding time (and then performs only the work necessary for the symbols that actually
/// get encoded or decoded).
///
/// # When Should I Use This Type of Entropy Model?
///
/// - Use this type if you want to encode or decode only a few symbols (or even just a
///   single symbol) with the same categorical distribution.
/// - Use [`ContiguousCategoricalEntropyModel`], [`NonContiguousCategoricalEncoderModel`],
///   or [`NonContiguousCategoricalDecoderModel`] if you want to encode several symbols with
///   the same categorical distribution. These models precalculate the fixed-point
///   approximation of the entire cumulative distribution function at model construction, so
///   that the calculation doesn't have to be done at every encoding/decoding step.
/// - Use [`ContiguousLookupDecoderModel`] or [`NonContiguousLookupDecoderModel`] (together
///   with a small `Probability` data type, see [discussion of presets]) for decoding a
///   *very* large number of i.i.d. symbols if runtime is more important to you than
///   near-optimal bit rate. These models create a lookup table that maps all `2^PRECISION`
///   possible quantiles to the corresponding symbol, thus eliminating the need for a binary
///   search over the CDF at decoding time.
///
/// # Computational Efficiency
///
/// For a probability distribution with a support of `N` symbols, a
/// `LazyContiguousCategoricalEntropyModel` has the following asymptotic costs:
///
/// - creation:
///   - runtime cost: `Θ(1)` if the normalization constant is known and provided, `O(N)`
///     otherwise (but still faster by a constant factor than creating a
///     [`ContiguousCategoricalEntropyModel`] from floating point probabilities);
///   - memory footprint: `Θ(N)`;
///   - both are cheaper by a constant factor than for a
///     [`NonContiguousCategoricalEncoderModel`] or a
///     [`NonContiguousCategoricalDecoderModel`].
/// - encoding a symbol (calling [`EncoderModel::left_cumulative_and_probability`]):
///   - runtime cost: `O(N)`, since the prefix sum of the floating point probabilities up
///     to the encoded symbol is calculated on the fly (this deferred per-symbol work is
///     the price of lazy construction);
///   - memory footprint: no heap allocations, constant stack space.
/// - decoding a symbol (calling [`DecoderModel::quantile_function`]):
///   - runtime cost: `O(N)` (a linear scan over the PMF; symbols that provably come before
///     the target quantile are skipped with cheap float comparisons);
///   - memory footprint: no heap allocations, constant stack space.
///
/// # Why is there no `NonContiguous` variant of this model?
///
/// In contrast to `NonContiguousCategorical{En, De}coderModel`, there is no `NonContiguous`
/// variant of this type. A `NonContiguous` variant of this type would offer no improvement
/// in runtime performance compared to using this type
/// (`LazyContiguousCategoricalEntropyModel`) together with a `HashMap` or an array (for
/// encoding or decoding, respectively) to map between a non-contiguous alphabet and a
/// contiguous set of indices. (This is different for
/// `NonContiguousCategorical{En, De}coderModel`, which avoid an additional array lookup
/// that would otherwise be necessary.)
///
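/// # Example
///
/// A minimal sketch of encoding and decoding a single symbol with a lazy model and this
/// crate's default ANS coder (the probability values below are made up for illustration):
///
/// ```
/// use constriction::stream::{
///     model::DefaultLazyContiguousCategoricalEntropyModel, stack::DefaultAnsCoder, Decode, Encode,
/// };
///
/// // Hypothetical (unnormalized) probabilities of a categorical distribution.
/// let model = DefaultLazyContiguousCategoricalEntropyModel
///     ::from_floating_point_probabilities_fast(vec![0.3f32, 0.1, 0.4, 0.2], None)
///     .unwrap();
///
/// // Most of the fixed-point math happens only now, and only for the symbol that actually
/// // gets encoded.
/// let mut coder = DefaultAnsCoder::new();
/// coder.encode_symbol(2usize, model.as_view()).unwrap();
///
/// // `DefaultAnsCoder` operates as a stack, so decoding pops the symbol back off.
/// assert_eq!(coder.decode_symbol(model.as_view()).unwrap(), 2);
/// ```
///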
/// [`ContiguousCategoricalEntropyModel`]:
///     crate::stream::model::ContiguousCategoricalEntropyModel
/// [`NonContiguousCategoricalEncoderModel`]:
///     crate::stream::model::NonContiguousCategoricalEncoderModel
/// [`NonContiguousCategoricalDecoderModel`]:
///     crate::stream::model::NonContiguousCategoricalDecoderModel
/// [`ContiguousLookupDecoderModel`]: crate::stream::model::ContiguousLookupDecoderModel
/// [`NonContiguousLookupDecoderModel`]:
///     crate::stream::model::NonContiguousLookupDecoderModel
/// [compatibility table for `ContiguousCategoricalEntropyModel`]:
///     crate::stream::model::ContiguousCategoricalEntropyModel#compatibility-table
/// [discussion of presets]: crate::stream#presets
#[derive(Debug, Clone, Copy)]
pub struct LazyContiguousCategoricalEntropyModel<Probability, F, Pmf, const PRECISION: usize> {
    /// Invariants:
    /// - `pmf.len() >= 2`
    pmf: Pmf,
    scale: F,
    phantom: PhantomData<Probability>,
}

impl<Probability, F, Pmf, const PRECISION: usize>
    LazyContiguousCategoricalEntropyModel<Probability, F, Pmf, PRECISION>
where
    Probability: BitArray,
    F: FloatCore + core::iter::Sum<F>,
    Pmf: AsRef<[F]>,
{
    /// Lazily constructs a leaky distribution whose PMF approximates given probabilities.
    ///
    /// Equivalent to (and binary compatible with) the [constructor for
    /// `ContiguousCategoricalEntropyModel` with the same
    /// name](crate::stream::model::ContiguousCategoricalEntropyModel::from_floating_point_probabilities_fast).
    /// However, this constructor is lazy, i.e., it defers most of the work required for
    /// approximating the given `probabilities` in fixed-point precision until encoding or
    /// decoding time (and then performs only the work necessary for the symbols that
    /// actually get encoded or decoded). See [struct documentation](Self).
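    ///
    /// # Example
    ///
    /// A minimal sketch (the probability values and the normalization constant below are
    /// made up for illustration):
    ///
    /// ```
    /// use constriction::stream::model::DefaultLazyContiguousCategoricalEntropyModel;
    ///
    /// // If the normalization constant (i.e., the sum of all probabilities) is already
    /// // known, providing it makes construction `Θ(1)`; passing `None` instead makes the
    /// // constructor compute the sum itself in `O(N)`.
    /// let model = DefaultLazyContiguousCategoricalEntropyModel
    ///     ::from_floating_point_probabilities_fast(vec![1.0f32, 0.5, 2.0, 0.5], Some(4.0))
    ///     .unwrap();
    /// assert_eq!(model.support_size(), 4);
    /// ```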
    #[allow(clippy::result_unit_err)]
    pub fn from_floating_point_probabilities_fast(
        probabilities: Pmf,
        normalization: Option<F>,
    ) -> Result<Self, ()>
    where
        F: AsPrimitive<Probability>,
        Probability: AsPrimitive<usize>,
        usize: AsPrimitive<Probability> + AsPrimitive<F>,
    {
        generic_static_asserts!(
            (Probability: BitArray; const PRECISION: usize);
            PROBABILITY_MUST_SUPPORT_PRECISION: PRECISION <= Probability::BITS;
            PRECISION_MUST_BE_NONZERO: PRECISION > 0;
        );

        let probs = probabilities.as_ref();

        if probs.len() < 2 || probs.len() >= wrapping_pow2::<usize>(PRECISION).wrapping_sub(1) {
            return Err(());
        }

        let remaining_free_weight =
            wrapping_pow2::<Probability>(PRECISION).wrapping_sub(&probs.len().as_());
        let normalization =
            normalization.unwrap_or_else(|| probabilities.as_ref().iter().copied().sum::<F>());
        if !normalization.is_normal() || !normalization.is_sign_positive() {
            return Err(());
        }

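        // Each symbol will later receive a fixed-point weight of roughly
        // `floor(probability * scale) + 1` out of a total weight of `2^PRECISION`. The
        // `+ 1` per symbol (which is why `remaining_free_weight` subtracts `probs.len()`)
        // is what makes the model leaky: every symbol in the support is guaranteed a
        // nonzero probability.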
        let scale = AsPrimitive::<F>::as_(remaining_free_weight.as_()) / normalization;

        Ok(Self {
            pmf: probabilities,
            scale,
            phantom: PhantomData,
        })
    }

    /// Returns the number of symbols supported by the model.
    ///
    /// The distribution is defined on the contiguous range of symbols from zero
    /// (inclusively) to `support_size()` (exclusively). All symbols within this range are
    /// guaranteed to have a nonzero probability, while all symbols outside of this range
    /// have a zero probability.
    #[inline(always)]
    pub fn support_size(&self) -> usize {
        self.pmf.as_ref().len()
    }

    /// Makes a very cheap shallow copy of the model that can be used much like a shared
    /// reference.
    ///
    /// The returned `LazyContiguousCategoricalEntropyModel` implements `Copy`, which is a
    /// requirement for some methods, such as [`Encode::encode_iid_symbols`] or
    /// [`Decode::decode_iid_symbols`]. These methods could also accept a shared reference
    /// to a `LazyContiguousCategoricalEntropyModel` (since all references to entropy models are
    /// also entropy models, and all shared references implement `Copy`), but passing a
    /// *view* instead may be slightly more efficient because it avoids one level of
    /// dereferencing.
    ///
    /// Note that `LazyContiguousCategoricalEntropyModel` is optimized for models that are used
    /// only rarely (often just a single time). Thus, if you find yourself handing out lots of
    /// views to the same `LazyContiguousCategoricalEntropyModel` then you'd likely be better off
    /// using a [`ContiguousCategoricalEntropyModel`] instead.
    ///
    /// [`Encode::encode_iid_symbols`]: crate::stream::Encode::encode_iid_symbols
    /// [`Decode::decode_iid_symbols`]: crate::stream::Decode::decode_iid_symbols
    /// [`ContiguousCategoricalEntropyModel`]: crate::stream::model::ContiguousCategoricalEntropyModel
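    ///
    /// # Example
    ///
    /// A minimal sketch (the probability values are made up for illustration):
    ///
    /// ```
    /// use constriction::stream::model::{
    ///     DefaultLazyContiguousCategoricalEntropyModel, EncoderModel,
    /// };
    ///
    /// let model = DefaultLazyContiguousCategoricalEntropyModel
    ///     ::from_floating_point_probabilities_fast(vec![0.1f32, 0.7, 0.2], None)
    ///     .unwrap();
    ///
    /// // The view implements `Copy`, so it can be passed by value (as some methods
    /// // require) without giving up ownership of `model`.
    /// let view = model.as_view();
    /// let (_left_cumulative, probability) = view.left_cumulative_and_probability(1usize).unwrap();
    /// assert!(probability.get() > 0);
    /// ```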
    #[inline]
    pub fn as_view(
        &self,
    ) -> LazyContiguousCategoricalEntropyModel<Probability, F, &[F], PRECISION> {
        LazyContiguousCategoricalEntropyModel {
            pmf: self.pmf.as_ref(),
            scale: self.scale,
            phantom: PhantomData,
        }
    }
}

impl<Probability, F, Pmf, const PRECISION: usize> EntropyModel<PRECISION>
    for LazyContiguousCategoricalEntropyModel<Probability, F, Pmf, PRECISION>
where
    Probability: BitArray,
{
    type Symbol = usize;
    type Probability = Probability;
}

impl<Probability, F, Pmf, const PRECISION: usize> EncoderModel<PRECISION>
    for LazyContiguousCategoricalEntropyModel<Probability, F, Pmf, PRECISION>
where
    Probability: BitArray,
    F: FloatCore + core::iter::Sum<F> + AsPrimitive<Probability>,
    usize: AsPrimitive<Probability>,
    Pmf: AsRef<[F]>,
{
    fn left_cumulative_and_probability(
        &self,
        symbol: impl Borrow<Self::Symbol>,
    ) -> Option<(Self::Probability, <Self::Probability as BitArray>::NonZero)> {
        let symbol = *symbol.borrow();
        let pmf = self.pmf.as_ref();
        let probability_float = *pmf.get(symbol)?;

        // SAFETY: the `?` above already returned if `symbol` was out of bounds, so
        // `..symbol` is a valid range into `pmf`.
        let left_side = unsafe { pmf.get_unchecked(..symbol) };
        let left_cumulative_float = left_side.iter().copied().sum::<F>();
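        // Quantize to fixed point: `floor(left_cumulative_float * scale)`, plus one
        // guaranteed unit of weight for each of the `symbol` preceding symbols (see the
        // comment on leakiness in `from_floating_point_probabilities_fast`).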
        let left_cumulative = (left_cumulative_float * self.scale).as_() + symbol.as_();

        // It may seem easier to calculate `probability` directly from `probability_float` but
        // this could pick up different rounding errors, breaking guarantees of `EncoderModel`.
        let right_cumulative_float = left_cumulative_float + probability_float;
        let right_cumulative: Probability = if symbol == pmf.len() - 1 {
            // We have to treat the last symbol as a special case since standard treatment could
            // lead to an inaccessible last quantile due to rounding errors.
            wrapping_pow2(PRECISION)
        } else {
            (right_cumulative_float * self.scale).as_() + symbol.as_() + Probability::one()
        };
        let probability = right_cumulative
            .wrapping_sub(&left_cumulative)
            .into_nonzero()
            .expect("leakiness should guarantee nonzero probabilities.");

        Some((left_cumulative, probability))
    }
}

impl<Probability, F, Pmf, const PRECISION: usize> DecoderModel<PRECISION>
    for LazyContiguousCategoricalEntropyModel<Probability, F, Pmf, PRECISION>
where
    F: FloatCore + core::iter::Sum<F> + AsPrimitive<Probability>,
    usize: AsPrimitive<Probability>,
    Probability: BitArray + AsPrimitive<F>,
    Pmf: AsRef<[F]>,
{
    fn quantile_function(
        &self,
        quantile: Self::Probability,
    ) -> (
        Self::Symbol,
        Self::Probability,
        <Self::Probability as BitArray>::NonZero,
    ) {
        // We avoid divisions entirely, and float-to-int conversions as much as possible,
        // because both are slow.

        let mut left_cumulative_float = F::zero();
        let mut right_cumulative_float = F::zero();

        // First, skip any symbols for which we can conclude, even without any expensive
        // float-to-int conversions, that they come before the target quantile. We slightly
        // overestimate `self.scale` so that any mismatch in rounding errors can only make
        // our bound more conservative.
        let enlarged_scale = (F::one() + F::epsilon() + F::epsilon()) * self.scale;
        let lower_bound =
            quantile.saturating_sub(self.pmf.as_ref().len().as_()).as_() / enlarged_scale;
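        // Why this cutoff is safe: for every symbol, `right_cumulative <=
        // right_cumulative_float * scale + pmf.len()` (the `+ pmf.len()` accounts for the
        // one guaranteed unit of weight per symbol). So, up to the rounding slack absorbed
        // by `enlarged_scale`, any symbol with `right_cumulative_float < lower_bound` ends
        // strictly below `quantile` and can be skipped.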

        let mut iter = self.pmf.as_ref().iter();
        let mut next_symbol = 0usize;
        for &next_probability in &mut iter {
            next_symbol = next_symbol.wrapping_add(1);
            left_cumulative_float = right_cumulative_float;
            right_cumulative_float = right_cumulative_float + next_probability;
            if right_cumulative_float >= lower_bound {
                break;
            }
        }

        // Then search for the correct `symbol` using the same float-to-int conversions as in
        // `EncoderModel::left_cumulative_and_probability`.
        let mut left_cumulative =
            (left_cumulative_float * self.scale).as_() + next_symbol.wrapping_sub(1).as_();

        for &next_probability in &mut iter {
            let right_cumulative = (right_cumulative_float * self.scale).as_() + next_symbol.as_();
            if right_cumulative > quantile {
                let probability = right_cumulative
                    .wrapping_sub(&left_cumulative)
                    .into_nonzero()
                    .expect("leakiness should guarantee nonzero probabilities.");
                return (next_symbol.wrapping_sub(1), left_cumulative, probability);
            }

            left_cumulative = right_cumulative;

            right_cumulative_float = right_cumulative_float + next_probability;
            next_symbol = next_symbol.wrapping_add(1);
        }

        // We have to treat the last symbol as a special case since standard treatment could
        // lead to an inaccessible last quantile due to rounding errors.
        let right_cumulative = wrapping_pow2::<Probability>(PRECISION);
        let probability = right_cumulative
            .wrapping_sub(&left_cumulative)
            .into_nonzero()
            .expect("leakiness should guarantee nonzero probabilities.");

        (next_symbol.wrapping_sub(1), left_cumulative, probability)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lazy_contiguous_categorical() {
        #[allow(clippy::excessive_precision)]
        let unnormalized_probs: [f32; 30] = [
            4.22713972, 1e-20, 0.22221771, 0.00927659, 1.58383270, 0.95804675, 0.78104103,
            0.81518454, 0.75206966, 0.58559047, 0.00024284, 1.81382388, 3.22535052, 0.77940434,
            0.24507986, 0.07767093, 0.0, 0.11429778, 0.00179474, 0.30613952, 0.72192056,
            0.00778274, 0.18957551, 10.2402638, 3.36959484, 0.02624742, 1.85103708, 0.25614601,
            0.09754817, 0.27998250,
        ];
        let normalization = 33.538302;

        const PRECISION: usize = 32;
        let model =
            LazyContiguousCategoricalEntropyModel::<u32, _, _, PRECISION>::from_floating_point_probabilities_fast(
                &unnormalized_probs,
                None,
            ).unwrap();

        let mut sum: u64 = 0;
        for (symbol, &unnormalized_prob) in unnormalized_probs.iter().enumerate() {
            let (left_cumulative, prob) = model.left_cumulative_and_probability(symbol).unwrap();
            assert_eq!(left_cumulative as u64, sum);
            let float_prob = prob.get() as f32 / (1u64 << PRECISION) as f32;
            assert!((float_prob - unnormalized_prob / normalization).abs() < 1e-6);
            sum += prob.get() as u64;

            let expected = (symbol, left_cumulative, prob);
            assert_eq!(model.quantile_function(left_cumulative), expected);
            assert_eq!(model.quantile_function((sum - 1).as_()), expected);
            assert_eq!(
                model.quantile_function((left_cumulative as u64 + prob.get() as u64 / 2) as u32),
                expected
            );
        }
        assert_eq!(sum, 1 << PRECISION);
    }
}