1use crate::error::{CodecError, CodecResult};
12
/// Next adaptive state after a `true` bit has been coded in state `i`.
///
/// The table is built at compile time from its rule rather than spelled out:
/// states `0..=253` step up by one, and the two top entries map to themselves
/// (`254 -> 254`, `255 -> 255`).
const ONE_STATE: [u8; 256] = {
    let mut table = [0u8; 256];
    let mut state = 0usize;
    while state < 256 {
        table[state] = match state {
            // The two highest states are fixed points.
            254 => 254,
            255 => 255,
            // `state + 1` is at most 254 here, so the cast is lossless.
            _ => (state + 1) as u8,
        };
        state += 1;
    }
    table
};
42
/// Next adaptive state after a `false` bit has been coded in state `i`.
///
/// The table is built at compile time from its rule rather than spelled out:
/// every state steps down by one, and state `0` maps to itself.
const ZERO_STATE: [u8; 256] = {
    // Entry 0 keeps its initial value of 0.
    let mut table = [0u8; 256];
    let mut state = 1usize;
    while state < 256 {
        // `state - 1` is at most 254, so the cast is lossless.
        table[state] = (state - 1) as u8;
        state += 1;
    }
    table
};
63
64const RANGE_BOTTOM: u32 = 0x100;
66
67pub struct SimpleRangeEncoder {
76 low: u32,
78 range: u32,
80 outstanding: u32,
82 buf: Vec<u8>,
84 defer_first: bool,
86 first_byte: u8,
88}
89
90impl SimpleRangeEncoder {
91 pub fn new() -> Self {
93 Self {
94 low: 0,
95 range: 0xFF00,
96 outstanding: 0,
97 buf: Vec::new(),
98 defer_first: true,
99 first_byte: 0,
100 }
101 }
102
103 fn shift_low(&mut self) {
105 if (self.low >> 8) >= 0xFF {
107 self.outstanding += 1;
109 } else {
110 let carry = (self.low >> 16) as u8; if self.defer_first {
113 self.first_byte = ((self.low >> 8) as u8).wrapping_add(carry);
114 self.defer_first = false;
115 } else {
116 self.buf.push(self.first_byte);
117 for _ in 0..self.outstanding {
118 self.buf.push(0xFFu8.wrapping_add(carry));
119 }
120 self.first_byte = (self.low >> 8) as u8;
121 }
122 self.outstanding = 0;
123 }
124 self.low = (self.low & 0xFF) << 8;
125 }
126
127 #[inline]
129 fn renorm(&mut self) {
130 while self.range < u32::from(RANGE_BOTTOM) {
131 self.range <<= 8;
132 self.shift_low();
133 }
134 }
135
136 pub fn put_bit(&mut self, state: &mut u8, bit: bool) {
138 let s = u32::from(*state);
139 let split = ((self.range >> 8) * s) & 0xFFFF_FF00;
141
142 if bit {
143 self.low += self.range - split;
145 self.range = split;
146 *state = ONE_STATE[*state as usize];
147 } else {
148 self.range -= split;
150 *state = ZERO_STATE[*state as usize];
151 }
152 self.renorm();
153 }
154
155 pub fn put_symbol(&mut self, states: &mut [u8], value: i32) {
157 let is_zero = value == 0;
159 self.put_bit(&mut states[0], is_zero);
160 if is_zero {
161 return;
162 }
163
164 let sign = value < 0;
165 let abs_val = value.unsigned_abs();
166
167 let e = if abs_val > 0 {
169 32 - abs_val.leading_zeros() as usize - 1
170 } else {
171 0
172 };
173
174 for i in 0..e {
175 let si = 1 + i.min(states.len() - 2);
176 self.put_bit(&mut states[si], false); }
178 if e < 31 {
179 let si = 1 + e.min(states.len() - 2);
180 self.put_bit(&mut states[si], true); }
182
183 for i in (0..e).rev() {
185 let bit = (abs_val >> i) & 1 != 0;
186 let mut bypass = 128u8;
187 self.put_bit(&mut bypass, bit);
188 }
189
190 let si = (e + 1).min(states.len() - 1);
192 self.put_bit(&mut states[si], sign);
193 }
194
195 pub fn finish(mut self) -> Vec<u8> {
197 self.range = u32::from(RANGE_BOTTOM);
199 for _ in 0..5 {
200 self.shift_low();
201 }
202 self.buf.push(self.first_byte);
204 for _ in 0..self.outstanding {
205 self.buf.push(0xFF);
206 }
207
208 let mut result = Vec::with_capacity(self.buf.len() + 2);
211 result.extend_from_slice(&self.buf);
212 if result.len() < 2 {
213 result.resize(2, 0);
214 }
215 result
216 }
217}
218
219pub struct SimpleRangeDecoder {
228 data: Vec<u8>,
230 pos: usize,
232 low: u32,
234 range: u32,
236}
237
238impl SimpleRangeDecoder {
239 pub fn new(data: &[u8]) -> CodecResult<Self> {
241 if data.len() < 2 {
242 return Err(CodecError::InvalidBitstream(
243 "range coder needs at least 2 bytes".to_string(),
244 ));
245 }
246 let low = (u32::from(data[0]) << 8) | u32::from(data[1]);
247 Ok(Self {
248 data: data.to_vec(),
249 pos: 2,
250 low,
251 range: 0xFF00,
252 })
253 }
254
255 #[inline]
257 fn read_byte(&mut self) -> u8 {
258 if self.pos < self.data.len() {
259 let b = self.data[self.pos];
260 self.pos += 1;
261 b
262 } else {
263 0
264 }
265 }
266
267 #[inline]
269 fn renorm(&mut self) {
270 while self.range < u32::from(RANGE_BOTTOM) {
271 self.range <<= 8;
272 self.low = (self.low << 8) | u32::from(self.read_byte());
273 }
274 }
275
276 pub fn get_bit(&mut self, state: &mut u8) -> CodecResult<bool> {
278 let s = u32::from(*state);
279 let split = ((self.range >> 8) * s) & 0xFFFF_FF00;
280
281 if self.low < self.range - split {
282 self.range -= split;
284 *state = ZERO_STATE[*state as usize];
285 self.renorm();
286 Ok(false)
287 } else {
288 self.low -= self.range - split;
290 self.range = split;
291 *state = ONE_STATE[*state as usize];
292 self.renorm();
293 Ok(true)
294 }
295 }
296
297 pub fn get_symbol(&mut self, states: &mut [u8]) -> CodecResult<i32> {
299 let is_zero = self.get_bit(&mut states[0])?;
301 if is_zero {
302 return Ok(0);
303 }
304
305 let mut e = 0usize;
307 while e < 31 {
308 let si = 1 + e.min(states.len() - 2);
309 if self.get_bit(&mut states[si])? {
310 break; }
312 e += 1;
313 }
314
315 let mut value: u32 = 1; for _ in 0..e {
318 let mut bypass = 128u8;
319 let bit = self.get_bit(&mut bypass)?;
320 value = (value << 1) | (bit as u32);
321 }
322
323 let si = (e + 1).min(states.len() - 1);
325 let sign = self.get_bit(&mut states[si])?;
326
327 if sign {
328 Ok(-(value as i32))
329 } else {
330 Ok(value as i32)
331 }
332 }
333
334 #[must_use]
336 pub fn bytes_consumed(&self) -> usize {
337 self.pos
338 }
339}
340
#[cfg(test)]
mod tests {
    //! Unit tests for the state tables and the simple range coder.
    //!
    //! NOTE(review): every test below is marked `#[ignore]`; the reason is not
    //! recorded in this file, so the attribute is left in place.

