// constriction/stream/model/uniform.rs

1use core::borrow::Borrow;
2
3use num_traits::AsPrimitive;
4
5use crate::{generic_static_asserts, wrapping_pow2, BitArray, NonZeroBitArray};
6
7use super::{DecoderModel, EncoderModel, EntropyModel, IterableEntropyModel};
8
9/// Type alias for a typical [`UniformModel`].
10///
11/// See:
12/// - [`UniformModel`]
13/// - [discussion of presets](crate::stream#presets)
14pub type DefaultUniformModel = UniformModel<u32, 24>;
15
16/// Type alias for a [`UniformModel`] that is easier to use within a sequence of compressed symbols
17/// that also involves some lookup models.
18///
19/// See:
20/// - [`UniformModel`]
21/// - [discussion of presets](crate::stream#presets)
22pub type SmallUniformModel = UniformModel<u16, 12>;
23
24#[derive(Debug, Clone, Copy)]
25pub struct UniformModel<Probability: BitArray, const PRECISION: usize> {
26    probability_per_bin: Probability::NonZero,
27    last_symbol: Probability,
28}
29
30impl<Probability: BitArray, const PRECISION: usize> UniformModel<Probability, PRECISION> {
31    pub fn new(range: usize) -> Self
32    where
33        usize: AsPrimitive<Probability>,
34        Probability: AsPrimitive<usize>,
35    {
36        generic_static_asserts!(
37            (Probability: BitArray; const PRECISION: usize);
38            PROBABILITY_MUST_SUPPORT_PRECISION: PRECISION <= Probability::BITS;
39            USIZE_MUST_SUPPORT_PRECISION: PRECISION <= <usize as BitArray>::BITS;
40            PRECISION_MUST_BE_NONZERO: PRECISION > 0;
41        );
42
43        assert!(range > 1); // We don't support degenerate probability distributions (i.e. range=1).
44        let range = unsafe { range.into_nonzero_unchecked() }; // For performance hint.
45        let last_symbol_usize = NonZeroBitArray::get(range) - 1;
46        let last_symbol = last_symbol_usize.as_();
47        assert!(
48            last_symbol
49                <= wrapping_pow2::<Probability>(PRECISION).wrapping_sub(&Probability::one())
50                && last_symbol.as_() == last_symbol_usize
51        );
52
53        if PRECISION == Probability::BITS {
54            let probability_per_bin = (wrapping_pow2::<usize>(PRECISION)
55                .wrapping_sub(NonZeroBitArray::get(range))
56                / NonZeroBitArray::get(range))
57            .as_()
58                + Probability::one();
59            unsafe {
60                Self {
61                    probability_per_bin: probability_per_bin.into_nonzero_unchecked(),
62                    last_symbol,
63                }
64            }
65        } else {
66            let probability_per_bin =
67                (Probability::one() << PRECISION) / NonZeroBitArray::get(range).as_();
68            let probability_per_bin = probability_per_bin
69                .into_nonzero()
70                .expect("range <= (1 << PRECISION)");
71            Self {
72                probability_per_bin,
73                last_symbol,
74            }
75        }
76    }
77}
78
79impl<Probability: BitArray, const PRECISION: usize> EntropyModel<PRECISION>
80    for UniformModel<Probability, PRECISION>
81{
82    type Symbol = usize;
83    type Probability = Probability;
84}
85
86impl<Probability: BitArray, const PRECISION: usize> EncoderModel<PRECISION>
87    for UniformModel<Probability, PRECISION>
88where
89    usize: AsPrimitive<Probability>,
90{
91    fn left_cumulative_and_probability(
92        &self,
93        symbol: impl Borrow<Self::Symbol>,
94    ) -> Option<(Self::Probability, <Self::Probability as BitArray>::NonZero)> {
95        let symbol = symbol.borrow().as_();
96        let left_cumulative = symbol.wrapping_mul(&self.probability_per_bin.get());
97
98        #[allow(clippy::comparison_chain)]
99        if symbol < self.last_symbol {
100            // Most common case.
101            Some((left_cumulative, self.probability_per_bin))
102        } else if symbol == self.last_symbol {
103            // Less common but possible case.
104            let probability =
105                wrapping_pow2::<Probability>(PRECISION).wrapping_sub(&left_cumulative);
106            let probability = unsafe { probability.into_nonzero_unchecked() };
107            Some((left_cumulative, probability))
108        } else {
109            // Least common case.
110            None
111        }
112    }
113}
114
115impl<Probability: BitArray, const PRECISION: usize> DecoderModel<PRECISION>
116    for UniformModel<Probability, PRECISION>
117where
118    Probability: AsPrimitive<usize>,
119{
120    fn quantile_function(
121        &self,
122        quantile: Self::Probability,
123    ) -> (
124        Self::Symbol,
125        Self::Probability,
126        <Self::Probability as BitArray>::NonZero,
127    ) {
128        let symbol_guess = quantile / self.probability_per_bin.get(); // Might be 1 too large for last symbol.
129        let remainder = quantile % self.probability_per_bin.get();
130        if symbol_guess < self.last_symbol {
131            (
132                symbol_guess.as_(),
133                quantile - remainder,
134                self.probability_per_bin,
135            )
136        } else {
137            let left_cumulative = self.last_symbol * self.probability_per_bin.get();
138            let prob = wrapping_pow2::<Probability>(PRECISION).wrapping_sub(&left_cumulative);
139            let prob = unsafe {
140                // SAFETY: prob can't be zero because we have a `quantile` that is contained in its interval.
141                prob.into_nonzero_unchecked()
142            };
143            (self.last_symbol.as_(), left_cumulative, prob)
144        }
145    }
146}
147
148impl<'m, Probability: BitArray, const PRECISION: usize> IterableEntropyModel<'m, PRECISION>
149    for UniformModel<Probability, PRECISION>
150where
151    Probability: AsPrimitive<usize>,
152    usize: AsPrimitive<Probability>,
153{
154    fn symbol_table(
155        &'m self,
156    ) -> impl Iterator<
157        Item = (
158            Self::Symbol,
159            Self::Probability,
160            <Self::Probability as BitArray>::NonZero,
161        ),
162    > {
163        // The following doesn't truncate on the conversion or overflow on the addition because it
164        // inverts an operation that was performed in the constructor (which checked for both
165        // potential sources of error).
166        let last_symbol = self.last_symbol.as_();
167        let range = last_symbol + 1;
168        let probability_per_bin = self.probability_per_bin;
169
170        (0..range).map(move |symbol| {
171            let left_cumulative = symbol.as_() * probability_per_bin.get();
172            let probability = if symbol != last_symbol {
173                probability_per_bin
174            } else {
175                let probability =
176                    wrapping_pow2::<Probability>(PRECISION).wrapping_sub(&left_cumulative);
177
178                // SAFETY: the constructor ensures that `range < 2^PRECISION`, so every bin has a
179                // nonzero probability mass.
180                unsafe { probability.into_nonzero_unchecked() }
181            };
182
183            (symbol, left_cumulative, probability)
184        })
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::super::tests::test_entropy_model;
    use super::*;

    /// Runs the shared `test_entropy_model` checks on `UniformModel`s for a selection of ranges
    /// and several `Probability`/`PRECISION` combinations.
    #[test]
    fn uniform() {
        const RANGES: [usize; 14] = [2, 3, 4, 5, 6, 7, 8, 9, 62, 63, 64, 254, 255, 256];

        for range in RANGES {
            test_entropy_model(&UniformModel::<u32, 24>::new(range), 0..range);
            test_entropy_model(&UniformModel::<u32, 32>::new(range), 0..range);
            test_entropy_model(&UniformModel::<u16, 12>::new(range), 0..range);
            test_entropy_model(&UniformModel::<u16, 16>::new(range), 0..range);

            // The `u8` models are only exercised for a subset of the ranges.
            if range <= 254 {
                test_entropy_model(&UniformModel::<u8, 8>::new(range), 0..range);
            }
            // A 6-bit precision can represent at most `2^6 = 64` symbols.
            if range <= 64 {
                test_entropy_model(&UniformModel::<u8, 6>::new(range), 0..range);
            }
        }
    }
}