constriction/stream/stack.rs
1//! Fast and Near-optimal compression on a stack ("last in first out")
2//!
3//! This module provides the [`AnsCoder`], a highly efficient entropy coder with
4//! near-optimal compression effectiveness that operates as a *stack* data structure. It
5//! implements the Asymmetric Numeral Systems (ANS) compression algorithm \[1].
6//!
7//! # Comparison to sister module `queue`
8//!
9//! ANS Coding operates as a stack, which means that encoding and decoding operate in
10//! reverse direction with respect to each other. The provided implementation of ANS Coding
11//! uses a single data structure, the [`AnsCoder`], for both encoding and decoding. It
12//! allows you to interleave encoding and decoding operations arbitrarily, which is in
13//! contrast to the situation in the sister module [`queue`] and important for advanced
14//! compression techniques such as bits-back coding in hierarchical probabilistic models.
15//!
16//! The parent module contains a more detailed discussion of the [differences between ANS
17//! Coding and Range Coding](super#which-stream-code-should-i-use) .
18//!
19//! # References
20//!
21//! \[1] Duda, Jarek, et al. "The use of asymmetric numeral systems as an accurate
22//! replacement for Huffman coding." 2015 Picture Coding Symposium (PCS). IEEE, 2015.
23//!
24//! [`queue`]: super::queue
25
26use alloc::vec::Vec;
27use core::{
28 borrow::Borrow, convert::Infallible, fmt::Debug, iter::Fuse, marker::PhantomData, ops::Deref,
29};
30use num_traits::AsPrimitive;
31
32use super::{
33 model::{DecoderModel, EncoderModel},
34 AsDecoder, Code, Decode, Encode, IntoDecoder, TryCodingError,
35};
36use crate::{
37 backends::{
38 self, AsReadWords, AsSeekReadWords, BoundedReadWords, Cursor, FallibleIteratorReadWords,
39 IntoReadWords, IntoSeekReadWords, ReadWords, Reverse, WriteWords,
40 },
41 bit_array_to_chunks_truncated, generic_static_asserts, BitArray, CoderError,
42 DefaultEncoderError, DefaultEncoderFrontendError, NonZeroBitArray, Pos, PosSeek, Seek, Stack,
43 UnwrapInfallible,
44};
45
46/// Entropy coder for both encoding and decoding on a stack.
47///
48/// This is the generic struct for an ANS coder. It provides fine-tuned control over type
49/// parameters (see [discussion in parent
50/// module](super#highly-customizable-implementations-with-sane-presets)). You'll usually
51/// want to use this type through the type alias [`DefaultAnsCoder`], which provides sane
52/// default settings for the type parameters.
53///
54/// The `AnsCoder` uses an entropy coding algorithm called [range Asymmetric
55/// Numeral Systems (rANS)]. This means that it operates as a stack, i.e., a "last
56/// in first out" data structure: encoding "pushes symbols on" the stack and
57/// decoding "pops symbols off" the stack in reverse order. In default operation, decoding
58/// with an `AnsCoder` *consumes* the compressed data for the decoded symbols (however, you
59/// can also decode immutable data by using a [`Cursor`]). This means
60/// that encoding and decoding can be interleaved arbitrarily, thus growing and shrinking
61/// the stack of compressed data as you go.
62///
63/// # Example
64///
65/// Basic usage example:
66///
67/// ```
68/// use constriction::stream::{model::DefaultLeakyQuantizer, stack::DefaultAnsCoder, Decode};
69///
70/// // `DefaultAnsCoder` is a type alias to `AnsCoder` with sane generic parameters.
71/// let mut ans = DefaultAnsCoder::new();
72///
73/// // Create an entropy model based on a quantized Gaussian distribution. You can use `AnsCoder`
74/// // with any entropy model defined in the `models` module.
75/// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
76/// let entropy_model = quantizer.quantize(probability::distribution::Gaussian::new(0.0, 10.0));
77///
78/// let symbols = vec![-10, 4, 0, 3];
79/// // Encode symbols in *reverse* order, so that we can decode them in forward order.
80/// ans.encode_iid_symbols_reverse(&symbols, &entropy_model).unwrap();
81///
82/// // Obtain temporary shared access to the compressed bit string. If you want ownership of the
83/// // compressed bit string, call `.into_compressed()` instead of `.get_compressed()`.
84/// println!("Encoded into {} bits: {:?}", ans.num_bits(), &*ans.get_compressed().unwrap());
85///
86/// // Decode the symbols and verify correctness.
87/// let reconstructed = ans
88/// .decode_iid_symbols(4, &entropy_model)
89/// .collect::<Result<Vec<_>, _>>()
90/// .unwrap();
91/// assert_eq!(reconstructed, symbols);
92/// ```
93///
94/// # Consistency Between Encoding and Decoding
95///
96/// As elaborated in the [parent module's documentation](super#whats-a-stream-code),
97/// encoding and decoding operates on a sequence of symbols. Each symbol can be encoded and
98/// decoded with its own entropy model (the symbols can even have heterogeneous types). If
99/// your goal is to reconstruct the originally encoded symbols during decoding, then you
100/// must employ the same sequence of entropy models (in reversed order) during encoding and
101/// decoding.
102///
103/// However, using the same entropy models for encoding and decoding is not a *general*
104/// requirement. It is perfectly legal to push (encode) symbols on the `AnsCoder` using some
105/// entropy models, and then pop off (decode) symbols using different entropy models. The
106/// popped off symbols will then in general be different from the original symbols, but will
107/// be generated in a deterministic way. If there is no deterministic relation between the
108/// entropy models used for pushing and popping, and if there is still compressed data left
109/// at the end (i.e., if [`is_empty`] returns false), then the popped off symbols are, to a
110/// very good approximation, distributed as independent samples from the respective entropy
111/// models. Such random samples, which consume parts of the compressed data, are useful in
112/// the bits-back algorithm.
113///
114/// [range Asymmetric Numeral Systems (rANS)]:
115/// https://en.wikipedia.org/wiki/Asymmetric_numeral_systems#Range_variants_(rANS)_and_streaming
/// [`is_empty`]: #method.is_empty
117/// [`Cursor`]: crate::backends::Cursor
#[derive(Clone)]
pub struct AnsCoder<Word, State, Backend = Vec<Word>>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// The compressed words that no longer fit into `state`. Encoding pushes words onto
    /// this backend as `state` overflows; decoding reads them back as `state` drains.
    bulk: Backend,

    /// Invariant: `state >= State::one() << (State::BITS - Word::BITS)` unless
    /// `bulk.is_empty()`.
    state: State,

    /// We keep track of the `Word` type so that we can statically enforce the invariant
    /// `Word: Into<State>`.
    phantom: PhantomData<Word>,
}
134
/// Type alias for an [`AnsCoder`] with sane parameters for typical use cases.
///
/// This type alias sets the generic type arguments `Word` and `State` to sane values
/// (`u32` and `u64`, respectively) for many typical use cases.
pub type DefaultAnsCoder<Backend = Vec<u32>> = AnsCoder<u32, u64, Backend>;
140
/// Type alias for an [`AnsCoder`] for use with a [`ContiguousLookupDecoderModel`] or
/// [`NonContiguousLookupDecoderModel`].
///
/// This encoder has a smaller word size (`u16`) and internal state (`u32`) than the
/// [`DefaultAnsCoder`]. It is optimized for use with a [`ContiguousLookupDecoderModel`]
/// or [`NonContiguousLookupDecoderModel`].
///
/// # Examples
///
/// See [`ContiguousLookupDecoderModel`].
///
/// [`ContiguousLookupDecoderModel`]: crate::stream::model::ContiguousLookupDecoderModel
/// [`NonContiguousLookupDecoderModel`]: crate::stream::model::NonContiguousLookupDecoderModel
pub type SmallAnsCoder<Backend = Vec<u16>> = AnsCoder<u16, u32, Backend>;
154
155impl<Word, State, Backend> Debug for AnsCoder<Word, State, Backend>
156where
157 Word: BitArray + Into<State>,
158 State: BitArray + AsPrimitive<Word>,
159 for<'a> &'a Backend: IntoIterator<Item = &'a Word>,
160{
161 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
162 f.debug_list().entries(self.iter_compressed()).finish()
163 }
164}
165
166impl<Word, State, Backend, const PRECISION: usize> IntoDecoder<PRECISION>
167 for AnsCoder<Word, State, Backend>
168where
169 Word: BitArray + Into<State>,
170 State: BitArray + AsPrimitive<Word>,
171 Backend: WriteWords<Word> + IntoReadWords<Word, Stack>,
172{
173 type IntoDecoder = AnsCoder<Word, State, Backend::IntoReadWords>;
174
175 fn into_decoder(self) -> Self::IntoDecoder {
176 AnsCoder {
177 bulk: self.bulk.into_read_words(),
178 state: self.state,
179 phantom: PhantomData,
180 }
181 }
182}
183
184impl<'a, Word, State, Backend> From<&'a AnsCoder<Word, State, Backend>>
185 for AnsCoder<Word, State, <Backend as AsReadWords<'a, Word, Stack>>::AsReadWords>
186where
187 Word: BitArray + Into<State>,
188 State: BitArray + AsPrimitive<Word>,
189 Backend: AsReadWords<'a, Word, Stack>,
190{
191 fn from(ans: &'a AnsCoder<Word, State, Backend>) -> Self {
192 AnsCoder {
193 bulk: ans.bulk().as_read_words(),
194 state: ans.state(),
195 phantom: PhantomData,
196 }
197 }
198}
199
200impl<'a, Word, State, Backend, const PRECISION: usize> AsDecoder<'a, PRECISION>
201 for AnsCoder<Word, State, Backend>
202where
203 Word: BitArray + Into<State>,
204 State: BitArray + AsPrimitive<Word>,
205 Backend: WriteWords<Word> + AsReadWords<'a, Word, Stack>,
206{
207 type AsDecoder = AnsCoder<Word, State, Backend::AsReadWords>;
208
209 fn as_decoder(&'a self) -> Self::AsDecoder {
210 self.into()
211 }
212}
213
214impl<Word, State> From<AnsCoder<Word, State, Vec<Word>>> for Vec<Word>
215where
216 Word: BitArray + Into<State>,
217 State: BitArray + AsPrimitive<Word>,
218{
219 fn from(val: AnsCoder<Word, State, Vec<Word>>) -> Self {
220 val.into_compressed().unwrap_infallible()
221 }
222}
223
224impl<Word, State> AnsCoder<Word, State, Vec<Word>>
225where
226 Word: BitArray + Into<State>,
227 State: BitArray + AsPrimitive<Word>,
228{
229 /// Creates an empty ANS entropy coder.
230 ///
231 /// This is usually the starting point if you want to *compress* data.
232 ///
233 /// # Example
234 ///
235 /// ```
236 /// let mut ans = constriction::stream::stack::DefaultAnsCoder::new();
237 ///
238 /// // ... push some symbols onto the ANS coder's stack ...
239 ///
240 /// // Finally, get the compressed data.
241 /// let compressed = ans.into_compressed();
242 /// ```
243 ///
244 /// # Generality
245 ///
246 /// To avoid type parameters in common use cases, `new` is only implemented for
247 /// `AnsCoder`s with a `Vec` backend. To create an empty coder with a different backend,
248 /// call [`Default::default`] instead.
249 pub fn new() -> Self {
250 Self::default()
251 }
252}
253
254impl<Word, State, Backend> Default for AnsCoder<Word, State, Backend>
255where
256 Word: BitArray + Into<State>,
257 State: BitArray + AsPrimitive<Word>,
258 Backend: Default,
259{
260 fn default() -> Self {
261 generic_static_asserts!(
262 (Word: BitArray, State:BitArray);
263 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
264 );
265
266 Self {
267 state: State::zero(),
268 bulk: Default::default(),
269 phantom: PhantomData,
270 }
271 }
272}
273
274impl<Word, State, Backend> AnsCoder<Word, State, Backend>
275where
276 Word: BitArray + Into<State>,
277 State: BitArray + AsPrimitive<Word>,
278{
279 /// Low-level constructor that assembles an `AnsCoder` from its internal components.
280 ///
281 /// The arguments `bulk` and `state` correspond to the two return values of the method
282 /// [`into_raw_parts`](Self::into_raw_parts).
283 ///
284 /// The caller must ensure that `state >= State::one() << (State::BITS - Word::BITS)`
285 /// unless `bulk` is empty. This cannot be checked by the method since not all
286 /// `Backend`s have an `is_empty` method. Violating this invariant is not a memory
287 /// safety issue but it will lead to incorrect behavior.
288 pub fn from_raw_parts(bulk: Backend, state: State) -> Self {
289 Self {
290 bulk,
291 state,
292 phantom: PhantomData,
293 }
294 }
295
296 /// Creates an ANS stack with some initial compressed data.
297 ///
298 /// This is usually the starting point if you want to *decompress* data previously
299 /// obtained from [`into_compressed`]. However, it can also be used to append more
300 /// symbols to an existing compressed buffer of data.
301 ///
302 /// Returns `Err(compressed)` if `compressed` is not empty and its last entry is
303 /// zero, since an `AnsCoder` cannot represent trailing zero words. This error cannot
304 /// occur if `compressed` was obtained from [`into_compressed`], which never returns
305 /// data with a trailing zero word. If you want to construct a `AnsCoder` from an
306 /// unknown source of binary data (e.g., to decode some side information into latent
307 /// variables) then call [`from_binary`] instead.
308 ///
309 /// [`into_compressed`]: #method.into_compressed
310 /// [`from_binary`]: #method.from_binary
311 pub fn from_compressed(mut compressed: Backend) -> Result<Self, Backend>
312 where
313 Backend: ReadWords<Word, Stack>,
314 {
315 generic_static_asserts!(
316 (Word: BitArray, State:BitArray);
317 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
318 );
319
320 let state = match Self::read_initial_state(|| compressed.read()) {
321 Ok(state) => state,
322 Err(_) => return Err(compressed),
323 };
324
325 Ok(Self {
326 bulk: compressed,
327 state,
328 phantom: PhantomData,
329 })
330 }
331
332 fn read_initial_state<Error>(
333 mut read_word: impl FnMut() -> Result<Option<Word>, Error>,
334 ) -> Result<State, ()>
335 where
336 Backend: ReadWords<Word, Stack>,
337 {
338 if let Some(first_word) = read_word().map_err(|_| ())? {
339 if first_word == Word::zero() {
340 return Err(());
341 }
342
343 let mut state = first_word.into();
344 while let Some(word) = read_word().map_err(|_| ())? {
345 state = (state << Word::BITS) | word.into();
346 if state >= State::one() << (State::BITS - Word::BITS) {
347 break;
348 }
349 }
350 Ok(state)
351 } else {
352 Ok(State::zero())
353 }
354 }
355
356 /// Like [`from_compressed`] but works on any binary data.
357 ///
358 /// This method is meant for rather advanced use cases. For most common use cases,
359 /// you probably want to call [`from_compressed`] instead.
360 ///
361 /// Different to `from_compressed`, this method also works if `data` ends in a zero
362 /// word. Calling this method is equivalent to (but likely more efficient than)
363 /// appending a `1` word to `data` and then calling `from_compressed`. Note that
364 /// therefore, this method always constructs a non-empty `AnsCoder` (even if `data` is
365 /// empty):
366 ///
367 /// ```
368 /// use constriction::stream::stack::DefaultAnsCoder;
369 ///
370 /// let stack1 = DefaultAnsCoder::from_binary(Vec::new()).unwrap();
371 /// assert!(!stack1.is_empty()); // <-- stack1 is *not* empty.
372 ///
373 /// let stack2 = DefaultAnsCoder::from_compressed(Vec::new()).unwrap();
374 /// assert!(stack2.is_empty()); // <-- stack2 is empty.
375 /// ```
376 /// [`from_compressed`]: #method.from_compressed
377 pub fn from_binary(mut data: Backend) -> Result<Self, Backend::ReadError>
378 where
379 Backend: ReadWords<Word, Stack>,
380 {
381 let mut state = State::one();
382
383 while state < State::one() << (State::BITS - Word::BITS) {
384 if let Some(word) = data.read()? {
385 state = (state << Word::BITS) | word.into();
386 } else {
387 break;
388 }
389 }
390
391 Ok(Self {
392 bulk: data,
393 state,
394 phantom: PhantomData,
395 })
396 }
397
    /// Returns a shared reference to the backend that holds the bulk of the compressed
    /// data (i.e., everything except for the part still held in the coder's `state`).
    #[inline(always)]
    pub fn bulk(&self) -> &Backend {
        &self.bulk
    }
402
403 /// Low-level method that disassembles the `AnsCoder` into its internal components.
404 ///
405 /// Can be used together with [`from_raw_parts`](Self::from_raw_parts).
406 pub fn into_raw_parts(self) -> (Backend, State) {
407 (self.bulk, self.state)
408 }
409
410 /// Check if no data for decoding is left.
411 ///
412 /// Note that you can still pop symbols off an empty stack, but this is only
413 /// useful in rare edge cases, see documentation of
414 /// [`decode_symbol`](#method.decode_symbol).
415 pub fn is_empty(&self) -> bool {
416 // We don't need to check if `bulk` is empty (which would require an additional
417 // type bound `Backend: ReadLookaheadItems<Word>` because we keep up the
418 // invariant that `state >= State::one() << (State::BITS - Word::BITS))`
419 // when `bulk` is not empty.
420 self.state == State::zero()
421 }
422
    /// Assembles the current compressed data into a single slice.
    ///
    /// Returns the concatenation of [`bulk`] and [`state`]. The concatenation truncates
    /// any trailing zero words, which is compatible with the constructor
    /// [`from_compressed`].
    ///
    /// This method requires a `&mut self` receiver to temporarily append `state` to
    /// [`bulk`] (this mutation will be reversed to recreate the original `bulk` as soon
    /// as the caller drops the returned value). If you don't have mutable access to the
    /// `AnsCoder`, consider calling [`iter_compressed`] instead, or get the `bulk` and
    /// `state` separately by calling [`bulk`] and [`state`], respectively.
    ///
    /// The return type dereferences to `&Backend` (which, for the default backend
    /// `Vec<Word>`, further dereferences to `&[Word]`), thus providing read-only
    /// access to the compressed data. If you need ownership of the compressed data,
    /// consider calling [`into_compressed`] instead.
    ///
    /// # Example
    ///
    /// ```
    /// use constriction::stream::{
    ///     model::DefaultContiguousCategoricalEntropyModel, stack::DefaultAnsCoder, Decode
    /// };
    ///
    /// let mut ans = DefaultAnsCoder::new();
    ///
    /// // Push some data on the ans.
    /// let symbols = vec![8, 2, 0, 7];
    /// let probabilities = vec![0.03, 0.07, 0.1, 0.1, 0.2, 0.2, 0.1, 0.15, 0.05];
    /// let model = DefaultContiguousCategoricalEntropyModel
    ///     ::from_floating_point_probabilities_fast(&probabilities, None).unwrap();
    /// ans.encode_iid_symbols_reverse(&symbols, &model).unwrap();
    ///
    /// // Inspect the compressed data.
    /// dbg!(ans.get_compressed());
    ///
    /// // We can still use the ANS coder afterwards.
    /// let reconstructed = ans
    ///     .decode_iid_symbols(4, &model)
    ///     .collect::<Result<Vec<_>, _>>()
    ///     .unwrap();
    /// assert_eq!(reconstructed, symbols);
    /// ```
    ///
    /// [`bulk`]: #method.bulk
    /// [`state`]: #method.state
    /// [`from_compressed`]: #method.from_compressed
    /// [`iter_compressed`]: #method.iter_compressed
    /// [`into_compressed`]: #method.into_compressed
    pub fn get_compressed(
        &mut self,
    ) -> Result<impl Deref<Target = Backend> + Debug + Drop + '_, Backend::WriteError>
    where
        Backend: ReadWords<Word, Stack> + WriteWords<Word> + Debug,
    {
        CoderGuard::<'_, _, _, _, false>::new(self).map_err(|err| match err {
            CoderError::Frontend(()) => unreachable!("Can't happen for SEALED==false."),
            CoderError::Backend(err) => err,
        })
    }
