arithmetic_coding_core/model/
max_length.rs

//! Helper trait for creating max-length Models
2
3use std::ops::Range;
4
5use crate::BitStore;
6
7/// A [`Model`] is used to calculate the probability of a given symbol occurring
8/// in a sequence.
9///
10/// The [`Model`] is used both for encoding and decoding. A
11/// 'max-length' model has a maximum length. The compressed size of a message
12/// equal to the maximum length is larger than with a
13/// [`fixed_length::Model`](crate::fixed_length::Model), but smaller than with a
14/// [`Model`](crate::Model).
15///
16/// A max-length model can be converted into a regular model using the
17/// convenience [`Wrapper`] type.
18///
19/// The more accurately a [`Model`] is able to predict the next symbol, the
20/// greater the compression ratio will be.
21///
22/// # Example
23///
24/// ```
25/// # use std::convert::Infallible;
26/// # use std::ops::Range;
27/// #
28/// # use arithmetic_coding_core::max_length;
29///
30/// pub enum Symbol {
31///     A,
32///     B,
33///     C,
34/// }
35///
36/// pub struct MyModel;
37///
38/// impl max_length::Model for MyModel {
39///     type B = u32;
40///     type Symbol = Symbol;
41///     type ValueError = Infallible;
42///
43///     fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<u32>, Infallible> {
44///         Ok(match symbol {
45///             Some(Symbol::A) => 0..1,
46///             Some(Symbol::B) => 1..2,
47///             Some(Symbol::C) => 2..3,
48///             None => 3..4,
49///         })
50///     }
51///
52///     fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
53///         match value {
54///             0..1 => Some(Symbol::A),
55///             1..2 => Some(Symbol::B),
56///             2..3 => Some(Symbol::C),
57///             3..4 => None,
58///             _ => unreachable!(),
59///         }
60///     }
61///
62///     fn max_denominator(&self) -> u32 {
63///         4
64///     }
65///
66///     fn max_length(&self) -> usize {
67///         3
68///     }
69/// }
70/// ```
71pub trait Model {
72    /// The type of symbol this [`Model`] describes
73    type Symbol;
74
75    /// Invalid symbol error
76    type ValueError: std::error::Error;
77
78    /// The internal representation to use for storing integers
79    type B: BitStore;
80
81    /// Given a symbol, return an interval representing the probability of that
82    /// symbol occurring.
83    ///
84    /// This is given as a range, over the denominator given by
85    /// [`Model::denominator`]. This range should in general include `EOF`,
86    /// which is denoted by `None`.
87    ///
88    /// For example, from the set {heads, tails}, the interval representing
89    /// heads could be `0..1`, and tails would be `1..2`, and `EOF` could be
90    /// `2..3` (with a denominator of `3`).
91    ///
92    /// This is the inverse of the [`Model::symbol`] method
93    ///
94    /// # Errors
95    ///
96    /// This returns a custom error if the given symbol is not valid
97    fn probability(
98        &self,
99        symbol: Option<&Self::Symbol>,
100    ) -> Result<Range<Self::B>, Self::ValueError>;
101
102    /// The denominator for probability ranges. See [`Model::probability`].
103    ///
104    /// By default this method simply returns the [`Model::max_denominator`],
105    /// which is suitable for non-adaptive models.
106    ///
107    /// In adaptive models this value may change, however it should never exceed
108    /// [`Model::max_denominator`], or it becomes possible for the
109    /// [`Encoder`](crate::Encoder) and [`Decoder`](crate::Decoder) to panic due
110    /// to overflow or underflow.
111    fn denominator(&self) -> Self::B {
112        self.max_denominator()
113    }
114
115    /// The maximum denominator used for probability ranges. See
116    /// [`Model::probability`].
117    ///
118    /// This value is used to calculate an appropriate precision for the
119    /// encoding, therefore this value must not change, and
120    /// [`Model::denominator`] must never exceed it.
121    fn max_denominator(&self) -> Self::B;
122
123    /// Given a value, return the symbol whose probability range it falls in.
124    ///
125    /// `None` indicates `EOF`
126    ///
127    /// This is the inverse of the [`Model::probability`] method
128    fn symbol(&self, value: Self::B) -> Option<Self::Symbol>;
129
130    /// Update the current state of the model with the latest symbol.
131    ///
132    /// This method only needs to be implemented for 'adaptive' models. It's a
133    /// no-op by default.
134    fn update(&mut self, _symbol: &Self::Symbol) {}
135
136    /// The maximum number of symbols to encode
137    fn max_length(&self) -> usize;
138}
139
140/// A wrapper which converts a [`max_length::Model`](Model) to a
141/// [`crate::Model`].
142#[derive(Debug, Clone)]
143pub struct Wrapper<M>
144where
145    M: Model,
146{
147    model: M,
148    remaining: usize,
149}
150
151impl<M> Wrapper<M>
152where
153    M: Model,
154{
155    /// Construct a new wrapper from a [`Model`]
156    pub fn new(model: M) -> Self {
157        let remaining = model.max_length();
158        Self { model, remaining }
159    }
160}
161
162impl<M> crate::Model for Wrapper<M>
163where
164    M: Model,
165{
166    type B = M::B;
167    type Symbol = M::Symbol;
168    type ValueError = Error<M::ValueError>;
169
170    fn probability(
171        &self,
172        symbol: Option<&Self::Symbol>,
173    ) -> Result<Range<Self::B>, Self::ValueError> {
174        if self.remaining == 0 {
175            if symbol.is_some() {
176                Err(Error::UnexpectedSymbol)
177            } else {
178                // got an EOF when we expected it, return a 100% probability
179                Ok(Self::B::ZERO..self.denominator())
180            }
181        } else {
182            self.model
183                .probability(symbol)
184                .map_err(Self::ValueError::Value)
185        }
186    }
187
188    fn max_denominator(&self) -> Self::B {
189        self.model.max_denominator()
190    }
191
192    fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
193        if self.remaining > 0 {
194            self.model.symbol(value)
195        } else {
196            None
197        }
198    }
199
200    fn denominator(&self) -> Self::B {
201        self.model.denominator()
202    }
203
204    fn update(&mut self, symbol: Option<&Self::Symbol>) {
205        if let Some(s) = symbol {
206            self.model.update(s);
207            self.remaining -= 1;
208        }
209    }
210}
211
212/// Fixed-length encoding/decoding errors
213#[derive(Debug, thiserror::Error)]
214pub enum Error<E>
215where
216    E: std::error::Error,
217{
218    /// Model received a symbol when it expected an EOF
219    #[error("Unexpected Symbol")]
220    UnexpectedSymbol,
221
222    /// The model received an invalid symbol
223    #[error(transparent)]
224    Value(E),
225}