use crate::{CodecError, CodecResult};
/// Number of bits in one output symbol; the coder emits whole bytes.
const EC_SYM_BITS: u32 = 8;
/// Mask selecting the low `EC_SYM_BITS` bits (0xFF); also the largest symbol value.
const EC_SYM_MASK: u32 = (1u32 << EC_SYM_BITS) - 1;
/// Width in bits of the coder registers (`val` and `rng` are `u32`).
const EC_CODE_BITS: u32 = 32;
/// 2^31: the initial value of `rng`; `EC_CODE_TOP - 1` masks `val` back to 31 bits
/// after each byte is shifted out.
const EC_CODE_TOP: u32 = 1u32 << (EC_CODE_BITS - 1);
/// 2^23: `normalize` shifts out a byte for as long as `rng` is at or below this value.
const EC_CODE_BOT: u32 = EC_CODE_TOP >> EC_SYM_BITS;
/// Right shift (23) that moves the top byte of `val`, together with the carry bit
/// above it, down to the low bits.
const EC_CODE_SHIFT: u32 = EC_CODE_BITS - EC_SYM_BITS - 1;
/// Largest number of bits `encode_uint` range-codes directly; any further low
/// bits of the value are written as raw bits.
const EC_UINT_BITS: u32 = 8;
/// Capacity in bits of the raw-bit accumulator `end_window`.
const EC_WINDOW_SIZE: u32 = 32;
/// Range encoder that collects range-coded bytes and raw bits in two separate
/// buffers and joins them when [`SilkRangeEncoder::finish`] is called.
#[derive(Debug)]
pub struct SilkRangeEncoder {
/// Range-coded bytes, in output order (written by `carry_out`).
front: Vec<u8>,
/// Raw-bit bytes in the order they were flushed; `finish` appends them reversed.
back: Vec<u8>,
/// Working code value; its top byte is shifted out by `normalize`.
val: u32,
/// Current range size; starts at `EC_CODE_TOP` and is rescaled by `normalize`.
rng: u32,
/// Most recent non-0xFF byte, held back so that a later carry can be added to it.
rem: Option<u8>,
/// Number of 0xFF symbols waiting behind `rem` until the carry is known.
ext: u32,
/// Running bit counter read by `tell`: grows by 8 per normalised byte and by
/// the bit count of every raw write.
nbits_total: i32,
/// Raw bits that have not yet been flushed to `back`, least significant first.
end_window: u32,
/// Number of valid bits currently held in `end_window`.
nend_bits: u32,
}
// `Default` delegates to `new`, so both construct the same empty encoder.
impl Default for SilkRangeEncoder {
/// Returns the encoder produced by [`SilkRangeEncoder::new`].
fn default() -> Self {
Self::new()
}
}
impl SilkRangeEncoder {
/// Creates an encoder with empty output buffers and the full initial range.
#[must_use]
pub fn new() -> Self {
    // `tell` subtracts the magnitude of `rng` from this counter, so it starts
    // one above the register width.
    let initial_bits = EC_CODE_BITS as i32 + 1;
    Self {
        // Range-coder state.
        val: 0,
        rng: EC_CODE_TOP,
        rem: None,
        ext: 0,
        nbits_total: initial_bits,
        // Output buffers.
        front: Vec::with_capacity(64),
        back: Vec::with_capacity(16),
        // Raw-bit accumulator.
        end_window: 0,
        nend_bits: 0,
    }
}
/// Returns the running bit counter minus `floor(log2(rng))`.
///
/// NOTE(review): this subtracts `log2_floor(rng)` rather than `ec_ilog(rng)`,
/// so the result is one higher than a formula based on `ec_ilog` would give.
/// Confirm that the matching decoder uses the same convention before
/// comparing the two values.
#[must_use]
pub fn tell(&self) -> i32 {
    let range_magnitude = log2_floor(self.rng) as i32;
    self.nbits_total - range_magnitude
}
pub fn encode_icdf(&mut self, s: usize, icdf: &[u8], ftb: u32) -> CodecResult<()> {
if icdf.is_empty() {
return Err(CodecError::InvalidData("empty ICDF table".to_string()));
}
if s >= icdf.len() {
return Err(CodecError::InvalidData(format!(
"symbol {s} out of range for ICDF of length {}",
icdf.len()
)));
}
let r = self.rng >> ftb;
if s > 0 {
let prev = u32::from(icdf[s - 1]);
self.val = self
.val
.wrapping_add(self.rng.wrapping_sub(r.wrapping_mul(prev)));
self.rng = r.wrapping_mul(prev - u32::from(icdf[s]));
} else {
self.rng = self.rng.wrapping_sub(r.wrapping_mul(u32::from(icdf[s])));
}
self.normalize();
Ok(())
}
pub fn encode_bit_logp(&mut self, bit: bool, logp: u32) -> CodecResult<()> {
let r = self.rng;
let s = r >> logp;
let r2 = r - s;
if bit {
self.val = self.val.wrapping_add(r2);
self.rng = s;
} else {
self.rng = r2;
}
self.normalize();
Ok(())
}
pub fn encode_uint(&mut self, value: u32, ft: u32) -> CodecResult<()> {
if ft <= 1 {
return Ok(());
}
if value >= ft {
return Err(CodecError::InvalidData(format!(
"encode_uint value {value} out of range for ft {ft}"
)));
}
let ft_minus_1 = ft - 1;
let ftb = (32 - ft_minus_1.leading_zeros()) as i32;
if ftb > EC_UINT_BITS as i32 {
let extra = (ftb - EC_UINT_BITS as i32) as u32;
let top = (ft_minus_1 >> extra) + 1;
let high = value >> extra;
let low = value & ((1u32 << extra) - 1);
self.encode_inner(high, high + 1, top)?;
self.encode_raw_bits(low, extra)?;
Ok(())
} else {
self.encode_inner(value, value + 1, ft)
}
}
/// Appends the low `bits` bits of `value` to the raw-bit stream.
///
/// Bits are packed least significant first into `end_window`; completed bytes
/// are moved to `back`. A request for zero bits writes nothing.
///
/// The merge is done in a 64-bit scratch window. Up to 7 bits can remain
/// pending after a flush, so merging more than 25 new bits directly into the
/// 32-bit window would shift the high bits of `value` out and lose them.
///
/// # Errors
///
/// Returns [`CodecError::InvalidData`] when `bits > 32`.
pub fn encode_raw_bits(&mut self, value: u32, bits: u32) -> CodecResult<()> {
    if bits == 0 {
        return Ok(());
    }
    if bits > 32 {
        return Err(CodecError::InvalidData(
            "cannot encode more than 32 raw bits".to_string(),
        ));
    }
    let value = if bits == 32 {
        value
    } else {
        value & ((1u32 << bits) - 1)
    };

    let sym_mask = u64::from(EC_SYM_MASK);
    let mut window = u64::from(self.end_window);
    let mut used = self.nend_bits;

    // Make room first: flush every complete byte that is already pending.
    if used + bits > EC_WINDOW_SIZE {
        while used >= EC_SYM_BITS {
            self.back.push((window & sym_mask) as u8);
            window >>= EC_SYM_BITS;
            used -= EC_SYM_BITS;
        }
    }

    // At most 7 + 32 bits are live here, which always fits in 64 bits.
    window |= u64::from(value) << used;
    used += bits;

    // Bring the state back within the 32-bit window. This only runs for
    // writes longer than 25 bits; the bytes leave in the same order a later
    // flush would have produced.
    while used > EC_WINDOW_SIZE {
        self.back.push((window & sym_mask) as u8);
        window >>= EC_SYM_BITS;
        used -= EC_SYM_BITS;
    }

    // `used <= 32` now, so the live bits fit in a `u32`.
    self.end_window = window as u32;
    self.nend_bits = used;
    self.nbits_total += bits as i32;
    Ok(())
}
/// Consumes the encoder and returns the finished byte stream.
///
/// The output is laid out as: the range-coded bytes, then `TAIL_PAD` zero
/// bytes, then the raw-bit bytes in reverse order of emission, so the first
/// raw byte written ends up last.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub fn finish(mut self) -> CodecResult<Vec<u8>> {
    const TAIL_PAD: usize = 8;
    let code_mask = EC_CODE_TOP - 1;

