use crate::error::{CodecError, CodecResult};
/// State transition table applied after a `true` bit is coded
/// (`state = ONE_STATE[state]` in both the encoder and the decoder).
///
/// Every index below 254 maps to `index + 1`; indices 254 and 255 map to
/// themselves. A larger state makes the `true` sub-interval wider, because the
/// split point is derived from `(range >> 8) * state`.
///
/// NOTE(review): index 254 maps to itself, so a state that starts below 254
/// never reaches 255 through this table — confirm this cap is intended.
#[rustfmt::skip]
const ONE_STATE: [u8; 256] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96,
97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128,
129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160,
161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176,
177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192,
193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208,
209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224,
225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240,
241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 254, 255,
];
/// State transition table applied after a `false` bit is coded
/// (`state = ZERO_STATE[state]` in both the encoder and the decoder).
///
/// Index 0 maps to 0; every other index maps to `index - 1`. A smaller state
/// makes the `true` sub-interval narrower, because the split point is derived
/// from `(range >> 8) * state`.
#[rustfmt::skip]
const ZERO_STATE: [u8; 256] = [
0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46,
47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62,
63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78,
79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94,
95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158,
159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174,
175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206,
207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222,
223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254,
];
/// Renormalization threshold: while `range` is below this value the encoder
/// shifts a byte out of `low` and the decoder shifts a byte into `low`, each
/// time multiplying `range` by 256.
const RANGE_BOTTOM: u32 = 0x100;
/// Byte-oriented adaptive binary range encoder.
///
/// Bits are coded with [`put_bit`](Self::put_bit) against a caller-owned
/// adaptive state byte; [`put_symbol`](Self::put_symbol) builds a signed
/// integer code on top of that. Call [`finish`](Self::finish) to flush the
/// coder and obtain the encoded bytes.
pub struct SimpleRangeEncoder {
    /// Low end of the current coding interval. The value lives in a 16-bit
    /// window; bit 16 (and above) is a carry that has not been propagated yet.
    low: u32,
    /// Width of the current coding interval; at least `RANGE_BOTTOM` between
    /// calls to `put_bit`.
    range: u32,
    /// Number of `0xFF` bytes that are withheld from `buf` because a later
    /// carry could still turn them into `0x00`.
    outstanding: u32,
    /// Bytes that are final and will not change any more.
    buf: Vec<u8>,
    /// `true` until the first byte has been staged in `first_byte`.
    defer_first: bool,
    /// Most recently produced byte; held back because a carry may still
    /// increment it.
    first_byte: u8,
}

impl SimpleRangeEncoder {
    /// Creates an encoder with an empty output buffer and the initial
    /// interval `[0, 0xFF00)`.
    pub fn new() -> Self {
        Self {
            low: 0,
            range: 0xFF00,
            outstanding: 0,
            buf: Vec::new(),
            defer_first: true,
            first_byte: 0,
        }
    }

    /// Moves the top byte of the 16-bit `low` window towards the output,
    /// resolving carries.
    ///
    /// Three cases, keyed on `top = low >> 8`:
    /// * `top == 0xFF`: the byte cannot be finalized yet (a later carry would
    ///   turn it into `0x00`), so it is only counted in `outstanding`.
    /// * `top < 0xFF`: no carry; the staged byte and all withheld `0xFF`
    ///   bytes are final and are written unchanged.
    /// * `top > 0xFF`: a carry occurred; the staged byte is incremented and
    ///   the withheld `0xFF` bytes become `0x00`.
    fn shift_low(&mut self) {
        let top = self.low >> 8;
        if top == 0xFF {
            self.outstanding += 1;
        } else {
            // `top` above 0xFF means bit 16 of `low` is set, i.e. a carry.
            let carry = u8::from(top > 0xFF);
            if self.defer_first {
                // Very first byte: nothing is staged yet, so there is nothing
                // to write. `low` starts at 0 with `low + range <= 0xFF00`,
                // so neither a carry nor an outstanding byte can exist here.
                self.defer_first = false;
            } else {
                // The staged byte is never 0xFF (that case is counted in
                // `outstanding` instead), so adding the carry cannot wrap.
                self.buf.push(self.first_byte.wrapping_add(carry));
                // 0xFF without carry, 0x00 with carry.
                let fill = 0xFFu8.wrapping_add(carry);
                let new_len = self.buf.len() + self.outstanding as usize;
                self.buf.resize(new_len, fill);
            }
            self.first_byte = (top & 0xFF) as u8;
            self.outstanding = 0;
        }
        // Drop the byte just handled (and the carry bit) and slide the
        // remaining low byte up into the window.
        self.low = (self.low & 0xFF) << 8;
    }

