use crate::celt_pvq_v::{pvq_codebook_size, PvqVError, PVQ_V_K_MAX, PVQ_V_N_MAX};
/// Failure modes of [`decode_pvq_vector`] and [`decode_pvq_vector_into`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PvqDecodeError {
    /// Wraps the error returned by `pvq_codebook_size` for the requested `(N, K)`.
    CodebookSize(PvqVError),
    /// The codeword index is not strictly below the codebook size `V(N, K)`.
    IndexOutOfRange { index: u32, codebook_size: u32 },
    /// The caller-supplied output slice holds fewer than `N` elements.
    OutputBufferTooSmall { required: usize, provided: usize },
}
impl core::fmt::Display for PvqDecodeError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match *self {
PvqDecodeError::CodebookSize(e) => {
write!(
f,
"oxideav-opus: PVQ vector decode codebook-size error: {e}"
)
}
PvqDecodeError::IndexOutOfRange {
index,
codebook_size,
} => write!(
f,
"oxideav-opus: PVQ vector decode index {index} out of range \
(must be < V(N, K) = {codebook_size}) per RFC 6716 §4.3.4.2"
),
PvqDecodeError::OutputBufferTooSmall { required, provided } => write!(
f,
"oxideav-opus: PVQ vector decode output buffer too small: \
required={required}, provided={provided}"
),
}
}
}
impl std::error::Error for PvqDecodeError {}
impl From<PvqVError> for PvqDecodeError {
fn from(e: PvqVError) -> Self {
PvqDecodeError::CodebookSize(e)
}
}
/// Allocating convenience wrapper around [`decode_pvq_vector_into`].
///
/// Returns the length-`n` signed pulse vector encoded by `index` in the
/// PVQ codebook of `n` dimensions and `k` pulses.
///
/// # Errors
///
/// Same as [`decode_pvq_vector_into`]; the output buffer is always sized
/// exactly, so `OutputBufferTooSmall` cannot occur here.
pub fn decode_pvq_vector(n: u32, k: u32, index: u32) -> Result<Vec<i32>, PvqDecodeError> {
    let mut pulses = vec![0i32; n as usize];
    decode_pvq_vector_into(n, k, index, &mut pulses).map(|_| pulses)
}
/// Decodes a PVQ codeword index into its signed pulse vector, writing into a
/// caller-owned buffer (RFC 6716 §4.3.4.2).
///
/// On success the first `n` entries of `out` hold the pulse vector and the
/// function returns `n`. Entries past `n` are left untouched.
///
/// # Errors
///
/// Checked in this order:
/// * [`PvqDecodeError::OutputBufferTooSmall`] if `out.len() < n`.
/// * [`PvqDecodeError::CodebookSize`] if `pvq_codebook_size(n, k)` fails.
/// * [`PvqDecodeError::IndexOutOfRange`] if `index >= V(n, k)`.
pub fn decode_pvq_vector_into(
    n: u32,
    k: u32,
    index: u32,
    out: &mut [i32],
) -> Result<usize, PvqDecodeError> {
    let width = n as usize;
    if out.len() < width {
        return Err(PvqDecodeError::OutputBufferTooSmall {
            required: width,
            provided: out.len(),
        });
    }

    let codebook_size = pvq_codebook_size(n, k)?;
    if index >= codebook_size {
        return Err(PvqDecodeError::IndexOutOfRange {
            index,
            codebook_size,
        });
    }

    // Walk the coordinates left to right, peeling the sign and magnitude of
    // each one off the remaining index. u64 arithmetic keeps the sum of two
    // u32 codebook sizes from overflowing.
    let mut rest = u64::from(index);
    let mut pulses_left = k;
    for (pos, slot) in (0..n).zip(out.iter_mut()) {
        let tail_dims = n - pos - 1;
        let v_tail = u64::from(pvq_codebook_size(tail_dims, pulses_left)?);
        let v_here = u64::from(pvq_codebook_size(n - pos, pulses_left)?);

        // Indices below the midpoint carry a positive sign; the rest are
        // negative and are shifted down into the same range.
        let mut bound = (v_tail + v_here) / 2;
        let sign: i32 = if rest < bound {
            1
        } else {
            rest -= bound;
            -1
        };

        // Move pulses onto this coordinate until the bound no longer
        // exceeds the remaining index.
        let pulses_before = pulses_left;
        bound -= v_tail;
        while bound > rest {
            pulses_left -= 1;
            bound -= u64::from(pvq_codebook_size(tail_dims, pulses_left)?);
        }

        *slot = sign * (pulses_before - pulses_left) as i32;
        rest -= bound;
    }
    Ok(width)
}
/// Returns the L1 norm (sum of absolute values) of an integer pulse vector.
///
/// Every codeword decoded from a `(N, K)` PVQ codebook has an L1 norm of
/// exactly `K`.
pub fn pvq_l1_norm(x: &[i32]) -> u64 {
    let mut total = 0u64;
    for &coord in x {
        total += u64::from(coord.unsigned_abs());
    }
    total
}
/// Returns the squared Euclidean (L2) norm of an integer pulse vector.
///
/// Each square is computed exactly; a single square is at most `2^62`. The
/// running sum saturates at `u64::MAX` instead of overflowing. Overflowing
/// would panic in debug builds and wrap in release builds, and a wrap to
/// zero would make a huge vector look like the zero vector. Saturation can
/// only happen for inputs far outside any PVQ codebook, e.g. four
/// coordinates equal to `i32::MIN`.
pub fn pvq_l2_norm_squared(x: &[i32]) -> u64 {
    x.iter().fold(0u64, |acc, &v| {
        let mag = u64::from(v.unsigned_abs());
        // mag <= 2^31, so mag * mag <= 2^62 cannot overflow.
        acc.saturating_add(mag * mag)
    })
}
/// Scales an integer pulse vector to unit Euclidean (L2) length.
///
/// Writes `x[i] / ||x||_2` into `out[i]` for every `i < x.len()`. Extra
/// slots in `out` beyond `x.len()` are left as they were. The all-zero
/// vector (the `K = 0` codeword) has no direction, so its slots are set
/// to `0.0`.
///
/// # Errors
///
/// Returns [`PvqShapeError::OutputBufferTooSmall`] when `out` is shorter
/// than `x`. Nothing is written in that case.
pub fn pvq_unit_normalize(x: &[i32], out: &mut [f64]) -> Result<(), PvqShapeError> {
    let len = x.len();
    if out.len() < len {
        return Err(PvqShapeError::OutputBufferTooSmall {
            required: len,
            provided: out.len(),
        });
    }
    let dst = &mut out[..len];

    let energy = pvq_l2_norm_squared(x);
    if energy == 0 {
        dst.fill(0.0);
        return Ok(());
    }

    let scale = 1.0 / (energy as f64).sqrt();
    for (slot, &coord) in dst.iter_mut().zip(x) {
        *slot = f64::from(coord) * scale;
    }
    Ok(())
}
/// Failure modes of [`decode_pvq_shape`], [`decode_pvq_shape_into`] and
/// [`pvq_unit_normalize`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PvqShapeError {
    /// Wraps the error returned by `pvq_codebook_size` for the requested `(N, K)`.
    CodebookSize(PvqVError),
    /// The range decoder failed while reading the codeword index.
    RangeDecoder(crate::Error),
    /// Error from the integer pulse-vector decoder. Codebook-size failures
    /// are normally reported through `CodebookSize` instead.
    PulseVector(PvqDecodeError),
    /// The caller-supplied output slice is too short.
    OutputBufferTooSmall { required: usize, provided: usize },
}
impl core::fmt::Display for PvqShapeError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match *self {
PvqShapeError::CodebookSize(e) => {
write!(f, "oxideav-opus: PVQ shape decode codebook-size error: {e}")
}
PvqShapeError::RangeDecoder(e) => write!(
f,
"oxideav-opus: PVQ shape decode range-decoder error reading \
ec_dec_uint(V(N, K)): {e}"
),
PvqShapeError::PulseVector(e) => {
write!(f, "oxideav-opus: PVQ shape decode pulse-vector error: {e}")
}
PvqShapeError::OutputBufferTooSmall { required, provided } => write!(
f,
"oxideav-opus: PVQ shape decode output buffer too small: \
required={required}, provided={provided}"
),
}
}
}
/// Lets the shape-decode error compose with `Box<dyn Error>`; the default
/// (empty) `source()` is kept.
impl std::error::Error for PvqShapeError {}

