arithmetic_coding/
decoder.rs

1//! The [`Decoder`] half of the arithmetic coding library.
2
3use std::{io, ops::Range};
4
5use bitstream_io::BitRead;
6
7use crate::{common, BitStore, Model};
8
9// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
10
11/// An arithmetic decoder
12///
13/// An arithmetic decoder converts a stream of bytes into a stream of some
14/// output symbol, using a predictive [`Model`].
15#[derive(Debug)]
16pub struct Decoder<M, R>
17where
18    M: Model,
19    R: BitRead,
20{
21    model: M,
22    state: State<M::B, R>,
23}
24
25trait BitReadExt {
26    fn next_bit(&mut self) -> io::Result<Option<bool>>;
27}
28
29impl<R: BitRead> BitReadExt for R {
30    fn next_bit(&mut self) -> io::Result<Option<bool>> {
31        match self.read_bit() {
32            Ok(bit) => Ok(Some(bit)),
33            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(None),
34            Err(e) => Err(e),
35        }
36    }
37}
38
39impl<M, R> Decoder<M, R>
40where
41    M: Model,
42    R: BitRead,
43{
44    /// Construct a new [`Decoder`]
45    ///
46    /// The 'precision' of the encoder is maximised, based on the number of bits
47    /// needed to represent the [`Model::denominator`]. 'precision' bits is
48    /// equal to [`u32::BITS`] - [`Model::denominator`] bits.
49    ///
50    /// # Panics
51    ///
52    /// The calculation of the number of bits used for 'precision' is subject to
53    /// the following constraints:
54    ///
55    /// - The total available bits is [`u32::BITS`]
56    /// - The precision must use at least 2 more bits than that needed to
57    ///   represent [`Model::denominator`]
58    ///
59    /// If these constraints cannot be satisfied this method will panic in debug
60    /// builds
61    pub fn new(model: M, input: R) -> Self {
62        let frequency_bits = model.max_denominator().log2() + 1;
63        let precision = M::B::BITS - frequency_bits;
64
65        Self::with_precision(model, input, precision)
66    }
67
68    /// Construct a new [`Decoder`] with a custom precision
69    ///
70    /// # Panics
71    ///
72    /// The calculation of the number of bits used for 'precision' is subject to
73    /// the following constraints:
74    ///
75    /// - The total available bits is [`BitStore::BITS`]
76    /// - The precision must use at least 2 more bits than that needed to
77    ///   represent [`Model::denominator`]
78    ///
79    /// If these constraints cannot be satisfied this method will panic in debug
80    /// builds
81    pub fn with_precision(model: M, input: R, precision: u32) -> Self {
82        let frequency_bits = model.max_denominator().log2() + 1;
83        debug_assert!(
84            (precision >= (frequency_bits + 2)),
85            "not enough bits of precision to prevent overflow/underflow",
86        );
87        debug_assert!(
88            (frequency_bits + precision) <= M::B::BITS,
89            "not enough bits in BitStore to support the required precision",
90        );
91
92        let state = State::new(precision, input);
93
94        Self { model, state }
95    }
96
97    /// todo
98    pub const fn with_state(state: State<M::B, R>, model: M) -> Self {
99        Self { model, state }
100    }
101
102    /// Return an iterator over the decoded symbols.
103    ///
104    /// The iterator will continue returning symbols until EOF is reached
105    pub fn decode_all(&mut self) -> DecodeIter<M, R> {
106        DecodeIter { decoder: self }
107    }
108
109    /// Read the next symbol from the stream of bits
110    ///
111    /// This method will return `Ok(None)` when EOF is reached.
112    ///
113    /// # Errors
114    ///
115    /// This method can fail if the underlying [`BitRead`] cannot be read from.
116    #[allow(clippy::missing_panics_doc)]
117    pub fn decode(&mut self) -> io::Result<Option<M::Symbol>> {
118        self.state.initialise()?;
119
120        let denominator = self.model.denominator();
121        debug_assert!(
122            denominator <= self.model.max_denominator(),
123            "denominator is greater than maximum!"
124        );
125        let value = self.state.value(denominator);
126        let symbol = self.model.symbol(value);
127
128        let p = self
129            .model
130            .probability(symbol.as_ref())
131            .expect("this should not be able to fail. Check the implementation of the model.");
132
133        self.state.scale(p, denominator)?;
134        self.model.update(symbol.as_ref());
135
136        Ok(symbol)
137    }
138
139    /// Reuse the internal state of the Decoder with a new model.
140    ///
141    /// Allows for chaining multiple sequences of symbols from a single stream
142    /// of bits
143    pub fn chain<X>(self, model: X) -> Decoder<X, R>
144    where
145        X: Model<B = M::B>,
146    {
147        Decoder {
148            model,
149            state: self.state,
150        }
151    }
152
153    /// todo
154    pub fn into_inner(self) -> (M, State<M::B, R>) {
155        (self.model, self.state)
156    }
157}
158
159/// The iterator returned by the [`Model::decode_all`] method
160#[allow(missing_debug_implementations)]
161pub struct DecodeIter<'a, M, R>
162where
163    M: Model,
164    R: BitRead,
165{
166    decoder: &'a mut Decoder<M, R>,
167}
168
169impl<M, R> Iterator for DecodeIter<'_, M, R>
170where
171    M: Model,
172    R: BitRead,
173{
174    type Item = io::Result<M::Symbol>;
175
176    fn next(&mut self) -> Option<Self::Item> {
177        self.decoder.decode().transpose()
178    }
179}
180
181/// A convenience struct which stores the internal state of an [`Decoder`].
182#[derive(Debug)]
183pub struct State<B, R>
184where
185    B: BitStore,
186    R: BitRead,
187{
188    state: common::State<B>,
189    input: R,
190    x: B,
191    uninitialised: bool,
192}
193
194impl<B, R> State<B, R>
195where
196    B: BitStore,
197    R: BitRead,
198{
199    /// todo
200    pub fn new(precision: u32, input: R) -> Self {
201        let state = common::State::new(precision);
202        let x = B::ZERO;
203
204        Self {
205            state,
206            input,
207            x,
208            uninitialised: true,
209        }
210    }
211
212    fn normalise(&mut self) -> io::Result<()> {
213        while self.state.high < self.state.half() || self.state.low >= self.state.half() {
214            if self.state.high < self.state.half() {
215                self.state.high <<= 1;
216                self.state.low <<= 1;
217                self.x <<= 1;
218            } else {
219                // self.low >= self.half()
220                self.state.low = (self.state.low - self.state.half()) << 1;
221                self.state.high = (self.state.high - self.state.half()) << 1;
222                self.x = (self.x - self.state.half()) << 1;
223            }
224
225            if self.input.next_bit()? == Some(true) {
226                self.x += B::ONE;
227            }
228        }
229
230        while self.state.low >= self.state.quarter()
231            && self.state.high < (self.state.three_quarter())
232        {
233            self.state.low = (self.state.low - self.state.quarter()) << 1;
234            self.state.high = (self.state.high - self.state.quarter()) << 1;
235            self.x = (self.x - self.state.quarter()) << 1;
236
237            if self.input.next_bit()? == Some(true) {
238                self.x += B::ONE;
239            }
240        }
241
242        Ok(())
243    }
244
245    fn scale(&mut self, p: Range<B>, denominator: B) -> io::Result<()> {
246        self.state.scale(p, denominator);
247        self.normalise()
248    }
249
250    fn value(&self, denominator: B) -> B {
251        let range = self.state.high - self.state.low + B::ONE;
252        ((self.x - self.state.low + B::ONE) * denominator - B::ONE) / range
253    }
254
255    fn fill(&mut self) -> io::Result<()> {
256        for _ in 0..self.state.precision {
257            self.x <<= 1;
258            if self.input.next_bit()? == Some(true) {
259                self.x += B::ONE;
260            }
261        }
262        Ok(())
263    }
264
265    fn initialise(&mut self) -> io::Result<()> {
266        if self.uninitialised {
267            self.fill()?;
268            self.uninitialised = false;
269        }
270        Ok(())
271    }
272}