arithmetic_coding_adder_dep/
decoder.rs

1//! The [`Decoder`] half of the arithmetic coding library.
2
3use std::{io, ops::Range};
4use std::marker::PhantomData;
5
6use bitstream_io::BitRead;
7
8use crate::{BitStore, Model};
9
10// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
11
12/// An arithmetic decoder
13///
14/// An arithmetic decoder converts a stream of bytes into a stream of some
15/// output symbol, using a predictive [`Model`].
16#[derive(Debug)]
17pub struct Decoder<M, R>
18where
19    M: Model,
20    R: BitRead,
21{
22    /// The model used to predict the next symbol
23    pub model: M,
24    state: State<M::B, R>,
25}
26
27trait BitReadExt {
28    fn next_bit(&mut self) -> io::Result<Option<bool>>;
29}
30
31impl<R: BitRead> BitReadExt for R {
32    fn next_bit(&mut self) -> io::Result<Option<bool>> {
33        match self.read_bit() {
34            Ok(bit) => Ok(Some(bit)),
35            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(None),
36            Err(e) => Err(e),
37        }
38    }
39}
40
41impl<M, R> Decoder<M, R>
42where
43    M: Model,
44    R: BitRead,
45{
46    /// Construct a new [`Decoder`]
47    ///
48    /// The 'precision' of the encoder is maximised, based on the number of bits
49    /// needed to represent the [`Model::denominator`]. 'precision' bits is
50    /// equal to [`u32::BITS`] - [`Model::denominator`] bits.
51    ///
52    /// # Panics
53    ///
54    /// The calculation of the number of bits used for 'precision' is subject to
55    /// the following constraints:
56    ///
57    /// - The total available bits is [`u32::BITS`]
58    /// - The precision must use at least 2 more bits than that needed to
59    ///   represent [`Model::denominator`]
60    ///
61    /// If these constraints cannot be satisfied this method will panic in debug
62    /// builds
63    pub fn new(model: M) -> Self {
64        let frequency_bits = model.max_denominator().log2() + 1;
65        let precision = M::B::BITS - frequency_bits;
66
67        Self::with_precision(model, precision)
68    }
69
70    /// Construct a new [`Decoder`] with a custom precision
71    ///
72    /// # Panics
73    ///
74    /// The calculation of the number of bits used for 'precision' is subject to
75    /// the following constraints:
76    ///
77    /// - The total available bits is [`BitStore::BITS`]
78    /// - The precision must use at least 2 more bits than that needed to
79    ///   represent [`Model::denominator`]
80    ///
81    /// If these constraints cannot be satisfied this method will panic in debug
82    /// builds
83    pub fn with_precision(model: M, precision: u32) -> Self {
84        let frequency_bits = model.max_denominator().log2() + 1;
85        debug_assert!(
86            (precision >= (frequency_bits + 2)),
87            "not enough bits of precision to prevent overflow/underflow",
88        );
89        debug_assert!(
90            (frequency_bits + precision) <= M::B::BITS,
91            "not enough bits in BitStore to support the required precision",
92        );
93
94        let state = State::new(precision);
95
96        Self { model, state }
97    }
98
99    /// todo
100    pub const fn with_state(state: State<M::B, R>, model: M) -> Self {
101        Self { model, state }
102    }
103
104    /// Return an iterator over the decoded symbols.
105    ///
106    /// The iterator will continue returning symbols until EOF is reached
107    pub fn decode_all<'a>(&'a mut self, input: &'a mut R) -> DecodeIter<M, R> {
108        DecodeIter { decoder: self, input}
109    }
110
111    /// Read the next symbol from the stream of bits
112    ///
113    /// This method will return `Ok(None)` when EOF is reached.
114    ///
115    /// # Errors
116    ///
117    /// This method can fail if the underlying [`BitRead`] cannot be read from.
118    pub fn decode(&mut self, input: &mut R) -> io::Result<Option<M::Symbol>> {
119        self.state.initialise(input)?;
120
121        let denominator = self.model.denominator();
122        debug_assert!(
123            denominator <= self.model.max_denominator(),
124            "denominator is greater than maximum!"
125        );
126        let value = self.state.value(denominator);
127        let symbol = self.model.symbol(value);
128
129        let p = self
130            .model
131            .probability(symbol.as_ref())
132            .expect("this should not be able to fail. Check the implementation of the model.");
133
134        self.state.scale(p, denominator, input)?;
135        self.model.update(symbol.as_ref());
136
137        Ok(symbol)
138    }
139
140    /// Reuse the internal state of the Decoder with a new model.
141    ///
142    /// Allows for chaining multiple sequences of symbols from a single stream
143    /// of bits
144    pub fn chain<X>(self, model: X) -> Decoder<X, R>
145    where
146        X: Model<B = M::B>,
147    {
148        Decoder {
149            model,
150            state: self.state,
151        }
152    }
153
154    /// todo
155    pub fn into_inner(self) -> (M, State<M::B, R>) {
156        (self.model, self.state)
157    }
158}
159
160/// The iterator returned by the [`Model::decode_all`] method
161#[allow(missing_debug_implementations)]
162pub struct DecodeIter<'a, M, R>
163where
164    M: Model,
165    R: BitRead,
166{
167    decoder: &'a mut Decoder<M, R>,
168    input: &'a mut R,
169}
170
171impl<'a, M, R> Iterator for DecodeIter<'a, M, R>
172where
173    M: Model,
174    R: BitRead,
175{
176    type Item = io::Result<M::Symbol>;
177
178    fn next(&mut self) -> Option<Self::Item> {
179        self.decoder.decode(self.input).transpose()
180    }
181}
182
183/// A convenience struct which stores the internal state of an [`Decoder`].
184#[derive(Debug)]
185pub struct State<B, R>
186where
187    B: BitStore,
188    R: BitRead,
189{
190    precision: u32,
191    low: B,
192    high: B,
193    _marker: PhantomData<R>,
194    x: B,
195    uninitialised: bool,
196}
197
198impl<B, R> State<B, R>
199where
200    B: BitStore,
201    R: BitRead,
202{
203    /// todo
204    #[must_use] pub fn new(precision: u32) -> Self {
205        let low = B::ZERO;
206        let high = B::ONE << precision;
207        let x = B::ZERO;
208
209        Self {
210            precision,
211            low,
212            high,
213            _marker: PhantomData,
214            x,
215            uninitialised: true,
216        }
217    }
218
219    fn half(&self) -> B {
220        B::ONE << (self.precision - 1)
221    }
222
223    fn quarter(&self) -> B {
224        B::ONE << (self.precision - 2)
225    }
226
227    fn three_quarter(&self) -> B {
228        self.half() + self.quarter()
229    }
230
231    fn normalise(&mut self, input: &mut R) -> io::Result<()> {
232        while self.high < self.half() || self.low >= self.half() {
233            if self.high < self.half() {
234                self.high <<= 1;
235                self.low <<= 1;
236                self.x <<= 1;
237            } else {
238                // self.low >= self.half()
239                self.low = (self.low - self.half()) << 1;
240                self.high = (self.high - self.half()) << 1;
241                self.x = (self.x - self.half()) << 1;
242            }
243
244            match input.next_bit()? {
245                Some(true) => {
246                    self.x += B::ONE;
247                }
248                Some(false) | None => (),
249            }
250        }
251
252        while self.low >= self.quarter() && self.high < (self.three_quarter()) {
253            self.low = (self.low - self.quarter()) << 1;
254            self.high = (self.high - self.quarter()) << 1;
255            self.x = (self.x - self.quarter()) << 1;
256
257            match input.next_bit()? {
258                Some(true) => {
259                    self.x += B::ONE;
260                }
261                Some(false) | None => (),
262            }
263        }
264
265        Ok(())
266    }
267
268    fn scale(&mut self, p: Range<B>, denominator: B, input: &mut R) -> io::Result<()> {
269        let range = self.high - self.low + B::ONE;
270
271        self.high = self.low + (range * p.end) / denominator - B::ONE;
272        self.low += (range * p.start) / denominator;
273
274        self.normalise(input)
275    }
276
277    fn value(&self, denominator: B) -> B {
278        let range = self.high - self.low + B::ONE;
279        ((self.x - self.low + B::ONE) * denominator - B::ONE) / range
280    }
281
282    fn fill(&mut self, input: &mut R) -> io::Result<()> {
283        for _ in 0..self.precision {
284            self.x <<= 1;
285            match input.next_bit()? {
286                Some(true) => {
287                    self.x += B::ONE;
288                }
289                Some(false) | None => (),
290            }
291        }
292        Ok(())
293    }
294
295    fn initialise(&mut self, input: &mut R) -> io::Result<()> {
296        if self.uninitialised {
297            self.fill(input)?;
298            self.uninitialised = false;
299        }
300        Ok(())
301    }
302}