/// Lifts a `pvq_codebook_size` failure into the shape error type, which
/// lets `?` be used directly on that call.
impl From<PvqVError> for PvqShapeError {
    fn from(inner: PvqVError) -> Self {
        Self::CodebookSize(inner)
    }
}

/// Converts a vector-decoder error into a shape error.
///
/// The `CodebookSize` case is flattened into [`PvqShapeError::CodebookSize`],
/// so that failure surfaces under a single variant. Every other case is
/// wrapped as [`PvqShapeError::PulseVector`].
impl From<PvqDecodeError> for PvqShapeError {
    fn from(err: PvqDecodeError) -> Self {
        match err {
            PvqDecodeError::CodebookSize(inner) => Self::CodebookSize(inner),
            PvqDecodeError::IndexOutOfRange { .. }
            | PvqDecodeError::OutputBufferTooSmall { .. } => Self::PulseVector(err),
        }
    }
}
/// Allocating convenience wrapper around [`decode_pvq_shape_into`].
///
/// Reads one PVQ codeword for `n` dimensions and `k` pulses from `rd` and
/// returns it as a length-`n` unit-norm shape vector.
///
/// # Errors
///
/// Same as [`decode_pvq_shape_into`]; the output buffer is always sized
/// exactly, so `OutputBufferTooSmall` cannot occur here.
pub fn decode_pvq_shape(
    rd: &mut crate::RangeDecoder<'_>,
    n: u32,
    k: u32,
) -> Result<Vec<f64>, PvqShapeError> {
    let mut shape = vec![0.0f64; n as usize];
    decode_pvq_shape_into(rd, n, k, &mut shape).map(|_| shape)
}
/// Reads one PVQ codeword from the range decoder and writes its unit-norm
/// shape vector into `out` (RFC 6716 §4.3.4.2).
///
/// Processing happens in three steps:
/// 1. The codeword index is read with `ec_dec_uint(V(n, k))`.
/// 2. The index is expanded into its integer pulse vector with
///    [`decode_pvq_vector_into`].
/// 3. The pulse vector is scaled to unit L2 norm with [`pvq_unit_normalize`].
///
/// Only the first `n` entries of `out` are written, and the return value
/// is `n`.
///
/// When `V(n, k) <= 1`, nothing is read from `rd`:
/// * A one-entry codebook (`k == 0`) can only encode index 0.
/// * An empty codebook (`n == 0`, `k > 0`) has no valid index. The call then
///   fails with `PulseVector(IndexOutOfRange { index: 0, codebook_size: 0 })`
///   instead of asking the range decoder for a value from an empty range.
///
/// # Errors
///
/// * [`PvqShapeError::OutputBufferTooSmall`] if `out.len() < n`.
/// * [`PvqShapeError::CodebookSize`] if `pvq_codebook_size(n, k)` fails.
/// * [`PvqShapeError::RangeDecoder`] if reading the index fails.
/// * [`PvqShapeError::PulseVector`] if the index cannot be expanded.
pub fn decode_pvq_shape_into(
    rd: &mut crate::RangeDecoder<'_>,
    n: u32,
    k: u32,
    out: &mut [f64],
) -> Result<usize, PvqShapeError> {
    let n_usize = n as usize;
    if out.len() < n_usize {
        return Err(PvqShapeError::OutputBufferTooSmall {
            required: n_usize,
            provided: out.len(),
        });
    }
    let codebook_size = pvq_codebook_size(n, k)?;
    // The reference decoder (libopus) asserts ft > 1 in ec_dec_uint(ft), and
    // ft == 0 would describe an empty range. For such tiny codebooks index 0
    // is the only candidate; decode_pvq_vector_into rejects it when the
    // codebook is empty.
    let index = if codebook_size <= 1 {
        0
    } else {
        rd.dec_uint(codebook_size)
            .map_err(PvqShapeError::RangeDecoder)?
    };
    let mut pulses = vec![0i32; n_usize];
    decode_pvq_vector_into(n, k, index, &mut pulses)?;
    pvq_unit_normalize(&pulses, out)?;
    Ok(n_usize)
}
/// Upper bound on the dimension `N` accepted by the PVQ decoders. It mirrors
/// `PVQ_V_N_MAX`; larger values are rejected by `pvq_codebook_size`.
pub const PVQ_DECODE_N_MAX: u32 = PVQ_V_N_MAX;
/// Upper bound on the pulse count `K` accepted by the PVQ decoders. It
/// mirrors `PVQ_V_K_MAX`; larger values are rejected by `pvq_codebook_size`.
pub const PVQ_DECODE_K_MAX: u32 = PVQ_V_K_MAX;
#[cfg(test)]
mod tests {
    // Unit tests for the integer pulse-vector decoder, the norm helpers,
    // unit normalization, and the range-decoder-driven shape decoder.
    use super::*;

