arithmetic_coding_core/model/
max_length.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
//! Helper trait for creating Models with a maximum length

use std::ops::Range;

use crate::BitStore;

/// A [`Model`] is used to calculate the probability of a given symbol occurring
/// in a sequence.
///
/// The [`Model`] is used both for encoding and decoding. A
/// 'max-length' model has a maximum length. The compressed size of a message
/// equal to the maximum length is larger than with a
/// [`fixed_length::Model`](crate::fixed_length::Model), but smaller than with a
/// [`Model`](crate::Model).
///
/// A max-length model can be converted into a regular model using the
/// convenience [`Wrapper`] type.
///
/// The more accurately a [`Model`] is able to predict the next symbol, the
/// greater the compression ratio will be.
///
/// # Example
///
/// ```
/// # use std::convert::Infallible;
/// # use std::ops::Range;
/// #
/// # use arithmetic_coding_core::max_length;
///
/// pub enum Symbol {
///     A,
///     B,
///     C,
/// }
///
/// pub struct MyModel;
///
/// impl max_length::Model for MyModel {
///     type B = u32;
///     type Symbol = Symbol;
///     type ValueError = Infallible;
///
///     fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<u32>, Infallible> {
///         Ok(match symbol {
///             Some(Symbol::A) => 0..1,
///             Some(Symbol::B) => 1..2,
///             Some(Symbol::C) => 2..3,
///             None => 3..4,
///         })
///     }
///
///     fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
///         match value {
///             0..1 => Some(Symbol::A),
///             1..2 => Some(Symbol::B),
///             2..3 => Some(Symbol::C),
///             3..4 => None,
///             _ => unreachable!(),
///         }
///     }
///
///     fn max_denominator(&self) -> u32 {
///         4
///     }
///
///     fn max_length(&self) -> usize {
///         3
///     }
/// }
/// ```
pub trait Model {
    /// The type of symbol this [`Model`] describes
    type Symbol;

    /// The error returned by [`Model::probability`] when it is given a symbol
    /// that is not valid for this model
    type ValueError: std::error::Error;

    /// The internal representation to use for storing integers
    type B: BitStore;

    /// Given a symbol, return an interval representing the probability of that
    /// symbol occurring.
    ///
    /// This is given as a range, over the denominator given by
    /// [`Model::denominator`]. This range should in general include `EOF`,
    /// which is denoted by `None`.
    ///
    /// For example, from the set {heads, tails}, the interval representing
    /// heads could be `0..1`, and tails would be `1..2`, and `EOF` could be
    /// `2..3` (with a denominator of `3`).
    ///
    /// This is the inverse of the [`Model::symbol`] method.
    ///
    /// # Errors
    ///
    /// Returns [`Model::ValueError`] if the given symbol is not valid.
    fn probability(
        &self,
        symbol: Option<&Self::Symbol>,
    ) -> Result<Range<Self::B>, Self::ValueError>;

    /// The denominator for probability ranges. See [`Model::probability`].
    ///
    /// By default this method simply returns the [`Model::max_denominator`],
    /// which is suitable for non-adaptive models.
    ///
    /// In adaptive models this value may change, however it should never exceed
    /// [`Model::max_denominator`], or it becomes possible for the
    /// [`Encoder`](crate::Encoder) and [`Decoder`](crate::Decoder) to panic due
    /// to overflow or underflow.
    fn denominator(&self) -> Self::B {
        self.max_denominator()
    }

    /// The maximum denominator used for probability ranges. See
    /// [`Model::probability`].
    ///
    /// This value is used to calculate an appropriate precision for the
    /// encoding, therefore this value must not change, and
    /// [`Model::denominator`] must never exceed it.
    fn max_denominator(&self) -> Self::B;

    /// Given a value, return the symbol whose probability range it falls in.
    ///
    /// `None` indicates `EOF`.
    ///
    /// This is the inverse of the [`Model::probability`] method.
    fn symbol(&self, value: Self::B) -> Option<Self::Symbol>;

    /// Update the current state of the model with the latest symbol.
    ///
    /// Unlike [`Model::probability`], this takes a plain symbol reference
    /// rather than an `Option`, so there is no `EOF` case to handle here.
    ///
    /// This method only needs to be implemented for 'adaptive' models. It's a
    /// no-op by default.
    fn update(&mut self, _symbol: &Self::Symbol) {}

    /// The maximum number of symbols to encode.
    ///
    /// [`Wrapper::new`] reads this value once to initialise its count of
    /// remaining symbols, so a [`Wrapper`] does not observe later changes to
    /// it.
    fn max_length(&self) -> usize;
}

/// Adapter that lets a [`max_length::Model`](Model) be used wherever a
/// [`crate::Model`] is expected.
#[derive(Debug, Clone)]
pub struct Wrapper<M: Model> {
    /// The wrapped max-length model
    model: M,
    /// How many more symbols may be processed before the maximum length is
    /// reached
    remaining: usize,
}

impl<M: Model> Wrapper<M> {
    /// Wrap the given [`Model`], starting the remaining-symbol count at the
    /// model's [`Model::max_length`].
    pub fn new(model: M) -> Self {
        Self {
            // Query the length before `model` is moved into the struct.
            remaining: model.max_length(),
            model,
        }
    }
}

impl<M> crate::Model for Wrapper<M>
where
    M: Model,
{
    type B = M::B;
    type Symbol = M::Symbol;
    type ValueError = Error<M::ValueError>;

    fn probability(
        &self,
        symbol: Option<&Self::Symbol>,
    ) -> Result<Range<Self::B>, Self::ValueError> {
        if self.remaining == 0 {
            if symbol.is_some() {
                Err(Error::UnexpectedSymbol)
            } else {
                // got an EOF when we expected it, return a 100% probability
                Ok(Self::B::ZERO..self.denominator())
            }
        } else {
            self.model
                .probability(symbol)
                .map_err(Self::ValueError::Value)
        }
    }

    fn max_denominator(&self) -> Self::B {
        self.model.max_denominator()
    }

    fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
        if self.remaining > 0 {
            self.model.symbol(value)
        } else {
            None
        }
    }

    fn denominator(&self) -> Self::B {
        self.model.denominator()
    }

    fn update(&mut self, symbol: Option<&Self::Symbol>) {
        if let Some(s) = symbol {
            self.model.update(s);
            self.remaining -= 1;
        }
    }
}

/// Max-length encoding/decoding errors
#[derive(Debug, thiserror::Error)]
pub enum Error<E: std::error::Error> {
    /// A symbol was supplied where only an EOF was acceptable
    #[error("Unexpected Symbol")]
    UnexpectedSymbol,

    /// The underlying model reported the symbol as invalid
    #[error(transparent)]
    Value(E),
}