arithmetic_coding_core/model/max_length.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
//! Helper trait for creating max-length Models
use std::ops::Range;
use crate::BitStore;
/// A probability model for sequences that have an upper bound on their length.
///
/// A [`Model`] estimates how likely each symbol is to occur next, and is
/// consulted both when encoding and when decoding. What makes a 'max-length'
/// model special is that it declares a maximum number of symbols up front. For
/// a message that is exactly as long as that maximum, the compressed output is
/// larger than with a [`fixed_length::Model`](crate::fixed_length::Model), but
/// smaller than with a plain [`Model`](crate::Model).
///
/// Use the [`Wrapper`] type to turn a max-length model into a regular model.
///
/// Compression improves as the predictions made by the [`Model`] become more
/// accurate.
///
/// # Example
///
/// ```
/// # use std::convert::Infallible;
/// # use std::ops::Range;
/// #
/// # use arithmetic_coding_core::max_length;
///
/// pub enum Symbol {
///     A,
///     B,
///     C,
/// }
///
/// pub struct MyModel;
///
/// impl max_length::Model for MyModel {
///     type B = u32;
///     type Symbol = Symbol;
///     type ValueError = Infallible;
///
///     fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<u32>, Infallible> {
///         Ok(match symbol {
///             Some(Symbol::A) => 0..1,
///             Some(Symbol::B) => 1..2,
///             Some(Symbol::C) => 2..3,
///             None => 3..4,
///         })
///     }
///
///     fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
///         match value {
///             0..1 => Some(Symbol::A),
///             1..2 => Some(Symbol::B),
///             2..3 => Some(Symbol::C),
///             3..4 => None,
///             _ => unreachable!(),
///         }
///     }
///
///     fn max_denominator(&self) -> u32 {
///         4
///     }
///
///     fn max_length(&self) -> usize {
///         3
///     }
/// }
/// ```
pub trait Model {
    /// Integer type used internally to store probability bounds.
    type B: BitStore;

    /// The kind of symbol described by this [`Model`].
    type Symbol;

    /// Error reported when a symbol is not valid for this model.
    type ValueError: std::error::Error;

    /// Map a symbol to the interval that represents its probability.
    ///
    /// The interval is expressed as a range over the denominator returned by
    /// [`Model::denominator`]. `EOF` is represented by `None`, and in general
    /// it should be assigned an interval as well.
    ///
    /// As an illustration, for the set {heads, tails} one could use `0..1`
    /// for heads, `1..2` for tails and `2..3` for `EOF`, giving a denominator
    /// of `3`.
    ///
    /// [`Model::symbol`] performs the inverse mapping.
    ///
    /// # Errors
    ///
    /// A custom error is returned when the supplied symbol is not valid.
    fn probability(
        &self,
        symbol: Option<&Self::Symbol>,
    ) -> Result<Range<Self::B>, Self::ValueError>;

    /// Map a value back to the symbol whose probability interval contains it.
    ///
    /// A result of `None` stands for `EOF`.
    ///
    /// [`Model::probability`] performs the inverse mapping.
    fn symbol(&self, value: Self::B) -> Option<Self::Symbol>;

    /// The largest denominator that will ever be used for probability ranges.
    /// See [`Model::probability`].
    ///
    /// The precision of the encoding is derived from this value, so it must
    /// stay constant, and [`Model::denominator`] must never be greater than
    /// it.
    fn max_denominator(&self) -> Self::B;

    /// The upper bound on the number of symbols that will be encoded.
    fn max_length(&self) -> usize;

    /// The denominator currently in use for probability ranges. See
    /// [`Model::probability`].
    ///
    /// The default implementation forwards to [`Model::max_denominator`],
    /// which is the right choice for non-adaptive models.
    ///
    /// Adaptive models may let this value vary, but it should never be
    /// greater than [`Model::max_denominator`]; otherwise the
    /// [`Encoder`](crate::Encoder) and [`Decoder`](crate::Decoder) may panic
    /// because of overflow or underflow.
    fn denominator(&self) -> Self::B {
        self.max_denominator()
    }

    /// Feed the most recent symbol back into the model's state.
    ///
    /// Only 'adaptive' models need to implement this; the default does
    /// nothing.
    fn update(&mut self, _symbol: &Self::Symbol) {}
}
/// Adapter that exposes a [`max_length::Model`](Model) as a
/// [`crate::Model`].
#[derive(Debug, Clone)]
pub struct Wrapper<M: Model> {
    /// The wrapped max-length model.
    model: M,
    /// Number of symbols that may still be processed before only `EOF` is
    /// accepted.
    remaining: usize,
}
impl<M: Model> Wrapper<M> {
    /// Wrap a [`Model`], starting the symbol budget at the model's
    /// [`Model::max_length`].
    pub fn new(model: M) -> Self {
        Self {
            // Field initialisers run in written order, so the length is read
            // before `model` is moved into the struct.
            remaining: model.max_length(),
            model,
        }
    }
}
impl<M> crate::Model for Wrapper<M>
where
M: Model,
{
type B = M::B;
type Symbol = M::Symbol;
type ValueError = Error<M::ValueError>;
fn probability(
&self,
symbol: Option<&Self::Symbol>,
) -> Result<Range<Self::B>, Self::ValueError> {
if self.remaining == 0 {
if symbol.is_some() {
Err(Error::UnexpectedSymbol)
} else {
// got an EOF when we expected it, return a 100% probability
Ok(Self::B::ZERO..self.denominator())
}
} else {
self.model
.probability(symbol)
.map_err(Self::ValueError::Value)
}
}
fn max_denominator(&self) -> Self::B {
self.model.max_denominator()
}
fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
if self.remaining > 0 {
self.model.symbol(value)
} else {
None
}
}
fn denominator(&self) -> Self::B {
self.model.denominator()
}
fn update(&mut self, symbol: Option<&Self::Symbol>) {
if let Some(s) = symbol {
self.model.update(s);
self.remaining -= 1;
}
}
}
/// Max-length encoding/decoding errors
#[derive(Debug, thiserror::Error)]
pub enum Error<E: std::error::Error> {
    /// A symbol was supplied to the model at a point where it expected an
    /// EOF
    #[error("Unexpected Symbol")]
    UnexpectedSymbol,
    /// An invalid symbol was supplied to the model; carries the underlying
    /// error
    #[error(transparent)]
    Value(E),
}