arithmetic_coding_core/model/max_length.rs
//! Helper trait for creating max-length Models
2
3use std::ops::Range;
4
5use crate::BitStore;
6
7/// A [`Model`] is used to calculate the probability of a given symbol occurring
8/// in a sequence.
9///
10/// The [`Model`] is used both for encoding and decoding. A
11/// 'max-length' model has a maximum length. The compressed size of a message
12/// equal to the maximum length is larger than with a
13/// [`fixed_length::Model`](crate::fixed_length::Model), but smaller than with a
14/// [`Model`](crate::Model).
15///
16/// A max-length model can be converted into a regular model using the
17/// convenience [`Wrapper`] type.
18///
19/// The more accurately a [`Model`] is able to predict the next symbol, the
20/// greater the compression ratio will be.
21///
22/// # Example
23///
24/// ```
25/// # use std::convert::Infallible;
26/// # use std::ops::Range;
27/// #
28/// # use arithmetic_coding_core::max_length;
29///
30/// pub enum Symbol {
31/// A,
32/// B,
33/// C,
34/// }
35///
36/// pub struct MyModel;
37///
38/// impl max_length::Model for MyModel {
39/// type B = u32;
40/// type Symbol = Symbol;
41/// type ValueError = Infallible;
42///
43/// fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<u32>, Infallible> {
44/// Ok(match symbol {
45/// Some(Symbol::A) => 0..1,
46/// Some(Symbol::B) => 1..2,
47/// Some(Symbol::C) => 2..3,
48/// None => 3..4,
49/// })
50/// }
51///
52/// fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
53/// match value {
54/// 0..1 => Some(Symbol::A),
55/// 1..2 => Some(Symbol::B),
56/// 2..3 => Some(Symbol::C),
57/// 3..4 => None,
58/// _ => unreachable!(),
59/// }
60/// }
61///
62/// fn max_denominator(&self) -> u32 {
63/// 4
64/// }
65///
66/// fn max_length(&self) -> usize {
67/// 3
68/// }
69/// }
70/// ```
71pub trait Model {
72 /// The type of symbol this [`Model`] describes
73 type Symbol;
74
75 /// Invalid symbol error
76 type ValueError: std::error::Error;
77
78 /// The internal representation to use for storing integers
79 type B: BitStore;
80
81 /// Given a symbol, return an interval representing the probability of that
82 /// symbol occurring.
83 ///
84 /// This is given as a range, over the denominator given by
85 /// [`Model::denominator`]. This range should in general include `EOF`,
86 /// which is denoted by `None`.
87 ///
88 /// For example, from the set {heads, tails}, the interval representing
89 /// heads could be `0..1`, and tails would be `1..2`, and `EOF` could be
90 /// `2..3` (with a denominator of `3`).
91 ///
92 /// This is the inverse of the [`Model::symbol`] method
93 ///
94 /// # Errors
95 ///
96 /// This returns a custom error if the given symbol is not valid
97 fn probability(
98 &self,
99 symbol: Option<&Self::Symbol>,
100 ) -> Result<Range<Self::B>, Self::ValueError>;
101
102 /// The denominator for probability ranges. See [`Model::probability`].
103 ///
104 /// By default this method simply returns the [`Model::max_denominator`],
105 /// which is suitable for non-adaptive models.
106 ///
107 /// In adaptive models this value may change, however it should never exceed
108 /// [`Model::max_denominator`], or it becomes possible for the
109 /// [`Encoder`](crate::Encoder) and [`Decoder`](crate::Decoder) to panic due
110 /// to overflow or underflow.
111 fn denominator(&self) -> Self::B {
112 self.max_denominator()
113 }
114
115 /// The maximum denominator used for probability ranges. See
116 /// [`Model::probability`].
117 ///
118 /// This value is used to calculate an appropriate precision for the
119 /// encoding, therefore this value must not change, and
120 /// [`Model::denominator`] must never exceed it.
121 fn max_denominator(&self) -> Self::B;
122
123 /// Given a value, return the symbol whose probability range it falls in.
124 ///
125 /// `None` indicates `EOF`
126 ///
127 /// This is the inverse of the [`Model::probability`] method
128 fn symbol(&self, value: Self::B) -> Option<Self::Symbol>;
129
130 /// Update the current state of the model with the latest symbol.
131 ///
132 /// This method only needs to be implemented for 'adaptive' models. It's a
133 /// no-op by default.
134 fn update(&mut self, _symbol: &Self::Symbol) {}
135
136 /// The maximum number of symbols to encode
137 fn max_length(&self) -> usize;
138}
139
140/// A wrapper which converts a [`max_length::Model`](Model) to a
141/// [`crate::Model`].
142#[derive(Debug, Clone)]
143pub struct Wrapper<M>
144where
145 M: Model,
146{
147 model: M,
148 remaining: usize,
149}
150
151impl<M> Wrapper<M>
152where
153 M: Model,
154{
155 /// Construct a new wrapper from a [`Model`]
156 pub fn new(model: M) -> Self {
157 let remaining = model.max_length();
158 Self { model, remaining }
159 }
160}
161
162impl<M> crate::Model for Wrapper<M>
163where
164 M: Model,
165{
166 type B = M::B;
167 type Symbol = M::Symbol;
168 type ValueError = Error<M::ValueError>;
169
170 fn probability(
171 &self,
172 symbol: Option<&Self::Symbol>,
173 ) -> Result<Range<Self::B>, Self::ValueError> {
174 if self.remaining == 0 {
175 if symbol.is_some() {
176 Err(Error::UnexpectedSymbol)
177 } else {
178 // got an EOF when we expected it, return a 100% probability
179 Ok(Self::B::ZERO..self.denominator())
180 }
181 } else {
182 self.model
183 .probability(symbol)
184 .map_err(Self::ValueError::Value)
185 }
186 }
187
188 fn max_denominator(&self) -> Self::B {
189 self.model.max_denominator()
190 }
191
192 fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
193 if self.remaining > 0 {
194 self.model.symbol(value)
195 } else {
196 None
197 }
198 }
199
200 fn denominator(&self) -> Self::B {
201 self.model.denominator()
202 }
203
204 fn update(&mut self, symbol: Option<&Self::Symbol>) {
205 if let Some(s) = symbol {
206 self.model.update(s);
207 self.remaining -= 1;
208 }
209 }
210}
211
212/// Fixed-length encoding/decoding errors
213#[derive(Debug, thiserror::Error)]
214pub enum Error<E>
215where
216 E: std::error::Error,
217{
218 /// Model received a symbol when it expected an EOF
219 #[error("Unexpected Symbol")]
220 UnexpectedSymbol,
221
222 /// The model received an invalid symbol
223 #[error(transparent)]
224 Value(E),
225}