    /// Restores `range >= RANGE_BOTTOM` by scaling the interval by 256 and
    /// shifting one byte out of `low` per step.
    #[inline]
    fn renorm(&mut self) {
        while self.range < RANGE_BOTTOM {
            self.range <<= 8;
            self.shift_low();
        }
    }

    /// Encodes one bit with the adaptive probability `state` and updates
    /// `state` through `ONE_STATE` / `ZERO_STATE`.
    ///
    /// The interval is split at `split`, derived from `(range >> 8) * state`
    /// with the low 8 bits cleared and clamped to `1..=range - 1` so that both
    /// sub-intervals stay non-empty. A `true` bit takes the upper sub-interval
    /// of width `split`; a `false` bit keeps the lower `range - split`.
    pub fn put_bit(&mut self, state: &mut u8, bit: bool) {
        let s = u32::from(*state);
        let raw_split = ((self.range >> 8) * s) & 0xFFFF_FF00;
        let split = raw_split.clamp(1, self.range.saturating_sub(1).max(1));
        if bit {
            self.low += self.range - split;
            self.range = split;
            *state = ONE_STATE[*state as usize];
        } else {
            self.range -= split;
            *state = ZERO_STATE[*state as usize];
        }
        self.renorm();
    }

    /// Encodes a signed integer using the context states in `states`.
    ///
    /// Layout of the code:
    /// 1. a zero flag (`true` when `value == 0`) with `states[0]`; nothing
    ///    else is written for zero;
    /// 2. the exponent `e = floor(log2(|value|))` in unary: `e` `false` bits
    ///    followed by one `true` terminator (omitted when `e == 31`), bit `i`
    ///    using `states[1 + min(i, states.len() - 2)]`;
    /// 3. the `e` bits of `|value|` below its leading one, most significant
    ///    first, each coded with a fresh state of 128;
    /// 4. the sign (`true` for negative) with
    ///    `states[min(e + 1, states.len() - 1)]`.
    ///
    /// # Panics
    ///
    /// Panics if `states` is empty, or if `value` is non-zero and `states`
    /// has fewer than two entries.
    pub fn put_symbol(&mut self, states: &mut [u8], value: i32) {
        let is_zero = value == 0;
        self.put_bit(&mut states[0], is_zero);
        if is_zero {
            return;
        }

        // `value` is non-zero here, so `leading_zeros()` is at most 31.
        let abs_val = value.unsigned_abs();
        let e = (31 - abs_val.leading_zeros()) as usize;

        // Exponent contexts beyond the end of `states` share the last one.
        let last_exp_ctx = states.len() - 2;
        for i in 0..e {
            self.put_bit(&mut states[1 + i.min(last_exp_ctx)], false);
        }
        // With e == 31 the decoder stops on its own, so no terminator.
        if e < 31 {
            self.put_bit(&mut states[1 + e.min(last_exp_ctx)], true);
        }

        for i in (0..e).rev() {
            let mut bypass = 128u8;
            self.put_bit(&mut bypass, (abs_val >> i) & 1 != 0);
        }

        let sign_ctx = (e + 1).min(states.len() - 1);
        self.put_bit(&mut states[sign_ctx], value < 0);
    }

    /// Flushes the coder and returns the encoded bytes.
    ///
    /// `low` is shifted out completely (two payload bytes followed by zero
    /// padding from the remaining shifts) and the staged byte is appended.
    /// The result always has at least two bytes, which is the minimum
    /// `SimpleRangeDecoder::new` accepts.
    pub fn finish(mut self) -> Vec<u8> {
        // Five shifts push both bytes of `low` plus padding through the
        // staged-byte pipeline. From the third shift on `low` is zero, so no
        // `0xFF` byte can still be outstanding afterwards.
        for _ in 0..5 {
            self.shift_low();
        }
        self.buf.push(self.first_byte);

        let mut encoded = self.buf;
        if encoded.len() < 2 {
            encoded.resize(2, 0);
        }
        encoded
    }
}
/// Byte-oriented adaptive binary range decoder, the counterpart of
/// `SimpleRangeEncoder`.
pub struct SimpleRangeDecoder {
    /// Private copy of the encoded input.
    data: Vec<u8>,
    /// Index of the next byte of `data` to read.
    pos: usize,
    /// Offset of the code value inside the current interval.
    low: u32,
    /// Width of the current coding interval.
    range: u32,
}

