// datacortex_core/entropy/arithmetic.rs

//! Binary Arithmetic Coder — PAQ8-style, 12-bit precision, carry-free.
//!
//! Encodes/decodes one bit at a time given a 12-bit probability of bit=1.
//! Uses a 32-bit range [low, high] with byte-wise normalization.
//!
//! Probabilities must be in [1, 4095]. 0 and 4096 are forbidden.

/// Precision bits for probability (12-bit).
const PROB_BITS: u32 = 12;

/// Probability scale (`1 << PROB_BITS` = 4096): the exclusive upper bound of
/// the probability domain. Legal probabilities lie strictly between 0 and this.
const PROB_SCALE: u32 = 1 << PROB_BITS;

// ─── Encoder ────────────────────────────────────────────────────────────

16/// Binary arithmetic encoder. Accumulates compressed bytes.
17pub struct ArithmeticEncoder {
18    low: u32,
19    high: u32,
20    output: Vec<u8>,
21}
22
23impl ArithmeticEncoder {
24    /// Create a new encoder.
25    pub fn new() -> Self {
26        ArithmeticEncoder {
27            low: 0,
28            high: 0xFFFF_FFFF,
29            output: Vec::new(),
30        }
31    }
32
33    /// Encode a single bit with probability `p` of bit=1 (12-bit, [1, 4095]).
34    #[inline(always)]
35    pub fn encode(&mut self, bit: u8, p: u32) {
36        debug_assert!(
37            (1..=4095).contains(&p),
38            "probability {p} out of range [1,4095]"
39        );
40
41        let range = self.high - self.low;
42        // mid divides the range: [low, low+mid) = bit 0, [low+mid, high] = bit 1
43        let mid = self.low
44            + (range >> PROB_BITS) * (PROB_SCALE - p)
45            + (((range & (PROB_SCALE - 1)) * (PROB_SCALE - p)) >> PROB_BITS);
46
47        if bit != 0 {
48            self.low = mid + 1;
49        } else {
50            self.high = mid;
51        }
52
53        // Normalize: output matching top bytes.
54        while (self.low ^ self.high) < 0x0100_0000 {
55            self.output.push((self.low >> 24) as u8);
56            self.low <<= 8;
57            self.high = (self.high << 8) | 0xFF;
58        }
59    }
60
61    /// Flush the encoder — write remaining state bytes.
62    /// Must be called after encoding all bits.
63    pub fn finish(mut self) -> Vec<u8> {
64        // Write 4 more bytes to flush the state.
65        self.output.push((self.low >> 24) as u8);
66        self.output.push((self.low >> 16) as u8);
67        self.output.push((self.low >> 8) as u8);
68        self.output.push(self.low as u8);
69        self.output
70    }
71}
72
73impl Default for ArithmeticEncoder {
74    fn default() -> Self {
75        Self::new()
76    }
77}

// ─── Decoder ────────────────────────────────────────────────────────────