    // Pick a terminating value: `val` rounded up to a multiple of
    // `slack + 1`. If that value with all slack bits set would reach the end
    // of the interval, spend one more bit of precision.
    let mut bits_left = (EC_CODE_BITS - ec_ilog(self.rng)) as i32;
    let mut slack = code_mask >> bits_left;
    let mut tail = self.val.wrapping_add(slack) & !slack;
    if (tail | slack) >= self.val.wrapping_add(self.rng) {
        bits_left += 1;
        slack >>= 1;
        tail = self.val.wrapping_add(slack) & !slack;
    }

    // Emit the terminating value one byte at a time.
    while bits_left > 0 {
        self.carry_out((tail >> EC_CODE_SHIFT) as i32);
        tail = (tail << EC_SYM_BITS) & code_mask;
        bits_left -= EC_SYM_BITS as i32;
    }

    // Push out whatever `carry_out` is still holding back.
    if self.rem.is_some() || self.ext > 0 {
        self.carry_out(0);
    }

    // Flush the raw-bit accumulator: whole bytes first, then any partial byte.
    let mut window = self.end_window;
    let mut pending = self.nend_bits;
    while pending >= EC_SYM_BITS {
        self.back.push((window & EC_SYM_MASK) as u8);
        window >>= EC_SYM_BITS;
        pending -= EC_SYM_BITS;
    }
    if pending > 0 {
        self.back.push((window & EC_SYM_MASK) as u8);
    }