    use super::*;

    /// Number of context states handed to the symbol coder in these tests.
    const CONTEXT_COUNT: usize = 32;

    /// Fresh context array with every state at the neutral value 128.
    fn neutral_states() -> Vec<u8> {
        vec![128u8; CONTEXT_COUNT]
    }

    /// Encodes `values` with a single encoder and returns the finished stream.
    fn encode_symbols(values: &[i32]) -> Vec<u8> {
        let mut encoder = SimpleRangeEncoder::new();
        let mut contexts = neutral_states();
        for &value in values {
            encoder.put_symbol(&mut contexts, value);
        }
        encoder.finish()
    }

    /// State 128 must not move in the wrong direction in either table.
    #[test]
    #[ignore]
    fn test_state_tables_identity_at_128() {
        let up = ONE_STATE[128];
        let down = ZERO_STATE[128];
        assert!(up >= 128);
        assert!(down <= 128);
    }

    /// Both transition tables are non-decreasing in the state index.
    #[test]
    #[ignore]
    fn test_state_tables_monotone() {
        for pair in ONE_STATE.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
        for pair in ZERO_STATE.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }

    /// A short bit sequence coded with one adaptive state decodes unchanged.
    #[test]
    #[ignore]
    fn test_simple_range_coder_single_bit_roundtrip() {
        let pattern = [true, false, true, true, false, false, true];

        let mut encoder = SimpleRangeEncoder::new();
        let mut write_state = 128u8;
        pattern
            .iter()
            .for_each(|&bit| encoder.put_bit(&mut write_state, bit));
        let stream = encoder.finish();

        let mut decoder = SimpleRangeDecoder::new(&stream).expect("valid data");
        let mut read_state = 128u8;
        for &expected in pattern.iter() {
            let got = decoder.get_bit(&mut read_state).expect("decode ok");
            assert_eq!(expected, got);
        }
    }

