arithmetic_coding_core/model/fixed_length.rs
1//! Helper trait for creating fixed-length Models
2
3use std::ops::Range;
4
5use crate::BitStore;
6
7/// A [`Model`] is used to calculate the probability of a given symbol occurring
8/// in a sequence.
9///
10/// The [`Model`] is used both for encoding and decoding. A
11/// 'fixed-length' model always expects an exact number of symbols, and so does
12/// not need to encode an EOF symbol.
13///
14/// A fixed length model can be converted into a regular model using the
15/// convenience [`Wrapper`] type.
16///
17/// The more accurately a [`Model`] is able to predict the next symbol, the
18/// greater the compression ratio will be.
19///
20/// # Example
21///
22/// ```
23/// # use std::convert::Infallible;
24/// # use std::ops::Range;
25/// #
26/// # use arithmetic_coding_core::fixed_length;
27///
28/// pub enum Symbol {
29/// A,
30/// B,
31/// C,
32/// }
33///
34/// pub struct MyModel;
35///
36/// impl fixed_length::Model for MyModel {
37/// type B = u32;
38/// type Symbol = Symbol;
39/// type ValueError = Infallible;
40///
41/// fn probability(&self, symbol: &Self::Symbol) -> Result<Range<u32>, Infallible> {
42/// Ok(match symbol {
43/// Symbol::A => 0..1,
44/// Symbol::B => 1..2,
45/// Symbol::C => 2..3,
46/// })
47/// }
48///
49/// fn symbol(&self, value: Self::B) -> Self::Symbol {
50/// match value {
51/// 0..1 => Symbol::A,
52/// 1..2 => Symbol::B,
53/// 2..3 => Symbol::C,
54/// _ => unreachable!(),
55/// }
56/// }
57///
58/// fn max_denominator(&self) -> u32 {
59/// 3
60/// }
61///
62/// fn length(&self) -> usize {
63/// 3
64/// }
65/// }
66/// ```
67pub trait Model {
68 /// The type of symbol this [`Model`] describes
69 type Symbol;
70
71 /// Invalid symbol error
72 type ValueError: std::error::Error;
73
74 /// The internal representation to use for storing integers
75 type B: BitStore;
76
77 /// Given a symbol, return an interval representing the probability of that
78 /// symbol occurring.
79 ///
80 /// This is given as a range, over the denominator given by
81 /// [`Model::denominator`]. This range should in general include `EOF`,
82 /// which is denoted by `None`.
83 ///
84 /// For example, from the set {heads, tails}, the interval representing
85 /// heads could be `0..1`, and tails would be `1..2`, and `EOF` could be
86 /// `2..3` (with a denominator of `3`).
87 ///
88 /// This is the inverse of the [`Model::symbol`] method
89 ///
90 /// # Errors
91 ///
92 /// This returns a custom error if the given symbol is not valid
93 fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError>;
94
95 /// The denominator for probability ranges. See [`Model::probability`].
96 ///
97 /// By default this method simply returns the [`Model::max_denominator`],
98 /// which is suitable for non-adaptive models.
99 ///
100 /// In adaptive models this value may change, however it should never exceed
101 /// [`Model::max_denominator`], or it becomes possible for the
102 /// [`Encoder`](crate::Encoder) and [`Decoder`](crate::Decoder) to panic due
103 /// to overflow or underflow.
104 fn denominator(&self) -> Self::B {
105 self.max_denominator()
106 }
107
108 /// The maximum denominator used for probability ranges. See
109 /// [`Model::probability`].
110 ///
111 /// This value is used to calculate an appropriate precision for the
112 /// encoding, therefore this value must not change, and
113 /// [`Model::denominator`] must never exceed it.
114 fn max_denominator(&self) -> Self::B;
115
116 /// Given a value, return the symbol whose probability range it falls in.
117 ///
118 /// `None` indicates `EOF`
119 ///
120 /// This is the inverse of the [`Model::probability`] method
121 fn symbol(&self, value: Self::B) -> Self::Symbol;
122
123 /// Update the current state of the model with the latest symbol.
124 ///
125 /// This method only needs to be implemented for 'adaptive' models. It's a
126 /// no-op by default.
127 fn update(&mut self, _symbol: &Self::Symbol) {}
128
129 /// The total number of symbols to encode
130 fn length(&self) -> usize;
131}
132
133/// A wrapper which converts a [`fixed_length::Model`](Model) to a
134/// [`crate::Model`].
135#[derive(Debug, Clone)]
136pub struct Wrapper<M>
137where
138 M: Model,
139{
140 model: M,
141 remaining: usize,
142}
143
144impl<M> Wrapper<M>
145where
146 M: Model,
147{
148 /// Construct a new wrapper from a [`fixed_length::Model`](Model)
149 pub fn new(model: M) -> Self {
150 let remaining = model.length();
151 Self { model, remaining }
152 }
153}
154
155impl<M> crate::Model for Wrapper<M>
156where
157 M: Model,
158{
159 type B = M::B;
160 type Symbol = M::Symbol;
161 type ValueError = Error<M::ValueError>;
162
163 fn probability(
164 &self,
165 symbol: Option<&Self::Symbol>,
166 ) -> Result<Range<Self::B>, Self::ValueError> {
167 if self.remaining > 0 {
168 symbol.map_or(
169 // We are expecting more symbols, but got an EOF
170 Err(Self::ValueError::UnexpectedEof),
171 // Expected a symbol and got one. return the probability.
172 |s| self.model.probability(s).map_err(Self::ValueError::Value),
173 )
174 } else if symbol.is_some() {
175 // we should be finished, but got an extra symbol
176 Err(Error::UnexpectedSymbol)
177 } else {
178 // got an EOF when we expected it, return a 100% probability
179 Ok(Self::B::ZERO..self.denominator())
180 }
181 }
182
183 fn max_denominator(&self) -> Self::B {
184 self.model.max_denominator()
185 }
186
187 fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
188 if self.remaining > 0 {
189 Some(self.model.symbol(value))
190 } else {
191 None
192 }
193 }
194
195 fn denominator(&self) -> Self::B {
196 self.model.denominator()
197 }
198
199 fn update(&mut self, symbol: Option<&Self::Symbol>) {
200 if let Some(s) = symbol {
201 self.model.update(s);
202 self.remaining -= 1;
203 }
204 }
205}
206
207/// Fixed-length encoding/decoding errors
208#[derive(Debug, thiserror::Error)]
209pub enum Error<E>
210where
211 E: std::error::Error,
212{
213 /// Model received an EOF when it expected more symbols
214 #[error("Unexpected EOF")]
215 UnexpectedEof,
216
217 /// Model received a symbol when it expected an EOF
218 #[error("Unexpected Symbol")]
219 UnexpectedSymbol,
220
221 /// The model received an invalid symbol
222 #[error(transparent)]
223 Value(E),
224}