    // Exhaustive sweep over small (N, K): every in-range index decodes to a
    // length-N vector whose L1 norm is exactly K.
    #[test]
    fn every_index_decodes_to_l1_norm_k() {
        for n in 1..=6u32 {
            for k in 0..=6u32 {
                let v = pvq_codebook_size(n, k).unwrap();
                for index in 0..v {
                    let x = decode_pvq_vector(n, k, index).unwrap();
                    assert_eq!(x.len(), n as usize, "len at (N={n}, K={k}, i={index})");
                    assert_eq!(
                        pvq_l1_norm(&x),
                        k as u64,
                        "L1 norm at (N={n}, K={k}, i={index}): got {x:?}"
                    );
                }
            }
        }
    }

    // Distinct indices must yield distinct vectors, and the full index range
    // must produce exactly V(N, K) of them.
    #[test]
    fn decode_is_injective_over_full_index_range() {
        use std::collections::HashSet;
        for n in 1..=6u32 {
            for k in 0..=6u32 {
                let v = pvq_codebook_size(n, k).unwrap();
                let mut seen: HashSet<Vec<i32>> = HashSet::new();
                for index in 0..v {
                    let x = decode_pvq_vector(n, k, index).unwrap();
                    assert!(
                        seen.insert(x.clone()),
                        "duplicate vector {x:?} at (N={n}, K={k}, i={index})"
                    );
                }
                assert_eq!(seen.len() as u32, v, "coverage at (N={n}, K={k})");
            }
        }
    }

