arithmetic_coding_adder_dep/
decoder.rs1use std::marker::PhantomData;
4use std::{io, ops::Range};
5
6use bitstream_io::BitRead;
7
8use crate::{BitStore, Model};
9
10#[derive(Debug)]
17pub struct Decoder<M, R>
18where
19 M: Model,
20 R: BitRead,
21{
22 pub model: M,
24 state: State<M::B, R>,
25}
26
27trait BitReadExt {
28 fn next_bit(&mut self) -> io::Result<Option<bool>>;
29}
30
31impl<R: BitRead> BitReadExt for R {
32 fn next_bit(&mut self) -> io::Result<Option<bool>> {
33 match self.read_bit() {
34 Ok(bit) => Ok(Some(bit)),
35 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(None),
36 Err(e) => Err(e),
37 }
38 }
39}
40
41impl<M, R> Decoder<M, R>
42where
43 M: Model,
44 R: BitRead,
45{
46 pub fn new(model: M) -> Self {
64 let frequency_bits = model.max_denominator().log2() + 1;
65 let precision = M::B::BITS - frequency_bits;
66
67 Self::with_precision(model, precision)
68 }
69
70 pub fn with_precision(model: M, precision: u32) -> Self {
84 let frequency_bits = model.max_denominator().log2() + 1;
85 debug_assert!(
86 (precision >= (frequency_bits + 2)),
87 "not enough bits of precision to prevent overflow/underflow",
88 );
89 debug_assert!(
90 (frequency_bits + precision) <= M::B::BITS,
91 "not enough bits in BitStore to support the required precision",
92 );
93
94 let state = State::new(precision);
95
96 Self { model, state }
97 }
98
99 pub const fn with_state(state: State<M::B, R>, model: M) -> Self {
101 Self { model, state }
102 }
103
104 pub fn decode_all<'a>(&'a mut self, input: &'a mut R) -> DecodeIter<'a, M, R> {
108 DecodeIter {
109 decoder: self,
110 input,
111 }
112 }
113
114 pub fn decode(&mut self, input: &mut R) -> io::Result<Option<M::Symbol>> {
122 self.state.initialise(input)?;
123
124 let denominator = self.model.denominator();
125 debug_assert!(
126 denominator <= self.model.max_denominator(),
127 "denominator is greater than maximum!"
128 );
129 let value = self.state.value(denominator);
130 let symbol = self.model.symbol(value);
131
132 let p = self
133 .model
134 .probability(symbol.as_ref())
135 .expect("this should not be able to fail. Check the implementation of the model.");
136
137 self.state.scale(p, denominator, input)?;
138 self.model.update(symbol.as_ref());
139
140 Ok(symbol)
141 }
142
143 pub fn chain<X>(self, model: X) -> Decoder<X, R>
148 where
149 X: Model<B = M::B>,
150 {
151 Decoder {
152 model,
153 state: self.state,
154 }
155 }
156
157 pub fn into_inner(self) -> (M, State<M::B, R>) {
159 (self.model, self.state)
160 }
161}
162
163#[allow(missing_debug_implementations)]
165pub struct DecodeIter<'a, M, R>
166where
167 M: Model,
168 R: BitRead,
169{
170 decoder: &'a mut Decoder<M, R>,
171 input: &'a mut R,
172}
173
174impl<M, R> Iterator for DecodeIter<'_, M, R>
175where
176 M: Model,
177 R: BitRead,
178{
179 type Item = io::Result<M::Symbol>;
180
181 fn next(&mut self) -> Option<Self::Item> {
182 self.decoder.decode(self.input).transpose()
183 }
184}
185
186#[derive(Debug)]
188pub struct State<B, R>
189where
190 B: BitStore,
191 R: BitRead,
192{
193 precision: u32,
194 low: B,
195 high: B,
196 _marker: PhantomData<R>,
197 x: B,
198 uninitialised: bool,
199}
200
201impl<B, R> State<B, R>
202where
203 B: BitStore,
204 R: BitRead,
205{
206 #[must_use]
208 pub fn new(precision: u32) -> Self {
209 let low = B::ZERO;
210 let high = B::ONE << precision;
211 let x = B::ZERO;
212
213 Self {
214 precision,
215 low,
216 high,
217 _marker: PhantomData,
218 x,
219 uninitialised: true,
220 }
221 }
222
223 fn half(&self) -> B {
224 B::ONE << (self.precision - 1)
225 }
226
227 fn quarter(&self) -> B {
228 B::ONE << (self.precision - 2)
229 }
230
231 fn three_quarter(&self) -> B {
232 self.half() + self.quarter()
233 }
234
235 fn normalise(&mut self, input: &mut R) -> io::Result<()> {
236 while self.high < self.half() || self.low >= self.half() {
237 if self.high < self.half() {
238 self.high <<= 1;
239 self.low <<= 1;
240 self.x <<= 1;
241 } else {
242 self.low = (self.low - self.half()) << 1;
244 self.high = (self.high - self.half()) << 1;
245 self.x = (self.x - self.half()) << 1;
246 }
247
248 if let Some(true) = input.next_bit()? {
249 self.x += B::ONE;
250 }
251 }
252
253 while self.low >= self.quarter() && self.high < (self.three_quarter()) {
254 self.low = (self.low - self.quarter()) << 1;
255 self.high = (self.high - self.quarter()) << 1;
256 self.x = (self.x - self.quarter()) << 1;
257
258 if input.next_bit()? == Some(true) {
259 self.x += B::ONE;
260 }
261 }
262
263 Ok(())
264 }
265
266 fn scale(&mut self, p: Range<B>, denominator: B, input: &mut R) -> io::Result<()> {
267 let range = self.high - self.low + B::ONE;
268
269 self.high = self.low + (range * p.end) / denominator - B::ONE;
270 self.low += (range * p.start) / denominator;
271
272 self.normalise(input)
273 }
274
275 fn value(&self, denominator: B) -> B {
276 let range = self.high - self.low + B::ONE;
277 ((self.x - self.low + B::ONE) * denominator - B::ONE) / range
278 }
279
280 fn fill(&mut self, input: &mut R) -> io::Result<()> {
281 for _ in 0..self.precision {
282 self.x <<= 1;
283 if input.next_bit()? == Some(true) {
284 self.x += B::ONE;
285 }
286 }
287 Ok(())
288 }
289
290 fn initialise(&mut self, input: &mut R) -> io::Result<()> {
291 if self.uninitialised {
292 self.fill(input)?;
293 self.uninitialised = false;
294 }
295 Ok(())
296 }
297}