#![forbid(unsafe_code)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_possible_wrap)]
pub const MAX_RICE_PARAM: u8 = 14;
#[inline]
pub fn zigzag_encode(v: i32) -> u32 {
if v >= 0 {
(v as u32) << 1
} else {
((-v - 1) as u32) << 1 | 1
}
}
#[inline]
pub fn zigzag_decode(u: u32) -> i32 {
if u & 1 == 0 {
(u >> 1) as i32
} else {
-((u >> 1) as i32) - 1
}
}
#[must_use]
pub fn rice_bit_cost(residuals: &[i32], k: u8) -> u64 {
residuals
.iter()
.map(|&r| {
let u = zigzag_encode(r);
let quotient = u >> k;
1u64 + u64::from(k) + u64::from(quotient)
})
.sum()
}
#[must_use]
pub fn optimal_rice_param(residuals: &[i32]) -> u8 {
if residuals.is_empty() {
return 0;
}
let mut best_k = 0u8;
let mut best_cost = u64::MAX;
for k in 0..=MAX_RICE_PARAM {
let cost = rice_bit_cost(residuals, k);
if cost < best_cost {
best_cost = cost;
best_k = k;
}
}
best_k
}
#[must_use]
pub fn rice_encode(residuals: &[i32], k: u8) -> Vec<u8> {
let mut bits: Vec<bool> = Vec::new();
for &r in residuals {
let u = zigzag_encode(r);
let quotient = u >> k;
let remainder = u & ((1u32 << k) - 1);
for _ in 0..quotient {
bits.push(true);
}
bits.push(false);
for bit_idx in (0..k).rev() {
bits.push((remainder >> bit_idx) & 1 != 0);
}
}
let mut out = Vec::with_capacity((bits.len() + 7) / 8);
let mut byte = 0u8;
let mut fill = 0u8;
for bit in bits {
byte = (byte << 1) | u8::from(bit);
fill += 1;
if fill == 8 {
out.push(byte);
byte = 0;
fill = 0;
}
}
if fill > 0 {
out.push(byte << (8 - fill));
}
out
}
pub struct RiceDecoder<'a> {
data: &'a [u8],
byte_pos: usize,
bit_pos: u8,
}
impl<'a> RiceDecoder<'a> {
#[must_use]
pub fn new(data: &'a [u8]) -> Self {
Self {
data,
byte_pos: 0,
bit_pos: 0,
}
}
fn read_bit(&mut self) -> Option<bool> {
if self.byte_pos >= self.data.len() {
return None;
}
let bit = (self.data[self.byte_pos] >> (7 - self.bit_pos)) & 1 != 0;
self.bit_pos += 1;
if self.bit_pos == 8 {
self.byte_pos += 1;
self.bit_pos = 0;
}
Some(bit)
}
pub fn decode_one(&mut self, k: u8) -> Option<i32> {
let mut quotient = 0u32;
loop {
let bit = self.read_bit()?;
if !bit {
break;
}
quotient += 1;
if quotient > 1024 * 1024 {
return None; }
}
let mut remainder = 0u32;
for _ in 0..k {
let bit = self.read_bit()?;
remainder = (remainder << 1) | u32::from(bit);
}
let u = (quotient << k) | remainder;
Some(zigzag_decode(u))
}
pub fn decode_n(&mut self, count: usize, k: u8) -> Vec<i32> {
(0..count).map_while(|_| self.decode_one(k)).collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_zigzag_encode_decode_identity() {
for v in [-100i32, -1, 0, 1, 100, i16::MAX as i32] {
let u = zigzag_encode(v);
let back = zigzag_decode(u);
assert_eq!(back, v, "zigzag roundtrip failed for {v}");
}
}
#[test]
fn test_zigzag_non_negative_output() {
for v in [-200i32, -100, -1, 0, 1, 100, 200] {
let _u = zigzag_encode(v);
}
}
#[test]
fn test_rice_bit_cost_zero_residuals() {
let res = vec![0i32; 16];
let cost = rice_bit_cost(&res, 0);
assert_eq!(cost, 16);
}
#[test]
fn test_rice_encode_decode_roundtrip() {
let residuals = vec![0i32, 1, -1, 2, -2, 5, -5, 10, -10];
let k = optimal_rice_param(&residuals);
let encoded = rice_encode(&residuals, k);
let mut dec = RiceDecoder::new(&encoded);
let decoded = dec.decode_n(residuals.len(), k);
assert_eq!(decoded, residuals, "Rice roundtrip must be lossless");
}
#[test]
fn test_rice_encode_empty() {
let encoded = rice_encode(&[], 4);
assert!(encoded.is_empty());
}
#[test]
fn test_optimal_rice_param_small_residuals() {
let residuals = vec![0i32; 32];
let k = optimal_rice_param(&residuals);
assert_eq!(k, 0, "All-zero residuals → k=0 is optimal");
}
#[test]
fn test_optimal_rice_param_large_residuals() {
let residuals: Vec<i32> = (0..32).map(|i| i * 1000).collect();
let k_large = optimal_rice_param(&residuals);
let k_small = optimal_rice_param(&vec![0i32; 32]);
assert!(k_large >= k_small, "Large residuals should use larger k");
}
#[test]
fn test_rice_decode_n_partial() {
let residuals = vec![1i32, 2, 3];
let k = 1;
let encoded = rice_encode(&residuals, k);
let mut dec = RiceDecoder::new(&encoded);
let decoded = dec.decode_n(100, k);
assert!(decoded.len() >= residuals.len());
assert_eq!(&decoded[..residuals.len()], &residuals[..]);
}
}