constriction/stream/
stack.rs

1//! Fast and Near-optimal compression on a stack ("last in first out")
2//!
3//! This module provides the [`AnsCoder`], a highly efficient entropy coder with
4//! near-optimal compression effectiveness that operates as a *stack* data structure. It
5//! implements the Asymmetric Numeral Systems (ANS) compression algorithm \[1].
6//!
7//! # Comparison to sister module `queue`
8//!
9//! ANS Coding operates as a stack, which means that encoding and decoding operate in
10//! reverse direction with respect to each other. The provided implementation of ANS Coding
11//! uses a single data structure, the [`AnsCoder`], for both encoding and decoding. It
12//! allows you to interleave encoding and decoding operations arbitrarily, which is in
13//! contrast to the situation in the sister module [`queue`] and important for advanced
14//! compression techniques such as bits-back coding in hierarchical probabilistic models.
15//!
16//! The parent module contains a more detailed discussion of the [differences between ANS
//! Coding and Range Coding](super#which-stream-code-should-i-use).
18//!
19//! # References
20//!
21//! \[1] Duda, Jarek, et al. "The use of asymmetric numeral systems as an accurate
22//! replacement for Huffman coding." 2015 Picture Coding Symposium (PCS). IEEE, 2015.
23//!
24//! [`queue`]: super::queue
25
26use alloc::vec::Vec;
27use core::{
28    borrow::Borrow, convert::Infallible, fmt::Debug, iter::Fuse, marker::PhantomData, ops::Deref,
29};
30use num_traits::AsPrimitive;
31
32use super::{
33    model::{DecoderModel, EncoderModel},
34    AsDecoder, Code, Decode, Encode, IntoDecoder, TryCodingError,
35};
36use crate::{
37    backends::{
38        self, AsReadWords, AsSeekReadWords, BoundedReadWords, Cursor, FallibleIteratorReadWords,
39        IntoReadWords, IntoSeekReadWords, ReadWords, Reverse, WriteWords,
40    },
41    bit_array_to_chunks_truncated, generic_static_asserts, BitArray, CoderError,
42    DefaultEncoderError, DefaultEncoderFrontendError, NonZeroBitArray, Pos, PosSeek, Seek, Stack,
43    UnwrapInfallible,
44};
45
46/// Entropy coder for both encoding and decoding on a stack.
47///
48/// This is the generic struct for an ANS coder. It provides fine-tuned control over type
49/// parameters (see [discussion in parent
50/// module](super#highly-customizable-implementations-with-sane-presets)). You'll usually
51/// want to use this type through the type alias [`DefaultAnsCoder`], which provides sane
52/// default settings for the type parameters.
53///
54/// The `AnsCoder` uses an entropy coding algorithm called [range Asymmetric
55/// Numeral Systems (rANS)]. This means that it operates as a stack, i.e., a "last
56/// in first out" data structure: encoding "pushes symbols on" the stack and
57/// decoding "pops symbols off" the stack in reverse order. In default operation, decoding
58/// with an `AnsCoder` *consumes* the compressed data for the decoded symbols (however, you
59/// can also decode immutable data by using a [`Cursor`]). This means
60/// that encoding and decoding can be interleaved arbitrarily, thus growing and shrinking
61/// the stack of compressed data as you go.
62///
63/// # Example
64///
65/// Basic usage example:
66///
67/// ```
68/// use constriction::stream::{model::DefaultLeakyQuantizer, stack::DefaultAnsCoder, Decode};
69///
70/// // `DefaultAnsCoder` is a type alias to `AnsCoder` with sane generic parameters.
71/// let mut ans = DefaultAnsCoder::new();
72///
73/// // Create an entropy model based on a quantized Gaussian distribution. You can use `AnsCoder`
74/// // with any entropy model defined in the `models` module.
75/// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
76/// let entropy_model = quantizer.quantize(probability::distribution::Gaussian::new(0.0, 10.0));
77///
78/// let symbols = vec![-10, 4, 0, 3];
79/// // Encode symbols in *reverse* order, so that we can decode them in forward order.
80/// ans.encode_iid_symbols_reverse(&symbols, &entropy_model).unwrap();
81///
82/// // Obtain temporary shared access to the compressed bit string. If you want ownership of the
83/// // compressed bit string, call `.into_compressed()` instead of `.get_compressed()`.
84/// println!("Encoded into {} bits: {:?}", ans.num_bits(), &*ans.get_compressed().unwrap());
85///
86/// // Decode the symbols and verify correctness.
87/// let reconstructed = ans
88///     .decode_iid_symbols(4, &entropy_model)
89///     .collect::<Result<Vec<_>, _>>()
90///     .unwrap();
91/// assert_eq!(reconstructed, symbols);
92/// ```
93///
94/// # Consistency Between Encoding and Decoding
95///
96/// As elaborated in the [parent module's documentation](super#whats-a-stream-code),
97/// encoding and decoding operates on a sequence of symbols. Each symbol can be encoded and
98/// decoded with its own entropy model (the symbols can even have heterogeneous types). If
99/// your goal is to reconstruct the originally encoded symbols during decoding, then you
100/// must employ the same sequence of entropy models (in reversed order) during encoding and
101/// decoding.
102///
103/// However, using the same entropy models for encoding and decoding is not a *general*
104/// requirement. It is perfectly legal to push (encode) symbols on the `AnsCoder` using some
105/// entropy models, and then pop off (decode) symbols using different entropy models. The
106/// popped off symbols will then in general be different from the original symbols, but will
107/// be generated in a deterministic way. If there is no deterministic relation between the
108/// entropy models used for pushing and popping, and if there is still compressed data left
109/// at the end (i.e., if [`is_empty`] returns false), then the popped off symbols are, to a
110/// very good approximation, distributed as independent samples from the respective entropy
111/// models. Such random samples, which consume parts of the compressed data, are useful in
112/// the bits-back algorithm.
113///
114/// [range Asymmetric Numeral Systems (rANS)]:
115/// https://en.wikipedia.org/wiki/Asymmetric_numeral_systems#Range_variants_(rANS)_and_streaming
/// [`is_empty`]: #method.is_empty
117/// [`Cursor`]: crate::backends::Cursor
#[derive(Clone)]
pub struct AnsCoder<Word, State, Backend = Vec<Word>>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// The bulk of the compressed data, i.e., everything except for the coder's
    /// internal `state` (see [`get_compressed`](#method.get_compressed), which
    /// returns the concatenation of the two).
    bulk: Backend,

    /// Invariant: `state >= State::one() << (State::BITS - Word::BITS)` unless
    /// `bulk.is_empty()`.
    state: State,

    /// We keep track of the `Word` type so that we can statically enforce the invariant
    /// `Word: Into<State>`.
    phantom: PhantomData<Word>,
}
134
/// Type alias for an [`AnsCoder`] with sane parameters for typical use cases.
///
/// This type alias sets the generic type arguments `Word` and `State` to sane values for
/// many typical use cases (`Word = u32` and `State = u64`).
pub type DefaultAnsCoder<Backend = Vec<u32>> = AnsCoder<u32, u64, Backend>;
140
/// Type alias for an [`AnsCoder`] for use with a [`ContiguousLookupDecoderModel`] or [`NonContiguousLookupDecoderModel`]
///
/// This encoder has a smaller word size (`u16`) and internal state (`u32`) than the
/// [`DefaultAnsCoder`]. It is optimized for use with a [`ContiguousLookupDecoderModel`]
/// or [`NonContiguousLookupDecoderModel`].
///
/// # Examples
///
/// See [`ContiguousLookupDecoderModel`].
///
/// [`ContiguousLookupDecoderModel`]: crate::stream::model::ContiguousLookupDecoderModel
/// [`NonContiguousLookupDecoderModel`]: crate::stream::model::NonContiguousLookupDecoderModel
pub type SmallAnsCoder<Backend = Vec<u16>> = AnsCoder<u16, u32, Backend>;
154
155impl<Word, State, Backend> Debug for AnsCoder<Word, State, Backend>
156where
157    Word: BitArray + Into<State>,
158    State: BitArray + AsPrimitive<Word>,
159    for<'a> &'a Backend: IntoIterator<Item = &'a Word>,
160{
161    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
162        f.debug_list().entries(self.iter_compressed()).finish()
163    }
164}
165
166impl<Word, State, Backend, const PRECISION: usize> IntoDecoder<PRECISION>
167    for AnsCoder<Word, State, Backend>
168where
169    Word: BitArray + Into<State>,
170    State: BitArray + AsPrimitive<Word>,
171    Backend: WriteWords<Word> + IntoReadWords<Word, Stack>,
172{
173    type IntoDecoder = AnsCoder<Word, State, Backend::IntoReadWords>;
174
175    fn into_decoder(self) -> Self::IntoDecoder {
176        AnsCoder {
177            bulk: self.bulk.into_read_words(),
178            state: self.state,
179            phantom: PhantomData,
180        }
181    }
182}
183
184impl<'a, Word, State, Backend> From<&'a AnsCoder<Word, State, Backend>>
185    for AnsCoder<Word, State, <Backend as AsReadWords<'a, Word, Stack>>::AsReadWords>
186where
187    Word: BitArray + Into<State>,
188    State: BitArray + AsPrimitive<Word>,
189    Backend: AsReadWords<'a, Word, Stack>,
190{
191    fn from(ans: &'a AnsCoder<Word, State, Backend>) -> Self {
192        AnsCoder {
193            bulk: ans.bulk().as_read_words(),
194            state: ans.state(),
195            phantom: PhantomData,
196        }
197    }
198}
199
200impl<'a, Word, State, Backend, const PRECISION: usize> AsDecoder<'a, PRECISION>
201    for AnsCoder<Word, State, Backend>
202where
203    Word: BitArray + Into<State>,
204    State: BitArray + AsPrimitive<Word>,
205    Backend: WriteWords<Word> + AsReadWords<'a, Word, Stack>,
206{
207    type AsDecoder = AnsCoder<Word, State, Backend::AsReadWords>;
208
209    fn as_decoder(&'a self) -> Self::AsDecoder {
210        self.into()
211    }
212}
213
214impl<Word, State> From<AnsCoder<Word, State, Vec<Word>>> for Vec<Word>
215where
216    Word: BitArray + Into<State>,
217    State: BitArray + AsPrimitive<Word>,
218{
219    fn from(val: AnsCoder<Word, State, Vec<Word>>) -> Self {
220        val.into_compressed().unwrap_infallible()
221    }
222}
223
impl<Word, State> AnsCoder<Word, State, Vec<Word>>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Creates an empty ANS entropy coder.
    ///
    /// This is usually the starting point if you want to *compress* data.
    ///
    /// # Example
    ///
    /// ```
    /// let mut ans = constriction::stream::stack::DefaultAnsCoder::new();
    ///
    /// // ... push some symbols onto the ANS coder's stack ...
    ///
    /// // Finally, get the compressed data.
    /// let compressed = ans.into_compressed();
    /// ```
    ///
    /// # Generality
    ///
    /// To avoid type parameters in common use cases, `new` is only implemented for
    /// `AnsCoder`s with a `Vec` backend. To create an empty coder with a different backend,
    /// call [`Default::default`] instead.
    pub fn new() -> Self {
        Self::default()
    }
}
253
254impl<Word, State, Backend> Default for AnsCoder<Word, State, Backend>
255where
256    Word: BitArray + Into<State>,
257    State: BitArray + AsPrimitive<Word>,
258    Backend: Default,
259{
260    fn default() -> Self {
261        generic_static_asserts!(
262            (Word: BitArray, State:BitArray);
263            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
264        );
265
266        Self {
267            state: State::zero(),
268            bulk: Default::default(),
269            phantom: PhantomData,
270        }
271    }
272}
273
impl<Word, State, Backend> AnsCoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Low-level constructor that assembles an `AnsCoder` from its internal components.
    ///
    /// The arguments `bulk` and `state` correspond to the two return values of the method
    /// [`into_raw_parts`](Self::into_raw_parts).
    ///
    /// The caller must ensure that `state >= State::one() << (State::BITS - Word::BITS)`
    /// unless `bulk` is empty. This cannot be checked by the method since not all
    /// `Backend`s have an `is_empty` method. Violating this invariant is not a memory
    /// safety issue but it will lead to incorrect behavior.
    pub fn from_raw_parts(bulk: Backend, state: State) -> Self {
        Self {
            bulk,
            state,
            phantom: PhantomData,
        }
    }

    /// Creates an ANS stack with some initial compressed data.
    ///
    /// This is usually the starting point if you want to *decompress* data previously
    /// obtained from [`into_compressed`].  However, it can also be used to append more
    /// symbols to an existing compressed buffer of data.
    ///
    /// Returns `Err(compressed)` if `compressed` is not empty and its last entry is
    /// zero, since an `AnsCoder` cannot represent trailing zero words. This error cannot
    /// occur if `compressed` was obtained from [`into_compressed`], which never returns
    /// data with a trailing zero word. If you want to construct a `AnsCoder` from an
    /// unknown source of binary data (e.g., to decode some side information into latent
    /// variables) then call [`from_binary`] instead.
    ///
    /// [`into_compressed`]: #method.into_compressed
    /// [`from_binary`]: #method.from_binary
    pub fn from_compressed(mut compressed: Backend) -> Result<Self, Backend>
    where
        Backend: ReadWords<Word, Stack>,
    {
        generic_static_asserts!(
            (Word: BitArray, State:BitArray);
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
        );

        // Pull trailing words out of `compressed` to reconstruct `state`. On failure
        // (trailing zero word or a backend read error), hand the backend back to the
        // caller.
        let state = match Self::read_initial_state(|| compressed.read()) {
            Ok(state) => state,
            Err(_) => return Err(compressed),
        };

        Ok(Self {
            bulk: compressed,
            state,
            phantom: PhantomData,
        })
    }

    /// Reads words from the end of the compressed data until `state` satisfies the
    /// invariant `state >= State::one() << (State::BITS - Word::BITS)` or the data is
    /// exhausted.
    ///
    /// Returns `Err(())` if the first word read is zero (an `AnsCoder` cannot represent
    /// a trailing zero word) or if `read_word` fails; returns `Ok(State::zero())` if
    /// there is no data at all.
    fn read_initial_state<Error>(
        mut read_word: impl FnMut() -> Result<Option<Word>, Error>,
    ) -> Result<State, ()>
    where
        Backend: ReadWords<Word, Stack>,
    {
        if let Some(first_word) = read_word().map_err(|_| ())? {
            if first_word == Word::zero() {
                return Err(());
            }

            let mut state = first_word.into();
            while let Some(word) = read_word().map_err(|_| ())? {
                state = (state << Word::BITS) | word.into();
                if state >= State::one() << (State::BITS - Word::BITS) {
                    break;
                }
            }
            Ok(state)
        } else {
            Ok(State::zero())
        }
    }

    /// Like [`from_compressed`] but works on any binary data.
    ///
    /// This method is meant for rather advanced use cases. For most common use cases,
    /// you probably want to call [`from_compressed`] instead.
    ///
    /// Different to `from_compressed`, this method also works if `data` ends in a zero
    /// word. Calling this method is equivalent to (but likely more efficient than)
    /// appending a `1` word to `data` and then calling `from_compressed`. Note that
    /// therefore, this method always constructs a non-empty `AnsCoder` (even if `data` is
    /// empty):
    ///
    /// ```
    /// use constriction::stream::stack::DefaultAnsCoder;
    ///
    /// let stack1 = DefaultAnsCoder::from_binary(Vec::new()).unwrap();
    /// assert!(!stack1.is_empty()); // <-- stack1 is *not* empty.
    ///
    /// let stack2 = DefaultAnsCoder::from_compressed(Vec::new()).unwrap();
    /// assert!(stack2.is_empty()); // <-- stack2 is empty.
    /// ```
    /// [`from_compressed`]: #method.from_compressed
    pub fn from_binary(mut data: Backend) -> Result<Self, Backend::ReadError>
    where
        Backend: ReadWords<Word, Stack>,
    {
        // Start from an implicit leading `1` word, then shift in words from `data`
        // until the state invariant holds (or `data` runs out, which is fine here).
        let mut state = State::one();

        while state < State::one() << (State::BITS - Word::BITS) {
            if let Some(word) = data.read()? {
                state = (state << Word::BITS) | word.into();
            } else {
                break;
            }
        }

        Ok(Self {
            bulk: data,
            state,
            phantom: PhantomData,
        })
    }

    /// Returns a read-only reference to the backend that stores the bulk of the
    /// compressed data (i.e., everything except for the coder's internal `state`).
    #[inline(always)]
    pub fn bulk(&self) -> &Backend {
        &self.bulk
    }

    /// Low-level method that disassembles the `AnsCoder` into its internal components.
    ///
    /// Can be used together with [`from_raw_parts`](Self::from_raw_parts).
    pub fn into_raw_parts(self) -> (Backend, State) {
        (self.bulk, self.state)
    }

    /// Check if no data for decoding is left.
    ///
    /// Note that you can still pop symbols off an empty stack, but this is only
    /// useful in rare edge cases, see documentation of
    /// [`decode_symbol`](#method.decode_symbol).
    pub fn is_empty(&self) -> bool {
        // We don't need to check if `bulk` is empty (which would require an additional
        // type bound `Backend: ReadLookaheadItems<Word>` because we keep up the
        // invariant that `state >= State::one() << (State::BITS - Word::BITS))`
        // when `bulk` is not empty.
        self.state == State::zero()
    }

    /// Assembles the current compressed data into a single slice.
    ///
    /// Returns the concatenation of [`bulk`] and [`state`]. The concatenation truncates
    /// any trailing zero words, which is compatible with the constructor
    /// [`from_compressed`].
    ///
    /// This method requires a `&mut self` receiver to temporarily append `state` to
    /// [`bulk`] (this mutation will be reversed to recreate the original `bulk` as soon
    /// as the caller drops the returned value). If you don't have mutable access to the
    /// `AnsCoder`, consider calling [`iter_compressed`] instead, or get the `bulk` and
    /// `state` separately by calling [`bulk`] and [`state`], respectively.
    ///
    /// The return type dereferences to `&[Word]`, thus providing read-only
    /// access to the compressed data. If you need ownership of the compressed data,
    /// consider calling [`into_compressed`] instead.
    ///
    /// # Example
    ///
    /// ```
    /// use constriction::stream::{
    ///     model::DefaultContiguousCategoricalEntropyModel, stack::DefaultAnsCoder, Decode
    /// };
    ///
    /// let mut ans = DefaultAnsCoder::new();
    ///
    /// // Push some data onto the ANS coder's stack.
    /// let symbols = vec![8, 2, 0, 7];
    /// let probabilities = vec![0.03, 0.07, 0.1, 0.1, 0.2, 0.2, 0.1, 0.15, 0.05];
    /// let model = DefaultContiguousCategoricalEntropyModel
    ///     ::from_floating_point_probabilities_fast(&probabilities, None).unwrap();
    /// ans.encode_iid_symbols_reverse(&symbols, &model).unwrap();
    ///
    /// // Inspect the compressed data.
    /// dbg!(ans.get_compressed());
    ///
    /// // We can still use the ANS coder afterwards.
    /// let reconstructed = ans
    ///     .decode_iid_symbols(4, &model)
    ///     .collect::<Result<Vec<_>, _>>()
    ///     .unwrap();
    /// assert_eq!(reconstructed, symbols);
    /// ```
    ///
    /// [`bulk`]: #method.bulk
    /// [`state`]: #method.state
    /// [`from_compressed`]: #method.from_compressed
    /// [`iter_compressed`]: #method.iter_compressed
    /// [`into_compressed`]: #method.into_compressed
    pub fn get_compressed(
        &mut self,
    ) -> Result<impl Deref<Target = Backend> + Debug + Drop + '_, Backend::WriteError>
    where
        Backend: ReadWords<Word, Stack> + WriteWords<Word> + Debug,
    {
        // `SEALED == false` means the guard's frontend error variant is unreachable,
        // so only backend write errors can surface to the caller.
        CoderGuard::<'_, _, _, _, false>::new(self).map_err(|err| match err {
            CoderError::Frontend(()) => unreachable!("Can't happen for SEALED==false."),
            CoderError::Backend(err) => err,
        })
    }

    /// Temporarily assembles the compressed data in the format expected by
    /// [`from_binary`](#method.from_binary), analogously to
    /// [`get_compressed`](#method.get_compressed).
    ///
    /// NOTE(review): implemented via `CoderGuard` with `SEALED == true` (defined
    /// elsewhere in this file). Unlike `get_compressed`, this can also fail with a
    /// frontend error (`CoderError::Frontend(())`) — confirm the exact failure
    /// condition against the `CoderGuard` implementation.
    pub fn get_binary(
        &mut self,
    ) -> Result<impl Deref<Target = Backend> + Debug + Drop + '_, CoderError<(), Backend::WriteError>>
    where
        Backend: ReadWords<Word, Stack> + WriteWords<Word> + Debug,
    {
        CoderGuard::<'_, _, _, _, true>::new(self)
    }

    /// Iterates over the compressed data currently on the ANS coder's stack.
    ///
    /// In contrast to [`get_compressed`] or [`into_compressed`], this method does
    /// not require mutable access or even ownership of the `AnsCoder`.
    ///
    /// # Example
    ///
    /// ```
    /// use constriction::stream::{model::DefaultLeakyQuantizer, stack::DefaultAnsCoder, Decode};
    ///
    /// // Create a stack and encode some stuff.
    /// let mut ans = DefaultAnsCoder::new();
    /// let symbols = vec![8, -12, 0, 7];
    /// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
    /// let model =
    ///     quantizer.quantize(probability::distribution::Gaussian::new(0.0, 10.0));
    /// ans.encode_iid_symbols_reverse(&symbols, &model).unwrap();
    ///
    /// // Iterate over compressed data, collect it into a `Vec`, and compare to the direct method.
    /// let compressed_iter = ans.iter_compressed();
    /// let compressed_collected = compressed_iter.collect::<Vec<_>>();
    /// assert!(!compressed_collected.is_empty());
    /// assert_eq!(compressed_collected, *ans.get_compressed().unwrap());
    /// ```
    ///
    /// [`get_compressed`]: #method.get_compressed
    /// [`into_compressed`]: #method.into_compressed
    pub fn iter_compressed<'a>(&'a self) -> impl Iterator<Item = Word> + 'a
    where
        &'a Backend: IntoIterator<Item = &'a Word>,
    {
        // Yield the bulk first, then the words of `state` (most significant chunk
        // last, hence the `rev()`), matching the layout of `get_compressed`.
        let bulk_iter = self.bulk.into_iter().cloned();
        let state_iter = bit_array_to_chunks_truncated(self.state).rev();
        bulk_iter.chain(state_iter)
    }

    /// Returns the number of compressed words on the ANS coder's stack.
    ///
    /// This includes a constant overhead of between one and two words unless the
    /// stack is completely empty.
    ///
    /// This method returns the length of the slice, the `Vec<Word>`, or the iterator
    /// that would be returned by [`get_compressed`], [`into_compressed`], or
    /// [`iter_compressed`], respectively, when called at this time.
    ///
    /// See also [`num_bits`].
    ///
    /// [`get_compressed`]: #method.get_compressed
    /// [`into_compressed`]: #method.into_compressed
    /// [`iter_compressed`]: #method.iter_compressed
    /// [`num_bits`]: #method.num_bits
    pub fn num_words(&self) -> usize
    where
        Backend: BoundedReadWords<Word, Stack>,
    {
        self.bulk.remaining() + bit_array_to_chunks_truncated::<_, Word>(self.state).len()
    }

    /// Returns the size of the current compressed data in bits, i.e.,
    /// [`num_words`](#method.num_words) times the bit length of a `Word`.
    pub fn num_bits(&self) -> usize
    where
        Backend: BoundedReadWords<Word, Stack>,
    {
        Word::BITS * self.num_words()
    }

    /// Returns the number of bits of compressed data, not counting the leading one
    /// bit of `state` (i.e., counting only the bits of `state` below its most
    /// significant set bit, plus all bits in `bulk`).
    pub fn num_valid_bits(&self) -> usize
    where
        Backend: BoundedReadWords<Word, Stack>,
    {
        // `max(..., 1) - 1` maps a zero `state` (empty coder) to zero valid bits
        // without underflowing.
        Word::BITS * self.bulk.remaining()
            + core::cmp::max(State::BITS - self.state.leading_zeros() as usize, 1)
            - 1
    }

    /// Consumes the `AnsCoder` and returns a decoder that reads from the backend's
    /// read-words form; the internal `state` carries over unchanged.
    pub fn into_decoder(self) -> AnsCoder<Word, State, Backend::IntoReadWords>
    where
        Backend: IntoReadWords<Word, Stack>,
    {
        AnsCoder {
            bulk: self.bulk.into_read_words(),
            state: self.state,
            phantom: PhantomData,
        }
    }

    /// Consumes the `AnsCoder` and returns a decoder that implements [`Seek`].
    ///
    /// This method is similar to [`as_seekable_decoder`] except that it takes ownership of
    /// the original `AnsCoder`, so the returned seekable decoder can typically be returned
    /// from the calling function or put on the heap.
    ///
    /// [`as_seekable_decoder`]: Self::as_seekable_decoder
    pub fn into_seekable_decoder(self) -> AnsCoder<Word, State, Backend::IntoSeekReadWords>
    where
        Backend: IntoSeekReadWords<Word, Stack>,
    {
        AnsCoder {
            bulk: self.bulk.into_seek_read_words(),
            state: self.state,
            phantom: PhantomData,
        }
    }

    /// Returns a decoder that borrows this coder's compressed data; the internal
    /// `state` is copied. The original `AnsCoder` can be used again once the returned
    /// decoder is dropped.
    pub fn as_decoder<'a>(&'a self) -> AnsCoder<Word, State, Backend::AsReadWords>
    where
        Backend: AsReadWords<'a, Word, Stack>,
    {
        AnsCoder {
            bulk: self.bulk.as_read_words(),
            state: self.state,
            phantom: PhantomData,
        }
    }

    /// Returns a decoder that implements [`Seek`].
    ///
    /// The returned decoder shares access to the compressed data with the original
    /// `AnsCoder` (i.e., `self`). This means that:
    /// - you can call this method several times to create several seekable decoders
    ///   with independent views into the same compressed data;
    /// - once the lifetime of all handed out seekable decoders ends, the original
    ///   `AnsCoder` can be used again; and
    /// - the constructed seekable decoder cannot outlive the original `AnsCoder`; for
    ///   example, if the original `AnsCoder` lives on the calling function's call stack
    ///   frame then you cannot return the constructed seekable decoder from the calling
    ///   function. If this is a problem then call [`into_seekable_decoder`] instead.
    ///
    /// # Limitations
    ///
    /// TODO: this text is outdated.
    ///
    /// This method is only implemented for `AnsCoder`s whose backing store of compressed
    /// data (`Backend`) implements `AsRef<[Word]>`. This includes the default
    /// backing data store `Backend = Vec<Word>`.
    ///
    /// [`into_seekable_decoder`]: Self::into_seekable_decoder
    pub fn as_seekable_decoder<'a>(&'a self) -> AnsCoder<Word, State, Backend::AsSeekReadWords>
    where
        Backend: AsSeekReadWords<'a, Word, Stack>,
    {
        AnsCoder {
            bulk: self.bulk.as_seek_read_words(),
            state: self.state,
            phantom: PhantomData,
        }
    }
}
639
640impl<Word, State> AnsCoder<Word, State>
641where
642    Word: BitArray + Into<State>,
643    State: BitArray + AsPrimitive<Word>,
644{
645    /// Discards all compressed data and resets the coder to the same state as
646    /// [`Coder::new`](#method.new).
647    pub fn clear(&mut self) {
648        self.bulk.clear();
649        self.state = State::zero();
650    }
651}
652
653impl<'bulk, Word, State> AnsCoder<Word, State, Cursor<Word, &'bulk [Word]>>
654where
655    Word: BitArray + Into<State>,
656    State: BitArray + AsPrimitive<Word>,
657{
658    // TODO: proper error type (also for `from_compressed`)
659    #[allow(clippy::result_unit_err)]
660    pub fn from_compressed_slice(compressed: &'bulk [Word]) -> Result<Self, ()> {
661        Self::from_compressed(backends::Cursor::new_at_write_end(compressed)).map_err(|_| ())
662    }
663
664    pub fn from_binary_slice(data: &'bulk [Word]) -> Self {
665        Self::from_binary(backends::Cursor::new_at_write_end(data)).unwrap_infallible()
666    }
667}
668
669impl<Word, State, Buf> AnsCoder<Word, State, Reverse<Cursor<Word, Buf>>>
670where
671    Word: BitArray + Into<State>,
672    State: BitArray + AsPrimitive<Word>,
673    Buf: AsRef<[Word]>,
674{
675    pub fn from_reversed_compressed(compressed: Buf) -> Result<Self, Buf> {
676        Self::from_compressed(Reverse(Cursor::new_at_write_beginning(compressed)))
677            .map_err(|Reverse(cursor)| cursor.into_buf_and_pos().0)
678    }
679
680    pub fn from_reversed_binary(data: Buf) -> Self {
681        Self::from_binary(Reverse(Cursor::new_at_write_beginning(data))).unwrap_infallible()
682    }
683}
684
685impl<Word, State, Iter, ReadError> AnsCoder<Word, State, FallibleIteratorReadWords<Iter>>
686where
687    Word: BitArray + Into<State>,
688    State: BitArray + AsPrimitive<Word>,
689    Iter: Iterator<Item = Result<Word, ReadError>>,
690    FallibleIteratorReadWords<Iter>: ReadWords<Word, Stack, ReadError = ReadError>,
691{
692    pub fn from_reversed_compressed_iter(compressed: Iter) -> Result<Self, Fuse<Iter>> {
693        Self::from_compressed(FallibleIteratorReadWords::new(compressed))
694            .map_err(|iterator_backend| iterator_backend.into_iter())
695    }
696
697    pub fn from_reversed_binary_iter(data: Iter) -> Result<Self, ReadError> {
698        Self::from_binary(FallibleIteratorReadWords::new(data))
699    }
700}
701
702impl<Word, State, Backend> AnsCoder<Word, State, Backend>
703where
704    Word: BitArray + Into<State>,
705    State: BitArray + AsPrimitive<Word>,
706    Backend: WriteWords<Word>,
707{
708    /// Recommended way to encode a heterogeneously distributed sequence of
709    /// symbols onto an `AnsCoder`.
710    ///
711    /// This method is similar to the trait method [`Encode::encode_symbols`],
712    /// but it encodes the symbols in *reverse* order (and therefore requires
713    /// the provided iterator to implement [`DoubleEndedIterator`]). Encoding
714    /// in reverse order is the recommended way to encode onto an `AnsCoder`
715    /// because an `AnsCoder` is a *stack*, i.e., the last symbol you encode
716    /// onto an `AnsCoder` is the first symbol that you will decode from it.
717    /// Thus, encoding a sequence of symbols in reverse order will allow you to
718    /// decode them in normal order.
719    pub fn encode_symbols_reverse<S, M, I, const PRECISION: usize>(
720        &mut self,
721        symbols_and_models: I,
722    ) -> Result<(), DefaultEncoderError<Backend::WriteError>>
723    where
724        S: Borrow<M::Symbol>,
725        M: EncoderModel<PRECISION>,
726        M::Probability: Into<Word>,
727        Word: AsPrimitive<M::Probability>,
728        I: IntoIterator<Item = (S, M)>,
729        I::IntoIter: DoubleEndedIterator,
730    {
731        self.encode_symbols(symbols_and_models.into_iter().rev())
732    }
733
734    /// Recommended way to encode onto an `AnsCoder` from a fallible iterator.
735    ///
736    /// This method is similar to the trait method
737    /// [`Encode::try_encode_symbols`], but it encodes the symbols in *reverse*
738    /// order (and therefore requires the provided iterator to implement
739    /// [`DoubleEndedIterator`]). Encoding in reverse order is the recommended
740    /// way to encode onto an `AnsCoder` because an `AnsCoder` is a *stack*,
741    /// i.e., the last symbol you encode  onto an `AnsCoder` is the first symbol
742    /// that you will decode from it. Thus, encoding a sequence of symbols in
743    /// reverse order will allow you to decode them in normal order.
744    pub fn try_encode_symbols_reverse<S, M, E, I, const PRECISION: usize>(
745        &mut self,
746        symbols_and_models: I,
747    ) -> Result<(), TryCodingError<DefaultEncoderError<Backend::WriteError>, E>>
748    where
749        S: Borrow<M::Symbol>,
750        M: EncoderModel<PRECISION>,
751        M::Probability: Into<Word>,
752        Word: AsPrimitive<M::Probability>,
753        I: IntoIterator<Item = core::result::Result<(S, M), E>>,
754        I::IntoIter: DoubleEndedIterator,
755    {
756        self.try_encode_symbols(symbols_and_models.into_iter().rev())
757    }
758
759    /// Recommended way to encode a sequence of i.i.d. symbols onto an
760    /// `AnsCoder`.
761    ///
762    /// This method is similar to the trait method
763    /// [`Encode::encode_iid_symbols`], but it encodes the symbols in *reverse*
764    /// order (and therefore requires the provided iterator to implement
765    /// [`DoubleEndedIterator`]). Encoding in reverse order is the recommended
766    /// way to encode onto an `AnsCoder` because an `AnsCoder` is a *stack*,
767    /// i.e., the last symbol you encode onto an `AnsCoder` is the first symbol
768    /// that you will decode from it. Thus, encoding a sequence of symbols in
769    /// reverse order will allow you to decode them in normal order.
770    pub fn encode_iid_symbols_reverse<S, M, I, const PRECISION: usize>(
771        &mut self,
772        symbols: I,
773        model: M,
774    ) -> Result<(), DefaultEncoderError<Backend::WriteError>>
775    where
776        S: Borrow<M::Symbol>,
777        M: EncoderModel<PRECISION> + Copy,
778        M::Probability: Into<Word>,
779        Word: AsPrimitive<M::Probability>,
780        I: IntoIterator<Item = S>,
781        I::IntoIter: DoubleEndedIterator,
782    {
783        self.encode_iid_symbols(symbols.into_iter().rev(), model)
784    }
785
786    /// Consumes the ANS coder and returns the compressed data.
787    ///
788    /// The returned data can be used to recreate an ANS coder with the same state
789    /// (e.g., for decoding) by passing it to
790    /// [`from_compressed`](#method.from_compressed).
791    ///
792    /// If you don't want to consume the ANS coder, consider calling
793    /// [`get_compressed`](#method.get_compressed),
794    /// [`iter_compressed`](#method.iter_compressed) instead.
795    ///
796    /// # Example
797    ///
798    /// ```
799    /// use constriction::stream::{
800    ///     model::DefaultContiguousCategoricalEntropyModel, stack::DefaultAnsCoder, Decode
801    /// };
802    ///
803    /// let mut ans = DefaultAnsCoder::new();
804    ///
805    /// // Push some data onto the ANS coder's stack:
806    /// let symbols = vec![8, 2, 0, 7];
807    /// let probabilities = vec![0.03, 0.07, 0.1, 0.1, 0.2, 0.2, 0.1, 0.15, 0.05];
808    /// let model = DefaultContiguousCategoricalEntropyModel
809    ///     ::from_floating_point_probabilities_fast(&probabilities, None).unwrap();
810    /// ans.encode_iid_symbols_reverse(&symbols, &model).unwrap();
811    ///
812    /// // Get the compressed data, consuming the ANS coder:
813    /// let compressed = ans.into_compressed().unwrap();
814    ///
815    /// // ... write `compressed` to a file and then read it back later ...
816    ///
817    /// // Create a new ANS coder with the same state and use it for decompression:
818    /// let mut ans = DefaultAnsCoder::from_compressed(compressed).expect("Corrupted compressed file.");
819    /// let reconstructed = ans
820    ///     .decode_iid_symbols(4, &model)
821    ///     .collect::<Result<Vec<_>, _>>()
822    ///     .unwrap();
823    /// assert_eq!(reconstructed, symbols);
824    /// assert!(ans.is_empty())
825    /// ```
826    pub fn into_compressed(mut self) -> Result<Backend, Backend::WriteError> {
827        self.bulk
828            .extend_from_iter(bit_array_to_chunks_truncated(self.state).rev())?;
829        Ok(self.bulk)
830    }
831
    /// Returns the binary data if it fits precisely into an integer number of
    /// `Word`s
    ///
    /// This method is meant for rather advanced use cases. For most common use cases,
    /// you probably want to call [`into_compressed`] instead.
    ///
    /// This method is the inverse of [`from_binary`]. It is equivalent to calling
    /// [`into_compressed`], verifying that the returned vector ends in a `1` word, and
    /// popping off that trailing `1` word.
    ///
    /// Returns `Err(())` if the compressed data (excluding an obligatory trailing
    /// `1` bit) does not fit into an integer number of `Word`s. This error
    /// case includes the case of an empty `AnsCoder` (since an empty `AnsCoder` lacks the
    /// obligatory trailing one-bit).
    ///
    /// # Example
    ///
    /// ```
    /// // Some binary data we want to represent on a `AnsCoder`.
    /// let data = vec![0x89ab_cdef, 0x0123_4567];
    ///
    /// // Constructing a `AnsCoder` with `from_binary` indicates that all bits of `data` are
    /// // considered part of the information-carrying payload.
    /// let stack1 = constriction::stream::stack::DefaultAnsCoder::from_binary(data.clone()).unwrap();
    /// assert_eq!(stack1.clone().into_binary().unwrap(), data); // <-- Retrieves the original `data`.
    ///
    /// // By contrast, if we construct a `AnsCoder` with `from_compressed`, we indicate that
    /// // - any leading `0` bits of the last entry of `data` are not considered part of
    /// //   the information-carrying payload; and
    /// // - the (obligatory) first `1` bit of the last entry of `data` defines the
    /// //   boundary between unused bits and information-carrying bits; it is therefore
    /// //   also not considered part of the payload.
    /// // Therefore, `stack2` below only contains `32 * 2 - 7 - 1 = 56` bits of payload,
    /// // which cannot be exported into an integer number of `u32` words:
    /// let stack2 = constriction::stream::stack::DefaultAnsCoder::from_compressed(data.clone()).unwrap();
    /// assert!(stack2.clone().into_binary().is_err()); // <-- Returns an error.
    ///
    /// // Use `into_compressed` to retrieve the data in this case:
    /// assert_eq!(stack2.into_compressed().unwrap(), data);
    ///
    /// // Calling `into_compressed` on `stack1` would append an extra `1` bit to indicate
    /// // the boundary between information-carrying bits and padding `0` bits:
    /// assert_eq!(stack1.into_compressed().unwrap(), vec![0x89ab_cdef, 0x0123_4567, 0x0000_0001]);
    /// ```
    ///
    /// [`from_binary`]: #method.from_binary
    /// [`into_compressed`]: #method.into_compressed
    pub fn into_binary(mut self) -> Result<Backend, Option<Backend::WriteError>> {
        // Number of payload bits in `state`: everything below the highest set
        // bit (the highest set bit itself is the obligatory `1` delimiter).
        // If `state == 0`, `leading_zeros() == State::BITS` and the wrapping
        // subtraction yields `usize::MAX`, which is rejected below.
        let valid_bits = (State::BITS - 1).wrapping_sub(self.state.leading_zeros() as usize);

        if valid_bits % Word::BITS != 0 || valid_bits == usize::MAX {
            // Payload doesn't fill an integer number of `Word`s (or the coder
            // is empty and thus lacks the delimiting `1` bit).
            Err(None)
        } else {
            // Clear the delimiting `1` bit, then flush the remaining payload
            // words onto the bulk backend.
            let truncated_state = self.state ^ (State::one() << valid_bits);
            self.bulk
                .extend_from_iter(bit_array_to_chunks_truncated(truncated_state).rev())?;
            Ok(self.bulk)
        }
    }
891}
892
893impl<Word, State, Buf> AnsCoder<Word, State, Cursor<Word, Buf>>
894where
895    Word: BitArray,
896    State: BitArray + AsPrimitive<Word> + From<Word>,
897    Buf: AsRef<[Word]> + AsMut<[Word]>,
898{
899    pub fn into_reversed(self) -> AnsCoder<Word, State, Reverse<Cursor<Word, Buf>>> {
900        let (bulk, state) = self.into_raw_parts();
901        AnsCoder {
902            bulk: bulk.into_reversed(),
903            state,
904            phantom: PhantomData,
905        }
906    }
907}
908
909impl<Word, State, Buf> AnsCoder<Word, State, Reverse<Cursor<Word, Buf>>>
910where
911    Word: BitArray,
912    State: BitArray + AsPrimitive<Word> + From<Word>,
913    Buf: AsRef<[Word]> + AsMut<[Word]>,
914{
915    pub fn into_reversed(self) -> AnsCoder<Word, State, Cursor<Word, Buf>> {
916        let (bulk, state) = self.into_raw_parts();
917        AnsCoder {
918            bulk: bulk.into_reversed(),
919            state,
920            phantom: PhantomData,
921        }
922    }
923}
924
// `Code` ties the coder to its `Word` (unit of compressed data exchanged with
// the backend) and `State` (internal register) types.
impl<Word, State, Backend> Code for AnsCoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    type Word = Word;
    type State = State;

    /// Returns a copy of the coder's internal state register.
    #[inline(always)]
    fn state(&self) -> Self::State {
        self.state
    }
}
938
impl<Word, State, Backend, const PRECISION: usize> Encode<PRECISION>
    for AnsCoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word>,
{
    type FrontendError = DefaultEncoderFrontendError;
    type BackendError = Backend::WriteError;

    /// Encodes a single symbol and appends it to the compressed data.
    ///
    /// This is a low level method. You probably usually want to call a batch method
    /// like [`encode_symbols`](#method.encode_symbols) or
    /// [`encode_iid_symbols`](#method.encode_iid_symbols) instead. See examples there.
    ///
    /// The bound `impl Borrow<M::Symbol>` on argument `symbol` essentially means that
    /// you can provide the symbol either by value or by reference, at your choice.
    ///
    /// Returns [`Err(ImpossibleSymbol)`] if `symbol` has zero probability under the
    /// entropy model `model`. This error can usually be avoided by using a
    /// "leaky" distribution as the entropy model, i.e., a distribution that assigns a
    /// nonzero probability to all symbols within a finite domain. Leaky distributions
    /// can be constructed with, e.g., a
    /// [`LeakyQuantizer`](models/struct.LeakyQuantizer.html) or with
    /// [`LeakyCategorical::from_floating_point_probabilities`](
    /// models/struct.LeakyCategorical.html#method.from_floating_point_probabilities).
    ///
    /// TODO: move this and similar doc comments to the trait definition.
    ///
    /// [`Err(ImpossibleSymbol)`]: enum.EncodingError.html#variant.ImpossibleSymbol
    fn encode_symbol<M>(
        &mut self,
        symbol: impl Borrow<M::Symbol>,
        model: M,
    ) -> Result<(), DefaultEncoderError<Self::BackendError>>
    where
        M: EncoderModel<PRECISION>,
        M::Probability: Into<Self::Word>,
        Self::Word: AsPrimitive<M::Probability>,
    {
        // Compile-time sanity checks on the `Word`/`State`/`PRECISION`
        // configuration; violating any of these would break the arithmetic below.
        generic_static_asserts!(
            (Word: BitArray, State:BitArray; const PRECISION: usize);
            PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
            NON_ZERO_PRECISION: PRECISION > 0;
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
        );

        // A `None` here means the model assigns zero probability to `symbol`,
        // which an ANS coder cannot encode.
        let (left_sided_cumulative, probability) = model
            .left_cumulative_and_probability(symbol)
            .ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;

        // If `state` is large enough that the update below would overflow its
        // range, first flush the least significant word of `state` to the backend.
        if (self.state >> (State::BITS - PRECISION)) >= probability.get().into().into() {
            self.bulk.write(self.state.as_())?;
            self.state = self.state >> Word::BITS;
            // At this point, the invariant on `self.state` (see its doc comment) is
            // temporarily violated, but it will be restored below.
        }

        // ANS state update: split `state` into `prefix` (quotient) and
        // `remainder` (modulus) w.r.t. the symbol's probability, then pack the
        // symbol's quantile into the low `PRECISION` bits of the new state.
        let remainder = (self.state % probability.get().into().into()).as_().as_();
        let prefix = self.state / probability.get().into().into();
        let quantile = left_sided_cumulative + remainder;
        self.state = (prefix << PRECISION) | quantile.into().into();

        Ok(())
    }

    /// Delegates to the backend: the coder may be full iff the backend may be.
    fn maybe_full(&self) -> bool {
        self.bulk.maybe_full()
    }
}
1010
impl<Word, State, Backend, const PRECISION: usize> Decode<PRECISION>
    for AnsCoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: ReadWords<Word, Stack>,
{
    /// ANS coding is surjective, and we (deliberately) allow decoding past EOF (in a
    /// deterministic way) for consistency. Therefore, decoding cannot fail in the front
    /// end.
    type FrontendError = Infallible;

    type BackendError = Backend::ReadError;

    /// Decodes a single symbol, inverting the state transformation performed by
    /// `encode_symbol`.
    #[inline(always)]
    fn decode_symbol<M>(
        &mut self,
        model: M,
    ) -> Result<M::Symbol, CoderError<Self::FrontendError, Self::BackendError>>
    where
        M: DecoderModel<PRECISION>,
        M::Probability: Into<Self::Word>,
        Self::Word: AsPrimitive<M::Probability>,
    {
        // Compile-time sanity checks (same configuration requirements as encoding).
        generic_static_asserts!(
            (Word: BitArray, State:BitArray; const PRECISION: usize);
            PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
            NON_ZERO_PRECISION: PRECISION > 0;
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
        );

        // The low `PRECISION` bits of `state` hold the quantile that the encoder
        // packed in; look up which symbol's probability slot it falls into.
        let quantile = (self.state % (State::one() << PRECISION)).as_().as_();
        let (symbol, left_sided_cumulative, probability) = model.quantile_function(quantile);
        let remainder = quantile - left_sided_cumulative;
        // Invert the encoder's update: new state = prefix * probability + remainder,
        // where `prefix` is everything above the low `PRECISION` bits.
        self.state =
            (self.state >> PRECISION) * probability.get().into().into() + remainder.into().into();
        if self.state < State::one() << (State::BITS - Word::BITS) {
            // Invariant on `self.state` (see its doc comment) is violated. Restore it by
            // refilling with a compressed word from `self.bulk` if available.
            if let Some(word) = self.bulk.read()? {
                self.state = (self.state << Word::BITS) | word.into();
            }
        }

        Ok(symbol)
    }

    /// Delegates to `is_empty`: once the coder is empty there is nothing left to
    /// decode (though decoding past EOF is still deterministic, see above).
    fn maybe_exhausted(&self) -> bool {
        self.is_empty()
    }
}
1062
impl<Word, State, Backend> PosSeek for AnsCoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: PosSeek,
    Self: Code,
{
    /// A position of an `AnsCoder` consists of the backend's position together
    /// with the coder's state register; both are required to resume coding at a
    /// given point (see the `Seek` and `Pos` implementations).
    type Position = (Backend::Position, <Self as Code>::State);
}
1072
1073impl<Word, State, Backend> Seek for AnsCoder<Word, State, Backend>
1074where
1075    Word: BitArray + Into<State>,
1076    State: BitArray + AsPrimitive<Word>,
1077    Backend: Seek,
1078{
1079    fn seek(&mut self, (pos, state): Self::Position) -> Result<(), ()> {
1080        self.bulk.seek(pos)?;
1081        self.state = state;
1082        Ok(())
1083    }
1084}
1085
1086impl<Word, State, Backend> Pos for AnsCoder<Word, State, Backend>
1087where
1088    Word: BitArray + Into<State>,
1089    State: BitArray + AsPrimitive<Word>,
1090    Backend: Pos,
1091{
1092    fn pos(&self) -> Self::Position {
1093        (self.bulk.pos(), self.state())
1094    }
1095}
1096
/// Provides temporary read-only access to the compressed data wrapped in an
/// [`AnsCoder`].
///
/// While the guard exists, the coder's `state` register is temporarily flushed
/// onto its `bulk` backend (done in `new`, undone in `Drop`), so dereferencing
/// the guard exposes the *complete* compressed representation. Dereferences to
/// the backend (`Backend`). See [`Coder::get_compressed`] for an example.
///
/// The `SEALED` const parameter controls whether construction additionally
/// verifies (and omits) the state's delimiting `1` word; see `CoderGuard::new`.
///
/// [`AnsCoder`]: struct.Coder.html
/// [`Coder::get_compressed`]: struct.Coder.html#method.get_compressed
struct CoderGuard<'a, Word, State, Backend, const SEALED: bool>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    // Exclusive borrow of the coder whose bulk temporarily holds the state words.
    inner: &'a mut AnsCoder<Word, State, Backend>,
}
1112
impl<'a, Word, State, Backend, const SEALED: bool> CoderGuard<'a, Word, State, Backend, SEALED>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    /// Flushes the coder's `state` onto `ans.bulk` and wraps the coder in a
    /// guard whose `Drop` implementation pops those words off again.
    ///
    /// If `SEALED` is true, the first chunk yielded by
    /// `bit_array_to_chunks_truncated` must be exactly `1` (presumably the
    /// delimiting word of the state — TODO confirm against `into_binary`); it is
    /// consumed rather than written. Returns `Err(CoderError::Frontend(()))` if
    /// that check fails, and propagates backend write errors.
    #[inline(always)]
    fn new(
        ans: &'a mut AnsCoder<Word, State, Backend>,
    ) -> Result<Self, CoderError<(), Backend::WriteError>> {
        // Append state. Will be undone in `<Self as Drop>::drop`.
        let mut chunks_rev = bit_array_to_chunks_truncated(ans.state);
        if SEALED && chunks_rev.next() != Some(Word::one()) {
            return Err(CoderError::Frontend(()));
        }
        for chunk in chunks_rev.rev() {
            ans.bulk.write(chunk)?
        }

        Ok(Self { inner: ans })
    }
}
1135
impl<Word, State, Backend, const SEALED: bool> Drop for CoderGuard<'_, Word, State, Backend, SEALED>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    fn drop(&mut self) {
        // Revert what we did in `Self::new`: pop exactly as many words off the
        // backend as `new` pushed (one per state chunk, minus the sealing chunk
        // that `new` skipped when `SEALED`).
        let mut chunks_rev = bit_array_to_chunks_truncated(self.inner.state);
        if SEALED {
            chunks_rev.next();
        }
        for _ in chunks_rev {
            // Deliberately discard the read result (including any backend error):
            // `drop` has no way to propagate errors and should not panic.
            core::mem::drop(self.inner.bulk.read());
        }
    }
}
1153
impl<Word, State, Backend, const SEALED: bool> Deref
    for CoderGuard<'_, Word, State, Backend, SEALED>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    type Target = Backend;

    /// Exposes the backend, which temporarily contains the complete compressed
    /// data (bulk plus flushed state words) while the guard is alive.
    fn deref(&self) -> &Self::Target {
        &self.inner.bulk
    }
}
1167
impl<Word, State, Backend, const SEALED: bool> Debug
    for CoderGuard<'_, Word, State, Backend, SEALED>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + ReadWords<Word, Stack> + Debug,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Format via `Deref`, i.e., print the backend (which holds the complete
        // compressed data while the guard is alive).
        Debug::fmt(&**self, f)
    }
}
1179
1180#[cfg(test)]
1181mod tests {
1182    use super::super::model::{
1183        ContiguousCategoricalEntropyModel, DefaultLeakyQuantizer, IterableEntropyModel,
1184        LeakyQuantizer,
1185    };
1186    use super::*;
1187    extern crate std;
1188    use std::dbg;
1189
1190    use probability::distribution::{Gaussian, Inverse};
1191    use rand_xoshiro::{
1192        rand_core::{RngCore, SeedableRng},
1193        Xoshiro256StarStar,
1194    };
1195
1196    #[test]
1197    fn compress_none() {
1198        let coder1 = DefaultAnsCoder::new();
1199        assert!(coder1.is_empty());
1200        let compressed = coder1.into_compressed().unwrap();
1201        assert!(compressed.is_empty());
1202
1203        let coder2 = DefaultAnsCoder::from_compressed(compressed).unwrap();
1204        assert!(coder2.is_empty());
1205    }
1206
    // Round-trip a handful of symbols through `generic_compress_few` and verify
    // the exact compressed size (in words) for each input length.

    #[test]
    fn compress_one() {
        generic_compress_few(core::iter::once(5), 1)
    }

    #[test]
    fn compress_two() {
        generic_compress_few([2, 8].iter().cloned(), 1)
    }

    #[test]
    fn compress_ten() {
        generic_compress_few(0..10, 2)
    }

    #[test]
    fn compress_twenty() {
        generic_compress_few(-10..10, 4)
    }
1226
1227    fn generic_compress_few<I>(symbols: I, expected_size: usize)
1228    where
1229        I: IntoIterator<Item = i32>,
1230        I::IntoIter: Clone + DoubleEndedIterator,
1231    {
1232        let symbols = symbols.into_iter();
1233
1234        let mut encoder = DefaultAnsCoder::new();
1235        let quantizer = DefaultLeakyQuantizer::new(-127..=127);
1236        let model = quantizer.quantize(Gaussian::new(3.2, 5.1));
1237
1238        // We don't reuse the same encoder for decoding because we want to test
1239        // if exporting and re-importing of compressed data works.
1240        encoder.encode_iid_symbols(symbols.clone(), model).unwrap();
1241        let compressed = encoder.into_compressed().unwrap();
1242        assert_eq!(compressed.len(), expected_size);
1243
1244        let mut decoder = DefaultAnsCoder::from_compressed(compressed).unwrap();
1245        for symbol in symbols.rev() {
1246            assert_eq!(decoder.decode_symbol(model).unwrap(), symbol);
1247        }
1248        assert!(decoder.is_empty());
1249    }
1250
    // Exhaustive round-trip tests of `generic_compress_many` over the supported
    // combinations of word type, state type, probability type, and `PRECISION`.
    // Test names follow the pattern `compress_many_<Word>_<State>_<PRECISION>`.

    #[test]
    fn compress_many_u32_u64_32() {
        generic_compress_many::<u32, u64, u32, 32>();
    }

    #[test]
    fn compress_many_u32_u64_24() {
        generic_compress_many::<u32, u64, u32, 24>();
    }

    #[test]
    fn compress_many_u32_u64_16() {
        generic_compress_many::<u32, u64, u16, 16>();
    }

    #[test]
    fn compress_many_u32_u64_8() {
        generic_compress_many::<u32, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u16_u64_16() {
        generic_compress_many::<u16, u64, u16, 16>();
    }

    #[test]
    fn compress_many_u16_u64_12() {
        generic_compress_many::<u16, u64, u16, 12>();
    }

    #[test]
    fn compress_many_u16_u64_8() {
        generic_compress_many::<u16, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u64_8() {
        generic_compress_many::<u8, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u16_u32_16() {
        generic_compress_many::<u16, u32, u16, 16>();
    }

    #[test]
    fn compress_many_u16_u32_12() {
        generic_compress_many::<u16, u32, u16, 12>();
    }

    #[test]
    fn compress_many_u16_u32_8() {
        generic_compress_many::<u16, u32, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u32_8() {
        generic_compress_many::<u8, u32, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u16_8() {
        generic_compress_many::<u8, u16, u8, 8>();
    }
1315
    // Large round-trip test: encodes a mix of quantized-Gaussian symbols (with
    // per-symbol models) and categorical i.i.d. symbols, exports and re-imports
    // the compressed data, then decodes and verifies an exact reconstruction.
    // All randomness is seeded deterministically from the type configuration.
    fn generic_compress_many<Word, State, Probability, const PRECISION: usize>()
    where
        State: BitArray + AsPrimitive<Word>,
        Word: BitArray + Into<State> + AsPrimitive<Probability>,
        Probability: BitArray + Into<Word> + AsPrimitive<usize> + Into<f64>,
        u32: AsPrimitive<Probability>,
        usize: AsPrimitive<Probability>,
        f64: AsPrimitive<Probability>,
        i32: AsPrimitive<Probability>,
    {
        // Use a smaller workload under miri, which executes much more slowly.
        #[cfg(not(miri))]
        const AMT: usize = 1000;

        #[cfg(miri)]
        const AMT: usize = 100;

        let mut symbols_gaussian = Vec::with_capacity(AMT);
        let mut means = Vec::with_capacity(AMT);
        let mut stds = Vec::with_capacity(AMT);

        // Seed deterministically, but differently for each type configuration.
        let mut rng = Xoshiro256StarStar::seed_from_u64(
            (Word::BITS as u64).rotate_left(3 * 16)
                ^ (State::BITS as u64).rotate_left(2 * 16)
                ^ (Probability::BITS as u64).rotate_left(16)
                ^ PRECISION as u64,
        );

        // Draw a random Gaussian (mean, std) per symbol and sample the symbol by
        // inverse-CDF, clamped to the quantizer's support below.
        for _ in 0..AMT {
            let mean = (200.0 / u32::MAX as f64) * rng.next_u32() as f64 - 100.0;
            let std_dev = (10.0 / u32::MAX as f64) * rng.next_u32() as f64 + 0.001;
            let quantile = (rng.next_u32() as f64 + 0.5) / (1u64 << 32) as f64;
            let dist = Gaussian::new(mean, std_dev);
            let symbol = (dist.inverse(quantile).round() as i32).clamp(-127, 127);

            symbols_gaussian.push(symbol);
            means.push(mean);
            stds.push(std_dev);
        }

        // Histogram deliberately includes zero-probability and probability-one-ish
        // entries to exercise edge cases of the categorical model construction.
        let hist = [
            1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
            896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
            347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
        ];
        let categorical_probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
        let categorical =
            ContiguousCategoricalEntropyModel::<Probability, _, PRECISION>::from_floating_point_probabilities_fast::<f64>(
                &categorical_probabilities,None
            )
            .unwrap();
        let mut symbols_categorical = Vec::with_capacity(AMT);
        let max_probability = Probability::max_value() >> (Probability::BITS - PRECISION);
        // Sample categorical symbols by drawing random quantiles.
        for _ in 0..AMT {
            let quantile = rng.next_u32().as_() & max_probability;
            let symbol = categorical.quantile_function(quantile).0;
            symbols_categorical.push(symbol);
        }

        let mut ans = AnsCoder::<Word, State>::new();

        // Encode in reverse so decoding (below) yields symbols in normal order.
        ans.encode_iid_symbols_reverse(&symbols_categorical, &categorical)
            .unwrap();
        // Compare actual compressed size against the information content estimate.
        dbg!(
            ans.num_valid_bits(),
            AMT as f64 * categorical.entropy_base2::<f64>()
        );

        let quantizer = LeakyQuantizer::<_, _, Probability, PRECISION>::new(-127..=127);
        // (`core` here is just the closure's std-dev binding, not the `core` crate.)
        ans.encode_symbols_reverse(symbols_gaussian.iter().zip(&means).zip(&stds).map(
            |((&symbol, &mean), &core)| (symbol, quantizer.quantize(Gaussian::new(mean, core))),
        ))
        .unwrap();
        dbg!(ans.num_valid_bits());

        // Test if import/export of compressed data works.
        let compressed = ans.into_compressed().unwrap();
        let mut ans = AnsCoder::from_compressed(compressed).unwrap();

        let reconstructed_gaussian = ans
            .decode_symbols(
                means
                    .iter()
                    .zip(&stds)
                    .map(|(&mean, &core)| quantizer.quantize(Gaussian::new(mean, core))),
            )
            .collect::<Result<Vec<_>, CoderError<Infallible, Infallible>>>()
            .unwrap();
        let reconstructed_categorical = ans
            .decode_iid_symbols(AMT, &categorical)
            .collect::<Result<Vec<_>, CoderError<Infallible, Infallible>>>()
            .unwrap();

        // After decoding everything, the coder must be fully drained.
        assert!(ans.is_empty());

        assert_eq!(symbols_gaussian, reconstructed_gaussian);
        assert_eq!(symbols_categorical, reconstructed_categorical);
    }
1413
    // Tests `Pos`/`Seek`: records a jump table of (position, state) pairs while
    // encoding chunks, then verifies random-access decoding both on the original
    // (stack-direction) data and on the reversed (queue-direction) data.
    #[test]
    fn seek() {
        #[cfg(not(miri))]
        let (num_chunks, symbols_per_chunk) = (100, 100);

        // Use a much smaller workload under miri, which executes very slowly.
        #[cfg(miri)]
        let (num_chunks, symbols_per_chunk) = (10, 10);

        let quantizer = DefaultLeakyQuantizer::new(-100..=100);
        let model = quantizer.quantize(Gaussian::new(0.0, 10.0));

        let mut encoder = DefaultAnsCoder::new();

        let mut rng = Xoshiro256StarStar::seed_from_u64(123);
        let mut symbols = Vec::with_capacity(num_chunks);
        let mut jump_table = Vec::with_capacity(num_chunks);
        let (initial_pos, initial_state) = encoder.pos();

        // Encode `num_chunks` chunks, recording the coder position after each.
        for _ in 0..num_chunks {
            let chunk = (0..symbols_per_chunk)
                .map(|_| model.quantile_function(rng.next_u32() % (1 << 24)).0)
                .collect::<Vec<_>>();
            encoder.encode_iid_symbols_reverse(&chunk, &model).unwrap();
            symbols.push(chunk);
            jump_table.push(encoder.pos());
        }

        // Test decoding from back to front.
        {
            let mut seekable_decoder = encoder.as_seekable_decoder();

            // Verify that decoding leads to the same positions and states.
            for (chunk, &(pos, state)) in symbols.iter().zip(&jump_table).rev() {
                assert_eq!(seekable_decoder.pos(), (pos, state));
                let decoded = seekable_decoder
                    .decode_iid_symbols(symbols_per_chunk, &model)
                    .collect::<Result<Vec<_>, _>>()
                    .unwrap();
                assert_eq!(&decoded, chunk)
            }
            assert_eq!(seekable_decoder.pos(), (initial_pos, initial_state));
            assert!(seekable_decoder.is_empty());

            // Seek to some random offsets in the jump table and decode one chunk
            for _ in 0..100 {
                let chunk_index = rng.next_u32() as usize % num_chunks;
                let (pos, state) = jump_table[chunk_index];
                seekable_decoder.seek((pos, state)).unwrap();
                let decoded = seekable_decoder
                    .decode_iid_symbols(symbols_per_chunk, &model)
                    .collect::<Result<Vec<_>, _>>()
                    .unwrap();
                assert_eq!(&decoded, &symbols[chunk_index])
            }
        }

        // Reverse compressed data, map positions in jump table to reversed positions,
        // and test decoding from front to back.
        let mut compressed = encoder.into_compressed().unwrap();
        compressed.reverse();
        // A position measured from one end becomes `len - pos` from the other end.
        for (pos, _state) in jump_table.iter_mut() {
            *pos = compressed.len() - *pos;
        }
        let initial_pos = compressed.len() - initial_pos;

        {
            let mut seekable_decoder = AnsCoder::from_reversed_compressed(compressed).unwrap();

            // Verify that decoding leads to the expected positions and states.
            for (chunk, &(pos, state)) in symbols.iter().zip(&jump_table).rev() {
                assert_eq!(seekable_decoder.pos(), (pos, state));
                let decoded = seekable_decoder
                    .decode_iid_symbols(symbols_per_chunk, &model)
                    .collect::<Result<Vec<_>, _>>()
                    .unwrap();
                assert_eq!(&decoded, chunk)
            }
            assert_eq!(seekable_decoder.pos(), (initial_pos, initial_state));
            assert!(seekable_decoder.is_empty());

            // Seek to some random offsets in the jump table and decode one chunk each time.
            for _ in 0..100 {
                let chunk_index = rng.next_u32() as usize % num_chunks;
                let (pos, state) = jump_table[chunk_index];
                seekable_decoder.seek((pos, state)).unwrap();
                let decoded = seekable_decoder
                    .decode_iid_symbols(symbols_per_chunk, &model)
                    .collect::<Result<Vec<_>, _>>()
                    .unwrap();
                assert_eq!(&decoded, &symbols[chunk_index])
            }
        }
    }
1507}