use crate::range_decoder::RangeDecoder;
/// Number of distinct `spread` values (RFC 6716 Table 59).
pub const SPREAD_VALUE_COUNT: usize = 4;
/// Largest valid `spread` value, i.e. `SPREAD_VALUE_COUNT - 1`.
pub const SPREAD_MAX: u8 = (SPREAD_VALUE_COUNT - 1) as u8;
/// Probability of each `spread` value, in units of `1 / SPREAD_PDF_DENOMINATOR`.
pub const SPREAD_PDF: [u8; SPREAD_VALUE_COUNT] = [7, 2, 21, 2];
/// log2 of [`SPREAD_PDF_DENOMINATOR`]; passed as `ftb` to the range decoder.
pub const SPREAD_FTB: u32 = 5;
/// Sum of [`SPREAD_PDF`], equal to `2^SPREAD_FTB`.
pub const SPREAD_PDF_DENOMINATOR: u32 = 1u32 << SPREAD_FTB;
/// Inverse CDF for `dec_icdf`: entry `i` is the denominator minus the running
/// sum of [`SPREAD_PDF`] through symbol `i` (`32-7`, `25-2`, `23-21`, `2-2`).
pub const SPREAD_ICDF: [u8; SPREAD_VALUE_COUNT] = [25, 23, 2, 0];
/// Spreading factor `f_r` indexed by `spread`; `None` (spread 0) disables rotation.
pub const SPREAD_F_R: [Option<u32>; SPREAD_VALUE_COUNT] = [None, Some(15), Some(10), Some(5)];
/// Minimum per-block length for the strided `pi/2 - theta` pre-rotation done
/// by `apply_spreading`.
pub const SPREAD_PRE_ROTATION_MIN_BLOCK_LEN: usize = 8;
/// Errors reported by the CELT spreading helpers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpreadingError {
    /// A `spread` value outside `0..=SPREAD_MAX`.
    SpreadOutOfRange {
        /// The rejected value.
        spread: u8,
    },
    /// [`rotation_gain`] was asked about a zero-dimensional vector.
    ZeroDimensions,
    /// A block count of zero was supplied.
    ZeroBlocks,
    /// [`rotate_strided`] was called with `stride == 0`.
    ZeroStride,
    /// The vector length is not an exact multiple of the block count.
    BlocksDoNotDivideLength {
        /// Length of the vector that was passed in.
        len: usize,
        /// Requested number of blocks.
        nb_blocks: usize,
    },
}

impl core::fmt::Display for SpreadingError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::SpreadOutOfRange { spread } => write!(
                f,
                "oxideav-opus: CELT spread value {spread} out of range \
                 (RFC 6716 Table 59 allows 0..=3)"
            ),
            // Variants without payload have fixed messages.
            Self::ZeroDimensions => {
                f.write_str("oxideav-opus: CELT §4.3.4.3 rotation gain requires N >= 1")
            }
            Self::ZeroBlocks => {
                f.write_str("oxideav-opus: CELT §4.3.4.3 spreading requires nb_blocks >= 1")
            }
            Self::ZeroStride => {
                f.write_str("oxideav-opus: CELT §4.3.4.3 strided rotation requires stride >= 1")
            }
            Self::BlocksDoNotDivideLength { len, nb_blocks } => write!(
                f,
                "oxideav-opus: CELT §4.3.4.3 vector length {len} is not a \
                 multiple of nb_blocks = {nb_blocks}"
            ),
        }
    }
}

