rans 0.2.1

rANS (range variant of Asymmetric Numeral Systems) encoder and decoder
Documentation
/// Interleaved multi-stream rANS decoder interface.
///
/// An implementor decodes `N` interleaved rANS streams ("channels"),
/// addressed by the indices `0..N`. Decoding a symbol on a channel is a
/// two-phase operation: read the cumulative frequency with
/// [`Self::get_at()`], look up which symbol it belongs to, then consume that
/// symbol with [`Self::advance_at()`] (or with [`Self::advance_step_at()`]
/// followed by a renormalization).
///
/// In every method, `scale_bits` selects the frequency scale: symbol
/// frequencies are expressed out of a total of `1 << scale_bits`.
pub trait RansDecoderMulti<const N: usize> {
    /// Type of a Symbol value that can be decoded using this decoder.
    type Symbol: RansDecSymbol;

    /// Gets the cumulative frequency for the current symbol at the specified
    /// `channel`. Note that this does not advance the data position; for
    /// that, use [`Self::advance_at()`].
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::ByteRansDecoderMulti;
    /// use rans::RansDecoderMulti;
    ///
    /// let mut decoder = ByteRansDecoderMulti::<2>::new([2, 0, 0, 1, 0, 0, 0, 1]);
    /// assert_eq!(decoder.get_at(0, 2), 2);
    /// assert_eq!(decoder.get_at(1, 2), 0);
    /// ```
    #[must_use]
    fn get_at(&mut self, channel: usize, scale_bits: u32) -> u32;

    /// Advances the data position of `channel` past `symbol` after reading
    /// it. Typically `symbol` is the one whose cumulative-frequency range
    /// contains the value just returned by [`Self::get_at()`].
    ///
    /// Equivalent to calling [`Self::advance_step_at()`] and
    /// [`Self::renorm_at()`].
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::{ByteRansDecSymbol, ByteRansDecoderMulti};
    /// use rans::{RansDecSymbol, RansDecoderMulti};
    ///
    /// let mut decoder = ByteRansDecoderMulti::<2>::new([2, 0, 0, 1, 0, 0, 0, 1]);
    /// let symbol_1 = ByteRansDecSymbol::new(0, 2);
    /// let symbol_2 = ByteRansDecSymbol::new(2, 2);
    /// assert_eq!(decoder.get_at(0, 2), 2);
    /// decoder.advance_at(0, &symbol_2, 2);
    /// assert_eq!(decoder.get_at(1, 2), 0);
    /// decoder.advance_at(1, &symbol_1, 2);
    /// assert_eq!(decoder.get_at(0, 2), 0);
    /// assert_eq!(decoder.get_at(1, 2), 0);
    /// ```
    fn advance_at(&mut self, channel: usize, symbol: &Self::Symbol, scale_bits: u32);

    /// Pops a single symbol from the internal state of `channel`, without
    /// doing renormalization or modifying the internal buffer.
    ///
    /// Follow this with [`Self::renorm_at()`] for the same channel (or with
    /// [`Self::renorm_all()`]) before reading the next symbol. Splitting the
    /// step from the renormalization lets callers advance several channels
    /// first and then renormalize them together.
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::{ByteRansDecSymbol, ByteRansDecoderMulti};
    /// use rans::{RansDecSymbol, RansDecoderMulti};
    ///
    /// let mut decoder = ByteRansDecoderMulti::<2>::new([2, 0, 0, 1, 0, 0, 0, 1]);
    /// let symbol_1 = ByteRansDecSymbol::new(0, 2);
    /// let symbol_2 = ByteRansDecSymbol::new(2, 2);
    /// assert_eq!(decoder.get_at(0, 2), 2);
    /// decoder.advance_step_at(0, &symbol_2, 2);
    /// assert_eq!(decoder.get_at(1, 2), 0);
    /// decoder.advance_step_at(1, &symbol_1, 2);
    /// decoder.renorm_all();
    /// assert_eq!(decoder.get_at(0, 2), 0);
    /// assert_eq!(decoder.get_at(1, 2), 0);
    /// ```
    fn advance_step_at(&mut self, channel: usize, symbol: &Self::Symbol, scale_bits: u32);

    /// Renormalizes `channel` using the data in the internal buffer, after a
    /// symbol was popped from it with [`Self::advance_step_at()`].
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::{ByteRansDecSymbol, ByteRansDecoderMulti};
    /// use rans::{RansDecSymbol, RansDecoderMulti};
    ///
    /// let mut decoder = ByteRansDecoderMulti::<2>::new([2, 0, 0, 1, 0, 0, 0, 1]);
    /// let symbol_1 = ByteRansDecSymbol::new(0, 2);
    /// let symbol_2 = ByteRansDecSymbol::new(2, 2);
    /// assert_eq!(decoder.get_at(0, 2), 2);
    /// decoder.advance_step_at(0, &symbol_2, 2);
    /// assert_eq!(decoder.get_at(1, 2), 0);
    /// decoder.advance_step_at(1, &symbol_1, 2);
    /// decoder.renorm_at(0);
    /// decoder.renorm_at(1);
    /// assert_eq!(decoder.get_at(0, 2), 0);
    /// assert_eq!(decoder.get_at(1, 2), 0);
    /// ```
    fn renorm_at(&mut self, channel: usize);

