use crate::celt_band_layout::{
celt_band_bins_per_channel, celt_end_coded_band, celt_first_coded_band, CeltFrameSize,
CELT_NUM_BANDS,
};
use crate::celt_static_alloc::{
STATIC_ALLOC, STATIC_ALLOC_INTERP_STEPS, STATIC_ALLOC_Q_MAX, STATIC_ALLOC_RIGHT_SHIFT,
};
/// Largest fractional allocation index: the top static-allocation column expressed in
/// `1 / STATIC_ALLOC_INTERP_STEPS` sub-steps.
pub const Q_FP_MAX: u32 = STATIC_ALLOC_INTERP_STEPS * STATIC_ALLOC_Q_MAX;

/// Right shift applied after interpolation: the table's own `STATIC_ALLOC_RIGHT_SHIFT` plus
/// `log2(STATIC_ALLOC_INTERP_STEPS)`, which removes the interpolation weights' scale.
pub const STATIC_ALLOC_INTERP_RIGHT_SHIFT: u32 =
    STATIC_ALLOC_RIGHT_SHIFT + STATIC_ALLOC_INTERP_STEPS.trailing_zeros();

// Compile-time guards: the arithmetic in this module is written for exactly 64 interpolation
// sub-steps and a combined post-interpolation shift of 8.
const _: () = assert!(STATIC_ALLOC_INTERP_STEPS == 64);
const _: () = assert!(STATIC_ALLOC_INTERP_RIGHT_SHIFT == 8);
/// Errors reported by the allocation-search helpers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocSearchError {
    /// `channels` was not 1 or 2.
    ChannelsOutOfRange { channels: u32 },
    /// A fractional allocation index, or a `(q_lo, frac)` pair describing one, is not a valid
    /// encoding. The payload is the offending index as the caller supplied or composed it.
    QFpOutOfRange { q_fp: u32 },
    /// `band` was not below `CELT_NUM_BANDS`.
    BandOutOfRange { band: u32 },
}