482
    /// Like [`get_compressed`](#method.get_compressed), but exposes the data in the
    /// "binary" format accepted by [`from_binary`](#method.from_binary).
    ///
    /// In contrast to `get_compressed`, this method can additionally fail with a
    /// frontend error (`CoderError::Frontend(())`).
    // NOTE(review): the only difference to `get_compressed` visible here is the
    // `SEALED == true` parameter of `CoderGuard`; the frontend error presumably
    // corresponds to `into_binary`'s error case (payload not fitting into a whole
    // number of `Word`s) — confirm against `CoderGuard::new`.
    pub fn get_binary(
        &mut self,
    ) -> Result<impl Deref<Target = Backend> + Debug + Drop + '_, CoderError<(), Backend::WriteError>>
    where
        Backend: ReadWords<Word, Stack> + WriteWords<Word> + Debug,
    {
        CoderGuard::<'_, _, _, _, true>::new(self)
    }
491
    /// Iterates over the compressed data currently on the ans.
    ///
    /// In contrast to [`get_compressed`] or [`into_compressed`], this method does
    /// not require mutable access or even ownership of the `AnsCoder`.
    ///
    /// # Example
    ///
    /// ```
    /// use constriction::stream::{model::DefaultLeakyQuantizer, stack::DefaultAnsCoder, Decode};
    ///
    /// // Create a stack and encode some stuff.
    /// let mut ans = DefaultAnsCoder::new();
    /// let symbols = vec![8, -12, 0, 7];
    /// let quantizer = DefaultLeakyQuantizer::new(-100..=100);
    /// let model =
    ///     quantizer.quantize(probability::distribution::Gaussian::new(0.0, 10.0));
    /// ans.encode_iid_symbols_reverse(&symbols, &model).unwrap();
    ///
    /// // Iterate over the compressed data, collect it into a `Vec`, and compare to the direct method.
    /// let compressed_iter = ans.iter_compressed();
    /// let compressed_collected = compressed_iter.collect::<Vec<_>>();
    /// assert!(!compressed_collected.is_empty());
    /// assert_eq!(compressed_collected, *ans.get_compressed().unwrap());
    /// ```
    ///
    /// [`get_compressed`]: #method.get_compressed
    /// [`into_compressed`]: #method.into_compressed
    pub fn iter_compressed<'a>(&'a self) -> impl Iterator<Item = Word> + 'a
    where
        &'a Backend: IntoIterator<Item = &'a Word>,
    {
        // Yield `bulk` first, then the (truncated) chunks of `state` in reverse chunk
        // order, matching the layout produced by `into_compressed`.
        let bulk_iter = self.bulk.into_iter().cloned();
        let state_iter = bit_array_to_chunks_truncated(self.state).rev();
        bulk_iter.chain(state_iter)
    }