    // Assemble: front bytes, zero padding, then the raw bytes back to front.
    let mut output = self.front;
    output.resize(output.len() + TAIL_PAD, 0);
    output.extend(self.back.iter().rev());
    Ok(output)
}
/// Range-codes a symbol occupying the frequency bracket `fl..fh` out of a
/// total of `ft`.
///
/// # Errors
///
/// Returns [`CodecError::InvalidData`] when `ft` is zero, when the bracket is
/// empty or inverted (`fl >= fh`) or extends past `ft`, or when `ft` exceeds
/// the current range so that a single frequency unit would have zero width.
/// An empty bracket or a zero-width unit would set the range to zero, after
/// which `normalize` never terminates.
fn encode_inner(&mut self, fl: u32, fh: u32, ft: u32) -> CodecResult<()> {
    if ft == 0 {
        return Err(CodecError::InvalidData(
            "encode_inner total frequency zero".to_string(),
        ));
    }
    // `fl == fh` is rejected as well: an empty bracket has no code space.
    if fh > ft || fl >= fh {
        return Err(CodecError::InvalidData(format!(
            "encode_inner invalid bracket fl={fl}, fh={fh}, ft={ft}"
        )));
    }
    let r = self.rng / ft;
    if r == 0 {
        return Err(CodecError::InvalidData(format!(
            "encode_inner total frequency {ft} exceeds the current coder range"
        )));
    }
    // With a valid bracket and `r >= 1`, `r * (ft - fl) <= rng` and the new
    // range is at least `r`, so it can never reach zero.
    if fl > 0 {
        self.val = self
            .val
            .wrapping_add(self.rng.wrapping_sub(r.wrapping_mul(ft - fl)));
        self.rng = r.wrapping_mul(fh - fl);
    } else {
        self.rng = self.rng.wrapping_sub(r.wrapping_mul(ft - fh));
    }
    self.normalize();
    Ok(())
}
/// Rescales the coder while the range is at or below `EC_CODE_BOT`.
///
/// Each pass hands the top byte of `val` (with its carry bit) to
/// `carry_out`, shifts `val` and `rng` up by one byte, and adds eight bits to
/// the running bit counter.
fn normalize(&mut self) {
    let code_mask = EC_CODE_TOP - 1;
    while self.rng <= EC_CODE_BOT {
        let top_symbol = self.val >> EC_CODE_SHIFT;
        self.carry_out(top_symbol as i32);
        // Drop the byte just emitted and keep `val` within 31 bits.
        self.val = (self.val << EC_SYM_BITS) & code_mask;
        self.rng <<= EC_SYM_BITS;
        self.nbits_total += EC_SYM_BITS as i32;
    }
}
/// Accepts one output symbol `c` (a byte plus an optional carry in bit 8) and
/// resolves carry propagation into the bytes that are still held back.
///
/// A symbol equal to 0xFF cannot be finalised yet, because a later carry
/// would ripple through it, so it is only counted in `ext`. Any other symbol
/// settles the carry: the held byte `rem` and all pending 0xFF bytes are
/// written to `front` with the carry applied, and the symbol's low byte
/// becomes the new held byte.
fn carry_out(&mut self, c: i32) {
    let symbol = c as u32;
    if symbol == EC_SYM_MASK {
        self.ext += 1;
        return;
    }

    let carry = (symbol >> EC_SYM_BITS) as u8;
    if let Some(held) = self.rem.take() {
        self.front.push(held.wrapping_add(carry));
    }
    if self.ext > 0 {
        // Pending 0xFF bytes stay 0xFF without a carry and wrap to 0x00 with one.
        let fill = ((EC_SYM_MASK + u32::from(carry)) & EC_SYM_MASK) as u8;
        let new_len = self.front.len() + self.ext as usize;
        self.front.resize(new_len, fill);
        self.ext = 0;
    }
    self.rem = Some((symbol & EC_SYM_MASK) as u8);
}
}
/// Returns the number of bits needed to represent `x`: `0` for `x == 0`,
/// otherwise `floor(log2(x)) + 1`.
fn ec_ilog(x: u32) -> u32 {
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
/// Returns `floor(log2(x))`, with `0` substituted for the undefined case `x == 0`.
fn log2_floor(x: u32) -> u32 {
    match x {
        0 => 0,
        nonzero => (u32::BITS - 1) - nonzero.leading_zeros(),
    }
}
#[cfg(test)]
mod tests {
use super::super::silk_range::SilkRangeDecoder;
use super::*;
// Encodes a sequence of symbols with a four-entry ICDF table at 8-bit
// precision and checks that the decoder returns the same sequence.
#[test]
fn test_round_trip_uniform_icdf() {
// Four equally wide intervals of 64 units each out of 256.
let icdf: [u8; 4] = [192, 128, 64, 0];
let symbols = [0usize, 1, 2, 3, 0, 3, 1, 2];
let mut enc = SilkRangeEncoder::new();
for &s in &symbols {
enc.encode_icdf(s, &icdf, 8).expect("encode icdf");
}
let bytes = enc.finish().expect("finish");
// Decode with the same table and precision, in the same order.
let mut dec = SilkRangeDecoder::new(&bytes).expect("decoder init");
for &s in &symbols {
let got = dec.decode_icdf(&icdf, 8).expect("decode icdf");
assert_eq!(
got, s,
"round trip uniform iCDF failed: encoded {s}, decoded {got}"
);
}
}
// Encodes sixteen copies of symbol 1 with a two-entry ICDF table and checks
// that every one of them decodes back to 1.
#[test]
fn test_round_trip_binary_icdf() {
// Symbol 1 covers 206 of the 256 units; symbol 0 covers the remaining 50.
let icdf: [u8; 2] = [206, 0];
let symbols = [1usize; 16];
let mut enc = SilkRangeEncoder::new();
for &s in &symbols {
enc.encode_icdf(s, &icdf, 8).expect("encode");
}
let bytes = enc.finish().expect("finish");
let mut dec = SilkRangeDecoder::new(&bytes).expect("dec");
for &s in &symbols {
let got = dec.decode_icdf(&icdf, 8).expect("decode");
assert_eq!(got, s);
}
}
// Encodes a short bit pattern with `logp == 1` for every bit and checks that
// the decoder reproduces the pattern.
#[test]
fn test_round_trip_bit_logp() {
let bits = [true, false, true, true, false, false, true];
// One `logp` per bit; the decoder must be given the same values.
let logps = [1u32, 1, 1, 1, 1, 1, 1];
let mut enc = SilkRangeEncoder::new();
for (&b, &lp) in bits.iter().zip(logps.iter()) {
enc.encode_bit_logp(b, lp).expect("encode bit");
}
let bytes = enc.finish().expect("finish");
let mut dec = SilkRangeDecoder::new(&bytes).expect("dec");
for (&b, &lp) in bits.iter().zip(logps.iter()) {
let got = dec.decode_bit_logp(lp).expect("decode bit");
assert_eq!(got, b);
}
}
// Writes 32 range-coded bits followed by four raw-bit fields and checks that
// the raw fields are read back with the same values and widths.
#[test]
fn test_round_trip_raw_bits() {
// (value, bit width) pairs written to the raw-bit stream, in order.
let values: &[(u32, u32)] = &[(0x0B, 4), (0x0A, 4), (0x55, 8), (0xC, 4)];
let mut enc = SilkRangeEncoder::new();
// Put some range-coded data in front of the raw bits.
for _ in 0..32 {
enc.encode_bit_logp(false, 1).expect("seed");
}
for &(v, b) in values {
enc.encode_raw_bits(v, b).expect("raw");
}
let bytes = enc.finish().expect("finish");
let mut dec = SilkRangeDecoder::new(&bytes).expect("dec");
// Consume the range-coded prefix before reading the raw fields.
for _ in 0..32 {
let _ = dec.decode_bit_logp(1).expect("seed dec");
}
let g1 = dec.decode_raw_bits(4).expect("r1");
let g2 = dec.decode_raw_bits(4).expect("r2");
let g3 = dec.decode_raw_bits(8).expect("r3");
let g4 = dec.decode_raw_bits(4).expect("r4");
assert_eq!(g1, 0x0B);
assert_eq!(g2, 0x0A);
assert_eq!(g3, 0x55);
assert_eq!(g4, 0x0C);
}
// Round-trips several (value, range) pairs through `encode_uint`, covering
// ranges that fit in 8 bits as well as ranges that need extra raw bits.
#[test]
fn test_round_trip_uint() {
// (value, ft) pairs; 1024 and 65536 exceed the 8-bit direct-coding limit.
let cases: &[(u32, u32)] = &[(3, 7), (0, 4), (15, 16), (300, 1024), (0, 65536)];
let mut enc = SilkRangeEncoder::new();
for &(v, ft) in cases {
enc.encode_uint(v, ft).expect("encode uint");
}
let bytes = enc.finish().expect("finish");
let mut dec = SilkRangeDecoder::new(&bytes).expect("dec");
for &(v, ft) in cases {
let g = dec.decode_uint(ft).expect("decode uint");
assert_eq!(g, v, "uint round trip {v}/{ft} -> {g}");
}
}
// Checks that the value reported by `tell` never decreases while bits are
// being encoded.
#[test]
fn test_tell_monotonic() {
    let mut enc = SilkRangeEncoder::new();
    let mut previous = enc.tell();
    for _ in 0..32 {
        enc.encode_bit_logp(true, 1).expect("bit");
        let current = enc.tell();
        assert!(current >= previous, "tell must be monotonic");
        previous = current;
    }
}
}