// Human-readable messages. `core::fmt` is used so this impl works whether or not the crate
// links `std`.
// NOTE(review): if the crate links `std`, also add `impl std::error::Error for AllocSearchError {}`.
impl core::fmt::Display for AllocSearchError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::ChannelsOutOfRange { channels } => {
                write!(f, "channel count {} out of range (expected 1 or 2)", channels)
            }
            Self::QFpOutOfRange { q_fp } => {
                write!(f, "fractional allocation index q_fp = {} is out of range", q_fp)
            }
            Self::BandOutOfRange { band } => write!(f, "band index {} out of range", band),
        }
    }
}
/// A fractional allocation index split into a static-allocation column and an interpolation
/// weight.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QFpComponents {
    /// Lower column; interpolation reads columns `q_lo` and `q_lo + 1`.
    pub q_lo: u32,
    /// Weight given to column `q_lo + 1`, in `0..=STATIC_ALLOC_INTERP_STEPS`. Column `q_lo`
    /// receives the complementary weight `STATIC_ALLOC_INTERP_STEPS - frac`.
    pub frac: u32,
}
/// Splits a fractional allocation index into its lower column and interpolation weight.
///
/// Below `Q_FP_MAX` the split is the plain quotient and remainder by
/// `STATIC_ALLOC_INTERP_STEPS`. `Q_FP_MAX` itself is encoded as
/// `(STATIC_ALLOC_Q_MAX - 1, STATIC_ALLOC_INTERP_STEPS)` rather than `(STATIC_ALLOC_Q_MAX, 0)`.
/// That keeps `q_lo + 1` within `0..=STATIC_ALLOC_Q_MAX`.
///
/// # Errors
/// [`AllocSearchError::QFpOutOfRange`] when `q_fp > Q_FP_MAX`.
pub const fn q_fp_to_components(q_fp: u32) -> Result<QFpComponents, AllocSearchError> {
    match q_fp {
        Q_FP_MAX => Ok(QFpComponents {
            q_lo: STATIC_ALLOC_Q_MAX - 1,
            frac: STATIC_ALLOC_INTERP_STEPS,
        }),
        below if below < Q_FP_MAX => Ok(QFpComponents {
            q_lo: below / STATIC_ALLOC_INTERP_STEPS,
            frac: below % STATIC_ALLOC_INTERP_STEPS,
        }),
        _ => Err(AllocSearchError::QFpOutOfRange { q_fp }),
    }
}
/// Reassembles a fractional allocation index from a `(q_lo, frac)` pair.
///
/// Accepts every pair that [`q_fp_to_components`] produces. It also accepts the alias
/// `(STATIC_ALLOC_Q_MAX, 0)`, which composes to `Q_FP_MAX` as well.
///
/// # Errors
/// [`AllocSearchError::QFpOutOfRange`] for any other pair. The reported `q_fp` is
/// `q_lo * STATIC_ALLOC_INTERP_STEPS + frac`, saturated at `u32::MAX`. Saturation means
/// out-of-range inputs (e.g. `frac = u32::MAX`) cannot overflow while the error is built.
pub const fn q_fp_from_components(q_lo: u32, frac: u32) -> Result<u32, AllocSearchError> {
    // Saturating: this value doubles as the error payload for arbitrary caller input. For every
    // accepted pair it is at most Q_FP_MAX, so saturation never alters a successful result.
    let q_fp = q_lo
        .saturating_mul(STATIC_ALLOC_INTERP_STEPS)
        .saturating_add(frac);

    let valid = if q_lo == STATIC_ALLOC_Q_MAX {
        // The top column is only reachable as an exact integer step.
        frac == 0
    } else if q_lo < STATIC_ALLOC_Q_MAX {
        // A full-weight `frac` is reserved for the canonical encoding of `Q_FP_MAX`.
        frac < STATIC_ALLOC_INTERP_STEPS
            || (frac == STATIC_ALLOC_INTERP_STEPS && q_lo == STATIC_ALLOC_Q_MAX - 1)
    } else {
        false
    };

    if valid {
        Ok(q_fp)
    } else {
        Err(AllocSearchError::QFpOutOfRange { q_fp })
    }
}
pub fn per_band_eighth_bits_at_q_fp(
band: u32,
q_fp: u32,
channels: u32,
n_bins: u32,
lm: u32,
) -> Result<u64, AllocSearchError> {
if channels == 0 || channels > 2 {
return Err(AllocSearchError::ChannelsOutOfRange { channels });
}
if band >= CELT_NUM_BANDS as u32 {
return Err(AllocSearchError::BandOutOfRange { band });
}
let comps = q_fp_to_components(q_fp)?;
let q_lo = comps.q_lo as usize;
let frac = comps.frac as u64;
let row = &STATIC_ALLOC[band as usize];
let cell_lo = row[q_lo] as u64;
let cell_hi = row[q_lo + 1] as u64;
let cell_q11 = cell_lo * (STATIC_ALLOC_INTERP_STEPS as u64 - frac) + cell_hi * frac;
let scaled = (channels as u64) * (n_bins as u64) * cell_q11;
Ok((scaled << lm) >> STATIC_ALLOC_INTERP_RIGHT_SHIFT)
}
/// Sums [`per_band_eighth_bits_at_q_fp`] over every coded band of one frame.
///
/// The band range is `celt_first_coded_band(is_hybrid)..celt_end_coded_band()`. The frame
/// size supplies both each band's bin count and the `lm` shift
/// (`frame_size.column_index()`).
///
/// # Errors
/// - [`AllocSearchError::ChannelsOutOfRange`] unless `channels` is 1 or 2.
/// - [`AllocSearchError::QFpOutOfRange`] when `q_fp > Q_FP_MAX`. This is checked up front, so
///   it is reported even when the band range is empty.
/// - The first error returned by the per-band helper, which stops the summation.
///
/// # Panics
/// If `celt_band_bins_per_channel` rejects a band inside the coded range.
pub fn total_eighth_bits_at_q_fp(
    q_fp: u32,
    channels: u32,
    frame_size: CeltFrameSize,
    is_hybrid: bool,
) -> Result<u64, AllocSearchError> {
    if !matches!(channels, 1 | 2) {
        return Err(AllocSearchError::ChannelsOutOfRange { channels });
    }
    if q_fp > Q_FP_MAX {
        return Err(AllocSearchError::QFpOutOfRange { q_fp });
    }
    let lm = frame_size.column_index() as u32;
    // Summing `Result`s short-circuits on the first `Err`, exactly like `?` inside a loop.
    (celt_first_coded_band(is_hybrid)..celt_end_coded_band())
        .map(|band| {
            let n_bins = celt_band_bins_per_channel(band, frame_size)
                .expect("first..end is in-range for celt_band_bins_per_channel")
                as u32;
            per_band_eighth_bits_at_q_fp(band as u32, q_fp, channels, n_bins, lm)
        })
        .sum()
}
/// Result of [`search_q_fp`]: the chosen fractional allocation index and its total cost.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AllocSearchOutcome {
    /// Selected fractional allocation index, in `0..=Q_FP_MAX`.
    pub q_fp: u32,
    /// Total allocation reported for `q_fp`, in eighth bits.
    pub total_eighth_bits: u64,
}
/// Finds the largest fractional allocation index whose total cost fits within a budget.
///
/// This relies on [`total_eighth_bits_at_q_fp`] being non-decreasing in `q_fp`; this module's
/// tests assert that for every frame size, channel count and mode. Monotonicity lets the search
/// bisect `0..=Q_FP_MAX`, using about `log2(Q_FP_MAX) + 2` total evaluations instead of scanning
/// down one index at a time.
///
/// Returns `Q_FP_MAX` when full allocation fits. If not even `q_fp = 0` fits, returns
/// `q_fp = 0` together with its true total, which then exceeds the budget. The total is never
/// fabricated, so `total_eighth_bits` always equals `total_eighth_bits_at_q_fp(q_fp, ..)`.
///
/// # Errors
/// [`AllocSearchError::ChannelsOutOfRange`] unless `channels` is 1 or 2, plus anything
/// propagated from [`total_eighth_bits_at_q_fp`].
pub fn search_q_fp(
    budget_eighth_bits: u64,
    channels: u32,
    frame_size: CeltFrameSize,
    is_hybrid: bool,
) -> Result<AllocSearchOutcome, AllocSearchError> {
    if channels == 0 || channels > 2 {
        return Err(AllocSearchError::ChannelsOutOfRange { channels });
    }
    let total_at =
        |q_fp: u32| total_eighth_bits_at_q_fp(q_fp, channels, frame_size, is_hybrid);

    // Fast path: full allocation fits.
    let top_total = total_at(Q_FP_MAX)?;
    if top_total <= budget_eighth_bits {
        return Ok(AllocSearchOutcome {
            q_fp: Q_FP_MAX,
            total_eighth_bits: top_total,
        });
    }

    // Nothing fits: report the minimum index with its real cost rather than claiming zero.
    let floor_total = total_at(0)?;
    if floor_total > budget_eighth_bits {
        return Ok(AllocSearchOutcome {
            q_fp: 0,
            total_eighth_bits: floor_total,
        });
    }

    // Invariant: total_at(lo) <= budget < total_at(hi) and lo < hi. When hi == lo + 1, lo is the
    // largest index that fits.
    let (mut lo, mut hi, mut lo_total) = (0u32, Q_FP_MAX, floor_total);
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        let mid_total = total_at(mid)?;
        if mid_total <= budget_eighth_bits {
            lo = mid;
            lo_total = mid_total;
        } else {
            hi = mid;
        }
    }
    Ok(AllocSearchOutcome {
        q_fp: lo,
        total_eighth_bits: lo_total,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::celt_static_alloc::static_alloc_eighth_bits;

    /// Every CELT frame size, shared by the tests that sweep the full configuration space.
    const FRAME_SIZES: [CeltFrameSize; 4] = [
        CeltFrameSize::Ms2_5,
        CeltFrameSize::Ms5,
        CeltFrameSize::Ms10,
        CeltFrameSize::Ms20,
    ];

    #[test]
    fn q_fp_max_constant() {
        assert_eq!(Q_FP_MAX, 640);
        assert_eq!(STATIC_ALLOC_INTERP_RIGHT_SHIFT, 8);
    }

    #[test]
    fn q_fp_zero_decomposes_to_q_lo_zero_frac_zero() {
        let c = q_fp_to_components(0).unwrap();
        assert_eq!(c.q_lo, 0);
        assert_eq!(c.frac, 0);
    }

    #[test]
    fn q_fp_at_integer_column_has_zero_frac() {
        for q in 0..=9 {
            let c = q_fp_to_components(q * 64).unwrap();
            assert_eq!(c.q_lo, q);
            assert_eq!(c.frac, 0);
        }
    }

    #[test]
    fn q_fp_saturation_decomposes_to_q_lo_9_frac_64() {
        // Q_FP_MAX is encoded against the last pair of columns so that `q_lo + 1` stays in range.
        let c = q_fp_to_components(Q_FP_MAX).unwrap();
        assert_eq!(c.q_lo, 9);
        assert_eq!(c.frac, 64);
    }

    #[test]
    fn q_fp_mid_step_decomposes() {
        // 352 = 5 * 64 + 32.
        let c = q_fp_to_components(352).unwrap();
        assert_eq!(c.q_lo, 5);
        assert_eq!(c.frac, 32);
    }

    #[test]
    fn q_fp_to_components_rejects_out_of_range() {
        assert_eq!(
            q_fp_to_components(Q_FP_MAX + 1).unwrap_err(),
            AllocSearchError::QFpOutOfRange { q_fp: Q_FP_MAX + 1 },
        );
        assert_eq!(
            q_fp_to_components(u32::MAX).unwrap_err(),
            AllocSearchError::QFpOutOfRange { q_fp: u32::MAX },
        );
    }

    #[test]
    fn q_fp_round_trip_through_components() {
        for q_fp in 0..=Q_FP_MAX {
            let c = q_fp_to_components(q_fp).unwrap();
            let back = q_fp_from_components(c.q_lo, c.frac).unwrap();
            assert_eq!(back, q_fp, "round trip failed at q_fp = {}", q_fp);
        }
    }

    #[test]
    fn q_fp_from_components_rejects_invalid_combinations() {
        assert!(q_fp_from_components(10, 1).is_err());
        assert!(q_fp_from_components(0, 64).is_err());
        assert!(q_fp_from_components(8, 64).is_err());
        assert!(q_fp_from_components(5, 65).is_err());
        assert!(q_fp_from_components(11, 0).is_err());
    }

    #[test]
    fn q_fp_from_components_accepts_top_column_alias() {
        // The canonical saturation encoding and the integer-column alias both compose to Q_FP_MAX.
        assert_eq!(q_fp_from_components(9, 64), Ok(Q_FP_MAX));
        assert_eq!(q_fp_from_components(10, 0), Ok(Q_FP_MAX));
    }

    #[test]
    fn per_band_at_integer_q_matches_static_alloc_eighth_bits() {
        for band in 0..(CELT_NUM_BANDS as u32) {
            for q in 0..=9u32 {
                for &channels in &[1u32, 2] {
                    for &n_bins in &[1u32, 4, 88, 176] {
                        for lm in 0..4u32 {
                            let got =
                                per_band_eighth_bits_at_q_fp(band, q * 64, channels, n_bins, lm)
                                    .unwrap();
                            let want =
                                static_alloc_eighth_bits(band, q, channels, n_bins, lm).unwrap();
                            assert_eq!(
                                got, want as u64,
                                "parity at band {band} q {q} channels {channels} n_bins {n_bins} lm {lm}",
                            );
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn per_band_at_saturation_matches_column_ten() {
        for band in 0..(CELT_NUM_BANDS as u32) {
            for &channels in &[1u32, 2] {
                for &n_bins in &[1u32, 88] {
                    for lm in 0..4u32 {
                        let got =
                            per_band_eighth_bits_at_q_fp(band, Q_FP_MAX, channels, n_bins, lm)
                                .unwrap();
                        let want =
                            static_alloc_eighth_bits(band, 10, channels, n_bins, lm).unwrap();
                        assert_eq!(got, want as u64, "saturation parity at band {band}");
                    }
                }
            }
        }
    }

    #[test]
    fn per_band_at_column_zero_is_zero() {
        for band in 0..(CELT_NUM_BANDS as u32) {
            assert_eq!(per_band_eighth_bits_at_q_fp(band, 0, 2, 88, 3).unwrap(), 0);
        }
    }

    #[test]
    fn per_band_monotone_in_q_fp() {
        for band in 0..(CELT_NUM_BANDS as u32) {
            let mut prev = 0u64;
            for q_fp in 0..=Q_FP_MAX {
                let cur = per_band_eighth_bits_at_q_fp(band, q_fp, 1, 4, 0).unwrap();
                assert!(
                    cur >= prev,
                    "monotonicity failed at band {band} q_fp {q_fp}: {cur} < {prev}",
                );
                prev = cur;
            }
        }
    }

    #[test]
    fn per_band_rejects_invalid_band() {
        assert_eq!(
            per_band_eighth_bits_at_q_fp(CELT_NUM_BANDS as u32, 0, 1, 1, 0).unwrap_err(),
            AllocSearchError::BandOutOfRange {
                band: CELT_NUM_BANDS as u32
            },
        );
    }

    #[test]
    fn per_band_rejects_invalid_channels() {
        assert_eq!(
            per_band_eighth_bits_at_q_fp(0, 0, 0, 1, 0).unwrap_err(),
            AllocSearchError::ChannelsOutOfRange { channels: 0 },
        );
        assert_eq!(
            per_band_eighth_bits_at_q_fp(0, 0, 3, 1, 0).unwrap_err(),
            AllocSearchError::ChannelsOutOfRange { channels: 3 },
        );
    }

    #[test]
    fn per_band_rejects_invalid_q_fp() {
        assert_eq!(
            per_band_eighth_bits_at_q_fp(0, Q_FP_MAX + 1, 1, 1, 0).unwrap_err(),
            AllocSearchError::QFpOutOfRange { q_fp: Q_FP_MAX + 1 },
        );
    }

    #[test]
    fn total_at_q_fp_zero_is_zero() {
        for &is_hybrid in &[false, true] {
            for fs in FRAME_SIZES {
                for &channels in &[1u32, 2] {
                    assert_eq!(
                        total_eighth_bits_at_q_fp(0, channels, fs, is_hybrid).unwrap(),
                        0,
                        "fs {:?} channels {} is_hybrid {}",
                        fs,
                        channels,
                        is_hybrid,
                    );
                }
            }
        }
    }

    #[test]
    fn total_monotone_in_q_fp() {
        for &is_hybrid in &[false, true] {
            for fs in FRAME_SIZES {
                for &channels in &[1u32, 2] {
                    let mut prev = 0u64;
                    for q_fp in 0..=Q_FP_MAX {
                        let cur = total_eighth_bits_at_q_fp(q_fp, channels, fs, is_hybrid).unwrap();
                        assert!(
                            cur >= prev,
                            "total non-monotone at fs {:?} channels {} is_hybrid {} q_fp {}: {} < {}",
                            fs, channels, is_hybrid, q_fp, cur, prev,
                        );
                        prev = cur;
                    }
                }
            }
        }
    }

    #[test]
    fn total_celt_only_exceeds_hybrid_for_same_q_fp() {
        let celt_only = total_eighth_bits_at_q_fp(Q_FP_MAX, 1, CeltFrameSize::Ms20, false).unwrap();
        let hybrid = total_eighth_bits_at_q_fp(Q_FP_MAX, 1, CeltFrameSize::Ms20, true).unwrap();
        assert!(
            celt_only > hybrid,
            "expected CELT-only > Hybrid at saturation; got celt = {}, hybrid = {}",
            celt_only,
            hybrid,
        );
    }

    #[test]
    fn total_stereo_at_least_mono_at_saturation() {
        // Per band, floor(2x / 2^k) - 2 * floor(x / 2^k) is 0 or 1, so doubling the channel count
        // can add at most one unit of rounding slack per coded band.
        for fs in FRAME_SIZES {
            let mono = total_eighth_bits_at_q_fp(Q_FP_MAX, 1, fs, false).unwrap();
            let stereo = total_eighth_bits_at_q_fp(Q_FP_MAX, 2, fs, false).unwrap();
            assert!(
                stereo >= mono * 2,
                "stereo < 2 * mono at fs {:?}: mono = {} stereo = {}",
                fs,
                mono,
                stereo,
            );
            assert!(
                stereo <= mono * 2 + CELT_NUM_BANDS as u64,
                "stereo - 2 * mono exceeds per-band slack at fs {:?}: mono = {} stereo = {}",
                fs,
                mono,
                stereo,
            );
        }
    }

    #[test]
    fn total_rejects_invalid_channels() {
        assert!(matches!(
            total_eighth_bits_at_q_fp(0, 0, CeltFrameSize::Ms20, false),
            Err(AllocSearchError::ChannelsOutOfRange { channels: 0 }),
        ));
        assert!(matches!(
            total_eighth_bits_at_q_fp(0, 3, CeltFrameSize::Ms20, false),
            Err(AllocSearchError::ChannelsOutOfRange { channels: 3 }),
        ));
    }

    #[test]
    fn total_rejects_invalid_q_fp() {
        assert!(matches!(
            total_eighth_bits_at_q_fp(Q_FP_MAX + 1, 1, CeltFrameSize::Ms20, false),
            Err(AllocSearchError::QFpOutOfRange { .. }),
        ));
    }

    #[test]
    fn search_zero_budget_returns_q_fp_zero() {
        let out = search_q_fp(0, 1, CeltFrameSize::Ms20, false).unwrap();
        assert_eq!(out.q_fp, 0);
        assert_eq!(out.total_eighth_bits, 0);
    }

    #[test]
    fn search_saturation_budget_returns_q_fp_max() {
        let out = search_q_fp(u64::MAX, 1, CeltFrameSize::Ms20, false).unwrap();
        assert_eq!(out.q_fp, Q_FP_MAX);
        let expect = total_eighth_bits_at_q_fp(Q_FP_MAX, 1, CeltFrameSize::Ms20, false).unwrap();
        assert_eq!(out.total_eighth_bits, expect);
    }

    #[test]
    fn search_picks_exact_budget() {
        let fs = CeltFrameSize::Ms20;
        let channels = 1u32;
        for &q_fp in &[0u32, 64, 128, 320, 384, 640] {
            let exact = total_eighth_bits_at_q_fp(q_fp, channels, fs, false).unwrap();
            let out = search_q_fp(exact, channels, fs, false).unwrap();
            assert!(
                out.q_fp >= q_fp,
                "search undercut exact target {}: got {}",
                q_fp,
                out.q_fp,
            );
            assert!(
                out.total_eighth_bits <= exact,
                "search violated budget cap at exact {}: total {} > budget",
                exact,
                out.total_eighth_bits,
            );
        }
    }

    #[test]
    fn search_budget_one_less_picks_lower_q_fp() {
        let fs = CeltFrameSize::Ms20;
        let channels = 2u32;
        let q_fp = 320u32;
        let here = total_eighth_bits_at_q_fp(q_fp, channels, fs, false).unwrap();
        let down = total_eighth_bits_at_q_fp(q_fp - 1, channels, fs, false).unwrap();
        assert!(
            here > down,
            "test precondition: total at q_fp 320 should exceed q_fp 319",
        );
        let out = search_q_fp(here - 1, channels, fs, false).unwrap();
        assert!(
            out.q_fp < q_fp,
            "search exceeded budget total(320) - 1: got q_fp {} (total {})",
            out.q_fp,
            out.total_eighth_bits,
        );
        assert!(
            out.total_eighth_bits < here,
            "search violated budget: {} >= {}",
            out.total_eighth_bits,
            here,
        );
    }

    #[test]
    fn search_result_is_self_consistent() {
        let budgets = [0u64, 100, 1_000, 10_000, 100_000];
        for &is_hybrid in &[false, true] {
            for fs in FRAME_SIZES {
                for &channels in &[1u32, 2] {
                    for &budget in &budgets {
                        let out = search_q_fp(budget, channels, fs, is_hybrid).unwrap();
                        let recomputed =
                            total_eighth_bits_at_q_fp(out.q_fp, channels, fs, is_hybrid).unwrap();
                        assert_eq!(
                            out.total_eighth_bits, recomputed,
                            "total mismatch at fs {:?} channels {} is_hybrid {} budget {}",
                            fs, channels, is_hybrid, budget,
                        );
                        assert!(
                            out.total_eighth_bits <= budget,
                            "search exceeded budget at fs {:?} budget {}: total {}",
                            fs,
                            budget,
                            out.total_eighth_bits,
                        );
                        if out.q_fp < Q_FP_MAX {
                            let next =
                                total_eighth_bits_at_q_fp(out.q_fp + 1, channels, fs, is_hybrid)
                                    .unwrap();
                            assert!(
                                next > budget,
                                "search undercut at fs {:?} budget {}: next q_fp = {} total {} <= budget",
                                fs, budget, out.q_fp + 1, next,
                            );
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn search_matches_exhaustive_scan() {
        // Reference answer: the largest q_fp whose total fits, found by scanning every index.
        // Budgets sit exactly on, and one below, real totals, where an off-by-one would show.
        for &is_hybrid in &[false, true] {
            for fs in FRAME_SIZES {
                for &channels in &[1u32, 2] {
                    let mut totals = [0u64; Q_FP_MAX as usize + 1];
                    for (q_fp, slot) in (0..=Q_FP_MAX).zip(totals.iter_mut()) {
                        *slot = total_eighth_bits_at_q_fp(q_fp, channels, fs, is_hybrid).unwrap();
                    }
                    for &anchor in &[1usize, 100, 333, 639] {
                        for budget in [totals[anchor], totals[anchor].saturating_sub(1)] {
                            let expected = (0..=Q_FP_MAX)
                                .rev()
                                .find(|&q_fp| totals[q_fp as usize] <= budget)
                                .unwrap_or(0);
                            let out = search_q_fp(budget, channels, fs, is_hybrid).unwrap();
                            assert_eq!(
                                out.q_fp, expected,
                                "scan mismatch at fs {:?} channels {} is_hybrid {} budget {}",
                                fs, channels, is_hybrid, budget,
                            );
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn search_rejects_invalid_channels() {
        assert!(matches!(
            search_q_fp(1000, 0, CeltFrameSize::Ms20, false),
            Err(AllocSearchError::ChannelsOutOfRange { channels: 0 }),
        ));
        assert!(matches!(
            search_q_fp(1000, 3, CeltFrameSize::Ms20, false),
            Err(AllocSearchError::ChannelsOutOfRange { channels: 3 }),
        ));
    }
}