    // K = 0 decodes index 0 to the all-zero vector for every N, including N = 0.
    #[test]
    fn k_zero_decodes_to_all_zero_vector() {
        for n in 0..=8u32 {
            let x = decode_pvq_vector(n, 0, 0).unwrap();
            assert_eq!(x, vec![0i32; n as usize], "all-zero at N={n}");
        }
    }

    // N = 1 codebooks contain exactly the two signed pulses +K and -K.
    #[test]
    fn n_one_k_one_two_signed_pulses() {
        let mut got: Vec<Vec<i32>> = (0..2)
            .map(|i| decode_pvq_vector(1, 1, i).unwrap())
            .collect();
        got.sort();
        assert_eq!(got, vec![vec![-1], vec![1]]);
    }

    #[test]
    fn n_one_k_three_two_signed_pulses() {
        let mut got: Vec<Vec<i32>> = (0..2)
            .map(|i| decode_pvq_vector(1, 3, i).unwrap())
            .collect();
        got.sort();
        assert_eq!(got, vec![vec![-3], vec![3]]);
    }

    // Hand-enumerated small codebooks, compared order-independently.
    #[test]
    fn n_two_k_one_four_codewords() {
        let mut got: Vec<Vec<i32>> = (0..4)
            .map(|i| decode_pvq_vector(2, 1, i).unwrap())
            .collect();
        got.sort();
        assert_eq!(got, vec![vec![-1, 0], vec![0, -1], vec![0, 1], vec![1, 0]]);
    }

    #[test]
    fn n_two_k_two_full_codebook() {
        let mut got: Vec<Vec<i32>> = (0..8)
            .map(|i| decode_pvq_vector(2, 2, i).unwrap())
            .collect();
        got.sort();
        let mut expected = vec![
            vec![-2, 0],
            vec![-1, -1],
            vec![-1, 1],
            vec![0, -2],
            vec![0, 2],
            vec![1, -1],
            vec![1, 1],
            vec![2, 0],
        ];
        expected.sort();
        assert_eq!(got, expected);
        for v in &got {
            assert_eq!(pvq_l1_norm(v), 2);
        }
    }

    #[test]
    fn n_three_k_two_full_codebook_count_and_norm() {
        use std::collections::HashSet;
        let v = pvq_codebook_size(3, 2).unwrap();
        assert_eq!(v, 18);
        let mut seen = HashSet::new();
        for index in 0..v {
            let x = decode_pvq_vector(3, 2, index).unwrap();
            assert_eq!(pvq_l1_norm(&x), 2);
            assert!(seen.insert(x));
        }
        assert_eq!(seen.len(), 18);
    }

