arithmetic_coding_adder_dep/
decoder.rs1use std::{io, ops::Range};
4use std::marker::PhantomData;
5
6use bitstream_io::BitRead;
7
8use crate::{BitStore, Model};
9
10#[derive(Debug)]
17pub struct Decoder<M, R>
18where
19 M: Model,
20 R: BitRead,
21{
22 pub model: M,
24 state: State<M::B, R>,
25}
26
27trait BitReadExt {
28 fn next_bit(&mut self) -> io::Result<Option<bool>>;
29}
30
31impl<R: BitRead> BitReadExt for R {
32 fn next_bit(&mut self) -> io::Result<Option<bool>> {
33 match self.read_bit() {
34 Ok(bit) => Ok(Some(bit)),
35 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(None),
36 Err(e) => Err(e),
37 }
38 }
39}
40
41impl<M, R> Decoder<M, R>
42where
43 M: Model,
44 R: BitRead,
45{
46 pub fn new(model: M) -> Self {
64 let frequency_bits = model.max_denominator().log2() + 1;
65 let precision = M::B::BITS - frequency_bits;
66
67 Self::with_precision(model, precision)
68 }
69
70 pub fn with_precision(model: M, precision: u32) -> Self {
84 let frequency_bits = model.max_denominator().log2() + 1;
85 debug_assert!(
86 (precision >= (frequency_bits + 2)),
87 "not enough bits of precision to prevent overflow/underflow",
88 );
89 debug_assert!(
90 (frequency_bits + precision) <= M::B::BITS,
91 "not enough bits in BitStore to support the required precision",
92 );
93
94 let state = State::new(precision);
95
96 Self { model, state }
97 }
98
99 pub const fn with_state(state: State<M::B, R>, model: M) -> Self {
101 Self { model, state }
102 }
103
104 pub fn decode_all<'a>(&'a mut self, input: &'a mut R) -> DecodeIter<M, R> {
108 DecodeIter { decoder: self, input}
109 }
110
111 pub fn decode(&mut self, input: &mut R) -> io::Result<Option<M::Symbol>> {
119 self.state.initialise(input)?;
120
121 let denominator = self.model.denominator();
122 debug_assert!(
123 denominator <= self.model.max_denominator(),
124 "denominator is greater than maximum!"
125 );
126 let value = self.state.value(denominator);
127 let symbol = self.model.symbol(value);
128
129 let p = self
130 .model
131 .probability(symbol.as_ref())
132 .expect("this should not be able to fail. Check the implementation of the model.");
133
134 self.state.scale(p, denominator, input)?;
135 self.model.update(symbol.as_ref());
136
137 Ok(symbol)
138 }
139
140 pub fn chain<X>(self, model: X) -> Decoder<X, R>
145 where
146 X: Model<B = M::B>,
147 {
148 Decoder {
149 model,
150 state: self.state,
151 }
152 }
153
154 pub fn into_inner(self) -> (M, State<M::B, R>) {
156 (self.model, self.state)
157 }
158}
159
160#[allow(missing_debug_implementations)]
162pub struct DecodeIter<'a, M, R>
163where
164 M: Model,
165 R: BitRead,
166{
167 decoder: &'a mut Decoder<M, R>,
168 input: &'a mut R,
169}
170
171impl<'a, M, R> Iterator for DecodeIter<'a, M, R>
172where
173 M: Model,
174 R: BitRead,
175{
176 type Item = io::Result<M::Symbol>;
177
178 fn next(&mut self) -> Option<Self::Item> {
179 self.decoder.decode(self.input).transpose()
180 }
181}
182
183#[derive(Debug)]
185pub struct State<B, R>
186where
187 B: BitStore,
188 R: BitRead,
189{
190 precision: u32,
191 low: B,
192 high: B,
193 _marker: PhantomData<R>,
194 x: B,
195 uninitialised: bool,
196}
197
198impl<B, R> State<B, R>
199where
200 B: BitStore,
201 R: BitRead,
202{
203 #[must_use] pub fn new(precision: u32) -> Self {
205 let low = B::ZERO;
206 let high = B::ONE << precision;
207 let x = B::ZERO;
208
209 Self {
210 precision,
211 low,
212 high,
213 _marker: PhantomData,
214 x,
215 uninitialised: true,
216 }
217 }
218
219 fn half(&self) -> B {
220 B::ONE << (self.precision - 1)
221 }
222
223 fn quarter(&self) -> B {
224 B::ONE << (self.precision - 2)
225 }
226
227 fn three_quarter(&self) -> B {
228 self.half() + self.quarter()
229 }
230
231 fn normalise(&mut self, input: &mut R) -> io::Result<()> {
232 while self.high < self.half() || self.low >= self.half() {
233 if self.high < self.half() {
234 self.high <<= 1;
235 self.low <<= 1;
236 self.x <<= 1;
237 } else {
238 self.low = (self.low - self.half()) << 1;
240 self.high = (self.high - self.half()) << 1;
241 self.x = (self.x - self.half()) << 1;
242 }
243
244 match input.next_bit()? {
245 Some(true) => {
246 self.x += B::ONE;
247 }
248 Some(false) | None => (),
249 }
250 }
251
252 while self.low >= self.quarter() && self.high < (self.three_quarter()) {
253 self.low = (self.low - self.quarter()) << 1;
254 self.high = (self.high - self.quarter()) << 1;
255 self.x = (self.x - self.quarter()) << 1;
256
257 match input.next_bit()? {
258 Some(true) => {
259 self.x += B::ONE;
260 }
261 Some(false) | None => (),
262 }
263 }
264
265 Ok(())
266 }
267
268 fn scale(&mut self, p: Range<B>, denominator: B, input: &mut R) -> io::Result<()> {
269 let range = self.high - self.low + B::ONE;
270
271 self.high = self.low + (range * p.end) / denominator - B::ONE;
272 self.low += (range * p.start) / denominator;
273
274 self.normalise(input)
275 }
276
277 fn value(&self, denominator: B) -> B {
278 let range = self.high - self.low + B::ONE;
279 ((self.x - self.low + B::ONE) * denominator - B::ONE) / range
280 }
281
282 fn fill(&mut self, input: &mut R) -> io::Result<()> {
283 for _ in 0..self.precision {
284 self.x <<= 1;
285 match input.next_bit()? {
286 Some(true) => {
287 self.x += B::ONE;
288 }
289 Some(false) | None => (),
290 }
291 }
292 Ok(())
293 }
294
295 fn initialise(&mut self, input: &mut R) -> io::Result<()> {
296 if self.uninitialised {
297 self.fill(input)?;
298 self.uninitialised = false;
299 }
300 Ok(())
301 }
302}