arithmetic_coding_core/model/
fixed_length.rs

1//! Helper trait for creating fixed-length Models
2
3use std::ops::Range;
4
5use crate::BitStore;
6
7/// A [`Model`] is used to calculate the probability of a given symbol occurring
8/// in a sequence.
9///
10/// The [`Model`] is used both for encoding and decoding. A
11/// 'fixed-length' model always expects an exact number of symbols, and so does
12/// not need to encode an EOF symbol.
13///
14/// A fixed length model can be converted into a regular model using the
15/// convenience [`Wrapper`] type.
16///
17/// The more accurately a [`Model`] is able to predict the next symbol, the
18/// greater the compression ratio will be.
19///
20/// # Example
21///
22/// ```
23/// # use std::convert::Infallible;
24/// # use std::ops::Range;
25/// #
26/// # use arithmetic_coding_core::fixed_length;
27///
28/// pub enum Symbol {
29///     A,
30///     B,
31///     C,
32/// }
33///
34/// pub struct MyModel;
35///
36/// impl fixed_length::Model for MyModel {
37///     type B = u32;
38///     type Symbol = Symbol;
39///     type ValueError = Infallible;
40///
41///     fn probability(&self, symbol: &Self::Symbol) -> Result<Range<u32>, Infallible> {
42///         Ok(match symbol {
43///             Symbol::A => 0..1,
44///             Symbol::B => 1..2,
45///             Symbol::C => 2..3,
46///         })
47///     }
48///
49///     fn symbol(&self, value: Self::B) -> Self::Symbol {
50///         match value {
51///             0..1 => Symbol::A,
52///             1..2 => Symbol::B,
53///             2..3 => Symbol::C,
54///             _ => unreachable!(),
55///         }
56///     }
57///
58///     fn max_denominator(&self) -> u32 {
59///         3
60///     }
61///
62///     fn length(&self) -> usize {
63///         3
64///     }
65/// }
66/// ```
67pub trait Model {
68    /// The type of symbol this [`Model`] describes
69    type Symbol;
70
71    /// Invalid symbol error
72    type ValueError: std::error::Error;
73
74    /// The internal representation to use for storing integers
75    type B: BitStore;
76
77    /// Given a symbol, return an interval representing the probability of that
78    /// symbol occurring.
79    ///
80    /// This is given as a range, over the denominator given by
81    /// [`Model::denominator`]. This range should in general include `EOF`,
82    /// which is denoted by `None`.
83    ///
84    /// For example, from the set {heads, tails}, the interval representing
85    /// heads could be `0..1`, and tails would be `1..2`, and `EOF` could be
86    /// `2..3` (with a denominator of `3`).
87    ///
88    /// This is the inverse of the [`Model::symbol`] method
89    ///
90    /// # Errors
91    ///
92    /// This returns a custom error if the given symbol is not valid
93    fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError>;
94
95    /// The denominator for probability ranges. See [`Model::probability`].
96    ///
97    /// By default this method simply returns the [`Model::max_denominator`],
98    /// which is suitable for non-adaptive models.
99    ///
100    /// In adaptive models this value may change, however it should never exceed
101    /// [`Model::max_denominator`], or it becomes possible for the
102    /// [`Encoder`](crate::Encoder) and [`Decoder`](crate::Decoder) to panic due
103    /// to overflow or underflow.
104    fn denominator(&self) -> Self::B {
105        self.max_denominator()
106    }
107
108    /// The maximum denominator used for probability ranges. See
109    /// [`Model::probability`].
110    ///
111    /// This value is used to calculate an appropriate precision for the
112    /// encoding, therefore this value must not change, and
113    /// [`Model::denominator`] must never exceed it.
114    fn max_denominator(&self) -> Self::B;
115
116    /// Given a value, return the symbol whose probability range it falls in.
117    ///
118    /// `None` indicates `EOF`
119    ///
120    /// This is the inverse of the [`Model::probability`] method
121    fn symbol(&self, value: Self::B) -> Self::Symbol;
122
123    /// Update the current state of the model with the latest symbol.
124    ///
125    /// This method only needs to be implemented for 'adaptive' models. It's a
126    /// no-op by default.
127    fn update(&mut self, _symbol: &Self::Symbol) {}
128
129    /// The total number of symbols to encode.
130    ///
131    /// This must not change during encoding/decoding.
132    fn length(&self) -> usize;
133}
134
/// A wrapper which converts a [`fixed_length::Model`](Model) to a
/// [`crate::Model`].
///
/// The wrapper counts down the number of symbols it still expects. While that
/// count is non-zero it delegates to the inner model; once it reaches zero it
/// only accepts `EOF`, which is assigned the full probability range.
///
/// No trait bound is placed on `M` here; the impls that need
/// [`fixed_length::Model`](Model) declare it themselves.
#[derive(Debug, Clone)]
pub struct Wrapper<M> {
    /// The wrapped fixed-length model.
    model: M,
    /// How many symbols are still expected before `EOF`.
    remaining: usize,
}
145
146impl<M> Wrapper<M>
147where
148    M: Model,
149{
150    /// Construct a new wrapper from a [`fixed_length::Model`](Model)
151    pub fn new(model: M) -> Self {
152        let remaining = model.length();
153        Self { model, remaining }
154    }
155}
156
157impl<M> crate::Model for Wrapper<M>
158where
159    M: Model,
160{
161    type B = M::B;
162    type Symbol = M::Symbol;
163    type ValueError = Error<M::ValueError>;
164
165    fn probability(
166        &self,
167        symbol: Option<&Self::Symbol>,
168    ) -> Result<Range<Self::B>, Self::ValueError> {
169        if self.remaining > 0 {
170            symbol.map_or(
171                // We are expecting more symbols, but got an EOF
172                Err(Self::ValueError::UnexpectedEof),
173                // Expected a symbol and got one. return the probability.
174                |s| self.model.probability(s).map_err(Self::ValueError::Value),
175            )
176        } else if symbol.is_some() {
177            // we should be finished, but got an extra symbol
178            Err(Error::UnexpectedSymbol)
179        } else {
180            // got an EOF when we expected it, return a 100% probability
181            Ok(Self::B::ZERO..self.denominator())
182        }
183    }
184
185    fn max_denominator(&self) -> Self::B {
186        self.model.max_denominator()
187    }
188
189    fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
190        if self.remaining > 0 {
191            Some(self.model.symbol(value))
192        } else {
193            None
194        }
195    }
196
197    fn denominator(&self) -> Self::B {
198        self.model.denominator()
199    }
200
201    fn update(&mut self, symbol: Option<&Self::Symbol>) {
202        if let Some(s) = symbol {
203            self.model.update(s);
204            self.remaining -= 1;
205        }
206    }
207}
208
209/// Fixed-length encoding/decoding errors
210#[derive(Debug, thiserror::Error)]
211pub enum Error<E>
212where
213    E: std::error::Error,
214{
215    /// Model received an EOF when it expected more symbols
216    #[error("Unexpected EOF")]
217    UnexpectedEof,
218
219    /// Model received a symbol when it expected an EOF
220    #[error("Unexpected Symbol")]
221    UnexpectedSymbol,
222
223    /// The model received an invalid symbol
224    #[error(transparent)]
225    Value(E),
226}