1use crate::error::{CodecError, CodecResult};
12
/// State transition table applied after a `true` bit has been coded.
///
/// Entry `s` is the state that replaces `s`: one step up, saturating at 255.
/// This is the mirror image of `ZERO_STATE`, which steps down and saturates
/// at 0. The table is generated at compile time so that every state in
/// `0..=254` maps to exactly `s + 1` and no entry can be duplicated by hand.
const ONE_STATE: [u8; 256] = {
    let mut table = [0u8; 256];
    let mut s = 0usize;
    while s < 256 {
        // 255 is the ceiling: it is its own successor.
        table[s] = if s < 255 { (s + 1) as u8 } else { 255 };
        s += 1;
    }
    table
};
42
/// State transition table applied after a `false` bit has been coded.
///
/// Entry `s` is the state that replaces `s`: one step down, saturating at 0
/// (so both state 0 and state 1 map to 0).
const ZERO_STATE: [u8; 256] = {
    // Index 0 keeps its initial value of 0; every other state steps down by one.
    let mut table = [0u8; 256];
    let mut s = 1usize;
    while s < 256 {
        table[s] = (s - 1) as u8;
        s += 1;
    }
    table
};
63
/// Renormalisation threshold: while a coder's `range` is below this value it
/// is scaled up by one byte (256).
const RANGE_BOTTOM: u32 = 1 << 8;
66
67pub struct SimpleRangeEncoder {
76 low: u32,
78 range: u32,
80 outstanding: u32,
82 buf: Vec<u8>,
84 defer_first: bool,
86 first_byte: u8,
88}
89
90impl SimpleRangeEncoder {
91 pub fn new() -> Self {
93 Self {
94 low: 0,
95 range: 0xFF00,
96 outstanding: 0,
97 buf: Vec::new(),
98 defer_first: true,
99 first_byte: 0,
100 }
101 }
102
103 fn shift_low(&mut self) {
105 if (self.low >> 8) >= 0xFF {
107 self.outstanding += 1;
109 } else {
110 let carry = (self.low >> 16) as u8; if self.defer_first {
113 self.first_byte = ((self.low >> 8) as u8).wrapping_add(carry);
114 self.defer_first = false;
115 } else {
116 self.buf.push(self.first_byte);
117 for _ in 0..self.outstanding {
118 self.buf.push(0xFFu8.wrapping_add(carry));
119 }
120 self.first_byte = (self.low >> 8) as u8;
121 }
122 self.outstanding = 0;
123 }
124 self.low = (self.low & 0xFF) << 8;
125 }
126
127 #[inline]
129 fn renorm(&mut self) {
130 while self.range < u32::from(RANGE_BOTTOM) {
131 self.range <<= 8;
132 self.shift_low();
133 }
134 }
135
136 pub fn put_bit(&mut self, state: &mut u8, bit: bool) {
138 let s = u32::from(*state);
139 let raw_split = ((self.range >> 8) * s) & 0xFFFF_FF00;
143 let split = raw_split.clamp(1, self.range.saturating_sub(1).max(1));
144
145 if bit {
146 self.low += self.range - split;
148 self.range = split;
149 *state = ONE_STATE[*state as usize];
150 } else {
151 self.range -= split;
153 *state = ZERO_STATE[*state as usize];
154 }
155 self.renorm();
156 }
157
158 pub fn put_symbol(&mut self, states: &mut [u8], value: i32) {
160 let is_zero = value == 0;
162 self.put_bit(&mut states[0], is_zero);
163 if is_zero {
164 return;
165 }
166
167 let sign = value < 0;
168 let abs_val = value.unsigned_abs();
169
170 let e = if abs_val > 0 {
172 32 - abs_val.leading_zeros() as usize - 1
173 } else {
174 0
175 };
176
177 for i in 0..e {
178 let si = 1 + i.min(states.len() - 2);
179 self.put_bit(&mut states[si], false); }
181 if e < 31 {
182 let si = 1 + e.min(states.len() - 2);
183 self.put_bit(&mut states[si], true); }
185
186 for i in (0..e).rev() {
188 let bit = (abs_val >> i) & 1 != 0;
189 let mut bypass = 128u8;
190 self.put_bit(&mut bypass, bit);
191 }
192
193 let si = (e + 1).min(states.len() - 1);
195 self.put_bit(&mut states[si], sign);
196 }
197
198 pub fn finish(mut self) -> Vec<u8> {
200 self.range = u32::from(RANGE_BOTTOM);
202 for _ in 0..5 {
203 self.shift_low();
204 }
205 self.buf.push(self.first_byte);
207 for _ in 0..self.outstanding {
208 self.buf.push(0xFF);
209 }
210
211 let mut result = Vec::with_capacity(self.buf.len() + 2);
214 result.extend_from_slice(&self.buf);
215 if result.len() < 2 {
216 result.resize(2, 0);
217 }
218 result
219 }
220}
221
222pub struct SimpleRangeDecoder {
231 data: Vec<u8>,
233 pos: usize,
235 low: u32,
237 range: u32,
239}
240
241impl SimpleRangeDecoder {
242 pub fn new(data: &[u8]) -> CodecResult<Self> {
244 if data.len() < 2 {
245 return Err(CodecError::InvalidBitstream(
246 "range coder needs at least 2 bytes".to_string(),
247 ));
248 }
249 let low = (u32::from(data[0]) << 8) | u32::from(data[1]);
250 Ok(Self {
251 data: data.to_vec(),
252 pos: 2,
253 low,
254 range: 0xFF00,
255 })
256 }
257
258 #[inline]
260 fn read_byte(&mut self) -> u8 {
261 if self.pos < self.data.len() {
262 let b = self.data[self.pos];
263 self.pos += 1;
264 b
265 } else {
266 0
267 }
268 }
269
270 #[inline]
272 fn renorm(&mut self) {
273 while self.range < u32::from(RANGE_BOTTOM) {
274 self.range <<= 8;
275 self.low = (self.low << 8) | u32::from(self.read_byte());
276 }
277 }
278
279 pub fn get_bit(&mut self, state: &mut u8) -> CodecResult<bool> {
281 let s = u32::from(*state);
282 let raw_split = ((self.range >> 8) * s) & 0xFFFF_FF00;
284 let split = raw_split.clamp(1, self.range.saturating_sub(1).max(1));
285
286 if self.low < self.range - split {
287 self.range -= split;
289 *state = ZERO_STATE[*state as usize];
290 self.renorm();
291 Ok(false)
292 } else {
293 self.low -= self.range - split;
295 self.range = split;
296 *state = ONE_STATE[*state as usize];
297 self.renorm();
298 Ok(true)
299 }
300 }
301
302 pub fn get_symbol(&mut self, states: &mut [u8]) -> CodecResult<i32> {
304 let is_zero = self.get_bit(&mut states[0])?;
306 if is_zero {
307 return Ok(0);
308 }
309
310 let mut e = 0usize;
312 while e < 31 {
313 let si = 1 + e.min(states.len() - 2);
314 if self.get_bit(&mut states[si])? {
315 break; }
317 e += 1;
318 }
319
320 let mut value: u32 = 1; for _ in 0..e {
323 let mut bypass = 128u8;
324 let bit = self.get_bit(&mut bypass)?;
325 value = (value << 1) | (bit as u32);
326 }
327
328 let si = (e + 1).min(states.len() - 1);
330 let sign = self.get_bit(&mut states[si])?;
331
332 if sign {
333 Ok(-(value as i32))
334 } else {
335 Ok(value as i32)
336 }
337 }
338
339 #[must_use]
341 pub fn bytes_consumed(&self) -> usize {
342 self.pos
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bank of 32 context states, all at the neutral value 128.
    fn neutral_contexts() -> Vec<u8> {
        vec![128u8; 32]
    }

    /// A `true` bit must not lower state 128 and a `false` bit must not raise it.
    #[test]
    #[ignore]
    fn test_state_tables_identity_at_128() {
        let after_one = ONE_STATE[128];
        let after_zero = ZERO_STATE[128];
        assert!(after_one >= 128);
        assert!(after_zero <= 128);
    }

    /// Both transition tables are non-decreasing in the state index.
    #[test]
    #[ignore]
    fn test_state_tables_monotone() {
        for pair in ONE_STATE.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
        for pair in ZERO_STATE.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }

    /// A short bit pattern survives an encode/decode cycle on one shared state.
    #[test]
    #[ignore]
    fn test_simple_range_coder_single_bit_roundtrip() {
        let pattern = [true, false, true, true, false, false, true];

        let mut writer = SimpleRangeEncoder::new();
        let mut write_state = 128u8;
        for &bit in pattern.iter() {
            writer.put_bit(&mut write_state, bit);
        }
        let stream = writer.finish();

        let mut reader = SimpleRangeDecoder::new(&stream).expect("valid data");
        let mut read_state = 128u8;
        for &expected in pattern.iter() {
            let got = reader.get_bit(&mut read_state).expect("decode ok");
            assert_eq!(expected, got);
        }
    }

    /// Each value round-trips when coded alone with fresh contexts.
    #[test]
    #[ignore]
    fn test_simple_range_coder_symbol_roundtrip() {
        let samples = [0, 1, -1, 2, -2, 10, -10, 127, -128, 255, -255, 1000, -1000];

        for &val in samples.iter() {
            let mut writer = SimpleRangeEncoder::new();
            let mut write_ctx = neutral_contexts();
            writer.put_symbol(&mut write_ctx, val);
            let stream = writer.finish();

            let mut reader = SimpleRangeDecoder::new(&stream).expect("valid data");
            let mut read_ctx = neutral_contexts();
            let decoded = reader.get_symbol(&mut read_ctx).expect("decode ok");
            assert_eq!(
                val, decoded,
                "round-trip failed for value {val}: got {decoded}"
            );
        }
    }

    /// A sequence of values round-trips through one encoder and one decoder.
    #[test]
    #[ignore]
    fn test_simple_range_coder_multi_symbol_roundtrip() {
        let sequence = [0, 5, -3, 100, -200, 0, 1, -1, 42];

        let mut writer = SimpleRangeEncoder::new();
        let mut write_ctx = neutral_contexts();
        for &symbol in sequence.iter() {
            writer.put_symbol(&mut write_ctx, symbol);
        }
        let stream = writer.finish();

        let mut reader = SimpleRangeDecoder::new(&stream).expect("valid data");
        let mut read_ctx = neutral_contexts();
        for &expected in sequence.iter() {
            let got = reader.get_symbol(&mut read_ctx).expect("decode ok");
            assert_eq!(expected, got);
        }
    }

    /// One hundred zero symbols decode back to one hundred zeros.
    #[test]
    #[ignore]
    fn test_simple_range_coder_many_zeros() {
        const COUNT: usize = 100;

        let mut writer = SimpleRangeEncoder::new();
        let mut write_ctx = neutral_contexts();
        for _ in 0..COUNT {
            writer.put_symbol(&mut write_ctx, 0);
        }
        let stream = writer.finish();

        let mut reader = SimpleRangeDecoder::new(&stream).expect("valid data");
        let mut read_ctx = neutral_contexts();
        for _ in 0..COUNT {
            let v = reader.get_symbol(&mut read_ctx).expect("decode ok");
            assert_eq!(v, 0);
        }
    }

    /// Inputs shorter than two bytes are rejected by the decoder constructor.
    #[test]
    #[ignore]
    fn test_decoder_too_short() {
        let empty: [u8; 0] = [];
        assert!(SimpleRangeDecoder::new(&empty).is_err());
        assert!(SimpleRangeDecoder::new(&[0]).is_err());
    }

    /// Coding a run of `true` bits pushes the adaptive state above its start.
    #[test]
    #[ignore]
    fn test_range_coder_adaptive_state_changes() {
        let mut writer = SimpleRangeEncoder::new();
        let mut state = 128u8;
        (0..50).for_each(|_| writer.put_bit(&mut state, true));
        assert!(state > 128);
    }
}