arithmetic_coding/encoder.rs
1//! The [`Encoder`] half of the arithmetic coding library.
2
3use std::{io, ops::Range};
4
5use bitstream_io::BitWrite;
6
7use crate::{common, BitStore, Error, Model};
8
9// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
10
11/// An arithmetic encoder
12///
13/// An arithmetic decoder converts a stream of symbols into a stream of bits,
14/// using a predictive [`Model`].
15#[derive(Debug)]
16pub struct Encoder<'a, M, W>
17where
18 M: Model,
19 W: BitWrite,
20{
21 model: M,
22 state: State<'a, M::B, W>,
23}
24
25impl<'a, M, W> Encoder<'a, M, W>
26where
27 M: Model,
28 W: BitWrite,
29{
30 /// Construct a new [`Encoder`].
31 ///
32 /// The 'precision' of the encoder is maximised, based on the number of bits
33 /// needed to represent the [`Model::denominator`]. 'precision' bits is
34 /// equal to [`BitStore::BITS`] - [`Model::denominator`] bits. If you need
35 /// to set the precision manually, use [`Encoder::with_precision`].
36 ///
37 /// # Panics
38 ///
39 /// The calculation of the number of bits used for 'precision' is subject to
40 /// the following constraints:
41 ///
42 /// - The total available bits is [`BitStore::BITS`]
43 /// - The precision must use at least 2 more bits than that needed to
44 /// represent [`Model::denominator`]
45 ///
46 /// If these constraints cannot be satisfied this method will panic in debug
47 /// builds
48 pub fn new(model: M, bitwriter: &'a mut W) -> Self {
49 let frequency_bits = model.max_denominator().log2() + 1;
50 let precision = M::B::BITS - frequency_bits;
51 Self::with_precision(model, bitwriter, precision)
52 }
53
54 /// Construct a new [`Encoder`] with a custom precision.
55 ///
56 /// # Panics
57 ///
58 /// The calculation of the number of bits used for 'precision' is subject to
59 /// the following constraints:
60 ///
61 /// - The total available bits is [`BitStore::BITS`]
62 /// - The precision must use at least 2 more bits than that needed to
63 /// represent [`Model::denominator`]
64 ///
65 /// If these constraints cannot be satisfied this method will panic in debug
66 /// builds
67 pub fn with_precision(model: M, bitwriter: &'a mut W, precision: u32) -> Self {
68 let frequency_bits = model.max_denominator().log2() + 1;
69 debug_assert!(
70 (precision >= (frequency_bits + 2)),
71 "not enough bits of precision to prevent overflow/underflow",
72 );
73 debug_assert!(
74 (frequency_bits + precision) <= M::B::BITS,
75 "not enough bits in BitStore to support the required precision",
76 );
77
78 Self {
79 model,
80 state: State::new(precision, bitwriter),
81 }
82 }
83
84 /// Create an encoder from an existing [`State`].
85 ///
86 /// This is useful for manually chaining a shared buffer through multiple
87 /// encoders.
88 pub const fn with_state(state: State<'a, M::B, W>, model: M) -> Self {
89 Self { model, state }
90 }
91
92 /// Encode a stream of symbols into the provided output.
93 ///
94 /// This method will encode all the symbols in the iterator, followed by EOF
95 /// (`None`), and then call [`Encoder::flush`].
96 ///
97 /// # Errors
98 ///
99 /// This method can fail if the underlying [`BitWrite`] cannot be written
100 /// to.
101 pub fn encode_all(
102 &mut self,
103 symbols: impl IntoIterator<Item = M::Symbol>,
104 ) -> Result<(), Error<M::ValueError>> {
105 for symbol in symbols {
106 self.encode(Some(&symbol))?;
107 }
108 self.encode(None)?;
109 self.flush()?;
110 Ok(())
111 }
112
113 /// Encode a symbol into the provided output.
114 ///
115 /// When you finish encoding symbols, you must manually encode an EOF symbol
116 /// by calling [`Encoder::encode`] with `None`.
117 ///
118 /// The internal buffer must be manually flushed using [`Encoder::flush`].
119 ///
120 /// # Errors
121 ///
122 /// This method can fail if the underlying [`BitWrite`] cannot be written
123 /// to.
124 pub fn encode(&mut self, symbol: Option<&M::Symbol>) -> Result<(), Error<M::ValueError>> {
125 let p = self.model.probability(symbol).map_err(Error::ValueError)?;
126 let denominator = self.model.denominator();
127 debug_assert!(
128 denominator <= self.model.max_denominator(),
129 "denominator is greater than maximum!"
130 );
131
132 self.state.scale(p, denominator)?;
133 self.model.update(symbol);
134
135 Ok(())
136 }
137
138 /// Flush any pending bits from the buffer
139 ///
140 /// This method must be called when you finish writing symbols to a stream
141 /// of bits. This is called automatically when you use
142 /// [`Encoder::encode_all`].
143 ///
144 /// # Errors
145 ///
146 /// This method can fail if the underlying [`BitWrite`] cannot be written
147 /// to.
148 pub fn flush(&mut self) -> io::Result<()> {
149 self.state.flush()
150 }
151
152 /// todo
153 pub fn into_inner(self) -> (M, State<'a, M::B, W>) {
154 (self.model, self.state)
155 }
156
157 /// Reuse the internal state of the Encoder with a new model.
158 ///
159 /// Allows for chaining multiple sequences of symbols into a single stream
160 /// of bits
161 pub fn chain<X>(self, model: X) -> Encoder<'a, X, W>
162 where
163 X: Model<B = M::B>,
164 {
165 Encoder {
166 model,
167 state: self.state,
168 }
169 }
170}
171
172/// A convenience struct which stores the internal state of an [`Encoder`].
173#[derive(Debug)]
174pub struct State<'a, B, W>
175where
176 B: BitStore,
177 W: BitWrite,
178{
179 state: common::State<B>,
180 pending: u32,
181 output: &'a mut W,
182}
183
184impl<'a, B, W> State<'a, B, W>
185where
186 B: BitStore,
187 W: BitWrite,
188{
189 /// Manually construct a [`State`].
190 ///
191 /// Normally this would be done automatically using the [`Encoder::new`]
192 /// method.
193 pub fn new(precision: u32, output: &'a mut W) -> Self {
194 let state = common::State::new(precision);
195 let pending = 0;
196
197 Self {
198 state,
199 pending,
200 output,
201 }
202 }
203
204 fn scale(&mut self, p: Range<B>, denominator: B) -> io::Result<()> {
205 self.state.scale(p, denominator);
206 self.normalise()
207 }
208
209 fn normalise(&mut self) -> io::Result<()> {
210 while self.state.high < self.state.half() || self.state.low >= self.state.half() {
211 if self.state.high < self.state.half() {
212 self.emit(false)?;
213 self.state.high <<= 1;
214 self.state.low <<= 1;
215 } else {
216 self.emit(true)?;
217 self.state.low = (self.state.low - self.state.half()) << 1;
218 self.state.high = (self.state.high - self.state.half()) << 1;
219 }
220 }
221
222 while self.state.low >= self.state.quarter()
223 && self.state.high < (self.state.three_quarter())
224 {
225 self.pending += 1;
226 self.state.low = (self.state.low - self.state.quarter()) << 1;
227 self.state.high = (self.state.high - self.state.quarter()) << 1;
228 }
229
230 Ok(())
231 }
232
233 fn emit(&mut self, bit: bool) -> io::Result<()> {
234 self.output.write_bit(bit)?;
235 for _ in 0..self.pending {
236 self.output.write_bit(!bit)?;
237 }
238 self.pending = 0;
239 Ok(())
240 }
241
242 /// Flush the internal buffer and write all remaining bits to the output.
243 /// This method MUST be called when you finish writing symbols to ensure
244 /// they are fully written to the output.
245 ///
246 /// # Errors
247 ///
248 /// This method can fail if the output cannot be written to
249 pub fn flush(&mut self) -> io::Result<()> {
250 self.pending += 1;
251 if self.state.low <= self.state.quarter() {
252 self.emit(false)?;
253 } else {
254 self.emit(true)?;
255 }
256
257 Ok(())
258 }
259}