    /// Each value, coded alone in its own stream, decodes to itself.
    #[test]
    #[ignore]
    fn test_simple_range_coder_symbol_roundtrip() {
        let samples = [0, 1, -1, 2, -2, 10, -10, 127, -128, 255, -255, 1000, -1000];

        for val in samples.iter().copied() {
            let stream = encode_symbols(&[val]);

            let mut decoder = SimpleRangeDecoder::new(&stream).expect("valid data");
            let mut contexts = neutral_states();
            let decoded = decoder.get_symbol(&mut contexts).expect("decode ok");
            assert_eq!(
                val, decoded,
                "round-trip failed for value {val}: got {decoded}"
            );
        }
    }

    /// Several values sharing one stream and one context array decode in order.
    #[test]
    #[ignore]
    fn test_simple_range_coder_multi_symbol_roundtrip() {
        let sequence = [0, 5, -3, 100, -200, 0, 1, -1, 42];
        let stream = encode_symbols(&sequence);

        let mut decoder = SimpleRangeDecoder::new(&stream).expect("valid data");
        let mut contexts = neutral_states();
        for expected in sequence.iter().copied() {
            let got = decoder.get_symbol(&mut contexts).expect("decode ok");
            assert_eq!(expected, got);
        }
    }

    /// One hundred zero symbols in a row decode back to zeros.
    #[test]
    #[ignore]
    fn test_simple_range_coder_many_zeros() {
        let zeros = [0i32; 100];
        let stream = encode_symbols(&zeros);

        let mut decoder = SimpleRangeDecoder::new(&stream).expect("valid data");
        let mut contexts = neutral_states();
        for _ in zeros.iter() {
            let v = decoder.get_symbol(&mut contexts).expect("decode ok");
            assert_eq!(v, 0);
        }
    }

    /// Inputs shorter than two bytes are rejected by the decoder constructor.
    #[test]
    #[ignore]
    fn test_decoder_too_short() {
        let empty: [u8; 0] = [];
        let single = [0u8];
        assert!(SimpleRangeDecoder::new(&empty).is_err());
        assert!(SimpleRangeDecoder::new(&single).is_err());
    }

    /// Repeated `true` bits push the adaptive state above its starting value.
    #[test]
    #[ignore]
    fn test_range_coder_adaptive_state_changes() {
        let mut encoder = SimpleRangeEncoder::new();
        let mut probability = 128u8;
        let mut remaining = 50;
        while remaining > 0 {
            encoder.put_bit(&mut probability, true);
            remaining -= 1;
        }
        assert!(probability > 128);
    }
}