arithmetic_coding_adder_dep/
encoder.rs

1//! The [`Encoder`] half of the arithmetic coding library.
2
3use std::marker::PhantomData;
4use std::{io, ops::Range};
5
6use bitstream_io::BitWrite;
7
8use crate::Error::ValueError;
9use crate::{BitStore, Error, Model};
10
11// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
12
13/// An arithmetic encoder
14///
15/// An arithmetic decoder converts a stream of symbols into a stream of bits,
16/// using a predictive [`Model`].
17#[derive(Debug)]
18pub struct Encoder<M, W>
19where
20    M: Model,
21    W: BitWrite,
22{
23    /// The model used for the encoder
24    pub model: M,
25    state: State<M::B, W>,
26}
27
28impl<M, W> Encoder<M, W>
29where
30    M: Model,
31    W: BitWrite,
32{
33    /// Construct a new [`Encoder`].
34    ///
35    /// The 'precision' of the encoder is maximised, based on the number of bits
36    /// needed to represent the [`Model::denominator`]. 'precision' bits is
37    /// equal to [`BitStore::BITS`] - [`Model::denominator`] bits. If you need
38    /// to set the precision manually, use [`Encoder::with_precision`].
39    ///
40    /// # Panics
41    ///
42    /// The calculation of the number of bits used for 'precision' is subject to
43    /// the following constraints:
44    ///
45    /// - The total available bits is [`BitStore::BITS`]
46    /// - The precision must use at least 2 more bits than that needed to
47    ///   represent [`Model::denominator`]
48    ///
49    /// If these constraints cannot be satisfied this method will panic in debug
50    /// builds
51    pub fn new(model: M) -> Self {
52        let frequency_bits = model.max_denominator().log2() + 1;
53        let precision = M::B::BITS - frequency_bits;
54        Self::with_precision(model, precision)
55    }
56
57    /// Construct a new [`Encoder`] with a custom precision.
58    ///
59    /// # Panics
60    ///
61    /// The calculation of the number of bits used for 'precision' is subject to
62    /// the following constraints:
63    ///
64    /// - The total available bits is [`BitStore::BITS`]
65    /// - The precision must use at least 2 more bits than that needed to
66    ///   represent [`Model::denominator`]
67    ///
68    /// If these constraints cannot be satisfied this method will panic in debug
69    /// builds
70    pub fn with_precision(model: M, precision: u32) -> Self {
71        let frequency_bits = model.max_denominator().log2() + 1;
72        debug_assert!(
73            (precision >= (frequency_bits + 2)),
74            "not enough bits of precision to prevent overflow/underflow",
75        );
76        debug_assert!(
77            (frequency_bits + precision) <= M::B::BITS,
78            "not enough bits in BitStore to support the required precision",
79        );
80
81        Self {
82            model,
83            state: State::new(precision),
84        }
85    }
86
87    /// todo
88    pub const fn with_state(state: State<M::B, W>, model: M) -> Self {
89        Self { model, state }
90    }
91
92    /// Encode a stream of symbols into the provided output.
93    ///
94    /// This method will encode all the symbols in the iterator, followed by EOF
95    /// (`None`), and then call [`Encoder::flush`].
96    ///
97    /// # Errors
98    ///
99    /// This method can fail if the underlying [`BitWrite`] cannot be written
100    /// to.
101    pub fn encode_all(
102        &mut self,
103        symbols: impl IntoIterator<Item = M::Symbol>,
104        output: &mut W,
105    ) -> Result<(), Error> {
106        for symbol in symbols {
107            self.encode(Some(&symbol), output)?;
108        }
109        self.encode(None, output)?;
110        self.flush(output)?;
111        Ok(())
112    }
113
114    /// Encode a symbol into the provided output.
115    ///
116    /// When you finish encoding symbols, you must manually encode an EOF symbol
117    /// by calling [`Encoder::encode`] with `None`.
118    ///
119    /// The internal buffer must be manually flushed using [`Encoder::flush`].
120    ///
121    /// # Errors
122    ///
123    /// This method can fail if the underlying [`BitWrite`] cannot be written
124    /// to.
125    pub fn encode(&mut self, symbol: Option<&M::Symbol>, output: &mut W) -> Result<(), Error> {
126        let Ok(p) = self.model.probability(symbol) else {
127            return Err(ValueError);
128        };
129        let denominator = self.model.denominator();
130        debug_assert!(
131            denominator <= self.model.max_denominator(),
132            "denominator is greater than maximum!"
133        );
134
135        self.state.scale(p, denominator, output)?;
136        self.model.update(symbol);
137
138        Ok(())
139    }
140
141    /// Flush any pending bits from the buffer
142    ///
143    /// This method must be called when you finish writing symbols to a stream
144    /// of bits. This is called automatically when you use
145    /// [`Encoder::encode_all`].
146    ///
147    /// # Errors
148    ///
149    /// This method can fail if the underlying [`BitWrite`] cannot be written
150    /// to.
151    pub fn flush(&mut self, output: &mut W) -> io::Result<()> {
152        self.state.flush(output)
153    }
154
155    /// todo
156    pub fn into_inner(self) -> (M, State<M::B, W>) {
157        (self.model, self.state)
158    }
159
160    /// Reuse the internal state of the Encoder with a new model.
161    ///
162    /// Allows for chaining multiple sequences of symbols into a single stream
163    /// of bits
164    pub fn chain<X>(self, model: X) -> Encoder<X, W>
165    where
166        X: Model<B = M::B>,
167    {
168        Encoder {
169            model,
170            state: self.state,
171        }
172    }
173}
174
175/// A convenience struct which stores the internal state of an [`Encoder`].
176#[derive(Debug)]
177pub struct State<B, W>
178where
179    B: BitStore,
180    W: BitWrite,
181{
182    precision: u32,
183    low: B,
184    high: B,
185    pending: u32,
186    _marker: PhantomData<W>,
187}
188
189impl<B, W> State<B, W>
190where
191    B: BitStore,
192    W: BitWrite,
193{
194    /// todo
195    #[must_use]
196    pub fn new(precision: u32) -> Self {
197        let low = B::ZERO;
198        let high = B::ONE << precision;
199        let pending = 0;
200
201        Self {
202            precision,
203            low,
204            high,
205            pending,
206            _marker: PhantomData,
207        }
208    }
209
210    fn three_quarter(&self) -> B {
211        self.half() + self.quarter()
212    }
213
214    fn half(&self) -> B {
215        B::ONE << (self.precision - 1)
216    }
217
218    fn quarter(&self) -> B {
219        B::ONE << (self.precision - 2)
220    }
221
222    fn scale(&mut self, p: Range<B>, denominator: B, output: &mut W) -> io::Result<()> {
223        let range = self.high - self.low + B::ONE;
224
225        self.high = self.low + (range * p.end) / denominator - B::ONE;
226        self.low += (range * p.start) / denominator;
227
228        self.normalise(output)
229    }
230
231    fn normalise(&mut self, output: &mut W) -> io::Result<()> {
232        while self.high < self.half() || self.low >= self.half() {
233            if self.high < self.half() {
234                self.emit(false, output)?;
235                self.high <<= 1;
236                self.low <<= 1;
237            } else {
238                // self.low >= self.half()
239                self.emit(true, output)?;
240                self.low = (self.low - self.half()) << 1;
241                self.high = (self.high - self.half()) << 1;
242            }
243        }
244
245        while self.low >= self.quarter() && self.high < (self.three_quarter()) {
246            self.pending += 1;
247            self.low = (self.low - self.quarter()) << 1;
248            self.high = (self.high - self.quarter()) << 1;
249        }
250
251        Ok(())
252    }
253
254    fn emit(&mut self, bit: bool, output: &mut W) -> io::Result<()> {
255        output.write_bit(bit)?;
256        for _ in 0..self.pending {
257            output.write_bit(!bit)?;
258        }
259        self.pending = 0;
260        Ok(())
261    }
262
263    /// todo
264    pub fn flush(&mut self, output: &mut W) -> io::Result<()> {
265        self.pending += 1;
266        if self.low <= self.quarter() {
267            self.emit(false, output)?;
268        } else {
269            self.emit(true, output)?;
270        }
271
272        Ok(())
273    }
274}