    // Index 0 puts all K pulses, positively signed, on the first coordinate.
    #[test]
    fn index_zero_is_all_positive_leading_pulse() {
        for n in 1..=6u32 {
            for k in 1..=6u32 {
                let x = decode_pvq_vector(n, k, 0).unwrap();
                assert_eq!(x[0], k as i32, "index-0 leading pulse at (N={n}, K={k})");
                for (idx, &val) in x.iter().enumerate().skip(1) {
                    assert_eq!(val, 0, "index-0 trailing coord {idx} at (N={n}, K={k})");
                }
            }
        }
    }

    // The last index has a non-positive first coordinate.
    #[test]
    fn last_index_is_all_negative_leading_pulse() {
        for n in 1..=6u32 {
            for k in 1..=6u32 {
                let v = pvq_codebook_size(n, k).unwrap();
                let x = decode_pvq_vector(n, k, v - 1).unwrap();
                assert_eq!(pvq_l1_norm(&x), k as u64);
                assert!(
                    x[0] <= 0,
                    "last-index leading coord at (N={n}, K={k}): {x:?}"
                );
            }
        }
    }

    #[test]
    fn l2_norm_squared_matches_manual() {
        assert_eq!(pvq_l2_norm_squared(&[3, 0]), 9);
        assert_eq!(pvq_l2_norm_squared(&[1, 1]), 2);
        assert_eq!(pvq_l2_norm_squared(&[-2, 1, -1]), 6);
        assert_eq!(pvq_l2_norm_squared(&[]), 0);
    }

    #[test]
    fn l1_norm_matches_manual() {
        assert_eq!(pvq_l1_norm(&[3, 0]), 3);
        assert_eq!(pvq_l1_norm(&[-2, 1, -1]), 4);
        assert_eq!(pvq_l1_norm(&[0, 0, 0]), 0);
    }

    // The _into variant writes only the first N slots and leaves the tail
    // (pre-filled with zeros here) untouched.
    #[test]
    fn decode_into_matches_allocating_variant() {
        for n in 1..=5u32 {
            for k in 0..=5u32 {
                let v = pvq_codebook_size(n, k).unwrap();
                for index in 0..v {
                    let owned = decode_pvq_vector(n, k, index).unwrap();
                    let mut buf = vec![0i32; n as usize + 3];
                    let written = decode_pvq_vector_into(n, k, index, &mut buf).unwrap();
                    assert_eq!(written, n as usize);
                    assert_eq!(&buf[..n as usize], owned.as_slice());
                    assert_eq!(&buf[n as usize..], &[0, 0, 0]);
                }
            }
        }
    }

    #[test]
    fn decode_into_rejects_short_buffer() {
        let mut buf = vec![0i32; 2];
        let result = decode_pvq_vector_into(3, 2, 0, &mut buf);
        assert_eq!(
            result,
            Err(PvqDecodeError::OutputBufferTooSmall {
                required: 3,
                provided: 2,
            })
        );
    }

    #[test]
    fn decode_into_exact_length_buffer_ok() {
        let mut buf = vec![0i32; 3];
        let written = decode_pvq_vector_into(3, 2, 5, &mut buf).unwrap();
        assert_eq!(written, 3);
        assert_eq!(pvq_l1_norm(&buf), 2);
    }

    // Index validation: V(N, K) itself and anything above it are rejected,
    // while V(N, K) - 1 is accepted.
    #[test]
    fn rejects_index_equal_to_codebook_size() {
        let v = pvq_codebook_size(3, 2).unwrap();
        let result = decode_pvq_vector(3, 2, v);
        assert_eq!(
            result,
            Err(PvqDecodeError::IndexOutOfRange {
                index: v,
                codebook_size: v,
            })
        );
    }

    #[test]
    fn rejects_index_above_codebook_size() {
        let v = pvq_codebook_size(2, 3).unwrap();
        let result = decode_pvq_vector(2, 3, v + 100);
        assert_eq!(
            result,
            Err(PvqDecodeError::IndexOutOfRange {
                index: v + 100,
                codebook_size: v,
            })
        );
    }

    #[test]
    fn last_valid_index_is_accepted() {
        let v = pvq_codebook_size(4, 3).unwrap();
        let x = decode_pvq_vector(4, 3, v - 1).unwrap();
        assert_eq!(pvq_l1_norm(&x), 3);
    }