impl SimpleRangeDecoder {
    /// Creates a decoder over a copy of `data`, priming it with the first two
    /// bytes (big-endian).
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidBitstream` if `data` has fewer than two
    /// bytes.
    pub fn new(data: &[u8]) -> CodecResult<Self> {
        if data.len() < 2 {
            return Err(CodecError::InvalidBitstream(
                "range coder needs at least 2 bytes".to_string(),
            ));
        }
        let low = (u32::from(data[0]) << 8) | u32::from(data[1]);
        Ok(Self {
            data: data.to_vec(),
            pos: 2,
            low,
            range: 0xFF00,
        })
    }

    /// Returns the next input byte and advances, or `0` (without advancing)
    /// once the input is exhausted.
    #[inline]
    fn read_byte(&mut self) -> u8 {
        match self.data.get(self.pos).copied() {
            Some(byte) => {
                self.pos += 1;
                byte
            }
            None => 0,
        }
    }

    /// Restores `range >= RANGE_BOTTOM`, shifting one input byte into `low`
    /// per step.
    #[inline]
    fn renorm(&mut self) {
        while self.range < RANGE_BOTTOM {
            self.range <<= 8;
            self.low = (self.low << 8) | u32::from(self.read_byte());
        }
    }

    /// Decodes one bit with the adaptive probability `state` and updates
    /// `state` through `ONE_STATE` / `ZERO_STATE`.
    ///
    /// The split point is computed exactly as in
    /// `SimpleRangeEncoder::put_bit`: `low` below `range - split` decodes as
    /// `false`, anything else as `true`.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; reading past the end
    /// of the input yields zero bytes instead of an error.
    pub fn get_bit(&mut self, state: &mut u8) -> CodecResult<bool> {
        let s = u32::from(*state);
        let raw_split = ((self.range >> 8) * s) & 0xFFFF_FF00;
        let split = raw_split.clamp(1, self.range.saturating_sub(1).max(1));
        let zero_width = self.range - split;

        let bit = self.low >= zero_width;
        if bit {
            self.low -= zero_width;
            self.range = split;
            *state = ONE_STATE[*state as usize];
        } else {
            self.range = zero_width;
            *state = ZERO_STATE[*state as usize];
        }
        self.renorm();
        Ok(bit)
    }

    /// Decodes a signed integer written by `SimpleRangeEncoder::put_symbol`,
    /// using the same context states in `states`.
    ///
    /// Reads the zero flag (`states[0]`), the unary exponent `e` (at most 31
    /// bits, bit `i` with `states[1 + min(i, states.len() - 2)]`), `e`
    /// magnitude bits with a fresh state of 128 each, and finally the sign
    /// with `states[min(e + 1, states.len() - 1)]`.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidBitstream` if the decoded sign and
    /// magnitude do not form a value representable as `i32`.
    ///
    /// # Panics
    ///
    /// Panics if `states` is empty, or if the symbol is non-zero and `states`
    /// has fewer than two entries.
    pub fn get_symbol(&mut self, states: &mut [u8]) -> CodecResult<i32> {
        if self.get_bit(&mut states[0])? {
            return Ok(0);
        }

        // Exponent contexts beyond the end of `states` share the last one.
        let last_exp_ctx = states.len() - 2;
        let mut e = 0usize;
        while e < 31 {
            if self.get_bit(&mut states[1 + e.min(last_exp_ctx)])? {
                break;
            }
            e += 1;
        }

        // Implicit leading one followed by `e` explicit bits; with e <= 31
        // the magnitude always fits in a u32.
        let mut magnitude: u32 = 1;
        for _ in 0..e {
            let mut bypass = 128u8;
            let bit = self.get_bit(&mut bypass)?;
            magnitude = (magnitude << 1) | u32::from(bit);
        }

        let sign_ctx = (e + 1).min(states.len() - 1);
        let negative = self.get_bit(&mut states[sign_ctx])?;

        // Combine in i64 so that -2^31 (i32::MIN) is handled without
        // overflow and out-of-range magnitudes are rejected instead of
        // wrapping to a wrong value.
        let signed = if negative {
            -i64::from(magnitude)
        } else {
            i64::from(magnitude)
        };
        if signed < i64::from(i32::MIN) || signed > i64::from(i32::MAX) {
            return Err(CodecError::InvalidBitstream(
                "decoded symbol does not fit in i32".to_string(),
            ));
        }
        Ok(signed as i32)
    }

