constriction/stream/queue.rs

1//! Near-optimal compression on a queue ("first in first out")
2//!
3//! This module provides an implementation of the Range Coding algorithm \[1], an entropy
4//! coder with near-optimal compression effectiveness that operates as a *queue* data
5//! structure. Range Coding is a more computationally efficient variant of Arithmetic
6//! Coding.
7//!
8//! # Comparison to sister module `stack`
9//!
10//! Range Coding operates as a *queue*: decoding a sequence of symbols yields the symbols in
11//! the same order in which they were encoded. This is unlike the case with the [`AnsCoder`]
12//! in the sister module [`stack`], which decodes in reverse order. Therefore, Range Coding
13//! is typically the preferred method for autoregressive models. On the other hand, the
14//! provided implementation of Range Coding uses two distinct data structures,
15//! [`RangeEncoder`] and [`RangeDecoder`], for encoding and decoding, respectively. This
16//! means that, unlike the case with the `AnsCoder`, encoding and decoding operations on a
17//! Range Coder cannot be interleaved: once you've *sealed* a `RangeEncoder` (e.g., by
18//! calling [`.into_compressed()`] on it) you cannot add any more compressed data onto it.
19//! This makes Range Coding difficult to use for advanced compression techniques such as
20//! bits-back coding with hierarchical models.
21//!
22//! The parent module contains a more detailed discussion of the [differences between ANS
23//! Coding and Range Coding](super#which-stream-code-should-i-use).
24//!
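//! # Example
//!
//! The following round-trip sketch mirrors the unit tests at the bottom of this file (the
//! `DefaultLeakyQuantizer` with a `Gaussian` distribution is just one possible choice of
//! entropy model); note how decoding yields the symbols in the *same* order in which they
//! were encoded:
//!
//! ```
//! use constriction::stream::{
//!     model::DefaultLeakyQuantizer,
//!     queue::{DefaultRangeDecoder, DefaultRangeEncoder},
//!     Decode, Encode,
//! };
//! use probability::distribution::Gaussian;
//!
//! // Entropy model: a Gaussian, quantized to integer symbols in the range -100..=100.
//! let quantizer = DefaultLeakyQuantizer::new(-100..=100);
//! let model = quantizer.quantize(Gaussian::new(10.3, 5.1));
//!
//! // Encode some symbols; sealing the encoder yields the compressed words.
//! let symbols = vec![-1, 20, 8, 12];
//! let mut encoder = DefaultRangeEncoder::new();
//! encoder.encode_iid_symbols(&symbols, &model).unwrap();
//! let compressed = encoder.into_compressed().unwrap();
//!
//! // Decoding returns the symbols in encoding order (queue semantics).
//! let mut decoder = DefaultRangeDecoder::from_compressed(compressed).unwrap();
//! let reconstructed = decoder
//!     .decode_iid_symbols(symbols.len(), &model)
//!     .collect::<Result<Vec<_>, _>>()
//!     .unwrap();
//! assert_eq!(reconstructed, symbols);
//! ```
//!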
25//! # References
26//!
27//! \[1] Pasco, Richard Clark. Source coding algorithms for fast data compression. Diss.
28//! Stanford University, 1976.
29//!
30//! [`AnsCoder`]: super::stack::AnsCoder
31//! [`stack`]: super::stack
32//! [`.into_compressed()`]: RangeEncoder::into_compressed
33
34use alloc::vec::Vec;
35use core::{
36    borrow::Borrow,
37    fmt::{Debug, Display},
38    marker::PhantomData,
39    num::NonZeroUsize,
40    ops::Deref,
41};
42
43use num_traits::AsPrimitive;
44
45use super::{
46    model::{DecoderModel, EncoderModel},
47    Code, Decode, Encode, IntoDecoder,
48};
49use crate::{
50    backends::{AsReadWords, BoundedReadWords, Cursor, IntoReadWords, ReadWords, WriteWords},
51    generic_static_asserts, BitArray, CoderError, DefaultEncoderError, DefaultEncoderFrontendError,
52    NonZeroBitArray, Pos, PosSeek, Queue, Seek, UnwrapInfallible,
53};
54
55/// Type of the internal state used by [`RangeEncoder<Word, State>`] and
56/// [`RangeDecoder<Word, State>`]. Relevant for [`Seek`]ing.
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
58pub struct RangeCoderState<Word, State: BitArray> {
59    lower: State,
60
61    /// Invariant: `range >= State::one() << (State::BITS - Word::BITS)`
62    /// Therefore, the highest order `Word` of `lower` is always sufficient to
63    /// identify the current interval, so only it has to be flushed at the end.
64    range: State::NonZero,
65
66    /// We keep track of the `Word` type so that we can statically enforce
67    /// the invariants for `lower` and `range`.
68    phantom: PhantomData<Word>,
69}
70
71impl<Word: BitArray, State: BitArray> RangeCoderState<Word, State> {
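    /// Constructs a `RangeCoderState` from a lower bound and a range.
    ///
    /// Returns `Err(())` if `range` does not satisfy the invariant
    /// `range >= 1 << (State::BITS - Word::BITS)`, i.e., if its most significant `Word`
    /// is zero.
    ///
    /// # Example
    ///
    /// A small sketch of the invariant check, using the default `u32` word / `u64` state
    /// configuration (the concrete values are arbitrary):
    ///
    /// ```
    /// use constriction::stream::queue::RangeCoderState;
    ///
    /// // The most significant word of `range` must be nonzero ...
    /// assert!(RangeCoderState::<u32, u64>::new(0, 1 << 32).is_ok());
    /// // ... otherwise, construction fails.
    /// assert!(RangeCoderState::<u32, u64>::new(0, (1 << 32) - 1).is_err());
    /// ```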
72    #[allow(clippy::result_unit_err)]
73    pub fn new(lower: State, range: State) -> Result<Self, ()> {
74        if range >> (State::BITS - Word::BITS) == State::zero() {
75            Err(())
76        } else {
77            Ok(Self {
78                lower,
79                range: range.into_nonzero().expect("We checked above."),
80                phantom: PhantomData,
81            })
82        }
83    }
84
85    /// Get the lower bound of the current range (inclusive)
86    pub fn lower(&self) -> State {
87        self.lower
88    }
89
90    /// Get the size of the current range
91    pub fn range(&self) -> State::NonZero {
92        self.range
93    }
94}
95
96impl<Word: BitArray, State: BitArray> Default for RangeCoderState<Word, State> {
97    fn default() -> Self {
98        Self {
99            lower: State::zero(),
100            range: State::max_value().into_nonzero().expect("max_value() != 0"),
101            phantom: PhantomData,
102        }
103    }
104}
105
106#[derive(Debug, Clone)]
107pub struct RangeEncoder<Word, State, Backend = Vec<Word>>
108where
109    Word: BitArray,
110    State: BitArray,
111    Backend: WriteWords<Word>,
112{
113    bulk: Backend,
114    state: RangeCoderState<Word, State>,
115    situation: EncoderSituation<Word>,
116}
117
118/// Keeps track of yet-to-be-finalized compressed words during encoding with a
119/// [`RangeEncoder`].
120///
121/// This type is mostly for internal use. It is only exposed via
122/// [`RangeEncoder::into_raw_parts`] and [`RangeEncoder::from_raw_parts`].
123#[derive(Debug, Clone, Copy, PartialEq, Eq)]
124pub enum EncoderSituation<Word> {
125    /// In the `Normal` situation, all full `Words` of compressed data have been written to
126    /// the backend (or "bulk"), and the internal coder state holds less than one word of
127    /// additional information content.
128    Normal,
129
130    /// The `Inverted` situation occurs only rarely. In this situation, some full words of
131    /// compressed data have been held back and not yet written to the backend (or "bulk")
132    /// because their final values may still change depending on subsequently encoded
133    /// symbols.
134    ///
135    /// More precisely, a situation of `Inverted(num_subsequent, first_word)` means that the
136    /// held-back words can become either `first_word + 1` followed by `num_subsequent` zero
137    /// words, or `first_word` followed by `num_subsequent` words that have all bits set.
138    Inverted(NonZeroUsize, Word),
139}
140
141impl<Word> Default for EncoderSituation<Word> {
142    fn default() -> Self {
143        Self::Normal
144    }
145}
146
147/// Type alias for a [`RangeEncoder`] with sane parameters for typical use cases.
148pub type DefaultRangeEncoder<Backend = Vec<u32>> = RangeEncoder<u32, u64, Backend>;
149
150/// Type alias for a [`RangeEncoder`] for use with [lookup models]
151///
152/// This encoder has a smaller word size and internal state than [`DefaultRangeEncoder`].
153/// This allows you to use lookup models when decoding data that was encoded with this
154/// coder; see [`SmallRangeDecoder`], as well as [`ContiguousLookupDecoderModel`] and
155/// [`NonContiguousLookupDecoderModel`].
156///
157/// [lookup models]: crate::stream::model::lookup
158/// [`ContiguousLookupDecoderModel`]: crate::stream::model::ContiguousLookupDecoderModel
159/// [`NonContiguousLookupDecoderModel`]: crate::stream::model::NonContiguousLookupDecoderModel
160pub type SmallRangeEncoder<Backend = Vec<u16>> = RangeEncoder<u16, u32, Backend>;
161
162impl<Word, State, Backend> Code for RangeEncoder<Word, State, Backend>
163where
164    Word: BitArray + Into<State>,
165    State: BitArray + AsPrimitive<Word>,
166    Backend: WriteWords<Word>,
167{
168    type State = RangeCoderState<Word, State>;
169    type Word = Word;
170
171    fn state(&self) -> Self::State {
172        self.state
173    }
174}
175
176impl<Word, State, Backend> PosSeek for RangeEncoder<Word, State, Backend>
177where
178    Word: BitArray,
179    State: BitArray,
180    Backend: WriteWords<Word> + PosSeek,
181    Self: Code,
182{
183    type Position = (Backend::Position, <Self as Code>::State);
184}
185
186impl<Word, State, Backend> Pos for RangeEncoder<Word, State, Backend>
187where
188    Word: BitArray + Into<State>,
189    State: BitArray + AsPrimitive<Word>,
190    Backend: WriteWords<Word> + Pos<Position = usize>,
191{
192    fn pos(&self) -> Self::Position {
193        let num_inverted = if let EncoderSituation::Inverted(num_inverted, _) = self.situation {
194            num_inverted.get()
195        } else {
196            0
197        };
198        (self.bulk.pos() + num_inverted, self.state())
199    }
200}
201
202impl<Word, State, Backend> Default for RangeEncoder<Word, State, Backend>
203where
204    Word: BitArray + Into<State>,
205    State: BitArray + AsPrimitive<Word>,
206    Backend: WriteWords<Word> + Default,
207{
208    /// This is essentially the same as `#[derive(Default)]`, except for the assertions on
209    /// `State::BITS` and `Word::BITS`.
210    fn default() -> Self {
211        Self::with_backend(Backend::default())
212    }
213}
214
215impl<Word, State> RangeEncoder<Word, State>
216where
217    Word: BitArray + Into<State>,
218    State: BitArray + AsPrimitive<Word>,
219{
220    /// Creates an empty encoder for range coding.
221    pub fn new() -> Self {
222        generic_static_asserts!(
223            (Word: BitArray, State:BitArray);
224            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
225            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
226        );
227
228        Self {
229            bulk: Vec::new(),
230            state: RangeCoderState::default(),
231            situation: EncoderSituation::Normal,
232        }
233    }
234}
235
236impl<Word, State> From<RangeEncoder<Word, State>> for Vec<Word>
237where
238    Word: BitArray + Into<State>,
239    State: BitArray + AsPrimitive<Word>,
240{
241    fn from(val: RangeEncoder<Word, State>) -> Self {
242        val.into_compressed().unwrap_infallible()
243    }
244}
245
246impl<Word, State, Backend> RangeEncoder<Word, State, Backend>
247where
248    Word: BitArray + Into<State>,
249    State: BitArray + AsPrimitive<Word>,
250    Backend: WriteWords<Word>,
251{
252    /// Assumes that the `backend` is in a state where the encoder can start writing as if
253    /// it were an empty backend. If there's already some compressed data on `backend`, then
254    /// this method will just concatenate the new sequence of `Word`s to the existing
255    /// sequence of `Word`s without gluing them together. This is likely not what you want
256    /// since you won't be able to decode the data in one go (however, it is Ok to
257    /// concatenate arbitrary data to the output of a `RangeEncoder`; it won't invalidate
258    /// the existing data).
259    ///
260    /// If you need an entropy coder that can be interrupted and serialized/deserialized
261    /// (i.e., an encoder that can encode some symbols, return the compressed bit string as
262    /// a sequence of `Words`, load the `Words` back in at a later point and then encode
263    /// some more symbols), then consider using an [`AnsCoder`].
264    ///
265    /// TODO: rename to `with_write_backend` and then add the same method to `AnsCoder`
266    ///
267    /// [`AnsCoder`]: super::stack::AnsCoder
268    pub fn with_backend(backend: Backend) -> Self {
269        generic_static_asserts!(
270            (Word: BitArray, State:BitArray);
271            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
272            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
273        );
274
275        Self {
276            bulk: backend,
277            state: RangeCoderState::default(),
278            situation: EncoderSituation::Normal,
279        }
280    }
281
282    /// Check if no data has been encoded yet.
283    pub fn is_empty<'a>(&'a self) -> bool
284    where
285        Backend: AsReadWords<'a, Word, Queue>,
286        Backend::AsReadWords: BoundedReadWords<Word, Queue>,
287    {
288        self.state.range.get() == State::max_value() && self.bulk.as_read_words().is_exhausted()
289    }
290
291    /// Same as `Encode::maybe_full`, but can be called on a concrete type without type
292    /// annotations.
293    pub fn maybe_full(&self) -> bool {
294        self.bulk.maybe_full()
295    }
296
297    /// Same as `IntoDecoder::into_decoder(self)`, but can be used for any `PRECISION`
298    /// and therefore doesn't require type arguments on the caller side.
299    ///
300    /// TODO: there should also be a `decoder()` method that takes `&mut self`
301    #[allow(clippy::result_unit_err)]
302    pub fn into_decoder(self) -> Result<RangeDecoder<Word, State, Backend::IntoReadWords>, ()>
303    where
304        Backend: IntoReadWords<Word, Queue>,
305    {
306        // TODO: return proper error (or just box it up).
307        RangeDecoder::from_compressed(self.into_compressed().map_err(|_| ())?).map_err(|_| ())
308    }
309
310    pub fn into_compressed(mut self) -> Result<Backend, Backend::WriteError> {
311        self.seal()?;
312        Ok(self.bulk)
313    }
314
315    /// Private method; flushes held-back words if in inverted situation and adds one or two
316    /// additional words that identify the range regardless of what the compressed data may
317    /// be concatenated with (unless no symbols have been encoded yet, in which case this is
318    /// a no-op).
319    ///
320    /// Doesn't change `self.state` or `self.situation` so that this operation can be
321    /// reversed if the backend supports removing words (see method `unseal`).
322    fn seal(&mut self) -> Result<(), Backend::WriteError> {
323        if self.state.range.get() == State::max_value() {
324            // This condition only holds upon initialization because encoding a symbol first
325            // reduces `range` and then only (possibly) right-shifts it, which introduces
326            // some zero bits. We treat this case specially and don't emit any words, so that
327            // an empty sequence of symbols gets encoded to an empty sequence of words.
328            return Ok(());
329        }
330
331        let point = self
332            .state
333            .lower
334            .wrapping_add(&((State::one() << (State::BITS - Word::BITS)) - State::one()));
335
336        if let EncoderSituation::Inverted(num_inverted, first_inverted_lower_word) = self.situation
337        {
338            let (first_word, consecutive_words) = if point < self.state.lower {
339                // Unlikely case (addition has wrapped).
340                (first_inverted_lower_word + Word::one(), Word::zero())
341            } else {
342                // Likely case.
343                (first_inverted_lower_word, Word::max_value())
344            };
345
346            self.bulk.write(first_word)?;
347            for _ in 1..num_inverted.get() {
348                self.bulk.write(consecutive_words)?;
349            }
350        }
351
352        let point_word = (point >> (State::BITS - Word::BITS)).as_();
353        self.bulk.write(point_word)?;
354
355        let upper_word = (self.state.lower.wrapping_add(&self.state.range.get())
356            >> (State::BITS - Word::BITS))
357            .as_();
358        if upper_word == point_word {
359            self.bulk.write(Word::zero())?;
360        }
361
362        Ok(())
363    }
364
365    fn num_seal_words(&self) -> usize {
366        if self.state.range.get() == State::max_value() {
367            return 0;
368        }
369
370        let point = self
371            .state
372            .lower
373            .wrapping_add(&((State::one() << (State::BITS - Word::BITS)) - State::one()));
374        let point_word = (point >> (State::BITS - Word::BITS)).as_();
375        let upper_word = (self.state.lower.wrapping_add(&self.state.range.get())
376            >> (State::BITS - Word::BITS))
377            .as_();
378        let mut count = if upper_word == point_word { 2 } else { 1 };
379
380        if let EncoderSituation::Inverted(num_inverted, _) = self.situation {
381            count += num_inverted.get();
382        }
383        count
384    }
385
386    /// Returns the number of compressed words in the encoder.
387    ///
388    /// This includes a constant overhead of between one and two words unless the
389    /// coder is completely empty.
390    ///
391    /// This method returns the length of the slice, the `Vec<Word>`, or the iterator
392    /// that would be returned by [`get_compressed`], [`into_compressed`], or
393    /// [`iter_compressed`], respectively, when called at this time.
394    ///
395    /// See also [`num_bits`].
396    ///
397    /// [`get_compressed`]: #method.get_compressed
398    /// [`into_compressed`]: #method.into_compressed
399    /// [`iter_compressed`]: #method.iter_compressed
400    /// [`num_bits`]: #method.num_bits
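    ///
    /// # Example
    ///
    /// A small sketch (assuming the default `u32` word size; the entropy model is an
    /// arbitrary choice):
    ///
    /// ```
    /// use constriction::stream::{model::DefaultLeakyQuantizer, queue::DefaultRangeEncoder, Encode};
    /// use probability::distribution::Gaussian;
    ///
    /// let mut encoder = DefaultRangeEncoder::new();
    /// assert_eq!(encoder.num_words(), 0); // An empty encoder carries no overhead.
    ///
    /// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
    /// let model = quantizer.quantize(Gaussian::new(0.0, 10.0));
    /// encoder.encode_iid_symbols(0..10, &model).unwrap();
    ///
    /// // Once something has been encoded, there's a small constant overhead ...
    /// assert!(encoder.num_words() >= 1);
    /// // ... and `num_bits` is always `Word::BITS` times `num_words`.
    /// assert_eq!(encoder.num_bits(), 32 * encoder.num_words());
    /// ```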
401    pub fn num_words<'a>(&'a self) -> usize
402    where
403        Backend: AsReadWords<'a, Word, Queue>,
404        Backend::AsReadWords: BoundedReadWords<Word, Queue>,
405    {
406        self.bulk.as_read_words().remaining() + self.num_seal_words()
407    }
408
409    /// Returns the size of the current queue of compressed data in bits.
410    ///
411    /// This includes some constant overhead unless the coder is completely empty
412    /// (see [`num_words`](#method.num_words)).
413    ///
414    /// The returned value is a multiple of the bitlength of the compressed word
415    /// type `Word`.
416    pub fn num_bits<'a>(&'a self) -> usize
417    where
418        Backend: AsReadWords<'a, Word, Queue>,
419        Backend::AsReadWords: BoundedReadWords<Word, Queue>,
420    {
421        Word::BITS * self.num_words()
422    }
423
424    pub fn bulk(&self) -> &Backend {
425        &self.bulk
426    }
427
428    /// Low-level constructor that assembles a `RangeEncoder` from its internal components.
429    ///
430    /// The arguments `bulk`, `state`, and `situation` correspond to the three return values
431    /// of the method [`into_raw_parts`](Self::into_raw_parts).
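    ///
    /// # Example
    ///
    /// A sketch of a disassemble/reassemble round trip (the entropy model is an arbitrary
    /// choice):
    ///
    /// ```
    /// use constriction::stream::{model::DefaultLeakyQuantizer, queue::DefaultRangeEncoder, Encode};
    /// use probability::distribution::Gaussian;
    ///
    /// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
    /// let model = quantizer.quantize(Gaussian::new(0.0, 10.0));
    ///
    /// let mut encoder = DefaultRangeEncoder::new();
    /// encoder.encode_iid_symbols(0..5, &model).unwrap();
    ///
    /// // Disassemble the encoder and reassemble it without losing any information ...
    /// let (bulk, state, situation) = encoder.into_raw_parts();
    /// let mut encoder = DefaultRangeEncoder::from_raw_parts(bulk, state, situation);
    ///
    /// // ... so we can simply continue encoding.
    /// encoder.encode_iid_symbols(5..10, &model).unwrap();
    /// ```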
432    pub fn from_raw_parts(
433        bulk: Backend,
434        state: RangeCoderState<Word, State>,
435        situation: EncoderSituation<Word>,
436    ) -> Self {
437        generic_static_asserts!(
438            (Word: BitArray, State:BitArray);
439            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
440            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
441        );
442
443        // The invariants for `state` are already enforced statically.
444
445        Self {
446            bulk,
447            state,
448            situation,
449        }
450    }
451
452    /// Low-level method that disassembles the `RangeEncoder` into its internal components.
453    ///
454    /// Can be used together with [`from_raw_parts`](Self::from_raw_parts).
455    pub fn into_raw_parts(
456        self,
457    ) -> (
458        Backend,
459        RangeCoderState<Word, State>,
460        EncoderSituation<Word>,
461    ) {
462        (self.bulk, self.state, self.situation)
463    }
464}
465
466impl<Word, State> RangeEncoder<Word, State>
467where
468    Word: BitArray + Into<State>,
469    State: BitArray + AsPrimitive<Word>,
470{
471    /// Discards all compressed data and resets the coder to the same state as
472    /// [`Self::new`](#method.new).
473    pub fn clear(&mut self) {
474        self.bulk.clear();
475        self.state = RangeCoderState::default();
476    }
477
478    /// Assembles the current compressed data into a single slice.
479    ///
480    /// This method is only implemented for encoders backed by a `Vec<Word>`
481    /// because we have to temporarily seal the encoder and then unseal it when the returned
482    /// `EncoderGuard` is dropped, which requires precise knowledge of the backend (and
483    /// which is also the reason why this method takes a `&mut self` receiver). If you're
484    /// using a different backend than a `Vec`, consider calling [`into_compressed`]
485    /// instead.
486    ///
487    /// [`into_compressed`]: Self::into_compressed
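    ///
    /// # Example
    ///
    /// A small sketch (the entropy model is an arbitrary choice); note that the encoder
    /// remains usable after the returned guard is dropped:
    ///
    /// ```
    /// use constriction::stream::{model::DefaultLeakyQuantizer, queue::DefaultRangeEncoder, Encode};
    /// use probability::distribution::Gaussian;
    ///
    /// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
    /// let model = quantizer.quantize(Gaussian::new(0.0, 10.0));
    ///
    /// let mut encoder = DefaultRangeEncoder::new();
    /// encoder.encode_iid_symbols(0..50, &model).unwrap();
    ///
    /// // The returned guard dereferences to `&[u32]` ...
    /// let compressed = encoder.get_compressed();
    /// assert!(!compressed.is_empty());
    ///
    /// // ... and once it is dropped, we can continue encoding on the same encoder.
    /// drop(compressed);
    /// encoder.encode_iid_symbols(50..100, &model).unwrap();
    /// ```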
488    pub fn get_compressed(&mut self) -> EncoderGuard<'_, Word, State> {
489        EncoderGuard::new(self)
490    }
491
492    // TODO: implement `iter_compressed`
493
494    /// A decoder for temporary use.
495    ///
496    /// Once the returned decoder gets dropped, you can continue using this encoder. If you
497    /// don't need this flexibility, call [`into_decoder`] instead.
498    ///
499    /// This method is only implemented for encoders backed by a `Vec<Word>`
500    /// because we have to temporarily seal the encoder and then unseal it when the returned
501    /// decoder is dropped, which requires precise knowledge of the backend (and which is
502    /// also the reason why this method takes a `&mut self` receiver). If you're using a
503    /// different backend than a `Vec`, consider calling [`into_decoder`] instead.
504    ///
505    /// [`into_decoder`]: Self::into_decoder
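    ///
    /// # Example
    ///
    /// A sketch of interleaving encoding with temporary decoding (the entropy model is an
    /// arbitrary choice):
    ///
    /// ```
    /// use constriction::stream::{
    ///     model::DefaultLeakyQuantizer, queue::DefaultRangeEncoder, Decode, Encode,
    /// };
    /// use probability::distribution::Gaussian;
    ///
    /// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
    /// let model = quantizer.quantize(Gaussian::new(0.0, 10.0));
    ///
    /// let mut encoder = DefaultRangeEncoder::new();
    /// encoder.encode_iid_symbols(0..10, &model).unwrap();
    ///
    /// // Temporarily decode what has been encoded so far ...
    /// let mut decoder = encoder.decoder();
    /// let decoded = decoder
    ///     .decode_iid_symbols(10, &model)
    ///     .collect::<Result<Vec<_>, _>>()
    ///     .unwrap();
    /// assert_eq!(decoded, (0..10).collect::<Vec<_>>());
    /// drop(decoder);
    ///
    /// // ... and then continue using the encoder.
    /// encoder.encode_iid_symbols(10..20, &model).unwrap();
    /// ```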
506    pub fn decoder(
507        &mut self,
508    ) -> RangeDecoder<Word, State, Cursor<Word, EncoderGuard<'_, Word, State>>> {
509        RangeDecoder::from_compressed(self.get_compressed()).unwrap_infallible()
510    }
511
512    fn unseal(&mut self) {
513        for _ in 0..self.num_seal_words() {
514            let word = self.bulk.pop();
515            debug_assert!(word.is_some());
516        }
517    }
518}
519
520impl<Word, State, Backend, const PRECISION: usize> IntoDecoder<PRECISION>
521    for RangeEncoder<Word, State, Backend>
522where
523    Word: BitArray + Into<State>,
524    State: BitArray + AsPrimitive<Word>,
525    Backend: WriteWords<Word> + IntoReadWords<Word, Queue>,
526{
527    type IntoDecoder = RangeDecoder<Word, State, Backend::IntoReadWords>;
528
529    fn into_decoder(self) -> Self::IntoDecoder {
530        self.into()
531    }
532}
533
534impl<Word, State, Backend, const PRECISION: usize> Encode<PRECISION>
535    for RangeEncoder<Word, State, Backend>
536where
537    Word: BitArray + Into<State>,
538    State: BitArray + AsPrimitive<Word>,
539    Backend: WriteWords<Word>,
540{
541    type FrontendError = DefaultEncoderFrontendError;
542    type BackendError = Backend::WriteError;
543
544    fn encode_symbol<D>(
545        &mut self,
546        symbol: impl Borrow<D::Symbol>,
547        model: D,
548    ) -> Result<(), DefaultEncoderError<Self::BackendError>>
549    where
550        D: EncoderModel<PRECISION>,
551        D::Probability: Into<Self::Word>,
552        Self::Word: AsPrimitive<D::Probability>,
553    {
554        generic_static_asserts!(
555            (Word: BitArray, State:BitArray; const PRECISION: usize);
556            PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
557            NON_ZERO_PRECISION: PRECISION > 0;
558            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
559            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
560        );
561
562        // We maintain the following invariant (*):
563        //   range >= State::one() << (State::BITS - Word::BITS)
564
565        let (left_sided_cumulative, probability) = model
566            .left_cumulative_and_probability(symbol)
567            .ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;
568
569        let scale = self.state.range.get() >> PRECISION;
570        // This cannot overflow since `scale * probability <= (range >> PRECISION) << PRECISION`
571        self.state.range = (scale * probability.get().into().into())
572            .into_nonzero()
573            .ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;
574        let new_lower = self
575            .state
576            .lower
577            .wrapping_add(&(scale * left_sided_cumulative.into().into()));
578
579        if let EncoderSituation::Inverted(num_inverted, first_inverted_lower_word) = self.situation
580        {
581            // unlikely branch
582            if new_lower.wrapping_add(&self.state.range.get()) > new_lower {
583                // We've transitioned from an inverted to a normal situation.
584
585                let (first_word, consecutive_words) = if new_lower < self.state.lower {
586                    (first_inverted_lower_word + Word::one(), Word::zero())
587                } else {
588                    (first_inverted_lower_word, Word::max_value())
589                };
590
591                self.bulk.write(first_word)?;
592                for _ in 1..num_inverted.get() {
593                    self.bulk.write(consecutive_words)?;
594                }
595
596                self.situation = EncoderSituation::Normal;
597            }
598        }
599
600        self.state.lower = new_lower;
601
602        if self.state.range.get() < State::one() << (State::BITS - Word::BITS) {
603            // Invariant `range >= State::one() << (State::BITS - Word::BITS)` is
604            // violated. Since `left_cumulative_and_probability` succeeded, we know that
605            // `probability != 0` and therefore:
606            //   range >= scale * probability = (old_range >> PRECISION) * probability
607            //         >= old_range >> PRECISION
608            //         >= old_range >> Word::BITS
609            // where `old_range` is the `range` at method entry, which satisfied invariant (*)
610            // by assumption. Therefore, the following left-shift restores the invariant:
611            self.state.range = unsafe {
612                // SAFETY:
613                // - `range` is nonzero because it is a `State::NonZero`
614                // - Shifting `range` left by `Word::BITS` bits doesn't truncate
615                //   because we checked that `range < 1 << (State::BITS - Word::BITS)`.
616                (self.state.range.get() << Word::BITS).into_nonzero_unchecked()
617            };
618
619            let lower_word = (self.state.lower >> (State::BITS - Word::BITS)).as_();
620            self.state.lower = self.state.lower << Word::BITS;
621
622            if let EncoderSituation::Inverted(num_inverted, _) = &mut self.situation {
623                // Transition from an inverted to an inverted situation (TODO: mark as unlikely branch).
624                *num_inverted = NonZeroUsize::new(num_inverted.get().wrapping_add(1))
625                    .expect("Cannot encode more symbols than what's addressable with usize.");
626            } else if self.state.lower.wrapping_add(&self.state.range.get()) > self.state.lower {
627                // Transition from a normal to a normal situation (the most common case).
628                self.bulk.write(lower_word)?;
629            } else {
630                // Transition from a normal to an inverted situation.
631                self.situation =
632                    EncoderSituation::Inverted(NonZeroUsize::new(1).expect("1 != 0"), lower_word);
633            }
634        }
635
636        Ok(())
637    }
638
639    fn maybe_full(&self) -> bool {
640        RangeEncoder::maybe_full(self)
641    }
642}
643
644#[derive(Debug, Clone)]
645pub struct RangeDecoder<Word, State, Backend>
646where
647    Word: BitArray,
648    State: BitArray,
649    Backend: ReadWords<Word, Queue>,
650{
651    bulk: Backend,
652
653    state: RangeCoderState<Word, State>,
654
655    /// Invariant: `point.wrapping_sub(&state.lower) < state.range`
656    point: State,
657}
658
659/// Type alias for a [`RangeDecoder`] with sane parameters for typical use cases.
660pub type DefaultRangeDecoder<Backend = Cursor<u32, Vec<u32>>> = RangeDecoder<u32, u64, Backend>;
661
662/// Type alias for a [`RangeDecoder`] for use with [lookup models]
663///
664/// This decoder has a smaller word size and internal state than [`DefaultRangeDecoder`]. It
665/// is optimized for use with lookup entropy models, in particular with a
666/// [`ContiguousLookupDecoderModel`] or a [`NonContiguousLookupDecoderModel`].
667///
668/// # Examples
669///
670/// See [`ContiguousLookupDecoderModel`] and [`NonContiguousLookupDecoderModel`].
671///
672/// # See also
673///
674/// - [`SmallRangeEncoder`]
675///
676/// [lookup models]: crate::stream::model::ContiguousLookupDecoderModel
677/// [`ContiguousLookupDecoderModel`]: crate::stream::model::ContiguousLookupDecoderModel
678/// [`NonContiguousLookupDecoderModel`]: crate::stream::model::NonContiguousLookupDecoderModel
679pub type SmallRangeDecoder<Backend> = RangeDecoder<u16, u32, Backend>;
680
681impl<Word, State, Backend> RangeDecoder<Word, State, Backend>
682where
683    Word: BitArray + Into<State>,
684    State: BitArray + AsPrimitive<Word>,
685    Backend: ReadWords<Word, Queue>,
686{
687    pub fn from_compressed<Buf>(compressed: Buf) -> Result<Self, Backend::ReadError>
688    where
689        Buf: IntoReadWords<Word, Queue, IntoReadWords = Backend>,
690    {
691        generic_static_asserts!(
692            (Word: BitArray, State:BitArray);
693            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
694            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
695        );
696
697        let mut bulk = compressed.into_read_words();
698        let point = Self::read_point(&mut bulk)?;
699
700        Ok(RangeDecoder {
701            bulk,
702            state: RangeCoderState::default(),
703            point,
704        })
705    }
706
707    pub fn with_backend(backend: Backend) -> Result<Self, Backend::ReadError> {
708        generic_static_asserts!(
709            (Word: BitArray, State:BitArray);
710            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
711            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
712        );
713
714        let mut bulk = backend;
715        let point = Self::read_point(&mut bulk)?;
716
717        Ok(RangeDecoder {
718            bulk,
719            state: RangeCoderState::default(),
720            point,
721        })
722    }
723
724    pub fn for_compressed<'a, Buf>(compressed: &'a Buf) -> Result<Self, Backend::ReadError>
725    where
726        Buf: AsReadWords<'a, Word, Queue, AsReadWords = Backend>,
727    {
728        generic_static_asserts!(
729            (Word: BitArray, State:BitArray);
730            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
731            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
732        );
733
734        let mut bulk = compressed.as_read_words();
735        let point = Self::read_point(&mut bulk)?;
736
737        Ok(RangeDecoder {
738            bulk,
739            state: RangeCoderState::default(),
740            point,
741        })
742    }
743
744    /// Low-level constructor that assembles a `RangeDecoder` from its internal components.
745    ///
746    /// The arguments `bulk`, `state`, and `point` correspond to the three return values of
747    /// the method [`into_raw_parts`](Self::into_raw_parts).
748    ///
749    /// The construction fails if the argument `point` lies outside of the range represented
750    /// by `state`. In this case, the method returns the (unmodified) argument `bulk` back
751    /// to the caller, wrapped in an `Err` variant.
752    pub fn from_raw_parts(
753        bulk: Backend,
754        state: RangeCoderState<Word, State>,
755        point: State,
756    ) -> Result<Self, Backend> {
757        generic_static_asserts!(
758            (Word: BitArray, State:BitArray);
759            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
760            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
761        );
762
763        // The invariants for `state` are already enforced statically.
764
765        if point.wrapping_sub(&state.lower) >= state.range.get() {
766            Err(bulk)
767        } else {
768            Ok(Self { bulk, state, point })
769        }
770    }
771
772    /// Low-level method that disassembles the `RangeDecoder` into its internal components.
773    ///
774    /// Can be used together with [`from_raw_parts`](Self::from_raw_parts).
775    pub fn into_raw_parts(self) -> (Backend, RangeCoderState<Word, State>, State) {
776        (self.bulk, self.state, self.point)
777    }
778
779    fn read_point<B: ReadWords<Word, Queue>>(bulk: &mut B) -> Result<State, B::ReadError> {
780        let mut num_read = 0;
781        let mut point = State::zero();
782        while let Some(word) = bulk.read()? {
783            point = point << Word::BITS | word.into();
784            num_read += 1;
785            if num_read == State::BITS / Word::BITS {
786                break;
787            }
788        }
789
790        #[allow(clippy::collapsible_if)]
791        if num_read < State::BITS / Word::BITS {
792            if num_read != 0 {
793                point = point << (State::BITS - num_read * Word::BITS);
794            }
795            // TODO: do we need to advance the Backend's `pos` beyond the end to make
796            // `PosBackend` consistent with its implementation for the encoder?
797        }
798
799        Ok(point)
800    }
801
802    /// Same as `Decode::maybe_exhausted`, but can be called on a concrete type without
803    /// type annotations.
804    pub fn maybe_exhausted(&self) -> bool {
805        // The maximum possible difference between `point` and `lower`, even if the
806        // compressed data was concatenated with a lot of one bits.
807        let max_difference =
808            ((State::one() << (State::BITS - Word::BITS)) << 1).wrapping_sub(&State::one());
809
810        // The check for `self.state.range == State::max_value()` is for the special case of
811        // an empty buffer.
812        self.bulk.maybe_exhausted()
813            && (self.state.range.get() == State::max_value()
814                || self.point.wrapping_sub(&self.state.lower) < max_difference)
815    }
816}
817
818impl<Word, State, Backend> Code for RangeDecoder<Word, State, Backend>
819where
820    Word: BitArray + Into<State>,
821    State: BitArray + AsPrimitive<Word>,
822    Backend: ReadWords<Word, Queue>,
823{
824    type State = RangeCoderState<Word, State>;
825    type Word = Word;
826
827    fn state(&self) -> Self::State {
828        self.state
829    }
830}
831
832impl<Word, State, Backend> PosSeek for RangeDecoder<Word, State, Backend>
833where
834    Word: BitArray,
835    State: BitArray,
836    Backend: ReadWords<Word, Queue>,
837    Backend: PosSeek,
838    Self: Code,
839{
840    type Position = (Backend::Position, <Self as Code>::State);
841}
842
843impl<Word, State, Backend> Seek for RangeDecoder<Word, State, Backend>
844where
845    Word: BitArray + Into<State>,
846    State: BitArray + AsPrimitive<Word>,
847    Backend: ReadWords<Word, Queue> + Seek,
848{
849    fn seek(&mut self, pos_and_state: Self::Position) -> Result<(), ()> {
850        let (pos, state) = pos_and_state;
851
852        self.bulk.seek(pos)?;
853        self.point = Self::read_point(&mut self.bulk).map_err(|_| ())?;
854        self.state = state;
855
856        // TODO: deal with positions very close to end.
857
858        Ok(())
859    }
860}
861
862impl<Word, State, Backend> From<RangeEncoder<Word, State, Backend>>
863    for RangeDecoder<Word, State, Backend::IntoReadWords>
864where
865    Word: BitArray + Into<State>,
866    State: BitArray + AsPrimitive<Word>,
867    Backend: WriteWords<Word> + IntoReadWords<Word, Queue>,
868{
869    fn from(encoder: RangeEncoder<Word, State, Backend>) -> Self {
870        // TODO: implement a `try_into_decoder` or something instead. Or specialize this
871        // method to the case where both read and write error are Infallible, which is
872        // probably the only place where this will be used anyway.
873        encoder.into_decoder().unwrap()
874    }
875}
876
877// TODO (implement for infallible case)
878// impl<'a, Word, State, Backend> From<&'a mut RangeEncoder<Word, State, Backend>>
879//     for RangeDecoder<Word, State, Backend::AsReadWords>
880// where
881//     Word: BitArray + Into<State>,
882//     State: BitArray + AsPrimitive<Word>,
883//     Backend: WriteWords<Word> + AsReadWords<'a, Word, Queue>,
884// {
885//     fn from(encoder: &'a mut RangeEncoder<Word, State, Backend>) -> Self {
886//         encoder.as_decoder()
887//     }
888// }
889
890impl<Word, State, Backend, const PRECISION: usize> Decode<PRECISION>
891    for RangeDecoder<Word, State, Backend>
892where
893    Word: BitArray + Into<State>,
894    State: BitArray + AsPrimitive<Word>,
895    Backend: ReadWords<Word, Queue>,
896{
897    type FrontendError = DecoderFrontendError;
898
899    type BackendError = Backend::ReadError;
900
901    /// Decodes a single symbol and reads it off the front of the compressed data.
902    ///
903    /// This is a low-level method. You usually want to call a batch method
904    /// like [`decode_symbols`](#method.decode_symbols) or
905    /// [`decode_iid_symbols`](#method.decode_iid_symbols) instead.
906    ///
907    /// Since a Range Coder operates as a *queue*, `decode_symbol` returns the *next*
908    /// symbol in the order in which the symbols were previously encoded via
909    /// [`encode_symbol`](#method.encode_symbol) ("first in first out").
910    ///
911    /// The only frontend error this method can report is
912    /// [`DecoderFrontendError::InvalidData`], which cannot occur unless you use entropy
913    /// models with varying `PRECISION`s. Decoding past the end of the compressed data is
914    /// not an error either: it still produces symbols in a deterministic way, but such
915    /// symbols don't recover any previously encoded data and will generally have low entropy.
916    fn decode_symbol<D>(
917        &mut self,
918        model: D,
919    ) -> Result<D::Symbol, CoderError<Self::FrontendError, Self::BackendError>>
920    where
921        D: DecoderModel<PRECISION>,
922        D::Probability: Into<Self::Word>,
923        Self::Word: AsPrimitive<D::Probability>,
924    {
925        generic_static_asserts!(
926            (Word: BitArray, State:BitArray; const PRECISION: usize);
927            PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
928            NON_ZERO_PRECISION: PRECISION > 0;
929            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
930            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
931        );
932
933        // We maintain the following invariant (*):
934        //   point (-) lower < range
935        // where (-) denotes wrapping subtraction (in `Self::State`).
936
937        let scale = self.state.range.get() >> PRECISION;
938        let quantile = self.point.wrapping_sub(&self.state.lower) / scale;
939        if quantile >= State::one() << PRECISION {
940            return Err(CoderError::Frontend(DecoderFrontendError::InvalidData));
941        }
942
943        let (symbol, left_sided_cumulative, probability) =
944            model.quantile_function(quantile.as_().as_());
945
946        // Update `state` in the same way as we do in `encode_symbol` (see comments there):
947        self.state.lower = self
948            .state
949            .lower
950            .wrapping_add(&(scale * left_sided_cumulative.into().into()));
951        self.state.range = (scale * probability.get().into().into())
952            .into_nonzero()
953            .expect("TODO");
954
955        // Invariant (*) is still satisfied at this point because:
956        //   (point (-) lower) / scale = (point (-) old_lower) / scale (-) left_sided_cumulative
957        //                             = quantile (-) left_sided_cumulative
958        //                             < probability
959        // Therefore, we have:
960        //   point (-) lower < scale * probability <= range
961
962        if self.state.range.get() < State::one() << (State::BITS - Word::BITS) {
963            // First update `state` in the same way as we do in `encode_symbol`:
964            self.state.lower = self.state.lower << Word::BITS;
965            self.state.range = unsafe {
966                // SAFETY:
967                // - `range` is nonzero because it is a `State::NonZero`
968                // - Shifting `range` left by `Word::BITS` bits doesn't truncate
969                //   because we checked that `range < 1 << (State::BITS - Word::BITS)`.
970                (self.state.range.get() << Word::BITS).into_nonzero_unchecked()
971            };
972
973            // Then update `point`, which restores invariant (*):
974            self.point = self.point << Word::BITS;
975            if let Some(word) = self.bulk.read()? {
976                self.point = self.point | word.into();
977            }
978
979            // TODO: register reads past end?
980        }
981
982        Ok(symbol)
983    }
984
985    fn maybe_exhausted(&self) -> bool {
986        RangeDecoder::maybe_exhausted(self)
987    }
988}
989
990/// Provides temporary read-only access to the compressed data wrapped in a
991/// [`RangeEncoder`].
992///
993/// Dereferences to `&[Word]`. See [`RangeEncoder::get_compressed`] for an example.
994pub struct EncoderGuard<'a, Word, State>
995where
996    Word: BitArray + Into<State>,
997    State: BitArray + AsPrimitive<Word>,
998{
999    inner: &'a mut RangeEncoder<Word, State>,
1000}
1001
1002impl<Word, State> Debug for EncoderGuard<'_, Word, State>
1003where
1004    Word: BitArray + Into<State>,
1005    State: BitArray + AsPrimitive<Word>,
1006{
1007    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1008        Debug::fmt(&**self, f)
1009    }
1010}
1011
1012impl<'a, Word, State> EncoderGuard<'a, Word, State>
1013where
1014    Word: BitArray + Into<State>,
1015    State: BitArray + AsPrimitive<Word>,
1016{
1017    fn new(encoder: &'a mut RangeEncoder<Word, State>) -> Self {
1018        // Append state. Will be undone in `<Self as Drop>::drop`.
1019        if !encoder.is_empty() {
1020            encoder.seal().unwrap_infallible();
1021        }
1022        Self { inner: encoder }
1023    }
1024}
1025
1026impl<'a, Word, State> Drop for EncoderGuard<'a, Word, State>
1027where
1028    Word: BitArray + Into<State>,
1029    State: BitArray + AsPrimitive<Word>,
1030{
1031    fn drop(&mut self) {
1032        self.inner.unseal();
1033    }
1034}
1035
1036impl<'a, Word, State> Deref for EncoderGuard<'a, Word, State>
1037where
1038    Word: BitArray + Into<State>,
1039    State: BitArray + AsPrimitive<Word>,
1040{
1041    type Target = [Word];
1042
1043    fn deref(&self) -> &Self::Target {
1044        &self.inner.bulk
1045    }
1046}
1047
1048impl<'a, Word, State> AsRef<[Word]> for EncoderGuard<'a, Word, State>
1049where
1050    Word: BitArray + Into<State>,
1051    State: BitArray + AsPrimitive<Word>,
1052{
1053    fn as_ref(&self) -> &[Word] {
1054        self
1055    }
1056}
1057
1058#[cfg(test)]
1059mod tests {
1060    extern crate std;
1061    use std::dbg;
1062
1063    use super::super::model::{
1064        ContiguousCategoricalEntropyModel, IterableEntropyModel, LeakyQuantizer,
1065    };
1066    use super::*;
1067
1068    use probability::distribution::{Gaussian, Inverse};
1069    use rand_xoshiro::{
1070        rand_core::{RngCore, SeedableRng},
1071        Xoshiro256StarStar,
1072    };
1073
1074    #[test]
1075    fn compress_none() {
1076        let encoder = DefaultRangeEncoder::new();
1077        assert!(encoder.is_empty());
1078        let compressed = encoder.into_compressed().unwrap();
1079        assert!(compressed.is_empty());
1080
1081        let decoder = DefaultRangeDecoder::from_compressed(compressed).unwrap();
1082        assert!(decoder.maybe_exhausted());
1083    }
1084
1085    #[test]
1086    fn compress_one() {
1087        generic_compress_few(core::iter::once(5), 1)
1088    }
1089
1090    #[test]
1091    fn compress_two() {
1092        generic_compress_few([2, 8].iter().cloned(), 1)
1093    }
1094
1095    #[test]
1096    fn compress_ten() {
1097        generic_compress_few(0..10, 2)
1098    }
1099
1100    #[test]
1101    fn compress_twenty() {
1102        generic_compress_few(-10..10, 4)
1103    }
1104
1105    fn generic_compress_few<I>(symbols: I, expected_size: usize)
1106    where
1107        I: IntoIterator<Item = i32>,
1108        I::IntoIter: Clone,
1109    {
1110        let symbols = symbols.into_iter();
1111
1112        let mut encoder = DefaultRangeEncoder::new();
1113        let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-127..=127);
1114        let model = quantizer.quantize(Gaussian::new(3.2, 5.1));
1115
1116        encoder.encode_iid_symbols(symbols.clone(), model).unwrap();
1117        let compressed = encoder.into_compressed().unwrap();
1118        assert_eq!(compressed.len(), expected_size);
1119
1120        let mut decoder = DefaultRangeDecoder::from_compressed(&compressed).unwrap();
1121        for symbol in symbols {
1122            assert_eq!(decoder.decode_symbol(model).unwrap(), symbol);
1123        }
1124        assert!(decoder.maybe_exhausted());
1125    }
1126
1127    #[test]
1128    fn compress_many_u32_u64_32() {
1129        generic_compress_many::<u32, u64, u32, 32>();
1130    }
1131
1132    #[test]
1133    fn compress_many_u32_u64_24() {
1134        generic_compress_many::<u32, u64, u32, 24>();
1135    }
1136
1137    #[test]
1138    fn compress_many_u32_u64_16() {
1139        generic_compress_many::<u32, u64, u16, 16>();
1140    }
1141
1142    #[test]
1143    fn compress_many_u32_u64_8() {
1144        generic_compress_many::<u32, u64, u8, 8>();
1145    }
1146
1147    #[test]
1148    fn compress_many_u16_u64_16() {
1149        generic_compress_many::<u16, u64, u16, 16>();
1150    }
1151
1152    #[test]
1153    fn compress_many_u16_u64_12() {
1154        generic_compress_many::<u16, u64, u16, 12>();
1155    }
1156
1157    #[test]
1158    fn compress_many_u16_u64_8() {
1159        generic_compress_many::<u16, u64, u8, 8>();
1160    }
1161
1162    #[test]
1163    fn compress_many_u8_u64_8() {
1164        generic_compress_many::<u8, u64, u8, 8>();
1165    }
1166
1167    #[test]
1168    fn compress_many_u16_u32_16() {
1169        generic_compress_many::<u16, u32, u16, 16>();
1170    }
1171
1172    #[test]
1173    fn compress_many_u16_u32_12() {
1174        generic_compress_many::<u16, u32, u16, 12>();
1175    }
1176
1177    #[test]
1178    fn compress_many_u16_u32_8() {
1179        generic_compress_many::<u16, u32, u8, 8>();
1180    }
1181
1182    #[test]
1183    fn compress_many_u8_u32_8() {
1184        generic_compress_many::<u8, u32, u8, 8>();
1185    }
1186
1187    #[test]
1188    fn compress_many_u8_u16_8() {
1189        generic_compress_many::<u8, u16, u8, 8>();
1190    }
1191
1192    fn generic_compress_many<Word, State, Probability, const PRECISION: usize>()
1193    where
1194        State: BitArray + AsPrimitive<Word>,
1195        Word: BitArray + Into<State> + AsPrimitive<Probability>,
1196        Probability: BitArray + Into<Word> + AsPrimitive<usize> + Into<f64>,
1197        u32: AsPrimitive<Probability>,
1198        usize: AsPrimitive<Probability>,
1199        f64: AsPrimitive<Probability>,
1200        i32: AsPrimitive<Probability>,
1201    {
1202        #[cfg(not(miri))]
1203        const AMT: usize = 1000;
1204
1205        #[cfg(miri)]
1206        const AMT: usize = 100;
1207
1208        let mut symbols_gaussian = Vec::with_capacity(AMT);
1209        let mut means = Vec::with_capacity(AMT);
1210        let mut stds = Vec::with_capacity(AMT);
1211
1212        let mut rng = Xoshiro256StarStar::seed_from_u64(1234);
1213        for _ in 0..AMT {
1214            let mean = (200.0 / u32::MAX as f64) * rng.next_u32() as f64 - 100.0;
1215            let std_dev = (10.0 / u32::MAX as f64) * rng.next_u32() as f64 + 0.001;
1216            let quantile = (rng.next_u32() as f64 + 0.5) / (1u64 << 32) as f64;
1217            let dist = Gaussian::new(mean, std_dev);
1218            let symbol = (dist.inverse(quantile).round() as i32).clamp(-127, 127);
1219
1220            symbols_gaussian.push(symbol);
1221            means.push(mean);
1222            stds.push(std_dev);
1223        }
1224
1225        let hist = [
1226            1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
1227            896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
1228            347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
1229        ];
1230        let categorical_probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
1231        let categorical =
1232            ContiguousCategoricalEntropyModel::<Probability, _, PRECISION>::from_floating_point_probabilities_fast::<f64>(
1233                &categorical_probabilities,None
1234            )
1235            .unwrap();
1236        let mut symbols_categorical = Vec::with_capacity(AMT);
1237        let max_probability = Probability::max_value() >> (Probability::BITS - PRECISION);
1238        for _ in 0..AMT {
1239            let quantile = rng.next_u32().as_() & max_probability;
1240            let symbol = categorical.quantile_function(quantile).0;
1241            symbols_categorical.push(symbol);
1242        }
1243
1244        let mut encoder = RangeEncoder::<Word, State>::new();
1245
1246        encoder
1247            .encode_iid_symbols(&symbols_categorical, &categorical)
1248            .unwrap();
1249        dbg!(
1250            encoder.num_bits(),
1251            AMT as f64 * categorical.entropy_base2::<f64>()
1252        );
1253
1254        let quantizer = LeakyQuantizer::<_, _, Probability, PRECISION>::new(-127..=127);
1255        encoder
1256            .encode_symbols(symbols_gaussian.iter().zip(&means).zip(&stds).map(
1257                |((&symbol, &mean), &core)| (symbol, quantizer.quantize(Gaussian::new(mean, core))),
1258            ))
1259            .unwrap();
1260        dbg!(encoder.num_bits());
1261
1262        let mut decoder = encoder.into_decoder().unwrap();
1263
1264        let reconstructed_categorical = decoder
1265            .decode_iid_symbols(AMT, &categorical)
1266            .collect::<Result<Vec<_>, _>>()
1267            .unwrap();
1268        let reconstructed_gaussian = decoder
1269            .decode_symbols(
1270                means
1271                    .iter()
1272                    .zip(&stds)
1273                    .map(|(&mean, &core)| quantizer.quantize(Gaussian::new(mean, core))),
1274            )
1275            .collect::<Result<Vec<_>, _>>()
1276            .unwrap();
1277
1278        assert!(decoder.maybe_exhausted());
1279
1280        assert_eq!(symbols_categorical, reconstructed_categorical);
1281        assert_eq!(symbols_gaussian, reconstructed_gaussian);
1282    }
1283
1284    #[test]
1285    fn seek() {
1286        #[cfg(not(miri))]
1287        let (num_chunks, symbols_per_chunk) = (100, 100);
1288
1289        #[cfg(miri)]
1290        let (num_chunks, symbols_per_chunk) = (10, 10);
1291
1292        let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-100..=100);
1293        let model = quantizer.quantize(Gaussian::new(0.0, 10.0));
1294
1295        let mut encoder = DefaultRangeEncoder::new();
1296
1297        let mut rng = Xoshiro256StarStar::seed_from_u64(123);
1298        let mut symbols = Vec::with_capacity(num_chunks);
1299        let mut jump_table = Vec::with_capacity(num_chunks);
1300
1301        for _ in 0..num_chunks {
1302            jump_table.push(encoder.pos());
1303            let chunk = (0..symbols_per_chunk)
1304                .map(|_| model.quantile_function(rng.next_u32() % (1 << 24)).0)
1305                .collect::<Vec<_>>();
1306            encoder.encode_iid_symbols(&chunk, &model).unwrap();
1307            symbols.push(chunk);
1308        }
1309        let final_pos_and_state = encoder.pos();
1310
1311        let mut decoder = encoder.decoder();
1312
1313        // Verify we can decode the chunks normally (we can't verify that encoding and
1314        // decoding lead to the same `pos_and_state` because the range decoder currently
1315        // doesn't implement `Pos`, due to complications at the stream end).
1316        for (chunk, _) in symbols.iter().zip(&jump_table) {
1317            let decoded = decoder
1318                .decode_iid_symbols(symbols_per_chunk, &model)
1319                .collect::<Result<Vec<_>, _>>()
1320                .unwrap();
1321            assert_eq!(&decoded, chunk);
1322        }
1323        assert!(decoder.maybe_exhausted());
1324
1325        // Seek to some random offsets in the jump table and decode one chunk
1326        for i in 0..100 {
1327            let chunk_index = if i == 3 {
1328                // Make sure we test jumping to beginning at least once.
1329                0
1330            } else {
1331                rng.next_u32() as usize % num_chunks
1332            };
1333
1334            let pos_and_state = jump_table[chunk_index];
1335            decoder.seek(pos_and_state).unwrap();
1336            let decoded = decoder
1337                .decode_iid_symbols(symbols_per_chunk, &model)
1338                .collect::<Result<Vec<_>, _>>()
1339                .unwrap();
1340            assert_eq!(&decoded, &symbols[chunk_index])
1341        }
1342
1343        // Test jumping to end (but first make sure we're not already at the end).
1344        decoder.seek(jump_table[0]).unwrap();
1345        assert!(!decoder.maybe_exhausted());
1346        decoder.seek(final_pos_and_state).unwrap();
1347        assert!(decoder.maybe_exhausted());
1348    }
1349}
1350
1351#[derive(Debug)]
1352#[non_exhaustive]
1353pub enum DecoderFrontendError {
1354    /// This can only happen if both of the following conditions apply:
1355    ///
1356    /// 1. we are decoding invalid compressed data; and
1357    /// 2. we use entropy models with varying `PRECISION`s.
1358    ///
1359    /// Unless you change the `PRECISION` mid-decoding, this error cannot occur. However,
1360    /// note that the encoder is not surjective, i.e., it cannot reach all possible bit strings.
1361    /// The reason why the decoder still doesn't err (unless varying `PRECISION`s are used)
1362    /// is that it is not injective, i.e., it maps the bit strings that are unreachable by
1363    /// the encoder to symbols that could have been encoded into a different bit string.
1364    ///
1365    /// The lack of injectivity of the decoder makes the Range Coder implementation in this
1366    /// library unsuitable for bits-back coding. Even though you can decode an arbitrary bit
1367    /// string into a sequence of symbols using any entropy model, re-encoding that sequence
1368    /// of symbols with the same entropy models won't always give you back the same bit
1369    /// string. In other words,
1370    ///
1371    /// - `range_decode(range_encode(sequence_of_symbols)) = sequence_of_symbols` for all
1372    ///   `sequence_of_symbols`; but
1373    /// - `range_encode(range_decode(bit_string)) != bit_string` in general.
1374    ///
1375    /// If you need equality in the second relation, use an [`AnsCoder`].
1376    ///
1377    /// [`AnsCoder`]: super::stack::AnsCoder
1378    InvalidData,
1379}
1380
1381impl Display for DecoderFrontendError {
1382    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1383        match self {
1384            Self::InvalidData => write!(
1385                f,
1386                "Tried to decode from compressed data that is invalid for the employed entropy model."
1387            ),
1388        }
1389    }
1390}
1391
1392#[cfg(feature = "std")]
1393impl std::error::Error for DecoderFrontendError {}