527
528 /// Returns the number of compressed words on the ANS coder's stack.
529 ///
530 /// This includes a constant overhead of between one and two words unless the
531 /// stack is completely empty.
532 ///
533 /// This method returns the length of the slice, the `Vec<Word>`, or the iterator
534 /// that would be returned by [`get_compressed`], [`into_compressed`], or
535 /// [`iter_compressed`], respectively, when called at this time.
536 ///
537 /// See also [`num_bits`].
538 ///
539 /// [`get_compressed`]: #method.get_compressed
540 /// [`into_compressed`]: #method.into_compressed
541 /// [`iter_compressed`]: #method.iter_compressed
542 /// [`num_bits`]: #method.num_bits
543 pub fn num_words(&self) -> usize
544 where
545 Backend: BoundedReadWords<Word, Stack>,
546 {
547 self.bulk.remaining() + bit_array_to_chunks_truncated::<_, Word>(self.state).len()
548 }
549
550 pub fn num_bits(&self) -> usize
551 where
552 Backend: BoundedReadWords<Word, Stack>,
553 {
554 Word::BITS * self.num_words()
555 }
556
    /// Returns the number of information-carrying bits currently on the coder.
    ///
    /// Counts all bits of `bulk` plus the bits of `state` below its most significant
    /// set bit. The highest set bit of `state` itself is excluded (hence the trailing
    /// `- 1`) because it is the obligatory leading one-bit rather than payload (cf.
    /// [`from_binary`](#method.from_binary), which seeds `state` with `1` before
    /// shifting in data). The `max(…, 1)` clamps the middle term so that an empty
    /// coder (`state == 0`, where `leading_zeros == State::BITS`) yields zero instead
    /// of underflowing.
    pub fn num_valid_bits(&self) -> usize
    where
        Backend: BoundedReadWords<Word, Stack>,
    {
        Word::BITS * self.bulk.remaining()
            + core::cmp::max(State::BITS - self.state.leading_zeros() as usize, 1)
            - 1
    }