    // Errors from pvq_codebook_size surface as PvqDecodeError::CodebookSize.
    #[test]
    fn propagates_n_out_of_range() {
        let result = decode_pvq_vector(PVQ_V_N_MAX + 1, 2, 0);
        match result {
            Err(PvqDecodeError::CodebookSize(PvqVError::NOutOfRange { .. })) => {}
            other => panic!("expected CodebookSize(NOutOfRange), got {other:?}"),
        }
    }

    #[test]
    fn propagates_k_out_of_range() {
        let result = decode_pvq_vector(4, PVQ_V_K_MAX + 1, 0);
        match result {
            Err(PvqDecodeError::CodebookSize(PvqVError::KOutOfRange { .. })) => {}
            other => panic!("expected CodebookSize(KOutOfRange), got {other:?}"),
        }
    }

    #[test]
    fn propagates_overflow_for_large_codebook() {
        let result = decode_pvq_vector(176, 176, 0);
        match result {
            Err(PvqDecodeError::CodebookSize(PvqVError::OverflowsDecUintRange { .. })) => {}
            other => panic!("expected CodebookSize(OverflowsDecUintRange), got {other:?}"),
        }
    }

    // N = 0 edge cases: V(0, 0) = 1 (the empty vector), V(0, K > 0) = 0.
    #[test]
    fn n_zero_k_zero_empty_vector() {
        let x = decode_pvq_vector(0, 0, 0).unwrap();
        assert!(x.is_empty());
    }

    #[test]
    fn n_zero_k_positive_has_no_codewords() {
        let result = decode_pvq_vector(0, 3, 0);
        assert_eq!(
            result,
            Err(PvqDecodeError::IndexOutOfRange {
                index: 0,
                codebook_size: 0,
            })
        );
    }

    // Strided sample (about 97 points plus the last index) of a larger band,
    // too big to sweep exhaustively.
    #[test]
    fn larger_band_spot_check_l1_invariant() {
        let n = 16u32;
        let k = 4u32;
        let v = pvq_codebook_size(n, k).unwrap();
        let stride = (v / 97).max(1);
        let mut index = 0u32;
        while index < v {
            let x = decode_pvq_vector(n, k, index).unwrap();
            assert_eq!(x.len(), n as usize);
            assert_eq!(pvq_l1_norm(&x), k as u64, "L1 at (N={n}, K={k}, i={index})");
            index += stride;
        }
        let x_last = decode_pvq_vector(n, k, v - 1).unwrap();
        assert_eq!(pvq_l1_norm(&x_last), k as u64);
    }

    #[test]
    fn mirrored_bounds_match_pvq_v() {
        assert_eq!(PVQ_DECODE_N_MAX, PVQ_V_N_MAX);
        assert_eq!(PVQ_DECODE_K_MAX, PVQ_V_K_MAX);
    }

    #[test]
    fn display_messages_mention_the_failing_input() {
        let oob = PvqDecodeError::IndexOutOfRange {
            index: 50,
            codebook_size: 18,
        };
        let msg = format!("{oob}");
        assert!(msg.contains("50"));
        assert!(msg.contains("18"));
        assert!(msg.contains("4.3.4.2"));
        let small = PvqDecodeError::OutputBufferTooSmall {
            required: 7,
            provided: 3,
        };
        let msg = format!("{small}");
        assert!(msg.contains('7'));
        assert!(msg.contains('3'));
        let cb = PvqDecodeError::CodebookSize(PvqVError::OverflowsDecUintRange { n: 176, k: 176 });
        let msg = format!("{cb}");
        assert!(msg.contains("176"));
    }

    #[test]
    fn from_pvq_v_error_conversion() {
        let e: PvqDecodeError = PvqVError::NOutOfRange {
            provided: 999,
            max: PVQ_V_N_MAX,
        }
        .into();
        assert!(matches!(
            e,
            PvqDecodeError::CodebookSize(PvqVError::NOutOfRange { .. })
        ));
    }

    // ---- Unit normalization and range-decoder-driven shape decoding ----
    use crate::RangeDecoder;