81/// Binary arithmetic decoder. Reads bits from compressed data.
82pub struct ArithmeticDecoder<'a> {
83    low: u32,
84    high: u32,
85    code: u32, // current code value from input
86    data: &'a [u8],
87    pos: usize,
88}
89
90impl<'a> ArithmeticDecoder<'a> {
91    /// Create a new decoder from compressed data.
92    pub fn new(data: &'a [u8]) -> Self {
93        let mut dec = ArithmeticDecoder {
94            low: 0,
95            high: 0xFFFF_FFFF,
96            code: 0,
97            data,
98            pos: 0,
99        };
100        // Load initial 4 bytes into code.
101        for _ in 0..4 {
102            dec.code = (dec.code << 8) | dec.read_byte() as u32;
103        }
104        dec
105    }
106
107    /// Decode a single bit given probability `p` of bit=1 (12-bit, [1, 4095]).
108    #[inline(always)]
109    pub fn decode(&mut self, p: u32) -> u8 {
110        debug_assert!(
111            (1..=4095).contains(&p),
112            "probability {p} out of range [1,4095]"
113        );
114
115        let range = self.high - self.low;
116        let mid = self.low
117            + (range >> PROB_BITS) * (PROB_SCALE - p)
118            + (((range & (PROB_SCALE - 1)) * (PROB_SCALE - p)) >> PROB_BITS);
119
120        let bit = if self.code > mid { 1u8 } else { 0u8 };
121
122        if bit != 0 {
123            self.low = mid + 1;
124        } else {
125            self.high = mid;
126        }
127
128        // Normalize: shift out matching top bytes, read new byte.
129        while (self.low ^ self.high) < 0x0100_0000 {
130            self.low <<= 8;
131            self.high = (self.high << 8) | 0xFF;
132            self.code = (self.code << 8) | self.read_byte() as u32;
133        }
134
135        bit
136    }
137
138    /// Read the next byte from compressed data (0 if past end).
139    #[inline(always)]
140    fn read_byte(&mut self) -> u8 {
141        if self.pos < self.data.len() {
142            let b = self.data[self.pos];
143            self.pos += 1;
144            b
145        } else {
146            0
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `bits` using the per-bit probabilities in `probs` and return the
    /// finished compressed stream.
    fn encode_all(bits: &[u8], probs: &[u32]) -> Vec<u8> {
        let mut enc = ArithmeticEncoder::new();
        for (&bit, &p) in bits.iter().zip(probs) {
            enc.encode(bit, p);
        }
        enc.finish()
    }

    /// Encode then decode `bits` with `probs`, asserting every decoded bit
    /// matches. Returns the compressed stream for further checks.
    fn assert_roundtrip(bits: &[u8], probs: &[u32]) -> Vec<u8> {
        assert_eq!(bits.len(), probs.len(), "bits/probs length mismatch");
        let compressed = encode_all(bits, probs);

        let mut dec = ArithmeticDecoder::new(&compressed);
        for (i, (&expected_bit, &p)) in bits.iter().zip(probs).enumerate() {
            let decoded = dec.decode(p);
            assert_eq!(
                decoded, expected_bit,
                "mismatch at bit {i}: expected {expected_bit}, got {decoded}"
            );
        }
        compressed
    }

    /// Toy adaptive model shared by encoder and decoder sides of a test:
    /// nudge `p` by 100 towards the bit just seen, staying inside [1, 4095].
    fn adapt(p: u32, bit: u8) -> u32 {
        if bit == 1 {
            (p + 100).min(4095)
        } else if p > 101 {
            p - 100
        } else {
            1
        }
    }

    #[test]
    fn encode_decode_single_bit_0() {
        assert_roundtrip(&[0], &[2048]); // p=0.5
    }

    #[test]
    fn encode_decode_single_bit_1() {
        assert_roundtrip(&[1], &[2048]);
    }

    #[test]
    fn encode_decode_sequence() {
        let bits: Vec<u8> = vec![1, 0, 1, 1, 0, 0, 1, 0];
        let probs: Vec<u32> = vec![2048, 1000, 3000, 500, 2048, 100, 3900, 2048];
        assert_roundtrip(&bits, &probs);
    }

    #[test]
    fn encode_decode_all_zeros() {
        assert_roundtrip(&[0u8; 100], &[2048u32; 100]);
    }

    #[test]
    fn encode_decode_all_ones() {
        assert_roundtrip(&[1u8; 100], &[2048u32; 100]);
    }

    #[test]
    fn high_probability_compresses() {
        // All 1s with high P(1) should compress well.
        let n = 1000;
        let bits = vec![1u8; n];
        let probs = vec![4000u32; n]; // P(1)≈0.98

        // Also verifies the roundtrip.
        let compressed = assert_roundtrip(&bits, &probs);

        // 1000 bits at high probability should compress to much less than 125 bytes.
        assert!(
            compressed.len() < 50,
            "expected good compression, got {} bytes for {} bits at p=4000",
            compressed.len(),
            n
        );
    }

    #[test]
    fn extreme_probabilities() {
        // Test with probabilities near the bounds, including bits that
        // contradict the prediction.
        let bits: [u8; 6] = [0, 1, 0, 1, 1, 0];
        let probs: [u32; 6] = [1, 4095, 1, 4095, 1, 4095];
        assert_roundtrip(&bits, &probs);
    }

    #[test]
    fn byte_roundtrip() {
        // Encode a full byte (8 bits), MSB first, and decode it.
        let byte_val: u8 = 0xA5; // 10100101
        let mut enc = ArithmeticEncoder::new();
        for bpos in 0..8 {
            let bit = (byte_val >> (7 - bpos)) & 1;
            enc.encode(bit, 2048);
        }
        let compressed = enc.finish();

        let mut dec = ArithmeticDecoder::new(&compressed);
        let mut decoded_byte: u8 = 0;
        for bpos in 0..8 {
            let bit = dec.decode(2048);
            decoded_byte |= bit << (7 - bpos);
        }
        assert_eq!(decoded_byte, byte_val);
    }

    #[test]
    fn varying_probabilities_per_bit() {
        // Simulate a model that adapts probabilities.
        let data: Vec<u8> = (0u32..50).map(|i| ((i * 7 + 13) & 0xFF) as u8).collect();

        let mut enc = ArithmeticEncoder::new();
        let mut p: u32 = 2048;
        for &byte in &data {
            for bpos in 0..8 {
                let bit = (byte >> (7 - bpos)) & 1;
                enc.encode(bit, p);
                p = adapt(p, bit);
            }
        }
        let compressed = enc.finish();

        let mut dec = ArithmeticDecoder::new(&compressed);
        let mut p: u32 = 2048;
        for (i, &byte) in data.iter().enumerate() {
            let mut decoded: u8 = 0;
            for bpos in 0..8 {
                let bit = dec.decode(p);
                decoded |= bit << (7 - bpos);
                p = adapt(p, bit);
            }
            assert_eq!(decoded, byte, "byte mismatch at index {i}");
        }
    }

    #[test]
    fn finish_emits_expected_bytes() {
        // Pin the exact stream layout so format changes are caught.
        assert_eq!(encode_all(&[0], &[2048]), vec![0u8, 0, 0, 0]);
        assert_eq!(encode_all(&[1], &[2048]), vec![0x80u8, 0, 0, 0]);

        // Eight p=0.5 bits reproduce the byte, followed by the 4-byte flush.
        let bits: [u8; 8] = [1, 0, 1, 0, 0, 1, 0, 1]; // 0xA5
        assert_eq!(
            encode_all(&bits, &[2048u32; 8]),
            vec![0xA5u8, 0, 0, 0, 0]
        );
    }

    #[test]
    fn decoder_pads_missing_input_with_zeros() {
        // Empty input decodes as an all-zero code value.
        let mut dec = ArithmeticDecoder::new(&[]);
        for i in 0..16 {
            assert_eq!(dec.decode(2048), 0, "mismatch at bit {i}");
        }
    }
}
331}