565
566 pub fn into_decoder(self) -> AnsCoder<Word, State, Backend::IntoReadWords>
567 where
568 Backend: IntoReadWords<Word, Stack>,
569 {
570 AnsCoder {
571 bulk: self.bulk.into_read_words(),
572 state: self.state,
573 phantom: PhantomData,
574 }
575 }
576
577 /// Consumes the `AnsCoder` and returns a decoder that implements [`Seek`].
578 ///
579 /// This method is similar to [`as_seekable_decoder`] except that it takes ownership of
580 /// the original `AnsCoder`, so the returned seekable decoder can typically be returned
581 /// from the calling function or put on the heap.
582 ///
583 /// [`as_seekable_decoder`]: Self::as_seekable_decoder
584 pub fn into_seekable_decoder(self) -> AnsCoder<Word, State, Backend::IntoSeekReadWords>
585 where
586 Backend: IntoSeekReadWords<Word, Stack>,
587 {
588 AnsCoder {
589 bulk: self.bulk.into_seek_read_words(),
590 state: self.state,
591 phantom: PhantomData,
592 }
593 }
594
595 pub fn as_decoder<'a>(&'a self) -> AnsCoder<Word, State, Backend::AsReadWords>
596 where
597 Backend: AsReadWords<'a, Word, Stack>,
598 {
599 AnsCoder {
600 bulk: self.bulk.as_read_words(),
601 state: self.state,
602 phantom: PhantomData,
603 }
604 }
605
    /// Returns a decoder that implements [`Seek`].
    ///
    /// The returned decoder shares access to the compressed data with the original
    /// `AnsCoder` (i.e., `self`). This means that:
    /// - you can call this method several times to create several seekable decoders
    ///   with independent views into the same compressed data;
    /// - once the lifetime of all handed out seekable decoders ends, the original
    ///   `AnsCoder` can be used again; and
    /// - the constructed seekable decoder cannot outlive the original `AnsCoder`; for
    ///   example, if the original `AnsCoder` lives on the calling function's call stack
    ///   frame then you cannot return the constructed seekable decoder from the calling
    ///   function. If this is a problem then call [`into_seekable_decoder`] instead.
    ///
    /// # Limitations
    ///
    /// This method is only available for `AnsCoder`s whose backing store of compressed
    /// data (`Backend`) implements [`AsSeekReadWords`], as witnessed by the `where`
    /// clause below. This includes the default backing data store
    /// `Backend = Vec<Word>`.
    ///
    /// [`into_seekable_decoder`]: Self::into_seekable_decoder
    pub fn as_seekable_decoder<'a>(&'a self) -> AnsCoder<Word, State, Backend::AsSeekReadWords>
    where
        Backend: AsSeekReadWords<'a, Word, Stack>,
    {
        AnsCoder {
            bulk: self.bulk.as_seek_read_words(),
            state: self.state,
            phantom: PhantomData,
        }
    }
638}
639
impl<Word, State> AnsCoder<Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Discards all compressed data and resets the coder to the same state as
    /// [`AnsCoder::new`](#method.new).
    pub fn clear(&mut self) {
        self.bulk.clear();
        self.state = State::zero();
    }
}
652
impl<'bulk, Word, State> AnsCoder<Word, State, Cursor<Word, &'bulk [Word]>>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Like [`from_compressed`](Self::from_compressed), but decodes from a borrowed
    /// slice of compressed data (wrapped in a [`Cursor`] positioned at the write end)
    /// instead of taking ownership of it.
    ///
    /// Returns `Err(())` if `compressed` ends in a zero word (see `from_compressed`).
    // TODO: proper error type (also for `from_compressed`)
    #[allow(clippy::result_unit_err)]
    pub fn from_compressed_slice(compressed: &'bulk [Word]) -> Result<Self, ()> {
        Self::from_compressed(backends::Cursor::new_at_write_end(compressed)).map_err(|_| ())
    }

    /// Like [`from_binary`](Self::from_binary), but reads from a borrowed slice of
    /// arbitrary binary data instead of taking ownership of it. Infallible because
    /// reading from a slice-backed `Cursor` cannot fail.
    pub fn from_binary_slice(data: &'bulk [Word]) -> Self {
        Self::from_binary(backends::Cursor::new_at_write_end(data)).unwrap_infallible()
    }
}
668
impl<Word, State, Buf> AnsCoder<Word, State, Reverse<Cursor<Word, Buf>>>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Buf: AsRef<[Word]>,
{
    /// Like [`from_compressed`](Self::from_compressed), but for compressed data whose
    /// words are stored in *reversed* order (wrapped in a [`Reverse`]d [`Cursor`]
    /// positioned at the write beginning).
    ///
    /// On error, returns the original buffer (see `from_compressed`).
    pub fn from_reversed_compressed(compressed: Buf) -> Result<Self, Buf> {
        Self::from_compressed(Reverse(Cursor::new_at_write_beginning(compressed)))
            .map_err(|Reverse(cursor)| cursor.into_buf_and_pos().0)
    }