    /// Renormalizes every channel after advancing a symbol, by calling
    /// [`Self::renorm_at()`] for each channel in order `0..N`.
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::{ByteRansDecSymbol, ByteRansDecoderMulti};
    /// use rans::{RansDecSymbol, RansDecoderMulti};
    ///
    /// let mut decoder = ByteRansDecoderMulti::<2>::new([2, 0, 0, 1, 0, 0, 0, 1]);
    /// let symbol_1 = ByteRansDecSymbol::new(0, 2);
    /// let symbol_2 = ByteRansDecSymbol::new(2, 2);
    /// assert_eq!(decoder.get_at(0, 2), 2);
    /// decoder.advance_step_at(0, &symbol_2, 2);
    /// assert_eq!(decoder.get_at(1, 2), 0);
    /// decoder.advance_step_at(1, &symbol_1, 2);
    /// decoder.renorm_all();
    /// assert_eq!(decoder.get_at(0, 2), 0);
    /// assert_eq!(decoder.get_at(1, 2), 0);
    /// ```
    fn renorm_all(&mut self) {
        for channel in 0..N {
            self.renorm_at(channel);
        }
    }
}

/// Single-stream rANS decoder interface.
pub trait RansDecoder: RansDecoderMulti<1> {
    /// Gets the cumulative frequency for the current symbol. Note that this
    /// does not advance the data position; for that, use [`Self::advance()`].
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::{ByteRansDecSymbol, ByteRansDecoder};
    /// use rans::{RansDecSymbol, RansDecoder};
    ///
    /// let mut decoder = ByteRansDecoder::new([2, 0, 0, 2]);
    /// assert_eq!(decoder.get(4), 2);
    /// assert_eq!(decoder.get(4), 2);
    /// ```
    #[must_use]
    fn get(&mut self, scale_bits: u32) -> u32 {
        self.get_at(0, scale_bits)
    }

    /// Advances the data position after reading a symbol.
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::{ByteRansDecSymbol, ByteRansDecoder};
    /// use rans::{RansDecSymbol, RansDecoder};
    ///
    /// let mut decoder = ByteRansDecoder::new([2, 0, 0, 2]);
    /// let symbol = ByteRansDecSymbol::new(2, 2);
    /// assert_eq!(decoder.get(2), 2);
    /// decoder.advance(&symbol, 2);
    /// assert_eq!(decoder.get(2), 0);
    /// ```
    fn advance(&mut self, symbol: &Self::Symbol, scale_bits: u32) {
        self.advance_at(0, symbol, scale_bits);
    }
}

/// A symbol that can be decoded using a rANS decoder.
///
/// A decoder symbol is described by two numbers on the decoder's frequency
/// scale: its cumulative frequency (where its range starts) and its
/// frequency (how wide that range is).
pub trait RansDecSymbol {
    /// Builds a decoder symbol from its cumulative frequency `cum_freq` and
    /// its frequency `freq`.
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::ByteRansDecSymbol;
    /// use rans::RansDecSymbol;
    ///
    /// let _symbol = ByteRansDecSymbol::new(0, 2);
    /// ```
    #[must_use]
    fn new(cum_freq: u32, freq: u32) -> Self;

    /// Returns the cumulative frequency of this symbol, i.e. the start of
    /// its range on the frequency scale.
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::ByteRansDecSymbol;
    /// use rans::RansDecSymbol;
    ///
    /// let symbol = ByteRansDecSymbol::new(0, 2);
    /// assert_eq!(symbol.cum_freq(), 0);
    /// ```
    #[must_use]
    fn cum_freq(&self) -> u32;

    /// Returns the frequency of this symbol, i.e. the width of its range on
    /// the frequency scale.
    ///
    /// # Examples
    /// ```
    /// use rans::byte_decoder::ByteRansDecSymbol;
    /// use rans::RansDecSymbol;
    ///
    /// let symbol = ByteRansDecSymbol::new(0, 2);
    /// assert_eq!(symbol.freq(), 2);
    /// ```
    #[must_use]
    fn freq(&self) -> u32;
}

#[cfg(test)]
pub(crate) mod tests {
    //! Decoder-agnostic test routines shared by the concrete decoder
    //! implementations. Each helper receives a decoder constructed by the
    //! caller over a fixed input and asserts the exact sequence of values
    //! that input is expected to decode to, panicking on the first mismatch.

    use crate::decoder::RansDecSymbol;
    use crate::{RansDecoder, RansDecoderMulti};

    /// Asserts that the first cumulative frequency read at `scale_bits = 2`
    /// is 0.
    pub(crate) fn test_decode_empty<T: RansDecoder>(mut decoder: T) {
        assert_eq!(decoder.get(2), 0);
    }