impl std::error::Error for SpreadingError {}
/// Reads the `spread` symbol from the range decoder.
///
/// The symbol is decoded with [`SPREAD_ICDF`] at [`SPREAD_FTB`] bits of
/// precision. The result is an index into the Table 59 values, i.e.
/// `0..=SPREAD_MAX`; that range is only checked in debug builds.
pub fn decode_spread(rd: &mut RangeDecoder<'_>) -> u8 {
    let symbol: u32 = rd.dec_icdf(&SPREAD_ICDF, SPREAD_FTB);
    debug_assert!(symbol < SPREAD_VALUE_COUNT as u32);
    // Presumably `dec_icdf` returns an index into the 4-entry table, so this
    // narrowing cast does not truncate.
    symbol as u8
}
/// Looks up the spreading factor `f_r` for a Table 59 `spread` value.
///
/// Returns `Ok(None)` for spread 0 (no rotation) and `Ok(Some(f_r))` for
/// spreads 1..=3.
///
/// # Errors
/// [`SpreadingError::SpreadOutOfRange`] if `spread > SPREAD_MAX`.
pub fn spread_f_r(spread: u8) -> Result<Option<u32>, SpreadingError> {
    match SPREAD_F_R.get(usize::from(spread)) {
        Some(&f_r) => Ok(f_r),
        None => Err(SpreadingError::SpreadOutOfRange { spread }),
    }
}
/// Rotation gain `g_r = N / (N + f_r * K)`.
///
/// `n` is the dimension `N`, `k` the pulse count `K` and `f_r` the spreading
/// factor. Because `n >= 1` and `f_r * k >= 0`, the result lies in `(0, 1]`.
///
/// # Errors
/// [`SpreadingError::ZeroDimensions`] if `n == 0`.
pub fn rotation_gain(n: usize, k: u32, f_r: u32) -> Result<f64, SpreadingError> {
    let dims = match n {
        0 => return Err(SpreadingError::ZeroDimensions),
        nonzero => nonzero as f64,
    };
    let penalty = f64::from(f_r) * f64::from(k);
    Ok(dims / (dims + penalty))
}
/// Rotation angle `theta = pi * g_r^2 / 4`, in radians.
///
/// A gain of 1 gives `pi/4`, and a gain of 0 gives 0.
pub fn rotation_angle(g_r: f64) -> f64 {
    use core::f64::consts::PI;
    // Evaluated as ((pi * g) * g) / 4 to keep the original rounding order.
    let pi_g = PI * g_r;
    pi_g * g_r / 4.0
}
/// Spreading angle for an `n`-dimensional vector with `k` pulses at the given
/// Table 59 `spread`.
///
/// Returns `Ok(None)` when `spread == 0`. In that case no rotation applies and
/// `n` is not checked.
///
/// # Errors
/// * [`SpreadingError::SpreadOutOfRange`] if `spread > SPREAD_MAX` (checked first).
/// * [`SpreadingError::ZeroDimensions`] if `n == 0` and `spread != 0`.
pub fn spread_theta(n: usize, k: u32, spread: u8) -> Result<Option<f64>, SpreadingError> {
    let Some(f_r) = spread_f_r(spread)? else {
        return Ok(None);
    };
    let g_r = rotation_gain(n, k, f_r)?;
    Ok(Some(rotation_angle(g_r)))
}
/// Applies the 2-D rotation `R(i, j)` to the pair `(x[i], x[j])`:
/// `x[i]' = cos*x[i] + sin*x[j]` and `x[j]' = -sin*x[i] + cos*x[j]`.
///
/// Panics if `i` or `j` is out of bounds.
#[inline]
fn rot2(x: &mut [f64], i: usize, j: usize, cos_t: f64, sin_t: f64) {
    let (a, b) = (x[i], x[j]);
    x[i] = cos_t * a + sin_t * b;
    x[j] = -sin_t * a + cos_t * b;
}

