use crate::range::{RangeDecoder, RangeEncoder};
/// Start offset of each row of the packed U(N, K) table inside `PVQ_U_DATA`.
///
/// U is symmetric, so row `m` only stores U(m, k) for `k >= m`; U(m, k) lives at
/// index `PVQ_U_ROW[m] + k`. Rows 0..=14 are present, and the sixteenth entry is
/// the total number of stored values. The codebook size is V(N, K) = U(N, K) + U(N, K + 1).
const PVQ_U_ROW: [usize; 16] = [
    0, 176, 351, 525, 698, 870, 1041, 1131, 1178, 1207, 1226, 1240, 1248, 1254, 1257, 1272,
];

/// Number of `u32` values held by the packed U table.
const PVQ_U_DATA_LEN: usize = PVQ_U_ROW[15];

/// Builds the packed U table during constant evaluation.
///
/// First a dense grid of U(N, K) for N in 0..15 and K in 0..=176 is filled from
/// U(0, 0) = 1, U(0, K > 0) = 0, U(N > 0, 0) = 0 and the recurrence
/// U(N, K) = U(N - 1, K) + U(N, K - 1) + U(N - 1, K - 1).
/// `wrapping_add` keeps const evaluation from failing on cells too large for `u32`.
/// NOTE(review): this assumes the row lengths implied by `PVQ_U_ROW` only cover
/// cells that did not wrap. Confirm this if the table bounds change.
/// Then each row `m` is copied from its diagonal (column `m`) up to where the next
/// row's diagonal begins, or up to the end of the table for the last row.
const fn compute_pvq_u_data() -> [u32; PVQ_U_DATA_LEN] {
    const ROWS: usize = 15;
    const COLS: usize = 177;

    let mut grid = [[0u32; COLS]; ROWS];
    grid[0][0] = 1;

    let mut row = 1;
    while row < ROWS {
        let mut col = 1;
        while col < COLS {
            let up = grid[row - 1][col];
            let left = grid[row][col - 1];
            let diag = grid[row - 1][col - 1];
            grid[row][col] = up.wrapping_add(left).wrapping_add(diag);
            col += 1;
        }
        row += 1;
    }

    let mut packed = [0u32; PVQ_U_DATA_LEN];
    let mut m = 0;
    while m < ROWS {
        let base = PVQ_U_ROW[m];
        // Exclusive end of row `m`'s storage.
        let stop = if m + 1 < ROWS {
            PVQ_U_ROW[m + 1] + m + 1
        } else {
            PVQ_U_DATA_LEN
        };
        let mut slot = base + m;
        while slot < stop {
            packed[slot] = grid[m][slot - base];
            slot += 1;
        }
        m += 1;
    }
    packed
}

/// The packed U table; the static initializer is evaluated at compile time.
static PVQ_U_DATA: [u32; PVQ_U_DATA_LEN] = compute_pvq_u_data();

