arithmetic_coding_adder_dep/
encoder.rs1use std::marker::PhantomData;
4use std::{io, ops::Range};
5
6use bitstream_io::BitWrite;
7
8use crate::Error::ValueError;
9use crate::{BitStore, Error, Model};
10
11#[derive(Debug)]
18pub struct Encoder<M, W>
19where
20 M: Model,
21 W: BitWrite,
22{
23 pub model: M,
25 state: State<M::B, W>,
26}
27
28impl<M, W> Encoder<M, W>
29where
30 M: Model,
31 W: BitWrite,
32{
33 pub fn new(model: M) -> Self {
52 let frequency_bits = model.max_denominator().log2() + 1;
53 let precision = M::B::BITS - frequency_bits;
54 Self::with_precision(model, precision)
55 }
56
57 pub fn with_precision(model: M, precision: u32) -> Self {
71 let frequency_bits = model.max_denominator().log2() + 1;
72 debug_assert!(
73 (precision >= (frequency_bits + 2)),
74 "not enough bits of precision to prevent overflow/underflow",
75 );
76 debug_assert!(
77 (frequency_bits + precision) <= M::B::BITS,
78 "not enough bits in BitStore to support the required precision",
79 );
80
81 Self {
82 model,
83 state: State::new(precision),
84 }
85 }
86
87 pub const fn with_state(state: State<M::B, W>, model: M) -> Self {
89 Self { model, state }
90 }
91
92 pub fn encode_all(
102 &mut self,
103 symbols: impl IntoIterator<Item = M::Symbol>,
104 output: &mut W,
105 ) -> Result<(), Error> {
106 for symbol in symbols {
107 self.encode(Some(&symbol), output)?;
108 }
109 self.encode(None, output)?;
110 self.flush(output)?;
111 Ok(())
112 }
113
114 pub fn encode(&mut self, symbol: Option<&M::Symbol>, output: &mut W) -> Result<(), Error> {
126 let Ok(p) = self.model.probability(symbol) else {
127 return Err(ValueError);
128 };
129 let denominator = self.model.denominator();
130 debug_assert!(
131 denominator <= self.model.max_denominator(),
132 "denominator is greater than maximum!"
133 );
134
135 self.state.scale(p, denominator, output)?;
136 self.model.update(symbol);
137
138 Ok(())
139 }
140
141 pub fn flush(&mut self, output: &mut W) -> io::Result<()> {
152 self.state.flush(output)
153 }
154
155 pub fn into_inner(self) -> (M, State<M::B, W>) {
157 (self.model, self.state)
158 }
159
160 pub fn chain<X>(self, model: X) -> Encoder<X, W>
165 where
166 X: Model<B = M::B>,
167 {
168 Encoder {
169 model,
170 state: self.state,
171 }
172 }
173}
174
175#[derive(Debug)]
177pub struct State<B, W>
178where
179 B: BitStore,
180 W: BitWrite,
181{
182 precision: u32,
183 low: B,
184 high: B,
185 pending: u32,
186 _marker: PhantomData<W>,
187}
188
189impl<B, W> State<B, W>
190where
191 B: BitStore,
192 W: BitWrite,
193{
194 #[must_use]
196 pub fn new(precision: u32) -> Self {
197 let low = B::ZERO;
198 let high = B::ONE << precision;
199 let pending = 0;
200
201 Self {
202 precision,
203 low,
204 high,
205 pending,
206 _marker: PhantomData,
207 }
208 }
209
210 fn three_quarter(&self) -> B {
211 self.half() + self.quarter()
212 }
213
214 fn half(&self) -> B {
215 B::ONE << (self.precision - 1)
216 }
217
218 fn quarter(&self) -> B {
219 B::ONE << (self.precision - 2)
220 }
221
222 fn scale(&mut self, p: Range<B>, denominator: B, output: &mut W) -> io::Result<()> {
223 let range = self.high - self.low + B::ONE;
224
225 self.high = self.low + (range * p.end) / denominator - B::ONE;
226 self.low += (range * p.start) / denominator;
227
228 self.normalise(output)
229 }
230
231 fn normalise(&mut self, output: &mut W) -> io::Result<()> {
232 while self.high < self.half() || self.low >= self.half() {
233 if self.high < self.half() {
234 self.emit(false, output)?;
235 self.high <<= 1;
236 self.low <<= 1;
237 } else {
238 self.emit(true, output)?;
240 self.low = (self.low - self.half()) << 1;
241 self.high = (self.high - self.half()) << 1;
242 }
243 }
244
245 while self.low >= self.quarter() && self.high < (self.three_quarter()) {
246 self.pending += 1;
247 self.low = (self.low - self.quarter()) << 1;
248 self.high = (self.high - self.quarter()) << 1;
249 }
250
251 Ok(())
252 }
253
254 fn emit(&mut self, bit: bool, output: &mut W) -> io::Result<()> {
255 output.write_bit(bit)?;
256 for _ in 0..self.pending {
257 output.write_bit(!bit)?;
258 }
259 self.pending = 0;
260 Ok(())
261 }
262
263 pub fn flush(&mut self, output: &mut W) -> io::Result<()> {
265 self.pending += 1;
266 if self.low <= self.quarter() {
267 self.emit(false, output)?;
268 } else {
269 self.emit(true, output)?;
270 }
271
272 Ok(())
273 }
274}