    /// Like [`from_binary`](Self::from_binary), but for binary data whose words are
    /// stored in *reversed* order. Infallible because reading from a buffer-backed
    /// `Cursor` cannot fail.
    pub fn from_reversed_binary(data: Buf) -> Self {
        Self::from_binary(Reverse(Cursor::new_at_write_beginning(data))).unwrap_infallible()
    }
}
684
impl<Word, State, Iter, ReadError> AnsCoder<Word, State, FallibleIteratorReadWords<Iter>>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Iter: Iterator<Item = Result<Word, ReadError>>,
    FallibleIteratorReadWords<Iter>: ReadWords<Word, Stack, ReadError = ReadError>,
{
    /// Like [`from_compressed`](Self::from_compressed), but reads the (reversed)
    /// compressed words lazily from a fallible iterator.
    ///
    /// On error, returns the (fused) remains of the iterator.
    pub fn from_reversed_compressed_iter(compressed: Iter) -> Result<Self, Fuse<Iter>> {
        Self::from_compressed(FallibleIteratorReadWords::new(compressed))
            .map_err(|iterator_backend| iterator_backend.into_iter())
    }

    /// Like [`from_binary`](Self::from_binary), but reads the (reversed) binary data
    /// lazily from a fallible iterator. Fails only if the iterator itself yields an
    /// error.
    pub fn from_reversed_binary_iter(data: Iter) -> Result<Self, ReadError> {
        Self::from_binary(FallibleIteratorReadWords::new(data))
    }
}
701
702impl<Word, State, Backend> AnsCoder<Word, State, Backend>
703where
704 Word: BitArray + Into<State>,
705 State: BitArray + AsPrimitive<Word>,
706 Backend: WriteWords<Word>,
707{
708 /// Recommended way to encode a heterogeneously distributed sequence of
709 /// symbols onto an `AnsCoder`.
710 ///
711 /// This method is similar to the trait method [`Encode::encode_symbols`],
712 /// but it encodes the symbols in *reverse* order (and therefore requires
713 /// the provided iterator to implement [`DoubleEndedIterator`]). Encoding
714 /// in reverse order is the recommended way to encode onto an `AnsCoder`
715 /// because an `AnsCoder` is a *stack*, i.e., the last symbol you encode
716 /// onto an `AnsCoder` is the first symbol that you will decode from it.
717 /// Thus, encoding a sequence of symbols in reverse order will allow you to
718 /// decode them in normal order.
719 pub fn encode_symbols_reverse<S, M, I, const PRECISION: usize>(
720 &mut self,
721 symbols_and_models: I,
722 ) -> Result<(), DefaultEncoderError<Backend::WriteError>>
723 where
724 S: Borrow<M::Symbol>,
725 M: EncoderModel<PRECISION>,
726 M::Probability: Into<Word>,
727 Word: AsPrimitive<M::Probability>,
728 I: IntoIterator<Item = (S, M)>,
729 I::IntoIter: DoubleEndedIterator,
730 {
731 self.encode_symbols(symbols_and_models.into_iter().rev())
732 }
733
734 /// Recommended way to encode onto an `AnsCoder` from a fallible iterator.
735 ///
736 /// This method is similar to the trait method
737 /// [`Encode::try_encode_symbols`], but it encodes the symbols in *reverse*
738 /// order (and therefore requires the provided iterator to implement
739 /// [`DoubleEndedIterator`]). Encoding in reverse order is the recommended
740 /// way to encode onto an `AnsCoder` because an `AnsCoder` is a *stack*,
741 /// i.e., the last symbol you encode onto an `AnsCoder` is the first symbol
742 /// that you will decode from it. Thus, encoding a sequence of symbols in
743 /// reverse order will allow you to decode them in normal order.
744 pub fn try_encode_symbols_reverse<S, M, E, I, const PRECISION: usize>(
745 &mut self,
746 symbols_and_models: I,
747 ) -> Result<(), TryCodingError<DefaultEncoderError<Backend::WriteError>, E>>
748 where
749 S: Borrow<M::Symbol>,
750 M: EncoderModel<PRECISION>,
751 M::Probability: Into<Word>,
752 Word: AsPrimitive<M::Probability>,
753 I: IntoIterator<Item = core::result::Result<(S, M), E>>,
754 I::IntoIter: DoubleEndedIterator,
755 {
756 self.try_encode_symbols(symbols_and_models.into_iter().rev())
757 }
758
759 /// Recommended way to encode a sequence of i.i.d. symbols onto an
760 /// `AnsCoder`.
761 ///
762 /// This method is similar to the trait method
763 /// [`Encode::encode_iid_symbols`], but it encodes the symbols in *reverse*
764 /// order (and therefore requires the provided iterator to implement
765 /// [`DoubleEndedIterator`]). Encoding in reverse order is the recommended
766 /// way to encode onto an `AnsCoder` because an `AnsCoder` is a *stack*,
767 /// i.e., the last symbol you encode onto an `AnsCoder` is the first symbol
768 /// that you will decode from it. Thus, encoding a sequence of symbols in
769 /// reverse order will allow you to decode them in normal order.
770 pub fn encode_iid_symbols_reverse<S, M, I, const PRECISION: usize>(
771 &mut self,
772 symbols: I,
773 model: M,
774 ) -> Result<(), DefaultEncoderError<Backend::WriteError>>
775 where
776 S: Borrow<M::Symbol>,
777 M: EncoderModel<PRECISION> + Copy,
778 M::Probability: Into<Word>,
779 Word: AsPrimitive<M::Probability>,
780 I: IntoIterator<Item = S>,
781 I::IntoIter: DoubleEndedIterator,
782 {
783 self.encode_iid_symbols(symbols.into_iter().rev(), model)
784 }
785
786 /// Consumes the ANS coder and returns the compressed data.
787 ///
788 /// The returned data can be used to recreate an ANS coder with the same state
789 /// (e.g., for decoding) by passing it to
790 /// [`from_compressed`](#method.from_compressed).
791 ///
792 /// If you don't want to consume the ANS coder, consider calling
793 /// [`get_compressed`](#method.get_compressed),
794 /// [`iter_compressed`](#method.iter_compressed) instead.
795 ///
796 /// # Example
797 ///
798 /// ```
799 /// use constriction::stream::{
800 /// model::DefaultContiguousCategoricalEntropyModel, stack::DefaultAnsCoder, Decode
801 /// };
802 ///
803 /// let mut ans = DefaultAnsCoder::new();
804 ///
805 /// // Push some data onto the ANS coder's stack:
806 /// let symbols = vec![8, 2, 0, 7];
807 /// let probabilities = vec![0.03, 0.07, 0.1, 0.1, 0.2, 0.2, 0.1, 0.15, 0.05];
808 /// let model = DefaultContiguousCategoricalEntropyModel
809 /// ::from_floating_point_probabilities_fast(&probabilities, None).unwrap();
810 /// ans.encode_iid_symbols_reverse(&symbols, &model).unwrap();
811 ///
812 /// // Get the compressed data, consuming the ANS coder:
813 /// let compressed = ans.into_compressed().unwrap();
814 ///
815 /// // ... write `compressed` to a file and then read it back later ...
816 ///
817 /// // Create a new ANS coder with the same state and use it for decompression:
818 /// let mut ans = DefaultAnsCoder::from_compressed(compressed).expect("Corrupted compressed file.");
819 /// let reconstructed = ans
820 /// .decode_iid_symbols(4, &model)
821 /// .collect::<Result<Vec<_>, _>>()
822 /// .unwrap();
823 /// assert_eq!(reconstructed, symbols);
824 /// assert!(ans.is_empty())
825 /// ```
826 pub fn into_compressed(mut self) -> Result<Backend, Backend::WriteError> {
827 self.bulk
828 .extend_from_iter(bit_array_to_chunks_truncated(self.state).rev())?;
829 Ok(self.bulk)
830 }
831
832 /// Returns the binary data if it fits precisely into an integer number of
833 /// `Word`s
834 ///
835 /// This method is meant for rather advanced use cases. For most common use cases,
836 /// you probably want to call [`into_compressed`] instead.
837 ///
838 /// This method is the inverse of [`from_binary`]. It is equivalent to calling
839 /// [`into_compressed`], verifying that the returned vector ends in a `1` word, and
840 /// popping off that trailing `1` word.
841 ///
842 /// Returns `Err(())` if the compressed data (excluding an obligatory trailing
843 /// `1` bit) does not fit into an integer number of `Word`s. This error
844 /// case includes the case of an empty `AnsCoder` (since an empty `AnsCoder` lacks the
845 /// obligatory trailing one-bit).
846 ///
847 /// # Example
848 ///
849 /// ```
850 /// // Some binary data we want to represent on a `AnsCoder`.
851 /// let data = vec![0x89ab_cdef, 0x0123_4567];
852 ///
853 /// // Constructing a `AnsCoder` with `from_binary` indicates that all bits of `data` are
854 /// // considered part of the information-carrying payload.
855 /// let stack1 = constriction::stream::stack::DefaultAnsCoder::from_binary(data.clone()).unwrap();
856 /// assert_eq!(stack1.clone().into_binary().unwrap(), data); // <-- Retrieves the original `data`.
857 ///
858 /// // By contrast, if we construct a `AnsCoder` with `from_compressed`, we indicate that
859 /// // - any leading `0` bits of the last entry of `data` are not considered part of
860 /// // the information-carrying payload; and
861 /// // - the (obligatory) first `1` bit of the last entry of `data` defines the
862 /// // boundary between unused bits and information-carrying bits; it is therefore
863 /// // also not considered part of the payload.
864 /// // Therefore, `stack2` below only contains `32 * 2 - 7 - 1 = 56` bits of payload,
865 /// // which cannot be exported into an integer number of `u32` words:
866 /// let stack2 = constriction::stream::stack::DefaultAnsCoder::from_compressed(data.clone()).unwrap();
867 /// assert!(stack2.clone().into_binary().is_err()); // <-- Returns an error.
868 ///
869 /// // Use `into_compressed` to retrieve the data in this case:
870 /// assert_eq!(stack2.into_compressed().unwrap(), data);
871 ///
872 /// // Calling `into_compressed` on `stack1` would append an extra `1` bit to indicate
873 /// // the boundary between information-carrying bits and padding `0` bits:
874 /// assert_eq!(stack1.into_compressed().unwrap(), vec![0x89ab_cdef, 0x0123_4567, 0x0000_0001]);
875 /// ```
876 ///
877 /// [`from_binary`]: #method.from_binary
878 /// [`into_compressed`]: #method.into_compressed
879 pub fn into_binary(mut self) -> Result<Backend, Option<Backend::WriteError>> {
880 let valid_bits = (State::BITS - 1).wrapping_sub(self.state.leading_zeros() as usize);
881
882 if valid_bits % Word::BITS != 0 || valid_bits == usize::MAX {
883 Err(None)
884 } else {
885 let truncated_state = self.state ^ (State::one() << valid_bits);
886 self.bulk
887 .extend_from_iter(bit_array_to_chunks_truncated(truncated_state).rev())?;
888 Ok(self.bulk)
889 }
890 }
891}
892
893impl<Word, State, Buf> AnsCoder<Word, State, Cursor<Word, Buf>>
894where
895 Word: BitArray,
896 State: BitArray + AsPrimitive<Word> + From<Word>,
897 Buf: AsRef<[Word]> + AsMut<[Word]>,
898{
899 pub fn into_reversed(self) -> AnsCoder<Word, State, Reverse<Cursor<Word, Buf>>> {
900 let (bulk, state) = self.into_raw_parts();
901 AnsCoder {
902 bulk: bulk.into_reversed(),
903 state,
904 phantom: PhantomData,
905 }
906 }
907}
908
909impl<Word, State, Buf> AnsCoder<Word, State, Reverse<Cursor<Word, Buf>>>
910where
911 Word: BitArray,
912 State: BitArray + AsPrimitive<Word> + From<Word>,
913 Buf: AsRef<[Word]> + AsMut<[Word]>,
914{
915 pub fn into_reversed(self) -> AnsCoder<Word, State, Cursor<Word, Buf>> {
916 let (bulk, state) = self.into_raw_parts();
917 AnsCoder {
918 bulk: bulk.into_reversed(),
919 state,
920 phantom: PhantomData,
921 }
922 }
923}
924
// The generic `Code` interface exposes only the coder's head (`state`); the bulk of
// the compressed data remains hidden inside the backend.
impl<Word, State, Backend> Code for AnsCoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    type Word = Word;
    type State = State;

    #[inline(always)]
    fn state(&self) -> Self::State {
        self.state
    }
}
938
impl<Word, State, Backend, const PRECISION: usize> Encode<PRECISION>
    for AnsCoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word>,
{
    type FrontendError = DefaultEncoderFrontendError;
    type BackendError = Backend::WriteError;

    /// Encodes a single symbol and appends it to the compressed data.
    ///
    /// This is a low level method. You probably usually want to call a batch method
    /// like [`encode_symbols`](#method.encode_symbols) or
    /// [`encode_iid_symbols`](#method.encode_iid_symbols) instead. See examples there.
    ///
    /// The bound `impl Borrow<M::Symbol>` on argument `symbol` essentially means that
    /// you can provide the symbol either by value or by reference, at your choice.
    ///
    /// Returns [`Err(ImpossibleSymbol)`] if `symbol` has zero probability under the
    /// entropy model `model`. This error can usually be avoided by using a
    /// "leaky" distribution as the entropy model, i.e., a distribution that assigns a
    /// nonzero probability to all symbols within a finite domain. Leaky distributions
    /// can be constructed with, e.g., a
    /// [`LeakyQuantizer`](models/struct.LeakyQuantizer.html) or with
    /// [`LeakyCategorical::from_floating_point_probabilities`](
    /// models/struct.LeakyCategorical.html#method.from_floating_point_probabilities).
    ///
    /// TODO: move this and similar doc comments to the trait definition.
    ///
    /// [`Err(ImpossibleSymbol)`]: enum.EncodingError.html#variant.ImpossibleSymbol
    fn encode_symbol<M>(
        &mut self,
        symbol: impl Borrow<M::Symbol>,
        model: M,
    ) -> Result<(), DefaultEncoderError<Self::BackendError>>
    where
        M: EncoderModel<PRECISION>,
        M::Probability: Into<Self::Word>,
        Self::Word: AsPrimitive<M::Probability>,
    {
        // Compile-time sanity checks on the chosen `Word`/`State`/`PRECISION` combo.
        generic_static_asserts!(
            (Word: BitArray, State:BitArray; const PRECISION: usize);
            PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
            NON_ZERO_PRECISION: PRECISION > 0;
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
        );

        // Look up the symbol's fixed-point interval under the model; `None` means the
        // symbol has zero probability and therefore cannot be encoded.
        let (left_sided_cumulative, probability) = model
            .left_cumulative_and_probability(symbol)
            .ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;

        // Renormalization: if the upcoming encoding step would overflow `state`
        // (i.e., the top `PRECISION`-and-above bits already reach `probability`),
        // flush the least significant word of `state` to the bulk first.
        if (self.state >> (State::BITS - PRECISION)) >= probability.get().into().into() {
            self.bulk.write(self.state.as_())?;
            self.state = self.state >> Word::BITS;
            // At this point, the invariant on `self.state` (see its doc comment) is
            // temporarily violated, but it will be restored below.
        }

        // Core ANS update: state <- (state / p) << PRECISION | (cdf + state % p),
        // which the decoder inverts exactly in `decode_symbol`.
        let remainder = (self.state % probability.get().into().into()).as_().as_();
        let prefix = self.state / probability.get().into().into();
        let quantile = left_sided_cumulative + remainder;
        self.state = (prefix << PRECISION) | quantile.into().into();

        Ok(())
    }

    /// Conservative fullness check; simply forwards to the write backend.
    fn maybe_full(&self) -> bool {
        self.bulk.maybe_full()
    }
}
1010
impl<Word, State, Backend, const PRECISION: usize> Decode<PRECISION>
    for AnsCoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: ReadWords<Word, Stack>,
{
    /// ANS coding is surjective, and we (deliberately) allow decoding past EOF (in a
    /// deterministic way) for consistency. Therefore, decoding cannot fail in the front
    /// end.
    type FrontendError = Infallible;

    type BackendError = Backend::ReadError;

    /// Decodes one symbol, inverting `encode_symbol` step by step.
    #[inline(always)]
    fn decode_symbol<M>(
        &mut self,
        model: M,
    ) -> Result<M::Symbol, CoderError<Self::FrontendError, Self::BackendError>>
    where
        M: DecoderModel<PRECISION>,
        M::Probability: Into<Self::Word>,
        Self::Word: AsPrimitive<M::Probability>,
    {
        // Compile-time sanity checks, mirroring those in `encode_symbol`.
        generic_static_asserts!(
            (Word: BitArray, State:BitArray; const PRECISION: usize);
            PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
            NON_ZERO_PRECISION: PRECISION > 0;
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
        );

        // The lowest `PRECISION` bits of `state` are the quantile; inverting the
        // model's CDF on it identifies the encoded symbol and its interval.
        let quantile = (self.state % (State::one() << PRECISION)).as_().as_();
        let (symbol, left_sided_cumulative, probability) = model.quantile_function(quantile);
        let remainder = quantile - left_sided_cumulative;
        // Exact inverse of the encoder's update:
        // state <- (state >> PRECISION) * p + (quantile - cdf).
        self.state =
            (self.state >> PRECISION) * probability.get().into().into() + remainder.into().into();
        if self.state < State::one() << (State::BITS - Word::BITS) {
            // Invariant on `self.state` (see its doc comment) is violated. Restore it by
            // refilling with a compressed word from `self.bulk` if available.
            if let Some(word) = self.bulk.read()? {
                self.state = (self.state << Word::BITS) | word.into();
            }
            // If the bulk is exhausted we keep decoding with a small state; this is
            // the deliberate "decoding past EOF" behavior mentioned above.
        }

        Ok(symbol)
    }

    /// Conservative exhaustion check: the coder may be out of data iff it is empty.
    fn maybe_exhausted(&self) -> bool {
        self.is_empty()
    }
}
1062
// A snapshot position of an `AnsCoder` consists of the backend's position within the
// bulk data *and* the head `state`; both are required to resume coding at that point
// (see the `Seek` and `Pos` impls below).
impl<Word, State, Backend> PosSeek for AnsCoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: PosSeek,
    Self: Code,
{
    type Position = (Backend::Position, <Self as Code>::State);
}
1072
1073impl<Word, State, Backend> Seek for AnsCoder<Word, State, Backend>
1074where
1075 Word: BitArray + Into<State>,
1076 State: BitArray + AsPrimitive<Word>,
1077 Backend: Seek,
1078{
1079 fn seek(&mut self, (pos, state): Self::Position) -> Result<(), ()> {
1080 self.bulk.seek(pos)?;
1081 self.state = state;
1082 Ok(())
1083 }
1084}
1085
1086impl<Word, State, Backend> Pos for AnsCoder<Word, State, Backend>
1087where
1088 Word: BitArray + Into<State>,
1089 State: BitArray + AsPrimitive<Word>,
1090 Backend: Pos,
1091{
1092 fn pos(&self) -> Self::Position {
1093 (self.bulk.pos(), self.state())
1094 }
1095}
1096
/// Provides temporary read-only access to the compressed data wrapped in an
/// [`AnsCoder`].
///
/// On construction (see [`CoderGuard::new`]), the coder's head `state` is temporarily
/// flushed to the backend so that the guard exposes the complete compressed
/// representation; dropping the guard reverts this. Dereferences to the backend
/// `Backend` (e.g., a `Vec<Word>`), not to `&[Word]` directly.
///
/// When `SEALED` is `true`, construction additionally verifies the obligatory
/// trailing `1` word of the compressed representation (see `into_binary`).
///
/// [`AnsCoder`]: struct.AnsCoder.html
struct CoderGuard<'a, Word, State, Backend, const SEALED: bool>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    // Exclusive borrow of the coder whose `bulk` temporarily holds the state words.
    inner: &'a mut AnsCoder<Word, State, Backend>,
}
1112
impl<'a, Word, State, Backend, const SEALED: bool> CoderGuard<'a, Word, State, Backend, SEALED>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    /// Flushes `ans.state` onto the bulk backend and wraps the coder in a guard; the
    /// flush is undone in `<Self as Drop>::drop`.
    ///
    /// For `SEALED == true`, returns a frontend error (without writing anything) if
    /// the sealing word of the state is not exactly `1`.
    #[inline(always)]
    fn new(
        ans: &'a mut AnsCoder<Word, State, Backend>,
    ) -> Result<Self, CoderError<(), Backend::WriteError>> {
        // Append state. Will be undone in `<Self as Drop>::drop`.
        let mut chunks_rev = bit_array_to_chunks_truncated(ans.state);
        // For a sealed guard, the first chunk yielded here (the word that would end
        // up last in the compressed data, cf. `into_compressed`) must be the
        // obligatory `1` word; it is consumed and deliberately not written.
        if SEALED && chunks_rev.next() != Some(Word::one()) {
            return Err(CoderError::Frontend(()));
        }
        for chunk in chunks_rev.rev() {
            ans.bulk.write(chunk)?
        }

        Ok(Self { inner: ans })
    }
}
1135
impl<Word, State, Backend, const SEALED: bool> Drop for CoderGuard<'_, Word, State, Backend, SEALED>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    fn drop(&mut self) {
        // Revert what we did in `Self::new`: pop exactly as many words off the bulk
        // as `new` pushed — one per state chunk, skipping the sealing word that `new`
        // consumed without writing when `SEALED`.
        let mut chunks_rev = bit_array_to_chunks_truncated(self.inner.state);
        if SEALED {
            chunks_rev.next();
        }
        for _ in chunks_rev {
            // Deliberately ignore read errors: `drop` must not panic, and these
            // words were just written by `new`, so reading them back is expected
            // to succeed.
            core::mem::drop(self.inner.bulk.read());
        }
    }
}
1153
// The guard hands out shared access to the backend, which now (temporarily) contains
// the complete compressed representation including the flushed state words.
impl<Word, State, Backend, const SEALED: bool> Deref
    for CoderGuard<'_, Word, State, Backend, SEALED>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + ReadWords<Word, Stack>,
{
    type Target = Backend;

    fn deref(&self) -> &Self::Target {
        &self.inner.bulk
    }
}
1167
1168impl<Word, State, Backend, const SEALED: bool> Debug
1169 for CoderGuard<'_, Word, State, Backend, SEALED>
1170where
1171 Word: BitArray + Into<State>,
1172 State: BitArray + AsPrimitive<Word>,
1173 Backend: WriteWords<Word> + ReadWords<Word, Stack> + Debug,
1174{
1175 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1176 Debug::fmt(&**self, f)
1177 }
1178}
1179
#[cfg(test)]
mod tests {
    use super::super::model::{
        ContiguousCategoricalEntropyModel, DefaultLeakyQuantizer, IterableEntropyModel,
        LeakyQuantizer,
    };
    use super::*;
    extern crate std;
    use std::dbg;