/// Borrows the packed U table as a slice. Index it through `PVQ_U_ROW`.
fn pvq_u_data() -> &'static [u32] {
    &PVQ_U_DATA[..]
}
#[inline]
fn celt_pvq_u(data: &[u32], n: usize, k: usize) -> u32 {
let (lo, hi) = if n < k { (n, k) } else { (k, n) };
data[PVQ_U_ROW[lo] + hi]
}
#[inline]
fn urow(data: &[u32], row: usize, col: usize) -> u32 {
data[PVQ_U_ROW[row] + col]
}
/// Decodes the codebook index `i` into a pulse vector. It writes `n0` signed
/// coordinates into `y[..n0]` whose absolute values sum to `k0`.
///
/// This is the inverse of `icwrs`. Coordinates are recovered front to back. For each
/// one, the remaining index is compared against U-table entries (read through `urow`,
/// always with row <= column) to find the coordinate's sign and how many pulses it
/// holds. After that, `n` and `k` describe the still-undecoded suffix. The last two
/// coordinates are handled in closed form.
///
/// Preconditions: `n0 > 1` and `k0 > 0` (debug-asserted). `i` is expected to be
/// below `pvq_codebook_size(n0, k0)`, as supplied by `decode_pulses`.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `y.len() < n0`.
fn cwrsi(n0: usize, k0: usize, mut i: u32, y: &mut [i32]) {
    debug_assert!(n0 > 1 && k0 > 0);
    let data = pvq_u_data();
    // `n` / `k`: dimensions and pulses not yet decoded; `yi`: next output slot.
    let mut n = n0;
    let mut k = k0;
    let mut yi = 0usize;
    while n > 2 {
        if k >= n {
            // At least as many pulses as dimensions.
            // Indices at or above U(n, k + 1) mean this coordinate is negative.
            // Strip that offset and keep the sign in `s` (0 or -1).
            let p = urow(data, n, k + 1);
            let s = -i32::from(i >= p);
            if s != 0 {
                i -= p;
            }
            let k0_dim = k;
            // Find the largest remaining pulse count k' with U(n, k') <= i. If that
            // count is below n, read the symmetric entry U(k', n) so the row index
            // never exceeds the column.
            let q = urow(data, n, n);
            let p = if q > i {
                k = n;
                loop {
                    k -= 1;
                    let p = urow(data, k, n);
                    if p <= i {
                        break p;
                    }
                }
            } else {
                let mut p = urow(data, n, k);
                while p > i {
                    k -= 1;
                    p = urow(data, n, k);
                }
                p
            };
            i -= p;
            // Pulses consumed by this coordinate; `(v + s) ^ s` negates v when s == -1.
            let val = (k0_dim - k) as i32;
            y[yi] = (val + s) ^ s;
            yi += 1;
        } else {
            // Fewer pulses than dimensions.
            // An index in [U(k, n), U(k + 1, n)) means this coordinate is zero.
            let p = urow(data, k, n);
            let q = urow(data, k + 1, n);
            if p <= i && i < q {
                i -= p;
                y[yi] = 0;
            } else {
                // Indices at or above U(k + 1, n) encode a negative coordinate.
                let s = -i32::from(i >= q);
                if s != 0 {
                    i -= q;
                }
                let k0_dim = k;
                // Lower k until U(k, n) <= i; the drop is this coordinate's magnitude.
                let p = loop {
                    k -= 1;
                    let p = urow(data, k, n);
                    if p <= i {
                        break p;
                    }
                };
                i -= p;
                let val = (k0_dim - k) as i32;
                y[yi] = (val + s) ^ s;
            }
            yi += 1;
        }
        n -= 1;
    }
    // Two coordinates left. U(2, K) = 2K - 1, so the sign threshold
    // U(2, k + 1) = 2k + 1 is computed directly instead of being read from the table.
    let p = 2 * k as u32 + 1;
    let s = -i32::from(i >= p);
    if s != 0 {
        i -= p;
    }
    let k0_dim = k;
    // Pulses left for the final coordinate. For a valid index, `i` is 0 or 1 afterwards.
    k = ((i + 1) >> 1) as usize;
    if k != 0 {
        i -= 2 * k as u32 - 1;
    }
    let val = (k0_dim - k) as i32;
    y[yi] = (val + s) ^ s;
    yi += 1;
    // Last coordinate takes all remaining `k` pulses, negated when the leftover `i` is 1.
    let s = -(i as i32);
    y[yi] = (k as i32 + s) ^ s;
}
/// Maps a pulse vector `y` (length `n >= 2`, debug-asserted) to its codebook index.
///
/// This is the inverse of `cwrsi`. The last coordinate contributes only its sign bit.
/// Then coordinates are folded in from back to front: each adds U(n - j, pulses so far).
/// A negative coordinate additionally adds U(n - j, pulses including it + 1).
fn icwrs(y: &[i32]) -> u32 {
    let n = y.len();
    debug_assert!(n >= 2);
    let table = pvq_u_data();

    let tail = y[n - 1];
    let mut index = u32::from(tail < 0);
    let mut pulses = tail.unsigned_abs() as usize;

    for j in (0..=n - 2).rev() {
        let dims = n - j;
        index += celt_pvq_u(table, dims, pulses);
        pulses += y[j].unsigned_abs() as usize;
        if y[j] < 0 {
            index += celt_pvq_u(table, dims, pulses + 1);
        }
    }
    index
}
/// Returns V(n, k), the number of PVQ codewords for `n` dimensions and `k` pulses.
/// This is the number of integer vectors of length `n` whose absolute values sum to `k`.
///
/// It is computed from the packed U table as V(n, k) = U(n, k) + U(n, k + 1).
///
/// `n` and `k` must be small enough that U(n, k + 1) is present in the stored table.
/// Out-of-range arguments either panic on indexing or produce a meaningless value.
#[must_use]
pub fn pvq_codebook_size(n: usize, k: usize) -> u32 {
    let table = pvq_u_data();
    let below = celt_pvq_u(table, n, k);
    let above = celt_pvq_u(table, n, k + 1);
    below + above
}
/// Reads one PVQ codeword from `dec` and expands it into `y`.
///
/// `y.len()` is the dimension N and `k` the pulse count K (N >= 2 and K >= 1 are
/// debug-asserted). The index is obtained via `dec.decode_uint`, bounded by
/// `pvq_codebook_size(N, K)`, and converted to coordinates by `cwrsi`. On success,
/// every element of `y` is overwritten.
///
/// Returns `None` if `dec.decode_uint` fails; `y` is not modified in that case.
#[must_use]
pub fn decode_pulses(dec: &mut RangeDecoder, y: &mut [i32], k: usize) -> Option<()> {
    let dims = y.len();
    debug_assert!(dims >= 2 && k >= 1);
    let codeword = dec.decode_uint(pvq_codebook_size(dims, k))?;
    cwrsi(dims, k, codeword, y);
    Some(())
}
/// Writes the pulse vector `y` to `enc` as a single PVQ codeword.
///
/// `y.len()` is the dimension N and `k` the pulse count K. The index from `icwrs` is
/// coded with `enc.encode_uint` against the codebook size V(N, K) given by
/// `pvq_codebook_size`.
///
/// The absolute values of `y` must sum to exactly `k`. Any other vector maps to an
/// index that is not a codeword for (N, K); for example, y = (0, -3) with k = 1 gives
/// index 6 >= V(2, 1) = 4. Such an index would corrupt the stream, so debug builds
/// check the pulse count.
pub fn encode_pulses(enc: &mut RangeEncoder, y: &[i32], k: usize) {
    debug_assert!(y.len() >= 2 && k >= 1);
    debug_assert_eq!(
        y.iter().map(|&v| v.unsigned_abs() as usize).sum::<usize>(),
        k,
        "pulse vector does not contain exactly k pulses"
    );
    let index = icwrs(y);
    enc.encode_uint(index, pvq_codebook_size(y.len(), k));
}
#[cfg(test)]
mod tests {
    // Declared here so the `alloc::` paths below resolve whether or not the crate
    // links `std`; bare `vec!` is not in scope in a `no_std` build.
    extern crate alloc;