    /// Number of input bytes read so far, including the two bytes consumed by
    /// `new`.
    #[must_use]
    pub fn bytes_consumed(&self) -> usize {
        self.pos
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Starting value for every adaptive state used in these tests.
    const NEUTRAL: u8 = 128;
    /// Number of context states handed to the symbol coder.
    const CONTEXTS: usize = 32;

    /// Encodes `values` with one shared, freshly initialized context set.
    fn encode_symbols(values: &[i32]) -> Vec<u8> {
        let mut enc = SimpleRangeEncoder::new();
        let mut contexts = vec![NEUTRAL; CONTEXTS];
        for &v in values {
            enc.put_symbol(&mut contexts, v);
        }
        enc.finish()
    }

    /// Decodes `count` symbols from `encoded` with a freshly initialized
    /// context set.
    fn decode_symbols(encoded: &[u8], count: usize) -> Vec<i32> {
        let mut dec = SimpleRangeDecoder::new(encoded).expect("valid data");
        let mut contexts = vec![NEUTRAL; CONTEXTS];
        (0..count)
            .map(|_| dec.get_symbol(&mut contexts).expect("decode ok"))
            .collect()
    }

    // The neutral state must move up after a one and down after a zero
    // (or stay put).
    #[test]
    #[ignore]
    fn test_state_tables_identity_at_128() {
        let pivot = usize::from(NEUTRAL);
        assert!(ONE_STATE[pivot] >= NEUTRAL);
        assert!(ZERO_STATE[pivot] <= NEUTRAL);
    }

    // Both transition tables are non-decreasing in the state index.
    #[test]
    #[ignore]
    fn test_state_tables_monotone() {
        for pair in ONE_STATE.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
        for pair in ZERO_STATE.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }

    // A short bit pattern survives encode -> decode with one adaptive state.
    #[test]
    #[ignore]
    fn test_simple_range_coder_single_bit_roundtrip() {
        let pattern = [true, false, true, true, false, false, true];

        let mut enc = SimpleRangeEncoder::new();
        let mut enc_state = NEUTRAL;
        pattern
            .iter()
            .for_each(|&bit| enc.put_bit(&mut enc_state, bit));
        let encoded = enc.finish();

        let mut dec = SimpleRangeDecoder::new(&encoded).expect("valid data");
        let mut dec_state = NEUTRAL;
        for &expected in pattern.iter() {
            let got = dec.get_bit(&mut dec_state).expect("decode ok");
            assert_eq!(expected, got);
        }
    }

    // Each value is coded alone in its own stream and must decode unchanged.
    #[test]
    #[ignore]
    fn test_simple_range_coder_symbol_roundtrip() {
        let samples = [0, 1, -1, 2, -2, 10, -10, 127, -128, 255, -255, 1000, -1000];
        for &val in samples.iter() {
            let encoded = encode_symbols(&[val]);
            let decoded = decode_symbols(&encoded, 1)[0];
            assert_eq!(
                val, decoded,
                "round-trip failed for value {val}: got {decoded}"
            );
        }
    }

    // Several values sharing one stream and one context set.
    #[test]
    #[ignore]
    fn test_simple_range_coder_multi_symbol_roundtrip() {
        let values = [0, 5, -3, 100, -200, 0, 1, -1, 42];
        let encoded = encode_symbols(&values);
        let decoded = decode_symbols(&encoded, values.len());
        for (&expected, &got) in values.iter().zip(decoded.iter()) {
            assert_eq!(expected, got);
        }
    }

    // A long run of zeros, which drives states[0] far from neutral.
    #[test]
    #[ignore]
    fn test_simple_range_coder_many_zeros() {
        let zeros = [0i32; 100];
        let encoded = encode_symbols(&zeros);
        for v in decode_symbols(&encoded, zeros.len()) {
            assert_eq!(v, 0);
        }
    }

    // Inputs shorter than two bytes are rejected by the constructor.
    #[test]
    #[ignore]
    fn test_decoder_too_short() {
        let empty: [u8; 0] = [];
        assert!(SimpleRangeDecoder::new(&empty).is_err());
        assert!(SimpleRangeDecoder::new(&[0]).is_err());
    }

    // Repeated ones must push the adaptive state above its starting point.
    #[test]
    #[ignore]
    fn test_range_coder_adaptive_state_changes() {
        let mut enc = SimpleRangeEncoder::new();
        let mut state = NEUTRAL;
        (0..50).for_each(|_| enc.put_bit(&mut state, true));
        assert!(state > 128);
    }
}