    /// Decodes two symbols on a 2-bit scale with ranges `[0, 2)` and
    /// `[2, 4)`: first the upper symbol, then the lower one.
    pub(crate) fn test_decode_two_symbols<T: RansDecoder>(mut decoder: T) {
        let symbol1 = T::Symbol::new(0, 2);
        let symbol2 = T::Symbol::new(2, 2);

        assert_eq!(decoder.get(2), 2);
        decoder.advance(&symbol2, 2);
        assert_eq!(decoder.get(2), 0);
        decoder.advance(&symbol1, 2);
    }

    /// Same sequence as [`test_decode_two_symbols`], but advances with
    /// clones of the symbols to check that cloned symbols decode the same.
    pub(crate) fn test_decode_symbols_clone<T>(mut decoder: T)
    where
        T: RansDecoder,
        T::Symbol: Clone,
    {
        let symbol1 = T::Symbol::new(0, 2);
        let symbol2 = T::Symbol::new(2, 2);

        assert_eq!(decoder.get(2), 2);
        // The clone is the point of this test: it exercises `Clone`.
        #[allow(clippy::redundant_clone)]
        decoder.advance(&symbol2.clone(), 2);
        assert_eq!(decoder.get(2), 0);
        #[allow(clippy::redundant_clone)]
        decoder.advance(&symbol1.clone(), 2);
    }

    /// Decodes 32 symbols from an 8-symbol alphabet on an 8-bit scale and
    /// checks each one against the expected sequence.
    pub(crate) fn test_decode_more_data<T: RansDecoder>(mut decoder: T) {
        const SCALE_BITS: u32 = 8;
        // Ranges are contiguous and sorted; the frequencies sum to
        // 256 == 1 << SCALE_BITS.
        let s1 = T::Symbol::new(0, 3);
        let s2 = T::Symbol::new(3, 10);
        let s3 = T::Symbol::new(13, 58);
        let s4 = T::Symbol::new(71, 34);
        let s5 = T::Symbol::new(105, 41);
        let s6 = T::Symbol::new(146, 17);
        let s7 = T::Symbol::new(163, 55);
        let s8 = T::Symbol::new(218, 38);
        let symbols = [&s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8];

        // Listed in forward order; the decoder is expected to yield them
        // back to front.
        let symbol_data = [
            &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s3, &s3, &s3, &s3, &s3, &s5, &s4, &s3, &s4,
            &s3, &s7, &s8, &s8, &s6, &s5, &s3, &s4, &s7, &s6, &s7, &s7, &s3, &s4, &s5,
        ];

        for (i, expected_symbol) in symbol_data.iter().rev().enumerate() {
            let cum_freq = decoder.get(SCALE_BITS);
            let actual_symbol = get_symbol(&symbols, cum_freq);
            assert_eq!(
                actual_symbol.cum_freq(),
                expected_symbol.cum_freq(),
                "Invalid symbol at position {} (decoded cumulative frequency: {})",
                i,
                cum_freq
            );
            decoder.advance(expected_symbol, SCALE_BITS);
        }
    }

    /// Returns the first symbol in `symbols` whose range
    /// `[cum_freq(), cum_freq() + freq())` has an upper bound above
    /// `cum_freq`. With `symbols` sorted by cumulative frequency, this is the
    /// symbol whose range contains `cum_freq`.
    ///
    /// # Panics
    /// If no symbol's range reaches past `cum_freq`.
    #[must_use]
    fn get_symbol<'a, T: RansDecSymbol>(symbols: &'a [&T], cum_freq: u32) -> &'a T {
        symbols
            .iter()
            .copied()
            .find(|symbol| cum_freq < symbol.cum_freq() + symbol.freq())
            .unwrap_or_else(|| unreachable!("Invalid symbol frequency"))
    }

    /// Decodes four steps from a two-channel interleaved decoder on a 4-bit
    /// scale. Channel 0 yields `symbol4`, `symbol3`, `symbol2`, `symbol1` in
    /// turn; channel 1 yields `symbol1` every step. Both channels advance
    /// before they are renormalized together.
    pub(crate) fn test_decode_interleaved<T: RansDecoderMulti<2>>(mut decoder: T) {
        const SCALE_BITS: u32 = 4;
        let symbol1 = T::Symbol::new(0, 4);
        let symbol2 = T::Symbol::new(4, 4);
        let symbol3 = T::Symbol::new(8, 4);
        let symbol4 = T::Symbol::new(12, 4);

        // (symbol decoded on channel 0, its expected cumulative frequency)
        let channel0_steps = [(&symbol4, 12), (&symbol3, 8), (&symbol2, 4), (&symbol1, 0)];

        for &(channel0_symbol, channel0_cum_freq) in &channel0_steps {
            assert_eq!(decoder.get_at(0, SCALE_BITS), channel0_cum_freq);
            assert_eq!(decoder.get_at(1, SCALE_BITS), 0);
            decoder.advance_step_at(0, channel0_symbol, SCALE_BITS);
            decoder.advance_step_at(1, &symbol1, SCALE_BITS);
            decoder.renorm_all();
        }
    }
}