    use super::*;

    /// Reference codebook sizes: `V_TABLE[n][k]` = V(n, k), the number of integer
    /// vectors of length `n` whose absolute values sum to `k`.
    const V_TABLE: [[u32; 10]; 10] = [
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [1, 4, 8, 12, 16, 20, 24, 28, 32, 36],
        [1, 6, 18, 38, 66, 102, 146, 198, 258, 326],
        [1, 8, 32, 88, 192, 360, 608, 952, 1408, 1992],
        [1, 10, 50, 170, 450, 1002, 1970, 3530, 5890, 9290],
        [1, 12, 72, 292, 912, 2364, 5336, 10836, 20256, 35436],
        [1, 14, 98, 462, 1666, 4942, 12642, 28814, 59906, 115598],
        [1, 16, 128, 688, 2816, 9424, 27008, 68464, 157184, 332688],
        [1, 18, 162, 978, 4482, 16722, 53154, 148626, 374274, 864146],
    ];

    #[test]
    fn codebook_sizes_match_reference_table() {
        for (n, row) in V_TABLE.iter().enumerate().skip(2) {
            for (k, &expected) in row.iter().enumerate().skip(1) {
                assert_eq!(pvq_codebook_size(n, k), expected, "V({n}, {k})");
            }
        }
    }

    /// Every index below V(N, K) decodes to a vector with exactly K pulses, and that
    /// vector encodes back to the same index.
    #[test]
    fn exhaustive_index_bijection_small_nk() {
        for n in 2..=6usize {
            for k in 1..=6usize {
                let v = pvq_codebook_size(n, k);
                for i in 0..v {
                    let mut y = alloc::vec![0i32; n];
                    cwrsi(n, k, i, &mut y);
                    let pulses: u32 = y.iter().map(|x| x.unsigned_abs()).sum();
                    assert_eq!(pulses, k as u32, "N={n} K={k} i={i}: pulse count");
                    assert_eq!(icwrs(&y), i, "N={n} K={k}: index round-trip");
                }
            }
        }
    }

    /// Several pulse vectors written through the range coder read back unchanged, and
    /// the decoder ends in the same range state as the encoder.
    #[test]
    fn range_coder_round_trip() {
        let cases: [(usize, usize); 6] = [(2, 1), (4, 3), (8, 8), (16, 4), (24, 5), (96, 3)];
        let mut enc = RangeEncoder::new(1024);
        let mut vectors = alloc::vec::Vec::new();
        for &(n, k) in &cases {
            let mut y = alloc::vec![0i32; n];
            for p in 0..k {
                let at = (p * 7) % n;
                y[at] += if p % 2 == 0 { 1 } else { -1 };
            }
            // Fallback in case the alternating pattern cancels pulses on a shared slot.
            let total: u32 = y.iter().map(|x| x.unsigned_abs()).sum();
            if total != k as u32 {
                y = alloc::vec![0i32; n];
                for p in 0..k {
                    y[p % n] += 1;
                }
            }
            encode_pulses(&mut enc, &y, k);
            vectors.push((n, k, y));
        }
        let enc_rng = enc.range_size();
        let buf = enc.finalize().expect("within budget");
        let mut dec = RangeDecoder::new(&buf);
        for (n, k, expected) in vectors {
            let mut y = alloc::vec![0i32; n];
            decode_pulses(&mut dec, &mut y, k).expect("in range");
            assert_eq!(y, expected, "N={n} K={k}");
        }
        assert_eq!(dec.range_size(), enc_rng);
    }
}