arcode/
encode.rs

//! This module contains the main code for the arithmetic encoder. It also
//! contains a simple implementation of a binary encoder.
3
4use std::io::{Error, Write};
5
6use bitbit::BitWriter;
7
8use crate::{Model, Range};
9
10pub struct ArithmeticEncoder {
11    _precision: u64,
12    pending_bit_count: u32,
13    range: Range,
14}
15
16impl ArithmeticEncoder {
17    /// # Arguments
18    /// `precision` is the [bit precision](https://en.wikipedia.org/wiki/Arithmetic_coding#Precision_and_renormalization)
19    /// that the encoder should use. If the
20    /// precision is too low than symbols will not be able to be differentiated.
21    pub fn new(precision: u64) -> Self {
22        Self {
23            _precision: precision,
24            pending_bit_count: 0,
25            range: Range::new(precision),
26        }
27    }
28
29    pub fn encode<T: Write>(
30        &mut self,
31        symbol: u32,
32        source_model: &Model,
33        output: &mut BitWriter<T>,
34    ) -> Result<(), Error> {
35        let low_high = self.range.calculate_range(symbol, source_model);
36        self.range.update_range(low_high);
37
38        while self.range.in_bottom_half() || self.range.in_upper_half() {
39            if self.range.in_bottom_half() {
40                self.range.scale_bottom_half();
41                self.emit(false, output)?;
42            } else if self.range.in_upper_half() {
43                self.range.scale_upper_half();
44                self.emit(true, output)?;
45            }
46        }
47
48        while self.range.in_middle_half() {
49            self.pending_bit_count += 1;
50            self.range.scale_middle_half();
51        }
52
53        Ok(())
54    }
55
56    fn emit<T: Write>(&mut self, bit: bool, output: &mut BitWriter<T>) -> Result<(), Error> {
57        output.write_bit(bit)?;
58
59        while self.pending_bit_count > 0 {
60            output.write_bit(!bit)?;
61            self.pending_bit_count -= 1;
62        }
63
64        Ok(())
65    }
66
67    pub fn finish_encode<T: Write>(&mut self, output: &mut BitWriter<T>) -> Result<(), Error> {
68        self.pending_bit_count += 1;
69
70        if self.range.in_bottom_quarter() {
71            self.emit(false, output)?;
72        } else {
73            self.emit(true, output)?;
74        }
75
76        Ok(())
77    }
78}
79
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use bitbit::BitWriter;

    use super::ArithmeticEncoder;
    use crate::{EOFKind, Model};

    /// Encodes a short symbol stream followed by EOF with an adaptive model and
    /// pins the exact bytes produced, so any change to the bit output is caught.
    #[test]
    fn e2e() {
        let symbols: [u32; 5] = [7, 2, 2, 2, 7];

        let mut model = Model::builder().num_symbols(10).eof(EOFKind::End).build();
        let mut encoder = ArithmeticEncoder::new(30);
        let mut sink = Cursor::new(vec![]);
        let mut writer = BitWriter::new(&mut sink);

        for &symbol in symbols.iter() {
            encoder.encode(symbol, &model, &mut writer).unwrap();
            model.update_symbol(symbol);
        }

        let eof = model.eof();
        encoder.encode(eof, &model, &mut writer).unwrap();
        encoder.finish_encode(&mut writer).unwrap();
        writer.pad_to_byte().unwrap();

        assert_eq!(sink.get_ref(), &[184, 96, 208]);
    }
}
107}