arithmetic_coding/
decoder.rs1use std::{io, ops::Range};
4
5use bitstream_io::BitRead;
6
7use crate::{common, BitStore, Model};
8
9#[derive(Debug)]
16pub struct Decoder<M, R>
17where
18 M: Model,
19 R: BitRead,
20{
21 model: M,
22 state: State<M::B, R>,
23}
24
25trait BitReadExt {
26 fn next_bit(&mut self) -> io::Result<Option<bool>>;
27}
28
29impl<R: BitRead> BitReadExt for R {
30 fn next_bit(&mut self) -> io::Result<Option<bool>> {
31 match self.read_bit() {
32 Ok(bit) => Ok(Some(bit)),
33 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(None),
34 Err(e) => Err(e),
35 }
36 }
37}
38
39impl<M, R> Decoder<M, R>
40where
41 M: Model,
42 R: BitRead,
43{
44 pub fn new(model: M, input: R) -> Self {
62 let frequency_bits = model.max_denominator().log2() + 1;
63 let precision = M::B::BITS - frequency_bits;
64
65 Self::with_precision(model, input, precision)
66 }
67
68 pub fn with_precision(model: M, input: R, precision: u32) -> Self {
82 let frequency_bits = model.max_denominator().log2() + 1;
83 debug_assert!(
84 (precision >= (frequency_bits + 2)),
85 "not enough bits of precision to prevent overflow/underflow",
86 );
87 debug_assert!(
88 (frequency_bits + precision) <= M::B::BITS,
89 "not enough bits in BitStore to support the required precision",
90 );
91
92 let state = State::new(precision, input);
93
94 Self { model, state }
95 }
96
97 pub const fn with_state(state: State<M::B, R>, model: M) -> Self {
99 Self { model, state }
100 }
101
102 pub fn decode_all(&mut self) -> DecodeIter<M, R> {
106 DecodeIter { decoder: self }
107 }
108
109 #[allow(clippy::missing_panics_doc)]
117 pub fn decode(&mut self) -> io::Result<Option<M::Symbol>> {
118 self.state.initialise()?;
119
120 let denominator = self.model.denominator();
121 debug_assert!(
122 denominator <= self.model.max_denominator(),
123 "denominator is greater than maximum!"
124 );
125 let value = self.state.value(denominator);
126 let symbol = self.model.symbol(value);
127
128 let p = self
129 .model
130 .probability(symbol.as_ref())
131 .expect("this should not be able to fail. Check the implementation of the model.");
132
133 self.state.scale(p, denominator)?;
134 self.model.update(symbol.as_ref());
135
136 Ok(symbol)
137 }
138
139 pub fn chain<X>(self, model: X) -> Decoder<X, R>
144 where
145 X: Model<B = M::B>,
146 {
147 Decoder {
148 model,
149 state: self.state,
150 }
151 }
152
153 pub fn into_inner(self) -> (M, State<M::B, R>) {
155 (self.model, self.state)
156 }
157}
158
159#[allow(missing_debug_implementations)]
161pub struct DecodeIter<'a, M, R>
162where
163 M: Model,
164 R: BitRead,
165{
166 decoder: &'a mut Decoder<M, R>,
167}
168
169impl<M, R> Iterator for DecodeIter<'_, M, R>
170where
171 M: Model,
172 R: BitRead,
173{
174 type Item = io::Result<M::Symbol>;
175
176 fn next(&mut self) -> Option<Self::Item> {
177 self.decoder.decode().transpose()
178 }
179}
180
181#[derive(Debug)]
183pub struct State<B, R>
184where
185 B: BitStore,
186 R: BitRead,
187{
188 state: common::State<B>,
189 input: R,
190 x: B,
191 uninitialised: bool,
192}
193
194impl<B, R> State<B, R>
195where
196 B: BitStore,
197 R: BitRead,
198{
199 pub fn new(precision: u32, input: R) -> Self {
201 let state = common::State::new(precision);
202 let x = B::ZERO;
203
204 Self {
205 state,
206 input,
207 x,
208 uninitialised: true,
209 }
210 }
211
212 fn normalise(&mut self) -> io::Result<()> {
213 while self.state.high < self.state.half() || self.state.low >= self.state.half() {
214 if self.state.high < self.state.half() {
215 self.state.high <<= 1;
216 self.state.low <<= 1;
217 self.x <<= 1;
218 } else {
219 self.state.low = (self.state.low - self.state.half()) << 1;
221 self.state.high = (self.state.high - self.state.half()) << 1;
222 self.x = (self.x - self.state.half()) << 1;
223 }
224
225 if self.input.next_bit()? == Some(true) {
226 self.x += B::ONE;
227 }
228 }
229
230 while self.state.low >= self.state.quarter()
231 && self.state.high < (self.state.three_quarter())
232 {
233 self.state.low = (self.state.low - self.state.quarter()) << 1;
234 self.state.high = (self.state.high - self.state.quarter()) << 1;
235 self.x = (self.x - self.state.quarter()) << 1;
236
237 if self.input.next_bit()? == Some(true) {
238 self.x += B::ONE;
239 }
240 }
241
242 Ok(())
243 }
244
245 fn scale(&mut self, p: Range<B>, denominator: B) -> io::Result<()> {
246 self.state.scale(p, denominator);
247 self.normalise()
248 }
249
250 fn value(&self, denominator: B) -> B {
251 let range = self.state.high - self.state.low + B::ONE;
252 ((self.x - self.state.low + B::ONE) * denominator - B::ONE) / range
253 }
254
255 fn fill(&mut self) -> io::Result<()> {
256 for _ in 0..self.state.precision {
257 self.x <<= 1;
258 if self.input.next_bit()? == Some(true) {
259 self.x += B::ONE;
260 }
261 }
262 Ok(())
263 }
264
265 fn initialise(&mut self) -> io::Result<()> {
266 if self.uninitialised {
267 self.fill()?;
268 self.uninitialised = false;
269 }
270 Ok(())
271 }
272}