    use probability::distribution::{Gaussian, Inverse};
    use rand_xoshiro::{
        rand_core::{RngCore, SeedableRng},
        Xoshiro256StarStar,
    };

    /// An empty coder must export empty compressed data and round-trip back to an
    /// empty coder.
    #[test]
    fn compress_none() {
        let coder1 = DefaultAnsCoder::new();
        assert!(coder1.is_empty());
        let compressed = coder1.into_compressed().unwrap();
        assert!(compressed.is_empty());

        let coder2 = DefaultAnsCoder::from_compressed(compressed).unwrap();
        assert!(coder2.is_empty());
    }

    #[test]
    fn compress_one() {
        generic_compress_few(core::iter::once(5), 1)
    }

    #[test]
    fn compress_two() {
        generic_compress_few([2, 8].iter().cloned(), 1)
    }

    #[test]
    fn compress_ten() {
        generic_compress_few(0..10, 2)
    }

    #[test]
    fn compress_twenty() {
        generic_compress_few(-10..10, 4)
    }

    /// Round-trips a short symbol sequence through a quantized Gaussian model and
    /// checks that the compressed size (in words) matches `expected_size`.
    fn generic_compress_few<I>(symbols: I, expected_size: usize)
    where
        I: IntoIterator<Item = i32>,
        I::IntoIter: Clone + DoubleEndedIterator,
    {
        let symbols = symbols.into_iter();

        let mut encoder = DefaultAnsCoder::new();
        let quantizer = DefaultLeakyQuantizer::new(-127..=127);
        let model = quantizer.quantize(Gaussian::new(3.2, 5.1));

        // We don't reuse the same encoder for decoding because we want to test
        // if exporting and re-importing of compressed data works.
        encoder.encode_iid_symbols(symbols.clone(), model).unwrap();
        let compressed = encoder.into_compressed().unwrap();
        assert_eq!(compressed.len(), expected_size);

        let mut decoder = DefaultAnsCoder::from_compressed(compressed).unwrap();
        // Decoding must return the symbols in reverse encoding order (stack
        // semantics; `encode_iid_symbols` was not the `_reverse` variant here).
        for symbol in symbols.rev() {
            assert_eq!(decoder.decode_symbol(model).unwrap(), symbol);
        }
        assert!(decoder.is_empty());
    }