    // Absolute-tolerance float comparison for the normalization tests.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() <= eps
    }

    #[test]
    fn unit_normalize_produces_unit_l2_norm() {
        for n in 1..=6u32 {
            for k in 1..=6u32 {
                let v = pvq_codebook_size(n, k).unwrap();
                for index in 0..v {
                    let x = decode_pvq_vector(n, k, index).unwrap();
                    let mut out = vec![0.0f64; n as usize];
                    pvq_unit_normalize(&x, &mut out).unwrap();
                    let norm_sq: f64 = out.iter().map(|&c| c * c).sum();
                    assert!(
                        approx_eq(norm_sq, 1.0, 1e-12),
                        "‖shape‖² = {norm_sq} at (N={n}, K={k}, i={index})"
                    );
                }
            }
        }
    }

    // A 3-4-5 triangle gives exact expected components (0.6, 0, -0.8).
    #[test]
    fn unit_normalize_preserves_direction() {
        let x = [3i32, 0, -4];
        let mut out = [0.0f64; 3];
        pvq_unit_normalize(&x, &mut out).unwrap();
        assert!(approx_eq(out[0], 0.6, 1e-15));
        assert!(approx_eq(out[1], 0.0, 1e-15));
        assert!(approx_eq(out[2], -0.8, 1e-15));
    }

    #[test]
    fn unit_normalize_single_pulse_is_signed_unit() {
        let x = [0i32, 5, 0];
        let mut out = [0.0f64; 3];
        pvq_unit_normalize(&x, &mut out).unwrap();
        assert!(approx_eq(out[0], 0.0, 1e-15));
        assert!(approx_eq(out[1], 1.0, 1e-15));
        assert!(approx_eq(out[2], 0.0, 1e-15));
        let xn = [0i32, -2];
        let mut outn = [0.0f64; 2];
        pvq_unit_normalize(&xn, &mut outn).unwrap();
        assert!(approx_eq(outn[1], -1.0, 1e-15));
    }

    // The zero vector overwrites stale output (pre-filled with 9.0) with zeros.
    #[test]
    fn unit_normalize_zero_vector_stays_zero() {
        let x = [0i32, 0, 0];
        let mut out = [9.0f64; 3];
        pvq_unit_normalize(&x, &mut out).unwrap();
        assert_eq!(out, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn unit_normalize_rejects_short_buffer() {
        let x = [1i32, 1, 1];
        let mut out = [0.0f64; 2];
        let r = pvq_unit_normalize(&x, &mut out);
        assert_eq!(
            r,
            Err(PvqShapeError::OutputBufferTooSmall {
                required: 3,
                provided: 2,
            })
        );
    }

    #[test]
    fn unit_normalize_over_long_buffer_leaves_tail() {
        let x = [3i32, -4];
        let mut out = [7.0f64; 4];
        pvq_unit_normalize(&x, &mut out).unwrap();
        assert!(approx_eq(out[0], 0.6, 1e-15));
        assert!(approx_eq(out[1], -0.8, 1e-15));
        assert_eq!(out[2], 7.0);
        assert_eq!(out[3], 7.0);
    }

    // decode_pvq_shape must equal dec_uint + decode_pvq_vector +
    // pvq_unit_normalize applied by hand to the same bitstream.
    #[test]
    fn shape_matches_vector_then_normalize() {
        let buf = [0x9Au8, 0x3C, 0x71, 0x05, 0xE2, 0x4D, 0xB8, 0x16];
        for n in 1..=6u32 {
            for k in 1..=6u32 {
                let v = pvq_codebook_size(n, k).unwrap();
                let mut probe = RangeDecoder::new(&buf);
                let index = probe.dec_uint(v).unwrap();
                assert!(index < v, "probe index in range at (N={n}, K={k})");
                let pulses = decode_pvq_vector(n, k, index).unwrap();
                let mut expected = vec![0.0f64; n as usize];
                pvq_unit_normalize(&pulses, &mut expected).unwrap();
                let mut rd = RangeDecoder::new(&buf);
                let shape = decode_pvq_shape(&mut rd, n, k).unwrap();
                assert_eq!(shape.len(), n as usize);
                for (a, b) in shape.iter().zip(expected.iter()) {
                    assert!(
                        approx_eq(*a, *b, 1e-15),
                        "shape mismatch at (N={n}, K={k}): {shape:?} vs {expected:?}"
                    );
                }
                let norm_sq: f64 = shape.iter().map(|&c| c * c).sum();
                assert!(approx_eq(norm_sq, 1.0, 1e-12));
            }
        }
    }

    #[test]
    fn shape_into_matches_allocating_variant() {
        let buf = [0x42u8, 0xF1, 0x08, 0xAC, 0x55, 0x9D];
        for n in 1..=5u32 {
            for k in 1..=5u32 {
                let mut rd_a = RangeDecoder::new(&buf);
                let owned = decode_pvq_shape(&mut rd_a, n, k).unwrap();
                let mut rd_b = RangeDecoder::new(&buf);
                let mut buf_out = vec![0.0f64; n as usize + 2];
                let written = decode_pvq_shape_into(&mut rd_b, n, k, &mut buf_out).unwrap();
                assert_eq!(written, n as usize);
                assert_eq!(&buf_out[..n as usize], owned.as_slice());
                assert_eq!(&buf_out[n as usize..], &[0.0, 0.0]);
            }
        }
    }

    // K = 0 has a single codeword, so decoding must not advance the range
    // decoder (rd.tell() is unchanged).
    #[test]
    fn shape_k_zero_is_all_zero_and_consumes_nothing() {
        let buf = [0xABu8, 0xCD, 0xEF, 0x12];
        for n in 1..=6u32 {
            let mut rd = RangeDecoder::new(&buf);
            let tell_before = rd.tell();
            let shape = decode_pvq_shape(&mut rd, n, 0).unwrap();
            assert_eq!(shape, vec![0.0f64; n as usize]);
            assert_eq!(rd.tell(), tell_before);
        }
    }

    #[test]
    fn shape_rejects_short_output_buffer() {
        let buf = [0x00u8, 0x11, 0x22];
        let mut rd = RangeDecoder::new(&buf);
        let mut out = [0.0f64; 2];
        let r = decode_pvq_shape_into(&mut rd, 3, 2, &mut out);
        assert_eq!(
            r,
            Err(PvqShapeError::OutputBufferTooSmall {
                required: 3,
                provided: 2,
            })
        );
    }

    #[test]
    fn shape_propagates_codebook_size_error() {
        let buf = [0x00u8; 4];
        let mut rd = RangeDecoder::new(&buf);
        let r = decode_pvq_shape(&mut rd, PVQ_V_N_MAX + 1, 2);
        match r {
            Err(PvqShapeError::CodebookSize(PvqVError::NOutOfRange { .. })) => {}
            other => panic!("expected CodebookSize(NOutOfRange), got {other:?}"),
        }
    }

    #[test]
    fn shape_n_one_is_signed_unit() {
        let buf = [0x80u8, 0x00, 0x00, 0x00];
        let mut rd = RangeDecoder::new(&buf);
        let shape = decode_pvq_shape(&mut rd, 1, 3).unwrap();
        assert_eq!(shape.len(), 1);
        assert!(approx_eq(shape[0].abs(), 1.0, 1e-15));
    }

    // CodebookSize errors are flattened through the PvqDecodeError -> PvqShapeError
    // conversion; other vector errors are wrapped as PulseVector.
    #[test]
    fn shape_error_from_conversions() {
        let e: PvqShapeError = PvqVError::KOutOfRange {
            provided: 99_999,
            max: PVQ_V_K_MAX,
        }
        .into();
        assert!(matches!(
            e,
            PvqShapeError::CodebookSize(PvqVError::KOutOfRange { .. })
        ));
        let flat: PvqShapeError =
            PvqDecodeError::CodebookSize(PvqVError::OverflowsDecUintRange { n: 176, k: 176 })
                .into();
        assert!(matches!(
            flat,
            PvqShapeError::CodebookSize(PvqVError::OverflowsDecUintRange { .. })
        ));
        let pv: PvqShapeError = PvqDecodeError::IndexOutOfRange {
            index: 5,
            codebook_size: 3,
        }
        .into();
        assert!(matches!(pv, PvqShapeError::PulseVector(_)));
    }

    #[test]
    fn shape_error_display_mentions_inputs() {
        let small = PvqShapeError::OutputBufferTooSmall {
            required: 9,
            provided: 4,
        };
        let msg = format!("{small}");
        assert!(msg.contains('9'));
        assert!(msg.contains('4'));
        let rd = PvqShapeError::RangeDecoder(crate::Error::MalformedPacket);
        let msg = format!("{rd}");
        assert!(msg.contains("ec_dec_uint"));
        let cb = PvqShapeError::CodebookSize(PvqVError::OverflowsDecUintRange { n: 176, k: 176 });
        let msg = format!("{cb}");
        assert!(msg.contains("176"));
    }
}