// datacortex_core/entropy/arithmetic.rs

/// Number of bits of precision in a probability handed to the coder.
const PROB_BITS: u32 = 12;

/// Probability scale (`1 << PROB_BITS`, i.e. 4096): a probability value `p`
/// stands for the fraction `p / PROB_SCALE`.
const PROB_SCALE: u32 = 1 << PROB_BITS;

/// Binary arithmetic encoder working on a 32-bit coding interval.
///
/// The interval `[low, high]` is narrowed once per encoded bit; bytes are
/// appended to `output` as soon as the top byte of both bounds agrees.
#[derive(Debug)]
pub struct ArithmeticEncoder {
    /// Inclusive lower bound of the current coding interval.
    low: u32,
    /// Inclusive upper bound of the current coding interval.
    high: u32,
    /// Bytes emitted so far.
    output: Vec<u8>,
}
22
23impl ArithmeticEncoder {
24 pub fn new() -> Self {
26 ArithmeticEncoder {
27 low: 0,
28 high: 0xFFFF_FFFF,
29 output: Vec::new(),
30 }
31 }
32
33 #[inline(always)]
35 pub fn encode(&mut self, bit: u8, p: u32) {
36 debug_assert!(
37 (1..=4095).contains(&p),
38 "probability {p} out of range [1,4095]"
39 );
40
41 let range = self.high - self.low;
42 let mid = self.low
44 + (range >> PROB_BITS) * (PROB_SCALE - p)
45 + (((range & (PROB_SCALE - 1)) * (PROB_SCALE - p)) >> PROB_BITS);
46
47 if bit != 0 {
48 self.low = mid + 1;
49 } else {
50 self.high = mid;
51 }
52
53 while (self.low ^ self.high) < 0x0100_0000 {
55 self.output.push((self.low >> 24) as u8);
56 self.low <<= 8;
57 self.high = (self.high << 8) | 0xFF;
58 }
59 }
60
61 pub fn finish(mut self) -> Vec<u8> {
64 self.output.push((self.low >> 24) as u8);
66 self.output.push((self.low >> 16) as u8);
67 self.output.push((self.low >> 8) as u8);
68 self.output.push(self.low as u8);
69 self.output
70 }
71}
72
73impl Default for ArithmeticEncoder {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
/// Binary arithmetic decoder reading from a borrowed byte slice.
///
/// It tracks the same `[low, high]` interval as the encoder, plus a 32-bit
/// window (`code`) of the input that is compared against the split point.
#[derive(Debug)]
pub struct ArithmeticDecoder<'a> {
    /// Inclusive lower bound of the current coding interval.
    low: u32,
    /// Inclusive upper bound of the current coding interval.
    high: u32,
    /// The 32 most recently shifted-in bits of the input stream.
    code: u32,
    /// The encoded input.
    data: &'a [u8],
    /// Index of the next unread byte in `data`.
    pos: usize,
}
89
90impl<'a> ArithmeticDecoder<'a> {
91 pub fn new(data: &'a [u8]) -> Self {
93 let mut dec = ArithmeticDecoder {
94 low: 0,
95 high: 0xFFFF_FFFF,
96 code: 0,
97 data,
98 pos: 0,
99 };
100 for _ in 0..4 {
102 dec.code = (dec.code << 8) | dec.read_byte() as u32;
103 }
104 dec
105 }
106
107 #[inline(always)]
109 pub fn decode(&mut self, p: u32) -> u8 {
110 debug_assert!(
111 (1..=4095).contains(&p),
112 "probability {p} out of range [1,4095]"
113 );
114
115 let range = self.high - self.low;
116 let mid = self.low
117 + (range >> PROB_BITS) * (PROB_SCALE - p)
118 + (((range & (PROB_SCALE - 1)) * (PROB_SCALE - p)) >> PROB_BITS);
119
120 let bit = if self.code > mid { 1u8 } else { 0u8 };
121
122 if bit != 0 {
123 self.low = mid + 1;
124 } else {
125 self.high = mid;
126 }
127
128 while (self.low ^ self.high) < 0x0100_0000 {
130 self.low <<= 8;
131 self.high = (self.high << 8) | 0xFF;
132 self.code = (self.code << 8) | self.read_byte() as u32;
133 }
134
135 bit
136 }
137
138 #[inline(always)]
140 fn read_byte(&mut self) -> u8 {
141 if self.pos < self.data.len() {
142 let b = self.data[self.pos];
143 self.pos += 1;
144 b
145 } else {
146 0
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes every `(bit, p)` pair in order and returns the finished stream.
    fn encode_all(symbols: &[(u8, u32)]) -> Vec<u8> {
        let mut enc = ArithmeticEncoder::new();
        for &(bit, p) in symbols {
            enc.encode(bit, p);
        }
        enc.finish()
    }

    /// Decodes `compressed` with the probabilities from `symbols` and asserts
    /// that every decoded bit matches the expected one.
    fn assert_decodes_to(compressed: &[u8], symbols: &[(u8, u32)]) {
        let mut dec = ArithmeticDecoder::new(compressed);
        for (i, &(expected_bit, p)) in symbols.iter().enumerate() {
            let decoded = dec.decode(p);
            assert_eq!(
                decoded, expected_bit,
                "mismatch at bit {i}: expected {expected_bit}, got {decoded}"
            );
        }
    }

    /// Encodes then decodes `symbols`, asserting a lossless round trip.
    /// Returns the compressed bytes so callers can inspect their size.
    fn roundtrip(symbols: &[(u8, u32)]) -> Vec<u8> {
        let compressed = encode_all(symbols);
        assert_decodes_to(&compressed, symbols);
        compressed
    }

    /// Toy adaptive model used by `varying_probabilities_per_bit`: move `p`
    /// by 100 towards the bit just seen, staying within `1..=4095`.
    fn adapt(p: u32, bit: u8) -> u32 {
        if bit == 1 {
            (p + 100).min(4095)
        } else if p > 101 {
            p - 100
        } else {
            1
        }
    }

    #[test]
    fn encode_decode_single_bit_0() {
        roundtrip(&[(0, 2048)]);
    }

    #[test]
    fn encode_decode_single_bit_1() {
        roundtrip(&[(1, 2048)]);
    }

    #[test]
    fn encode_decode_sequence() {
        let bits: [u8; 8] = [1, 0, 1, 1, 0, 0, 1, 0];
        let probs: [u32; 8] = [2048, 1000, 3000, 500, 2048, 100, 3900, 2048];
        let symbols: Vec<(u8, u32)> = bits.iter().copied().zip(probs.iter().copied()).collect();
        roundtrip(&symbols);
    }

    #[test]
    fn encode_decode_all_zeros() {
        roundtrip(&[(0, 2048); 100]);
    }

    #[test]
    fn encode_decode_all_ones() {
        roundtrip(&[(1, 2048); 100]);
    }

    #[test]
    fn high_probability_compresses() {
        let n = 1000;
        // Every bit is a 1 and the model predicts 1 with p = 4000/4096.
        let symbols = vec![(1u8, 4000u32); n];
        let compressed = encode_all(&symbols);

        assert!(
            compressed.len() < 50,
            "expected good compression, got {} bytes for {} bits at p=4000",
            compressed.len(),
            n
        );

        assert_decodes_to(&compressed, &symbols);
    }

    #[test]
    fn extreme_probabilities() {
        // Alternate between the smallest and largest allowed probabilities.
        let bits: [u8; 6] = [0, 1, 0, 1, 1, 0];
        let probs: [u32; 6] = [1, 4095, 1, 4095, 1, 4095];
        let symbols: Vec<(u8, u32)> = bits.iter().copied().zip(probs.iter().copied()).collect();
        roundtrip(&symbols);
    }

    #[test]
    fn byte_roundtrip() {
        let byte_val: u8 = 0xA5;

        // Encode the byte most-significant bit first.
        let mut enc = ArithmeticEncoder::new();
        for bpos in 0..8 {
            let bit = (byte_val >> (7 - bpos)) & 1;
            enc.encode(bit, 2048);
        }
        let compressed = enc.finish();

        let mut dec = ArithmeticDecoder::new(&compressed);
        let mut decoded_byte: u8 = 0;
        for bpos in 0..8 {
            let bit = dec.decode(2048);
            decoded_byte |= bit << (7 - bpos);
        }
        assert_eq!(decoded_byte, byte_val);
    }

    #[test]
    fn varying_probabilities_per_bit() {
        let data: Vec<u8> = (0u32..50).map(|i| ((i * 7 + 13) & 0xFF) as u8).collect();

        // Encoder side: the probability follows the bits actually encoded.
        let mut enc = ArithmeticEncoder::new();
        let mut p: u32 = 2048;
        for &byte in &data {
            for bpos in 0..8 {
                let bit = (byte >> (7 - bpos)) & 1;
                enc.encode(bit, p);
                p = adapt(p, bit);
            }
        }
        let compressed = enc.finish();

        // Decoder side: the probability follows the bits actually decoded, so
        // both models stay in step only if every bit decodes correctly.
        let mut dec = ArithmeticDecoder::new(&compressed);
        let mut p: u32 = 2048;
        for (i, &byte) in data.iter().enumerate() {
            let mut decoded: u8 = 0;
            for bpos in 0..8 {
                let bit = dec.decode(p);
                decoded |= bit << (7 - bpos);
                p = adapt(p, bit);
            }
            assert_eq!(decoded, byte, "byte mismatch at index {i}");
        }
    }
}