    #[test]
    fn compress_many_u32_u64_32() {
        generic_compress_many::<u32, u64, u32, 32>();
    }

    #[test]
    fn compress_many_u32_u64_24() {
        generic_compress_many::<u32, u64, u32, 24>();
    }

    #[test]
    fn compress_many_u32_u64_16() {
        generic_compress_many::<u32, u64, u16, 16>();
    }

    #[test]
    fn compress_many_u32_u64_8() {
        generic_compress_many::<u32, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u16_u64_16() {
        generic_compress_many::<u16, u64, u16, 16>();
    }

    #[test]
    fn compress_many_u16_u64_12() {
        generic_compress_many::<u16, u64, u16, 12>();
    }

    #[test]
    fn compress_many_u16_u64_8() {
        generic_compress_many::<u16, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u64_8() {
        generic_compress_many::<u8, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u16_u32_16() {
        generic_compress_many::<u16, u32, u16, 16>();
    }

    #[test]
    fn compress_many_u16_u32_12() {
        generic_compress_many::<u16, u32, u16, 12>();
    }

    #[test]
    fn compress_many_u16_u32_8() {
        generic_compress_many::<u16, u32, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u32_8() {
        generic_compress_many::<u8, u32, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u16_8() {
        generic_compress_many::<u8, u16, u8, 8>();
    }

    /// Exhaustive round-trip test over a given `Word`/`State`/`Probability`/
    /// `PRECISION` combination: encodes `AMT` categorical symbols and `AMT`
    /// Gaussian-quantized symbols (interleaved models), exports/re-imports the
    /// compressed data, and checks exact reconstruction.
    fn generic_compress_many<Word, State, Probability, const PRECISION: usize>()
    where
        State: BitArray + AsPrimitive<Word>,
        Word: BitArray + Into<State> + AsPrimitive<Probability>,
        Probability: BitArray + Into<Word> + AsPrimitive<usize> + Into<f64>,
        u32: AsPrimitive<Probability>,
        usize: AsPrimitive<Probability>,
        f64: AsPrimitive<Probability>,
        i32: AsPrimitive<Probability>,
    {
        // Keep the workload small under miri, which executes much more slowly.
        #[cfg(not(miri))]
        const AMT: usize = 1000;

        #[cfg(miri)]
        const AMT: usize = 100;

        let mut symbols_gaussian = Vec::with_capacity(AMT);
        let mut means = Vec::with_capacity(AMT);
        let mut stds = Vec::with_capacity(AMT);

        // Derive a deterministic RNG seed from the type parameters so that each
        // test instantiation explores different (but reproducible) data.
        let mut rng = Xoshiro256StarStar::seed_from_u64(
            (Word::BITS as u64).rotate_left(3 * 16)
                ^ (State::BITS as u64).rotate_left(2 * 16)
                ^ (Probability::BITS as u64).rotate_left(16)
                ^ PRECISION as u64,
        );

        // Sample per-symbol Gaussian parameters and a symbol from each Gaussian.
        for _ in 0..AMT {
            let mean = (200.0 / u32::MAX as f64) * rng.next_u32() as f64 - 100.0;
            let std_dev = (10.0 / u32::MAX as f64) * rng.next_u32() as f64 + 0.001;
            let quantile = (rng.next_u32() as f64 + 0.5) / (1u64 << 32) as f64;
            let dist = Gaussian::new(mean, std_dev);
            let symbol = (dist.inverse(quantile).round() as i32).clamp(-127, 127);

            symbols_gaussian.push(symbol);
            means.push(mean);
            stds.push(std_dev);
        }

        // A deliberately irregular histogram (including zero-probability bins) for
        // the categorical model.
        let hist = [
            1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
            896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
            347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
        ];
        let categorical_probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
        let categorical =
            ContiguousCategoricalEntropyModel::<Probability, _, PRECISION>::from_floating_point_probabilities_fast::<f64>(
                &categorical_probabilities,None
            )
            .unwrap();
        let mut symbols_categorical = Vec::with_capacity(AMT);
        let max_probability = Probability::max_value() >> (Probability::BITS - PRECISION);
        // Sample categorical symbols by inverting the model's CDF at random quantiles.
        for _ in 0..AMT {
            let quantile = rng.next_u32().as_() & max_probability;
            let symbol = categorical.quantile_function(quantile).0;
            symbols_categorical.push(symbol);
        }

        let mut ans = AnsCoder::<Word, State>::new();

        ans.encode_iid_symbols_reverse(&symbols_categorical, &categorical)
            .unwrap();
        // Compare actual compressed size against the entropy lower bound (debug only).
        dbg!(
            ans.num_valid_bits(),
            AMT as f64 * categorical.entropy_base2::<f64>()
        );

        let quantizer = LeakyQuantizer::<_, _, Probability, PRECISION>::new(-127..=127);
        // NOTE(review): the closure binds the std-dev as `core`, shadowing the `core`
        // crate name within this scope — consider renaming to `std_dev`.
        ans.encode_symbols_reverse(symbols_gaussian.iter().zip(&means).zip(&stds).map(
            |((&symbol, &mean), &core)| (symbol, quantizer.quantize(Gaussian::new(mean, core))),
        ))
        .unwrap();
        dbg!(ans.num_valid_bits());

        // Test if import/export of compressed data works.
        let compressed = ans.into_compressed().unwrap();
        let mut ans = AnsCoder::from_compressed(compressed).unwrap();

        // Decode in the reverse order of encoding: Gaussian symbols first.
        let reconstructed_gaussian = ans
            .decode_symbols(
                means
                    .iter()
                    .zip(&stds)
                    .map(|(&mean, &core)| quantizer.quantize(Gaussian::new(mean, core))),
            )
            .collect::<Result<Vec<_>, CoderError<Infallible, Infallible>>>()
            .unwrap();
        let reconstructed_categorical = ans
            .decode_iid_symbols(AMT, &categorical)
            .collect::<Result<Vec<_>, CoderError<Infallible, Infallible>>>()
            .unwrap();

        assert!(ans.is_empty());

        assert_eq!(symbols_gaussian, reconstructed_gaussian);
        assert_eq!(symbols_categorical, reconstructed_categorical);
    }

    /// Encodes many chunks while recording `(pos, state)` snapshots, then verifies
    /// that sequential decoding reproduces the snapshots and that random `seek`s into
    /// the jump table decode the right chunk — both on the original backend and on a
    /// reversed ("front to back") copy of the compressed data.
    #[test]
    fn seek() {
        #[cfg(not(miri))]
        let (num_chunks, symbols_per_chunk) = (100, 100);

        #[cfg(miri)]
        let (num_chunks, symbols_per_chunk) = (10, 10);

        let quantizer = DefaultLeakyQuantizer::new(-100..=100);
        let model = quantizer.quantize(Gaussian::new(0.0, 10.0));

        let mut encoder = DefaultAnsCoder::new();

        let mut rng = Xoshiro256StarStar::seed_from_u64(123);
        let mut symbols = Vec::with_capacity(num_chunks);
        let mut jump_table = Vec::with_capacity(num_chunks);
        let (initial_pos, initial_state) = encoder.pos();

        for _ in 0..num_chunks {
            let chunk = (0..symbols_per_chunk)
                .map(|_| model.quantile_function(rng.next_u32() % (1 << 24)).0)
                .collect::<Vec<_>>();
            encoder.encode_iid_symbols_reverse(&chunk, &model).unwrap();
            symbols.push(chunk);
            jump_table.push(encoder.pos());
        }

        // Test decoding from back to front.
        {
            let mut seekable_decoder = encoder.as_seekable_decoder();

            // Verify that decoding leads to the same positions and states.
            for (chunk, &(pos, state)) in symbols.iter().zip(&jump_table).rev() {
                assert_eq!(seekable_decoder.pos(), (pos, state));
                let decoded = seekable_decoder
                    .decode_iid_symbols(symbols_per_chunk, &model)
                    .collect::<Result<Vec<_>, _>>()
                    .unwrap();
                assert_eq!(&decoded, chunk)
            }
            assert_eq!(seekable_decoder.pos(), (initial_pos, initial_state));
            assert!(seekable_decoder.is_empty());

            // Seek to some random offsets in the jump table and decode one chunk
            for _ in 0..100 {
                let chunk_index = rng.next_u32() as usize % num_chunks;
                let (pos, state) = jump_table[chunk_index];
                seekable_decoder.seek((pos, state)).unwrap();
                let decoded = seekable_decoder
                    .decode_iid_symbols(symbols_per_chunk, &model)
                    .collect::<Result<Vec<_>, _>>()
                    .unwrap();
                assert_eq!(&decoded, &symbols[chunk_index])
            }
        }

        // Reverse compressed data, map positions in jump table to reversed positions,
        // and test decoding from front to back.
        let mut compressed = encoder.into_compressed().unwrap();
        compressed.reverse();
        for (pos, _state) in jump_table.iter_mut() {
            *pos = compressed.len() - *pos;
        }
        let initial_pos = compressed.len() - initial_pos;

        {
            let mut seekable_decoder = AnsCoder::from_reversed_compressed(compressed).unwrap();

            // Verify that decoding leads to the expected positions and states.
            for (chunk, &(pos, state)) in symbols.iter().zip(&jump_table).rev() {
                assert_eq!(seekable_decoder.pos(), (pos, state));
                let decoded = seekable_decoder
                    .decode_iid_symbols(symbols_per_chunk, &model)
                    .collect::<Result<Vec<_>, _>>()
                    .unwrap();
                assert_eq!(&decoded, chunk)
            }
            assert_eq!(seekable_decoder.pos(), (initial_pos, initial_state));
            assert!(seekable_decoder.is_empty());

            // Seek to some random offsets in the jump table and decode one chunk each time.
            for _ in 0..100 {
                let chunk_index = rng.next_u32() as usize % num_chunks;
                let (pos, state) = jump_table[chunk_index];
                seekable_decoder.seek((pos, state)).unwrap();
                let decoded = seekable_decoder
                    .decode_iid_symbols(symbols_per_chunk, &model)
                    .collect::<Result<Vec<_>, _>>()
                    .unwrap();
                assert_eq!(&decoded, &symbols[chunk_index])
            }
        }
    }
}