use alloc::vec::Vec;
use crate::bit_io::{BitReader, BitWriter};
use crate::frame::DecodeError;
use crate::{MAX_PARTITION_ORDER, MAX_RICE_K};
/// Width, in bits, of the Rice parameter `k` header that precedes every partition:
/// the encoder writes `k` with this many bits and the decoder reads it back the same way.
pub const RICE_K_BITS: u8 = 5;
/// Folds a signed value onto the unsigned range so small magnitudes stay small:
/// `0, -1, 1, -2, 2, ...` map to `0, 1, 2, 3, 4, ...`.
#[inline]
pub fn zigzag(n: i32) -> u32 {
    // Arithmetic shift: all ones when `n` is negative, zero otherwise.
    let sign_mask = (n >> 31) as u32;
    ((n as u32) << 1) ^ sign_mask
}
/// Inverse of `zigzag`: even codes are non-negative, odd codes are negative.
#[inline]
pub fn unzigzag(z: u32) -> i32 {
    let magnitude = (z >> 1) as i32;
    // Odd codes carry a negative value; flipping every bit restores it.
    if z & 1 == 0 {
        magnitude
    } else {
        !magnitude
    }
}
/// Exact number of bits needed to Rice-code `zigzag_vals` with parameter `k`
/// (payload only, no partition header).
#[inline]
fn rice_cost(zigzag_vals: &[u32], k: u8) -> usize {
    // Unary quotient bits: each value contributes `v >> k` zeros.
    let unary_bits: u64 = zigzag_vals.iter().map(|&v| u64::from(v >> k)).sum();
    // Every value also pays one terminator bit plus `k` remainder bits.
    let fixed_bits = zigzag_vals.len() * (usize::from(k) + 1);
    unary_bits as usize + fixed_bits
}
/// Chooses the Rice parameter `k` (in `0..=MAX_RICE_K`) that minimizes the payload
/// size of `zigzag_vals`, returning `(k, cost_in_bits)`.
///
/// The search starts at `floor(log2(mean))`. It walks downward while the cost does not
/// increase, so ties prefer the smaller `k`. Only if that made no progress does it walk
/// upward while the cost strictly decreases. The per-`k` cost is convex, so this local
/// search reaches the same answer as trying every `k`. An empty slice yields `(0, 0)`.
pub fn select_k(zigzag_vals: &[u32]) -> (u8, usize) {
    let count = zigzag_vals.len();
    if count == 0 {
        return (0, 0);
    }
    let total: u64 = zigzag_vals.iter().map(|&v| u64::from(v)).sum();
    if total == 0 {
        // All zeros: k = 0 costs exactly one terminator bit per value.
        return (0, count);
    }
    let avg = total / count as u64;
    let seed: u8 = match avg {
        0 => 0,
        // `avg` fits in u32, so its log2 is at most 31 and fits in u8.
        _ => (avg.ilog2() as u8).min(MAX_RICE_K),
    };

    let mut k = seed;
    let mut cost = rice_cost(zigzag_vals, seed);

    // Downward pass: accept equal costs so ties resolve to the smallest k.
    while let Some(lower) = k.checked_sub(1) {
        let lower_cost = rice_cost(zigzag_vals, lower);
        if lower_cost > cost {
            break;
        }
        k = lower;
        cost = lower_cost;
    }

    // Upward pass, only when the downward pass did not move: strict improvement only.
    if k == seed {
        while k < MAX_RICE_K {
            let upper = k + 1;
            let upper_cost = rice_cost(zigzag_vals, upper);
            if upper_cost >= cost {
                break;
            }
            k = upper;
            cost = upper_cost;
        }
    }

    (k, cost)
}
/// Reference implementation of `select_k`: evaluates every `k` in `0..=MAX_RICE_K`
/// and keeps the cheapest, preferring the smallest `k` on ties.
#[cfg(test)]
fn select_k_exhaustive(zigzag_vals: &[u32]) -> (u8, usize) {
    // `min_by_key` returns the first of several equal minima, i.e. the smallest k.
    (0..=MAX_RICE_K)
        .map(|k| (k, rice_cost(zigzag_vals, k)))
        .min_by_key(|&(_, cost)| cost)
        .unwrap_or((0, usize::MAX))
}
/// Test helper: zigzag-maps signed residuals and Rice-encodes them into a fresh buffer.
#[cfg(test)]
pub(crate) fn rice_encode(residuals: &[i32], partition_order: u8) -> Vec<u8> {
    let mapped: Vec<u32> = residuals.iter().copied().map(zigzag).collect();
    let mut bytes = Vec::new();
    rice_encode_zigzag_into(&mapped, partition_order, &mut bytes);
    bytes
}
/// Rice-encodes already zigzag-mapped values into `out` through a `BitWriter`.
///
/// The input is split into `2^partition_order` equal partitions. For each partition the
/// cheapest parameter `k` is chosen with `select_k`, written as a `RICE_K_BITS`-bit
/// header, and followed by that partition's Rice codes.
///
/// # Panics
///
/// Panics if `partition_order > MAX_PARTITION_ORDER`, or if `zigzag_vals.len()` is not a
/// multiple of `2^partition_order`.
pub fn rice_encode_zigzag_into(zigzag_vals: &[u32], partition_order: u8, out: &mut Vec<u8>) {
    // These are hard asserts rather than `debug_assert!`. In release builds an invalid
    // order would otherwise reach an unchecked shift, and a non-multiple length would
    // silently drop the trailing `len % n_partitions` values, emitting a stream that
    // no longer round-trips.
    assert!(
        partition_order <= MAX_PARTITION_ORDER,
        "partition_order={partition_order} exceeds MAX_PARTITION_ORDER={MAX_PARTITION_ORDER}"
    );
    let n_partitions = 1usize << partition_order;
    assert!(
        zigzag_vals.len().is_multiple_of(n_partitions),
        "zigzag_vals.len()={} is not a multiple of partition count={}",
        zigzag_vals.len(),
        n_partitions
    );
    let partition_size = zigzag_vals.len() / n_partitions;
    let mut w = BitWriter::new(out);
    for p in 0..n_partitions {
        let partition = &zigzag_vals[p * partition_size..(p + 1) * partition_size];
        let (k, _cost) = select_k(partition);
        w.write_bits(k as u32, RICE_K_BITS);
        write_rice_partition(&mut w, partition, k);
    }
    w.finish();
}
/// Emits one Rice code per value: `value >> k` zero bits, a terminating one bit,
/// then (when `k > 0`) the low `k` bits of `value`.
fn write_rice_partition(w: &mut BitWriter<'_>, zigzag_vals: &[u32], k: u8) {
    let low_mask = match k {
        0 => 0,
        _ => (1u32 << k) - 1,
    };
    for &value in zigzag_vals {
        let quotient = value >> k;
        for _bit in 0..quotient {
            w.write_bit(false);
        }
        // Terminator of the unary quotient.
        w.write_bit(true);
        if k != 0 {
            w.write_bits(value & low_mask, k);
        }
    }
}
/// Test helper: decodes `total_count` residuals into a freshly allocated vector.
#[cfg(test)]
pub(crate) fn rice_decode(
    data: &[u8],
    partition_order: u8,
    total_count: usize,
) -> Result<Vec<i32>, DecodeError> {
    let mut decoded = Vec::with_capacity(total_count);
    rice_decode_into(data, partition_order, total_count, &mut decoded).map(|()| decoded)
}
/// Decodes `total_count` Rice-coded residuals from `data` into `out`.
///
/// `out` is cleared first. The layout read here mirrors `rice_encode_zigzag_into`:
/// `2^partition_order` partitions, each a `RICE_K_BITS`-bit `k` header followed by
/// `total_count / 2^partition_order` codes.
///
/// # Errors
///
/// - `DecodeError::InvalidParameter` if `partition_order > MAX_PARTITION_ORDER`, if
///   `total_count` is not a multiple of the partition count, if a header's `k` exceeds
///   `MAX_RICE_K`, or if a quotient is too large to shift left by `k` within a `u32`.
/// - `DecodeError::Truncated` if `data` ends before every value has been read.
pub fn rice_decode_into(
    data: &[u8],
    partition_order: u8,
    total_count: usize,
    out: &mut Vec<i32>,
) -> Result<(), DecodeError> {
    out.clear();
    if partition_order > MAX_PARTITION_ORDER {
        return Err(DecodeError::InvalidParameter);
    }
    let n_partitions = 1usize << partition_order;
    if !total_count.is_multiple_of(n_partitions) {
        return Err(DecodeError::InvalidParameter);
    }
    // Reserve only after validation, and never more than `data` could encode. Every Rice
    // code ends with a one-bit unary terminator, so `data.len()` bytes hold at most
    // `data.len() * 8` codes. Reserving the raw, caller-supplied `total_count` would let
    // an oversized count force a huge allocation (or a capacity-overflow panic) before
    // any input is read.
    out.reserve(total_count.min(data.len().saturating_mul(8)));
    let partition_size = total_count / n_partitions;
    let mut r = BitReader::new(data);
    for _ in 0..n_partitions {
        let k = r.read_bits(RICE_K_BITS).ok_or(DecodeError::Truncated)? as u8;
        if k > MAX_RICE_K {
            return Err(DecodeError::InvalidParameter);
        }
        // Largest quotient for which `q << k` still fits in a u32.
        let q_max: u32 = u32::MAX >> k;
        for _ in 0..partition_size {
            let q = r.read_unary().ok_or(DecodeError::Truncated)?;
            if q > q_max {
                return Err(DecodeError::InvalidParameter);
            }
            let remainder = if k > 0 {
                r.read_bits(k).ok_or(DecodeError::Truncated)?
            } else {
                0
            };
            let z = ((q as u64) << k) | (remainder as u64);
            debug_assert!(z <= u32::MAX as u64, "q<<k overflow: q={q} k={k}");
            out.push(unzigzag(z as u32));
        }
    }
    Ok(())
}
/// Predicts the encoded size, in bits, of `zigzag_vals` at `partition_order`.
///
/// Each partition contributes its `RICE_K_BITS`-bit header plus the payload cost of its
/// best `k`. Returns `None` when the order exceeds `MAX_PARTITION_ORDER` or when the
/// input length is not a multiple of `2^partition_order`.
pub fn estimate_cost(zigzag_vals: &[u32], partition_order: u8) -> Option<usize> {
    if partition_order > MAX_PARTITION_ORDER {
        return None;
    }
    let partitions = 1usize << partition_order;
    if zigzag_vals.len() % partitions != 0 {
        return None;
    }
    let size = zigzag_vals.len() / partitions;
    let header_bits = usize::from(RICE_K_BITS);
    let total = (0..partitions)
        .map(|i| &zigzag_vals[i * size..(i + 1) * size])
        .map(|part| header_bits + select_k(part).1)
        .sum();
    Some(total)
}
/// Test helper: zigzag-maps signed residuals and forwards them to `estimate_cost`.
#[cfg(test)]
pub(crate) fn estimate_cost_from_residuals(
    residuals: &[i32],
    partition_order: u8,
) -> Option<usize> {
    let mapped = residuals.iter().copied().map(zigzag).collect::<Vec<u32>>();
    estimate_cost(&mapped, partition_order)
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;

    #[test]
    fn zigzag_roundtrip() {
        let cases = [
            0,
            1,
            -1,
            2,
            -2,
            100,
            -100,
            i32::MAX / 2,
            i32::MIN / 2,
            i32::MAX,
            i32::MIN,
        ];
        for n in cases {
            assert_eq!(unzigzag(zigzag(n)), n, "zigzag failed for {n}");
        }
    }

    #[test]
    fn zigzag_ordering() {
        assert_eq!(zigzag(0), 0);
        assert_eq!(zigzag(-1), 1);
        assert_eq!(zigzag(1), 2);
        assert_eq!(zigzag(-2), 3);
        assert_eq!(zigzag(2), 4);
    }

    /// Round-trips a signed ramp through every partition order that evenly divides
    /// its length.
    #[test]
    fn rice_roundtrip_single_partition() {
        let residuals: Vec<i32> = (-480..480).collect();
        for po in 0..=MAX_PARTITION_ORDER {
            if !residuals.len().is_multiple_of(1usize << po) {
                continue;
            }
            let bytes = rice_encode(&residuals, po);
            let decoded = rice_decode(&bytes, po, residuals.len()).unwrap();
            assert_eq!(
                decoded, residuals,
                "roundtrip failed at partition_order={po}"
            );
        }
    }

    #[test]
    fn rice_roundtrip_all_partition_orders() {
        let residuals: Vec<i32> = (0..1024i32).map(|i| ((i * 13 + 7) % 200) - 100).collect();
        for po in 0..=MAX_PARTITION_ORDER {
            let bytes = rice_encode(&residuals, po);
            let decoded = rice_decode(&bytes, po, residuals.len()).unwrap();
            assert_eq!(
                decoded, residuals,
                "roundtrip failed at partition_order={po}"
            );
        }
    }

    /// All-zero input must cost exactly the header plus one terminator bit per value.
    #[test]
    fn rice_all_zeros_is_optimal() {
        let residuals = vec![0i32; 128];
        let bytes = rice_encode(&residuals, 0);
        assert_eq!(bytes.len(), (5usize + 128).div_ceil(8));
        let decoded = rice_decode(&bytes, 0, residuals.len()).unwrap();
        assert_eq!(decoded, residuals);
    }

    #[test]
    fn partitioned_beats_single_when_activity_varies() {
        let mut residuals = Vec::with_capacity(512);
        for i in 0..256i32 {
            residuals.push((i % 7) - 3);
        }
        for i in 0..256i32 {
            residuals.push(((i * 41) % 2000) - 1000);
        }
        let cost_po0 = estimate_cost_from_residuals(&residuals, 0).unwrap();
        let cost_po1 = estimate_cost_from_residuals(&residuals, 1).unwrap();
        assert!(
            cost_po1 < cost_po0,
            "partitioned should beat single-k on varying activity: po0={cost_po0}, po1={cost_po1}"
        );
    }

    #[test]
    fn rice_decode_rejects_q_shift_overflow() {
        let mut bytes = Vec::new();
        {
            let mut w = BitWriter::new(&mut bytes);
            w.write_bits(MAX_RICE_K as u32, 5);
            for _ in 0..512 {
                w.write_bit(false);
            }
            w.write_bit(true);
            w.write_bits(0, MAX_RICE_K);
            w.finish();
        }
        let result = rice_decode(&bytes, 0, 1);
        assert!(
            matches!(result, Err(DecodeError::InvalidParameter)),
            "adversarial q-overflow must return InvalidParameter, got {:?}",
            result
        );
    }

    #[test]
    fn rice_decode_rejects_oversize_k() {
        let mut bytes = Vec::new();
        {
            let mut w = BitWriter::new(&mut bytes);
            w.write_bits(31, 5);
            w.write_bit(true);
            w.finish();
        }
        let result = rice_decode(&bytes, 0, 1);
        assert!(matches!(result, Err(DecodeError::InvalidParameter)));
    }

    #[test]
    fn rice_decode_rejects_k_at_boundary() {
        let mut bytes = Vec::new();
        {
            let mut w = BitWriter::new(&mut bytes);
            w.write_bits(MAX_RICE_K as u32 + 1, 5);
            w.write_bit(true);
            w.finish();
        }
        let result = rice_decode(&bytes, 0, 1);
        assert!(
            matches!(result, Err(DecodeError::InvalidParameter)),
            "k = MAX_RICE_K + 1 = {} must be rejected",
            MAX_RICE_K + 1
        );
    }

    #[test]
    fn rice_decode_truncated_returns_error() {
        let bytes = rice_encode(&vec![0i32; 128], 0);
        let truncated = &bytes[..bytes.len() / 2];
        let result = rice_decode(truncated, 0, 128);
        assert!(matches!(result, Err(DecodeError::Truncated)));
    }

    #[test]
    fn rice_decode_truncated_mid_remainder() {
        let residuals: Vec<i32> = vec![1024; 64];
        let bytes = rice_encode(&residuals, 0);
        let cut = &bytes[..bytes.len() - 1];
        let result = rice_decode(cut, 0, residuals.len());
        assert!(
            matches!(result, Err(DecodeError::Truncated)),
            "mid-remainder truncation must return Truncated, got {:?}",
            result
        );
    }

    /// 63 zeros plus one residual of 20 (zigzag 40) makes `select_k` pick k = 0, so the
    /// last value is a 40-zero unary run starting at bit 5 + 63 = 68. Keeping only the
    /// first 9 bytes (72 bits) ends the stream four bits into that run.
    #[test]
    fn rice_decode_truncated_mid_unary() {
        let mut residuals = vec![0i32; 63];
        residuals.push(20);
        let bytes = rice_encode(&residuals, 0);
        // 5-bit header + 63 one-bit codes + 41-bit code for the last value.
        assert_eq!(bytes.len(), (5usize + 63 + 41).div_ceil(8));
        let cut = &bytes[..9];
        let result = rice_decode(cut, 0, residuals.len());
        assert!(
            matches!(result, Err(DecodeError::Truncated)),
            "mid-unary-run truncation must return Truncated, got {:?}",
            result
        );
    }

    #[test]
    fn rice_decode_rejects_mismatched_partition_size() {
        let bytes = rice_encode(&[0i32; 8], 3);
        let result = rice_decode(&bytes, 3, 7);
        assert!(matches!(result, Err(DecodeError::InvalidParameter)));
    }

    #[test]
    fn select_k_picks_k0_for_all_zeros() {
        let vals = vec![0u32; 128];
        let (k, _) = select_k(&vals);
        assert_eq!(k, 0);
    }

    #[test]
    fn select_k_picks_large_k_for_large_values() {
        let vals = vec![(1u32 << 20) - 1; 32];
        let (k, _) = select_k(&vals);
        assert!(k >= 16, "expected large k for large values, got {k}");
    }

    #[test]
    fn select_k_matches_exhaustive() {
        let test_cases: Vec<Vec<u32>> = vec![
            vec![0u32; 32],
            vec![0u32; 1],
            vec![0, 0, 0, 0, 0, 0, 0, 42],
            vec![1u32; 64],
            vec![3u32; 64],
            vec![7u32; 64],
            vec![(1u32 << 15) - 1; 32],
            vec![(1u32 << 20) - 1; 32],
            vec![(1u32 << 24) - 1; 16],
            {
                let mut v = vec![0u32; 64];
                for (i, slot) in v.iter_mut().take(8).enumerate() {
                    *slot = 1_000_000 + i as u32;
                }
                v
            },
            (0u32..64).collect(),
            (0u32..256).map(|i| i * 37).collect(),
            {
                let mut v = Vec::with_capacity(128);
                let mut state: u32 = 0x9E3779B9;
                for _ in 0..128 {
                    state ^= state << 13;
                    state ^= state >> 17;
                    state ^= state << 5;
                    v.push(state & 0xFFFFF);
                }
                v
            },
            vec![1u32 << 23; 16],
            vec![1u32 << 24; 16],
            vec![u32::MAX >> 8; 16],
        ];
        for case in &test_cases {
            let fast = select_k(case);
            let exhaustive = select_k_exhaustive(case);
            assert_eq!(
                fast,
                exhaustive,
                "select_k mismatch for case of len {}: fast={:?} exhaustive={:?}",
                case.len(),
                fast,
                exhaustive
            );
        }
    }

    /// The bit estimate must land within the final byte's padding of the real output.
    #[test]
    fn estimate_cost_matches_actual_payload_bits() {
        let residuals: Vec<i32> = (0..256i32).map(|i| ((i * 7 + 3) % 50) - 25).collect();
        for po in 0..=3 {
            let estimated = estimate_cost_from_residuals(&residuals, po).unwrap();
            let bytes = rice_encode(&residuals, po);
            let actual_max = bytes.len() * 8;
            let actual_min = actual_max.saturating_sub(7);
            assert!(
                estimated >= actual_min && estimated <= actual_max,
                "cost mismatch at po={po}: estimated={estimated}, actual range=[{actual_min},{actual_max}]"
            );
        }
    }
}