arithmetic_coding_core/model/
one_shot.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//! Helper trait for creating Models which only accept a single symbol

use std::ops::Range;

pub use crate::fixed_length::Wrapper;
use crate::{fixed_length, BitStore};

/// A [`Model`] is used to calculate the probability of a given symbol occurring
/// in a sequence.
///
/// The [`Model`] is used both for encoding and decoding. A 'one-shot' model
/// only ever encodes a single symbol, and so does not need to encode an `EOF`
/// symbol.
///
/// A one-shot [`Model`] is a special case of the [`fixed_length::Model`]:
/// every one-shot model automatically implements it with a length of `1`.
///
/// A one-shot model can be converted into a regular model using the
/// convenience [`Wrapper`] type.
///
/// The more accurately a [`Model`] is able to predict the next symbol, the
/// greater the compression ratio will be.
///
/// # Example
///
/// ```
/// # use std::convert::Infallible;
/// # use std::ops::Range;
/// #
/// # use arithmetic_coding_core::one_shot;
///
/// pub enum Symbol {
///     A,
///     B,
///     C,
/// }
///
/// pub struct MyModel;
///
/// impl one_shot::Model for MyModel {
///     type B = u32;
///     type Symbol = Symbol;
///     type ValueError = Infallible;
///
///     fn probability(&self, symbol: &Self::Symbol) -> Result<Range<u32>, Infallible> {
///         Ok(match symbol {
///             Symbol::A => 0..1,
///             Symbol::B => 1..2,
///             Symbol::C => 2..3,
///         })
///     }
///
///     fn symbol(&self, value: Self::B) -> Self::Symbol {
///         match value {
///             0..1 => Symbol::A,
///             1..2 => Symbol::B,
///             2..3 => Symbol::C,
///             _ => unreachable!(),
///         }
///     }
///
///     fn max_denominator(&self) -> u32 {
///         3
///     }
/// }
/// ```
pub trait Model {
    /// The type of symbol this [`Model`] describes
    type Symbol;

    /// Invalid symbol error
    type ValueError: std::error::Error;

    /// The internal representation to use for storing integers
    type B: BitStore;

    /// Given a symbol, return an interval representing the probability of that
    /// symbol occurring.
    ///
    /// This is given as a range, over the denominator given by
    /// [`Model::max_denominator`]. Because a one-shot model encodes exactly
    /// one symbol, no part of the range needs to be reserved for `EOF`.
    ///
    /// For example, from the set {heads, tails}, the interval representing
    /// heads could be `0..1`, and tails would be `1..2` (with a denominator
    /// of `2`).
    ///
    /// This is the inverse of the [`Model::symbol`] method
    ///
    /// # Errors
    ///
    /// This returns a custom error if the given symbol is not valid
    fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError>;

    /// The denominator used for probability ranges. See
    /// [`Model::probability`].
    ///
    /// This value is used to calculate an appropriate precision for the
    /// encoding, therefore this value must not change. The blanket
    /// [`fixed_length::Model`] implementation for one-shot models reports
    /// this same value as both its denominator and its maximum denominator.
    fn max_denominator(&self) -> Self::B;

    /// Given a value, return the symbol whose probability range it falls in.
    ///
    /// A one-shot model has no `EOF`, so this always returns a symbol rather
    /// than an `Option`.
    ///
    /// This is the inverse of the [`Model::probability`] method
    fn symbol(&self, value: Self::B) -> Self::Symbol;
}

/// Blanket implementation: every one-shot [`Model`] is a
/// [`fixed_length::Model`] whose message is exactly one symbol long.
///
/// Each call is forwarded to the one-shot trait using fully-qualified syntax,
/// so it is explicit which of the two identically-named `Model` traits is
/// being invoked.
impl<T: Model> fixed_length::Model for T {
    type B = <T as Model>::B;
    type Symbol = <T as Model>::Symbol;
    type ValueError = <T as Model>::ValueError;

    /// Forwards to the one-shot model's `probability`.
    fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError> {
        <T as Model>::probability(self, symbol)
    }

    /// Forwards to the one-shot model's `max_denominator`.
    fn max_denominator(&self) -> Self::B {
        <T as Model>::max_denominator(self)
    }

    /// Forwards to the one-shot model's `symbol`.
    fn symbol(&self, value: Self::B) -> Self::Symbol {
        <T as Model>::symbol(self, value)
    }

    /// A one-shot model always covers exactly one symbol.
    fn length(&self) -> usize {
        1
    }

    /// The denominator is the one-shot model's `max_denominator`.
    fn denominator(&self) -> Self::B {
        <T as Model>::max_denominator(self)
    }
}