1use std::io::{Error, Write};
5
6use bitbit::BitWriter;
7
8use crate::{Model, Range};
9
10pub struct ArithmeticEncoder {
11 _precision: u64,
12 pending_bit_count: u32,
13 range: Range,
14}
15
16impl ArithmeticEncoder {
17 pub fn new(precision: u64) -> Self {
22 Self {
23 _precision: precision,
24 pending_bit_count: 0,
25 range: Range::new(precision),
26 }
27 }
28
29 pub fn encode<T: Write>(
30 &mut self,
31 symbol: u32,
32 source_model: &Model,
33 output: &mut BitWriter<T>,
34 ) -> Result<(), Error> {
35 let low_high = self.range.calculate_range(symbol, source_model);
36 self.range.update_range(low_high);
37
38 while self.range.in_bottom_half() || self.range.in_upper_half() {
39 if self.range.in_bottom_half() {
40 self.range.scale_bottom_half();
41 self.emit(false, output)?;
42 } else if self.range.in_upper_half() {
43 self.range.scale_upper_half();
44 self.emit(true, output)?;
45 }
46 }
47
48 while self.range.in_middle_half() {
49 self.pending_bit_count += 1;
50 self.range.scale_middle_half();
51 }
52
53 Ok(())
54 }
55
56 fn emit<T: Write>(&mut self, bit: bool, output: &mut BitWriter<T>) -> Result<(), Error> {
57 output.write_bit(bit)?;
58
59 while self.pending_bit_count > 0 {
60 output.write_bit(!bit)?;
61 self.pending_bit_count -= 1;
62 }
63
64 Ok(())
65 }
66
67 pub fn finish_encode<T: Write>(&mut self, output: &mut BitWriter<T>) -> Result<(), Error> {
68 self.pending_bit_count += 1;
69
70 if self.range.in_bottom_quarter() {
71 self.emit(false, output)?;
72 } else {
73 self.emit(true, output)?;
74 }
75
76 Ok(())
77 }
78}
79
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use bitbit::BitWriter;

    use super::ArithmeticEncoder;
    use crate::{EOFKind, Model};

    /// Encodes a short symbol sequence with an adaptive model, terminates it
    /// with the model's EOF symbol, and checks the exact padded byte output.
    #[test]
    fn e2e() {
        const SYMBOLS: [u32; 5] = [7, 2, 2, 2, 7];
        const EXPECTED: [u8; 3] = [184, 96, 208];

        let mut model = Model::builder().num_symbols(10).eof(EOFKind::End).build();
        let mut encoder = ArithmeticEncoder::new(30);
        let mut buffer = Cursor::new(Vec::new());
        let mut writer = BitWriter::new(&mut buffer);

        for symbol in SYMBOLS.iter().copied() {
            encoder.encode(symbol, &model, &mut writer).unwrap();
            model.update_symbol(symbol);
        }

        let eof_symbol = model.eof();
        encoder.encode(eof_symbol, &model, &mut writer).unwrap();
        encoder.finish_encode(&mut writer).unwrap();
        writer.pad_to_byte().unwrap();

        assert_eq!(buffer.get_ref(), &EXPECTED);
    }
}
107}