/// Rotates `x` in place by `theta`, using a back-and-forth series of
/// adjacent 2-D rotations.
///
/// The series is `R(0,1), R(1,2), ..., R(N-2,N-1)` followed by
/// `R(N-3,N-2), ..., R(0,1)`. Each step is an orthogonal 2x2 rotation, so the
/// L2 norm is preserved up to rounding. Slices shorter than two elements are
/// left unchanged.
pub fn rotate_in_place(x: &mut [f64], theta: f64) {
    let dims = x.len();
    if dims <= 1 {
        return;
    }
    let (sin_t, cos_t) = theta.sin_cos();
    let ascending = 0..dims - 1;
    let descending = (0..dims - 2).rev();
    for i in ascending.chain(descending) {
        rot2(x, i, i + 1, cos_t, sin_t);
    }
}
/// Stride for the interleaved pre-rotation: `round(sqrt(len / nb_blocks))`.
///
/// The value is computed in `f64` and clamped to at least 1. `f64::round`
/// rounds halves away from zero.
///
/// # Errors
/// [`SpreadingError::ZeroBlocks`] if `nb_blocks == 0`.
pub fn spreading_stride(len: usize, nb_blocks: usize) -> Result<usize, SpreadingError> {
    match nb_blocks {
        0 => Err(SpreadingError::ZeroBlocks),
        blocks => {
            let samples_per_block = len as f64 / blocks as f64;
            let rounded = samples_per_block.sqrt().round() as usize;
            Ok(rounded.max(1))
        }
    }
}
/// Rotates each stride-interleaved subset of `x` independently by `theta`.
///
/// For each offset `k` in `0..stride`, the samples at `k, k + stride,
/// k + 2*stride, ...` form one set. Each set receives the same back-and-forth
/// series of 2-D rotations as [`rotate_in_place`], and samples in different
/// sets never interact. `stride == 1` delegates to [`rotate_in_place`]. A
/// stride of at least `x.len()` leaves `x` unchanged.
///
/// # Errors
/// [`SpreadingError::ZeroStride`] if `stride == 0`.
pub fn rotate_strided(x: &mut [f64], stride: usize, theta: f64) -> Result<(), SpreadingError> {
    match stride {
        0 => return Err(SpreadingError::ZeroStride),
        1 => {
            rotate_in_place(x, theta);
            return Ok(());
        }
        _ => {}
    }
    let (sin_t, cos_t) = theta.sin_cos();
    let total = x.len();
    for offset in 0..stride.min(total) {
        let members = (total - offset).div_ceil(stride);
        if members < 2 {
            continue;
        }
        // Absolute index of the m-th member of this set.
        let at = |m: usize| offset + m * stride;
        for m in (0..members - 1).chain((0..members - 2).rev()) {
            rot2(x, at(m), at(m + 1), cos_t, sin_t);
        }
    }
    Ok(())
}
/// Applies the §4.3.4.3 spreading rotation to a decoded vector in place.
///
/// * `x`: the vector. Its length must be a multiple of `nb_blocks`.
/// * `k`: pulse count `K` fed to [`rotation_gain`].
/// * `spread`: Table 59 value (`0..=3`). `0` leaves `x` untouched.
/// * `nb_blocks`: number of equal, contiguous blocks `x` is split into.
///
/// Each block of `block_len = len / nb_blocks` samples is rotated with
/// [`rotate_in_place`] by `theta = rotation_angle(rotation_gain(block_len, k, f_r))`.
/// When `nb_blocks > 1` and `block_len >= SPREAD_PRE_ROTATION_MIN_BLOCK_LEN`,
/// the whole vector is first rotated by `pi/2 - theta` with
/// [`rotate_strided`], using stride `spreading_stride(len, nb_blocks)`.
/// Blocks shorter than two samples are left unchanged.
///
/// # Errors
/// These checks run in this order, and even when `spread == 0`:
/// * [`SpreadingError::SpreadOutOfRange`] if `spread > SPREAD_MAX`;
/// * [`SpreadingError::ZeroBlocks`] if `nb_blocks == 0`;
/// * [`SpreadingError::BlocksDoNotDivideLength`] if `x.len() % nb_blocks != 0`.
///
/// NOTE(review): this follows one reading of the RFC prose. The libopus
/// reference (`exp_rotation` in `celt/vq.c`) appears to differ in several
/// ways. It seems to:
/// * compute the gain from the full vector length;
/// * apply the strided pre-rotation within each block, including when there
///   is a single block;
/// * skip spreading entirely when `2*K >= N`;
/// * rotate in the opposite direction on the decoder side.
///
/// Confirm against the reference before relying on bit-exact output.
pub fn apply_spreading(
    x: &mut [f64],
    k: u32,
    spread: u8,
    nb_blocks: usize,
) -> Result<(), SpreadingError> {
    // Validate everything up front so that spread == 0 still reports bad shapes.
    let f_r = spread_f_r(spread)?;
    let len = x.len();
    match nb_blocks {
        0 => return Err(SpreadingError::ZeroBlocks),
        blocks if len % blocks != 0 => {
            return Err(SpreadingError::BlocksDoNotDivideLength { len, nb_blocks });
        }
        _ => {}
    }
    let block_len = len / nb_blocks;
    // No rotation for spread 0, and nothing to rotate in 0- or 1-sample blocks.
    let f_r = match f_r {
        Some(f_r) if block_len >= 2 => f_r,
        _ => return Ok(()),
    };
    let theta = rotation_angle(rotation_gain(block_len, k, f_r)?);
    let wants_pre_rotation =
        nb_blocks > 1 && block_len >= SPREAD_PRE_ROTATION_MIN_BLOCK_LEN;
    if wants_pre_rotation {
        let stride = spreading_stride(len, nb_blocks)?;
        rotate_strided(x, stride, core::f64::consts::FRAC_PI_2 - theta)?;
    }
    x.chunks_exact_mut(block_len)
        .for_each(|block| rotate_in_place(block, theta));
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results that are analytically exact.
    const TOL: f64 = 1e-12;

    /// Euclidean (L2) norm of `x`.
    fn l2(x: &[f64]) -> f64 {
        x.iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    /// Deterministic test vector with entries in `[-0.5, 0.5)`, produced by a
    /// 64-bit linear congruential generator seeded with `seed`.
    fn pseudo_vector(len: usize, seed: u64) -> Vec<f64> {
        let mut state = seed;
        (0..len)
            .map(|_| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                ((state >> 33) as f64 / (1u64 << 31) as f64) - 0.5
            })
            .collect()
    }

    #[test]
    fn spread_pdf_sums_to_denominator() {
        let sum: u32 = SPREAD_PDF.iter().map(|&p| u32::from(p)).sum();
        assert_eq!(sum, SPREAD_PDF_DENOMINATOR);
        assert_eq!(1u32 << SPREAD_FTB, SPREAD_PDF_DENOMINATOR);
    }

    #[test]
    fn spread_icdf_matches_pdf() {
        let mut acc = SPREAD_PDF_DENOMINATOR;
        for (k, &p) in SPREAD_PDF.iter().enumerate() {
            acc -= u32::from(p);
            assert_eq!(u32::from(SPREAD_ICDF[k]), acc, "icdf cell {k}");
        }
        assert_eq!(SPREAD_ICDF[SPREAD_VALUE_COUNT - 1], 0);
    }

    #[test]
    fn decode_spread_is_always_in_table_59_range() {
        for b0 in 0..=u8::MAX {
            let buf = [b0, 0xA5, 0x5A, 0xFF];
            let mut rd = RangeDecoder::new(&buf);
            let spread = decode_spread(&mut rd);
            assert!(spread <= SPREAD_MAX, "byte {b0:#x} -> spread {spread}");
            assert!(spread_f_r(spread).is_ok());
        }
    }

    #[test]
    fn table_59_f_r_mapping() {
        assert_eq!(spread_f_r(0), Ok(None));
        assert_eq!(spread_f_r(1), Ok(Some(15)));
        assert_eq!(spread_f_r(2), Ok(Some(10)));
        assert_eq!(spread_f_r(3), Ok(Some(5)));
    }

    #[test]
    fn spread_out_of_range_is_rejected() {
        for spread in 4..=u8::MAX {
            assert_eq!(
                spread_f_r(spread),
                Err(SpreadingError::SpreadOutOfRange { spread })
            );
        }
    }

    #[test]
    fn rotation_gain_worked_points() {
        let g = rotation_gain(16, 4, 5).unwrap();
        assert!((g - 4.0 / 9.0).abs() < TOL);
        assert_eq!(rotation_gain(7, 0, 15).unwrap(), 1.0);
        assert!((rotation_gain(1, 1, 15).unwrap() - 1.0 / 16.0).abs() < TOL);
    }

    #[test]
    fn rotation_gain_zero_dimensions_is_rejected() {
        assert_eq!(rotation_gain(0, 1, 5), Err(SpreadingError::ZeroDimensions));
    }

    #[test]
    fn rotation_angle_worked_points() {
        assert!((rotation_angle(1.0) - core::f64::consts::FRAC_PI_4).abs() < TOL);
        assert_eq!(rotation_angle(0.0), 0.0);
        let theta = rotation_angle(4.0 / 9.0);
        assert!((theta - 4.0 * core::f64::consts::PI / 81.0).abs() < TOL);
    }

    #[test]
    fn theta_shrinks_with_more_pulses_and_larger_f_r() {
        let t1 = spread_theta(16, 1, 3).unwrap().unwrap();
        let t4 = spread_theta(16, 4, 3).unwrap().unwrap();
        let t16 = spread_theta(16, 16, 3).unwrap().unwrap();
        assert!(t1 > t4 && t4 > t16);
        let s1 = spread_theta(16, 4, 1).unwrap().unwrap();
        let s2 = spread_theta(16, 4, 2).unwrap().unwrap();
        let s3 = spread_theta(16, 4, 3).unwrap().unwrap();
        assert!(s1 < s2 && s2 < s3);
        assert_eq!(spread_theta(16, 4, 0), Ok(None));
    }

    #[test]
    fn spread_theta_error_precedence() {
        // The range check on `spread` wins over the dimension check.
        assert_eq!(
            spread_theta(0, 1, 4),
            Err(SpreadingError::SpreadOutOfRange { spread: 4 })
        );
        // With spreading disabled, the dimension is never consulted.
        assert_eq!(spread_theta(0, 1, 0), Ok(None));
        // With spreading enabled, a zero dimension is reported.
        assert_eq!(spread_theta(0, 1, 3), Err(SpreadingError::ZeroDimensions));
    }

    #[test]
    fn two_dim_rotation_matches_definition() {
        let theta = 0.3;
        let mut x = [1.0, 0.0];
        rotate_in_place(&mut x, theta);
        assert!((x[0] - theta.cos()).abs() < TOL);
        assert!((x[1] + theta.sin()).abs() < TOL);
    }

    #[test]
    fn three_dim_series_matches_matrix_composition() {
        let theta = 0.47f64;
        let (c, s) = (theta.cos(), theta.sin());
        let r01 = [[c, s, 0.0], [-s, c, 0.0], [0.0, 0.0, 1.0]];
        let r12 = [[1.0, 0.0, 0.0], [0.0, c, s], [0.0, -s, c]];
        let mat_vec = |m: &[[f64; 3]; 3], v: [f64; 3]| {
            [
                m[0][0] * v[0] + m[0][1] * v[1] + m[0][2] * v[2],
                m[1][0] * v[0] + m[1][1] * v[1] + m[1][2] * v[2],
                m[2][0] * v[0] + m[2][1] * v[1] + m[2][2] * v[2],
            ]
        };
        let v = [0.6, -1.1, 0.35];
        let expected = mat_vec(&r01, mat_vec(&r12, mat_vec(&r01, v)));
        let mut x = v;
        rotate_in_place(&mut x, theta);
        for (got, want) in x.iter().zip(expected.iter()) {
            assert!((got - want).abs() < TOL, "{got} vs {want}");
        }
    }

    #[test]
    fn rotation_preserves_l2_norm() {
        for n in 2..=16 {
            for (i, theta) in [0.0, 0.1, core::f64::consts::FRAC_PI_4, 1.2]
                .iter()
                .enumerate()
            {
                let mut x = pseudo_vector(n, (n * 31 + i) as u64);
                let before = l2(&x);
                rotate_in_place(&mut x, *theta);
                assert!((l2(&x) - before).abs() < 1e-9, "n={n} theta={theta}");
            }
        }
    }

    #[test]
    fn zero_angle_is_identity() {
        let mut x = pseudo_vector(9, 7);
        let orig = x.clone();
        rotate_in_place(&mut x, 0.0);
        for (got, want) in x.iter().zip(orig.iter()) {
            assert!((got - want).abs() < TOL);
        }
    }

    #[test]
    fn rotation_is_linear_in_sign() {
        let theta = 0.9;
        let x = pseudo_vector(11, 99);
        let mut pos = x.clone();
        let mut neg: Vec<f64> = x.iter().map(|v| -v).collect();
        rotate_in_place(&mut pos, theta);
        rotate_in_place(&mut neg, theta);
        for (p, n) in pos.iter().zip(neg.iter()) {
            assert!((p + n).abs() < TOL);
        }
    }

    #[test]
    fn short_vectors_are_untouched() {
        let mut empty: [f64; 0] = [];
        rotate_in_place(&mut empty, 1.0);
        let mut one = [2.5];
        rotate_in_place(&mut one, 1.0);
        assert_eq!(one, [2.5]);
    }

    #[test]
    fn spreading_stride_worked_points() {
        assert_eq!(spreading_stride(16, 1), Ok(4));
        assert_eq!(spreading_stride(15, 1), Ok(4));
        assert_eq!(spreading_stride(12, 1), Ok(3));
        assert_eq!(spreading_stride(8, 4), Ok(1));
        assert_eq!(spreading_stride(32, 2), Ok(4));
        // sqrt(25 / 4) = 2.5 rounds half away from zero.
        assert_eq!(spreading_stride(25, 4), Ok(3));
        // A zero-length vector still yields the minimum stride of 1.
        assert_eq!(spreading_stride(0, 3), Ok(1));
        assert_eq!(spreading_stride(16, 0), Err(SpreadingError::ZeroBlocks));
    }

    #[test]
    fn strided_rotation_stride_one_matches_plain() {
        let theta = 0.6;
        let mut a = pseudo_vector(10, 5);
        let mut b = a.clone();
        rotate_in_place(&mut a, theta);
        rotate_strided(&mut b, 1, theta).unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn strided_rotation_matches_gather_rotate_scatter() {
        let theta = 0.8;
        let stride = 3;
        let mut x = pseudo_vector(14, 21);
        let orig = x.clone();
        rotate_strided(&mut x, stride, theta).unwrap();
        for k in 0..stride {
            let mut set: Vec<f64> = orig[k..].iter().step_by(stride).copied().collect();
            rotate_in_place(&mut set, theta);
            for (m, want) in set.iter().enumerate() {
                let got = x[k + m * stride];
                assert!((got - want).abs() < TOL, "set {k} member {m}");
            }
        }
    }

    #[test]
    fn strided_rotation_leaves_other_sets_untouched() {
        let theta = 1.0;
        let stride = 4;
        let mut x = vec![0.0; 16];
        for m in 0..4 {
            x[m * stride] = 1.0 + m as f64;
        }
        rotate_strided(&mut x, stride, theta).unwrap();
        for (i, v) in x.iter().enumerate() {
            if i % stride != 0 {
                assert_eq!(*v, 0.0, "index {i} leaked across sets");
            }
        }
    }

    #[test]
    fn strided_rotation_preserves_l2_norm() {
        for stride in 1..=6 {
            let mut x = pseudo_vector(17, stride as u64);
            let before = l2(&x);
            rotate_strided(&mut x, stride, 0.7).unwrap();
            assert!((l2(&x) - before).abs() < 1e-9, "stride={stride}");
        }
    }

    #[test]
    fn strided_rotation_large_stride_is_identity() {
        let mut x = pseudo_vector(6, 77);
        let orig = x.clone();
        rotate_strided(&mut x, 6, 1.3).unwrap();
        assert_eq!(x, orig);
        assert_eq!(
            rotate_strided(&mut x, 0, 1.3),
            Err(SpreadingError::ZeroStride)
        );
    }

    #[test]
    fn strided_rotation_accepts_empty_input() {
        let mut empty: [f64; 0] = [];
        assert_eq!(rotate_strided(&mut empty, 3, 0.5), Ok(()));
        assert_eq!(rotate_strided(&mut empty, 1, 0.5), Ok(()));
    }

    #[test]
    fn apply_spreading_spread_zero_is_identity() {
        let mut x = pseudo_vector(24, 3);
        let orig = x.clone();
        apply_spreading(&mut x, 5, 0, 2).unwrap();
        assert_eq!(x, orig);
    }

    #[test]
    fn apply_spreading_single_block_matches_rotate_in_place() {
        let n = 16;
        let k = 4;
        let mut x = pseudo_vector(n, 13);
        let mut expected = x.clone();
        let theta = spread_theta(n, k, 3).unwrap().unwrap();
        rotate_in_place(&mut expected, theta);
        apply_spreading(&mut x, k, 3, 1).unwrap();
        assert_eq!(x, expected);
    }

    #[test]
    fn apply_spreading_small_blocks_skip_pre_rotation() {
        let k = 2;
        let mut x = pseudo_vector(8, 41);
        let mut expected = x.clone();
        let theta = spread_theta(4, k, 2).unwrap().unwrap();
        for block in expected.chunks_exact_mut(4) {
            rotate_in_place(block, theta);
        }
        apply_spreading(&mut x, k, 2, 2).unwrap();
        assert_eq!(x, expected);
    }

    #[test]
    fn apply_spreading_multi_block_runs_pre_rotation() {
        let k = 3;
        let mut x = pseudo_vector(16, 8);
        let before = l2(&x);
        let mut no_pre = x.clone();
        let theta = spread_theta(8, k, 3).unwrap().unwrap();
        for block in no_pre.chunks_exact_mut(8) {
            rotate_in_place(block, theta);
        }
        apply_spreading(&mut x, k, 3, 2).unwrap();
        assert!((l2(&x) - before).abs() < 1e-9);
        assert!(
            x.iter()
                .zip(no_pre.iter())
                .any(|(a, b)| (a - b).abs() > 1e-6),
            "pre-rotation had no effect"
        );
        let mut expected = pseudo_vector(16, 8);
        let stride = spreading_stride(16, 2).unwrap();
        rotate_strided(&mut expected, stride, core::f64::consts::FRAC_PI_2 - theta).unwrap();
        for block in expected.chunks_exact_mut(8) {
            rotate_in_place(block, theta);
        }
        assert_eq!(x, expected);
    }

    #[test]
    fn apply_spreading_zero_vector_stays_zero() {
        let mut x = vec![0.0; 12];
        apply_spreading(&mut x, 0, 3, 1).unwrap();
        assert!(x.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn apply_spreading_error_paths() {
        let mut x = pseudo_vector(10, 1);
        assert_eq!(
            apply_spreading(&mut x, 1, 4, 1),
            Err(SpreadingError::SpreadOutOfRange { spread: 4 })
        );
        assert_eq!(
            apply_spreading(&mut x, 1, 2, 0),
            Err(SpreadingError::ZeroBlocks)
        );
        assert_eq!(
            apply_spreading(&mut x, 1, 2, 3),
            Err(SpreadingError::BlocksDoNotDivideLength {
                len: 10,
                nb_blocks: 3
            })
        );
        let mut tiny = pseudo_vector(3, 2);
        let orig = tiny.clone();
        apply_spreading(&mut tiny, 1, 3, 3).unwrap();
        assert_eq!(tiny, orig);
    }

    #[test]
    fn apply_spreading_validates_shape_even_when_spread_is_zero() {
        let mut x = pseudo_vector(10, 1);
        let orig = x.clone();
        assert_eq!(
            apply_spreading(&mut x, 1, 0, 0),
            Err(SpreadingError::ZeroBlocks)
        );
        assert_eq!(
            apply_spreading(&mut x, 1, 0, 3),
            Err(SpreadingError::BlocksDoNotDivideLength {
                len: 10,
                nb_blocks: 3
            })
        );
        // Error paths must not touch the vector.
        assert_eq!(x, orig);
    }

    #[test]
    fn spreading_error_display_is_stable() {
        let cases: [(SpreadingError, &str); 5] = [
            (
                SpreadingError::SpreadOutOfRange { spread: 9 },
                "spread value 9",
            ),
            (SpreadingError::ZeroDimensions, "N >= 1"),
            (SpreadingError::ZeroBlocks, "nb_blocks >= 1"),
            (SpreadingError::ZeroStride, "stride >= 1"),
            (
                SpreadingError::BlocksDoNotDivideLength {
                    len: 10,
                    nb_blocks: 3,
                },
                "length 10",
            ),
        ];
        for (err, needle) in cases {
            let msg = err.to_string();
            assert!(msg.contains(needle), "{msg:?} lacks {needle:?}");
        }
    }
}