arithmetic_coding/
encoder.rs

1//! The [`Encoder`] half of the arithmetic coding library.
2
3use std::{io, ops::Range};
4
5use bitstream_io::BitWrite;
6
7use crate::{common, BitStore, Error, Model};
8
9// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
10
11/// An arithmetic encoder
12///
13/// An arithmetic decoder converts a stream of symbols into a stream of bits,
14/// using a predictive [`Model`].
15#[derive(Debug)]
16pub struct Encoder<'a, M, W>
17where
18    M: Model,
19    W: BitWrite,
20{
21    model: M,
22    state: State<'a, M::B, W>,
23}
24
25impl<'a, M, W> Encoder<'a, M, W>
26where
27    M: Model,
28    W: BitWrite,
29{
30    /// Construct a new [`Encoder`].
31    ///
32    /// The 'precision' of the encoder is maximised, based on the number of bits
33    /// needed to represent the [`Model::denominator`]. 'precision' bits is
34    /// equal to [`BitStore::BITS`] - [`Model::denominator`] bits. If you need
35    /// to set the precision manually, use [`Encoder::with_precision`].
36    ///
37    /// # Panics
38    ///
39    /// The calculation of the number of bits used for 'precision' is subject to
40    /// the following constraints:
41    ///
42    /// - The total available bits is [`BitStore::BITS`]
43    /// - The precision must use at least 2 more bits than that needed to
44    ///   represent [`Model::denominator`]
45    ///
46    /// If these constraints cannot be satisfied this method will panic in debug
47    /// builds
48    pub fn new(model: M, bitwriter: &'a mut W) -> Self {
49        let frequency_bits = model.max_denominator().log2() + 1;
50        let precision = M::B::BITS - frequency_bits;
51        Self::with_precision(model, bitwriter, precision)
52    }
53
54    /// Construct a new [`Encoder`] with a custom precision.
55    ///
56    /// # Panics
57    ///
58    /// The calculation of the number of bits used for 'precision' is subject to
59    /// the following constraints:
60    ///
61    /// - The total available bits is [`BitStore::BITS`]
62    /// - The precision must use at least 2 more bits than that needed to
63    ///   represent [`Model::denominator`]
64    ///
65    /// If these constraints cannot be satisfied this method will panic in debug
66    /// builds
67    pub fn with_precision(model: M, bitwriter: &'a mut W, precision: u32) -> Self {
68        let frequency_bits = model.max_denominator().log2() + 1;
69        debug_assert!(
70            (precision >= (frequency_bits + 2)),
71            "not enough bits of precision to prevent overflow/underflow",
72        );
73        debug_assert!(
74            (frequency_bits + precision) <= M::B::BITS,
75            "not enough bits in BitStore to support the required precision",
76        );
77
78        Self {
79            model,
80            state: State::new(precision, bitwriter),
81        }
82    }
83
84    /// Create an encoder from an existing [`State`].
85    ///
86    /// This is useful for manually chaining a shared buffer through multiple
87    /// encoders.
88    pub const fn with_state(state: State<'a, M::B, W>, model: M) -> Self {
89        Self { model, state }
90    }
91
92    /// Encode a stream of symbols into the provided output.
93    ///
94    /// This method will encode all the symbols in the iterator, followed by EOF
95    /// (`None`), and then call [`Encoder::flush`].
96    ///
97    /// # Errors
98    ///
99    /// This method can fail if the underlying [`BitWrite`] cannot be written
100    /// to.
101    pub fn encode_all(
102        &mut self,
103        symbols: impl IntoIterator<Item = M::Symbol>,
104    ) -> Result<(), Error<M::ValueError>> {
105        for symbol in symbols {
106            self.encode(Some(&symbol))?;
107        }
108        self.encode(None)?;
109        self.flush()?;
110        Ok(())
111    }
112
113    /// Encode a symbol into the provided output.
114    ///
115    /// When you finish encoding symbols, you must manually encode an EOF symbol
116    /// by calling [`Encoder::encode`] with `None`.
117    ///
118    /// The internal buffer must be manually flushed using [`Encoder::flush`].
119    ///
120    /// # Errors
121    ///
122    /// This method can fail if the underlying [`BitWrite`] cannot be written
123    /// to.
124    pub fn encode(&mut self, symbol: Option<&M::Symbol>) -> Result<(), Error<M::ValueError>> {
125        let p = self.model.probability(symbol).map_err(Error::ValueError)?;
126        let denominator = self.model.denominator();
127        debug_assert!(
128            denominator <= self.model.max_denominator(),
129            "denominator is greater than maximum!"
130        );
131
132        self.state.scale(p, denominator)?;
133        self.model.update(symbol);
134
135        Ok(())
136    }
137
138    /// Flush any pending bits from the buffer
139    ///
140    /// This method must be called when you finish writing symbols to a stream
141    /// of bits. This is called automatically when you use
142    /// [`Encoder::encode_all`].
143    ///
144    /// # Errors
145    ///
146    /// This method can fail if the underlying [`BitWrite`] cannot be written
147    /// to.
148    pub fn flush(&mut self) -> io::Result<()> {
149        self.state.flush()
150    }
151
152    /// todo
153    pub fn into_inner(self) -> (M, State<'a, M::B, W>) {
154        (self.model, self.state)
155    }
156
157    /// Reuse the internal state of the Encoder with a new model.
158    ///
159    /// Allows for chaining multiple sequences of symbols into a single stream
160    /// of bits
161    pub fn chain<X>(self, model: X) -> Encoder<'a, X, W>
162    where
163        X: Model<B = M::B>,
164    {
165        Encoder {
166            model,
167            state: self.state,
168        }
169    }
170}
171
172/// A convenience struct which stores the internal state of an [`Encoder`].
173#[derive(Debug)]
174pub struct State<'a, B, W>
175where
176    B: BitStore,
177    W: BitWrite,
178{
179    state: common::State<B>,
180    pending: u32,
181    output: &'a mut W,
182}
183
184impl<'a, B, W> State<'a, B, W>
185where
186    B: BitStore,
187    W: BitWrite,
188{
189    /// Manually construct a [`State`].
190    ///
191    /// Normally this would be done automatically using the [`Encoder::new`]
192    /// method.
193    pub fn new(precision: u32, output: &'a mut W) -> Self {
194        let state = common::State::new(precision);
195        let pending = 0;
196
197        Self {
198            state,
199            pending,
200            output,
201        }
202    }
203
204    fn scale(&mut self, p: Range<B>, denominator: B) -> io::Result<()> {
205        self.state.scale(p, denominator);
206        self.normalise()
207    }
208
209    fn normalise(&mut self) -> io::Result<()> {
210        while self.state.high < self.state.half() || self.state.low >= self.state.half() {
211            if self.state.high < self.state.half() {
212                self.emit(false)?;
213                self.state.high <<= 1;
214                self.state.low <<= 1;
215            } else {
216                self.emit(true)?;
217                self.state.low = (self.state.low - self.state.half()) << 1;
218                self.state.high = (self.state.high - self.state.half()) << 1;
219            }
220        }
221
222        while self.state.low >= self.state.quarter()
223            && self.state.high < (self.state.three_quarter())
224        {
225            self.pending += 1;
226            self.state.low = (self.state.low - self.state.quarter()) << 1;
227            self.state.high = (self.state.high - self.state.quarter()) << 1;
228        }
229
230        Ok(())
231    }
232
233    fn emit(&mut self, bit: bool) -> io::Result<()> {
234        self.output.write_bit(bit)?;
235        for _ in 0..self.pending {
236            self.output.write_bit(!bit)?;
237        }
238        self.pending = 0;
239        Ok(())
240    }
241
242    /// Flush the internal buffer and write all remaining bits to the output.
243    /// This method MUST be called when you finish writing symbols to ensure
244    /// they are fully written to the output.
245    ///
246    /// # Errors
247    ///
248    /// This method can fail if the output cannot be written to
249    pub fn flush(&mut self) -> io::Result<()> {
250        self.pending += 1;
251        if self.state.low <= self.state.quarter() {
252            self.emit(false)?;
253        } else {
254            self.emit(true)?;
255        }
256
257        Ok(())
258    }
259}