use std::num::NonZeroUsize;
use non_empty_slice::{NonEmptySlice, NonEmptyVec};
use super::BandMetrics;
/// Quantization parameters assigned to a single band of bins.
#[derive(Debug, Clone, PartialEq)]
pub struct BandAllocation {
/// First bin of the band (inclusive); `quantize`/`dequantize` iterate
/// `start_bin..end_bin`.
pub start_bin: usize,
/// One past the last bin of the band (exclusive); clamped to the number of
/// coefficients wherever the band is iterated in this module.
pub end_bin: usize,
/// Bits assigned to the band by `allocate_bits` (per-band minimum plus the
/// band's share of the remaining budget).
pub bits: u32,
/// Quantizer step: coefficients are divided by it before rounding and
/// multiplied by it when reconstructing.
pub step_size: f32,
}
/// Outcome of a bit-allocation pass: the per-band quantization parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct BitAllocationResult {
/// Band allocations; `allocate_bits` emits one entry per input band metric,
/// in the same order. Never empty.
pub allocations: NonEmptyVec<BandAllocation>,
}
/// Converts an allowed noise level in decibels into a quantizer step size.
///
/// The level is mapped to a linear amplitude via `10^(dB / 20)` and scaled by
/// `sqrt(12)` (presumably the uniform-quantizer relation where the noise power
/// is `step^2 / 12`). The result is never smaller than `1e-6`, so it is always
/// safe to divide by.
#[inline]
#[must_use]
pub fn step_size_from_allowed_noise(allowed_noise_db: f32) -> f32 {
    /// Smallest step size this function will return.
    const MIN_STEP: f32 = 1e-6;

    let exponent = allowed_noise_db / 20.0;
    let amplitude = f32::powf(10.0, exponent);
    let step = amplitude * f32::sqrt(12.0);
    f32::max(step, MIN_STEP)
}
/// Distributes `total_bits` across the bands described by `band_metrics`.
///
/// Every band first receives `min_bits_per_band`. Whatever is left of
/// `total_bits` after that reservation is shared out in proportion to each
/// band's weight, `10^(energy / 10) * (1 + max(importance, 0))`.
///
/// The proportional shares are apportioned with the largest-remainder method:
/// each band gets the floor of its exact share and the bits still unassigned
/// go, one each, to the bands with the largest fractional parts (ties favour
/// earlier bands). Rounding every share independently could hand out more
/// bits than the budget or leave part of it unused; this scheme distributes
/// exactly the remaining budget whenever the total weight is positive and
/// finite.
///
/// If the total weight is zero, NaN or infinite, no extra bits are
/// distributed and every band keeps just the minimum. If the reservation
/// alone exceeds `total_bits`, every band still receives the minimum.
///
/// Each band's `step_size` is derived from its `allowed_noise` via
/// [`step_size_from_allowed_noise`].
#[must_use]
pub fn allocate_bits(
    band_metrics: &BandMetrics,
    total_bits: u32,
    min_bits_per_band: u8,
) -> BitAllocationResult {
    let floor_bits = u32::from(min_bits_per_band);
    let n_bands = band_metrics.metrics.len().get();
    // Clamp rather than truncate: a plain `as u32` would silently wrap for
    // band counts above `u32::MAX`.
    let n_bands_u32 = n_bands.min(u32::MAX as usize) as u32;
    let reserved = n_bands_u32.saturating_mul(floor_bits);
    let remaining = total_bits.saturating_sub(reserved);

    // Compute every band weight exactly once.
    let weights: Vec<f32> = band_metrics
        .metrics
        .iter()
        .map(|m| {
            let energy_linear = 10.0_f32.powf(m.energy / 10.0);
            energy_linear * (1.0 + m.importance.max(0.0))
        })
        .collect();
    let total_weight: f64 = weights.iter().map(|&w| f64::from(w)).sum();

    let mut extra_bits = vec![0u32; weights.len()];
    if remaining > 0 && total_weight > 0.0 && total_weight.is_finite() {
        let budget = f64::from(remaining);
        let mut assigned = 0u32;
        let mut remainders: Vec<(usize, f64)> = Vec::with_capacity(weights.len());

        for (i, &w) in weights.iter().enumerate() {
            let share = f64::from(w) / total_weight * budget;
            let whole = share.floor();
            // The `min` guards against floating-point error pushing the
            // running total past the budget.
            let granted = (whole as u32).min(remaining - assigned);
            extra_bits[i] = granted;
            assigned += granted;
            remainders.push((i, share - whole));
        }

        // Stable sort, largest fractional part first; equal remainders keep
        // band order.
        remainders.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let leftover = (remaining - assigned) as usize;
        for &(i, _) in remainders.iter().take(leftover) {
            extra_bits[i] += 1;
        }
    }

    let allocations: Vec<BandAllocation> = band_metrics
        .metrics
        .iter()
        .zip(extra_bits)
        .map(|(m, extra)| BandAllocation {
            start_bin: m.band.start_bin,
            end_bin: m.band.end_bin,
            bits: floor_bits.saturating_add(extra),
            step_size: step_size_from_allowed_noise(m.allowed_noise),
        })
        .collect();

    // SAFETY: exactly one allocation is produced per band metric, and
    // `band_metrics.metrics` reports a non-zero length, so the vector is
    // non-empty.
    let allocations = unsafe { NonEmptyVec::new_unchecked(allocations) };
    BitAllocationResult { allocations }
}
/// Quantizes a run of coefficients with a single step size.
///
/// Each value is divided by `step_size` and rounded to the nearest integer
/// (halves round away from zero). The conversion to `i32` uses an `as` cast,
/// so out-of-range results saturate and NaN becomes `0`.
#[inline]
#[must_use]
pub fn quantize_band(coefficients: &[f32], step_size: f32) -> Vec<i32> {
    let mut levels = Vec::with_capacity(coefficients.len());
    for &value in coefficients {
        let scaled = value / step_size;
        levels.push(scaled.round() as i32);
    }
    levels
}
/// Reconstructs coefficient values from quantized levels.
///
/// Each level is converted to `f32` and multiplied by `step_size`; this is
/// the inverse scaling of [`quantize_band`].
#[inline]
#[must_use]
pub fn dequantize_band(quantized: &[i32], step_size: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(quantized.len());
    for &level in quantized {
        restored.push(level as f32 * step_size);
    }
    restored
}
/// Quantizes a full coefficient buffer using per-band step sizes.
///
/// `coefficients` is indexed as `bin * n_frames + frame`, i.e. all frames of
/// one bin are stored next to each other. For every band in `allocation`,
/// each coefficient in bins `start_bin..min(end_bin, n_coefficients)` is
/// divided by the band's `step_size` and rounded to the nearest integer.
/// Positions not covered by any band stay `0`; if bands overlap, the later
/// band in `allocation` wins.
///
/// In debug builds the buffer length is asserted to equal
/// `n_coefficients * n_frames`.
///
/// # Panics
///
/// Panics if a covered position lies outside `coefficients`.
#[must_use]
pub fn quantize(
    coefficients: &NonEmptySlice<f32>,
    n_coefficients: NonZeroUsize,
    n_frames: NonZeroUsize,
    allocation: &BitAllocationResult,
) -> NonEmptyVec<i32> {
    let bins = n_coefficients.get();
    let frames = n_frames.get();
    debug_assert_eq!(coefficients.len().get(), bins * frames);

    let mut levels = vec![0i32; bins * frames];
    for band in allocation.allocations.iter() {
        let last_bin = band.end_bin.min(bins);
        for bin in band.start_bin..last_bin {
            let row = bin * frames;
            for frame in 0..frames {
                let pos = row + frame;
                levels[pos] = (coefficients[pos] / band.step_size).round() as i32;
            }
        }
    }

    // SAFETY: `levels` has `bins * frames` elements and both factors come
    // from `NonZeroUsize`, so (barring overflow of the product) it is
    // non-empty.
    unsafe { NonEmptyVec::new_unchecked(levels) }
}
/// Reconstructs a full coefficient buffer from quantized levels.
///
/// `quantized` is indexed as `bin * n_frames + frame`. For every band in
/// `allocation`, each level in bins `start_bin..min(end_bin, n_coefficients)`
/// is converted to `f32` and multiplied by the band's `step_size`. Positions
/// not covered by any band stay `0.0`; if bands overlap, the later band in
/// `allocation` wins.
///
/// In debug builds the buffer length is asserted to equal
/// `n_coefficients * n_frames`.
///
/// # Panics
///
/// Panics if a covered position lies outside `quantized`.
#[must_use]
pub fn dequantize(
    quantized: &NonEmptySlice<i32>,
    n_coefficients: NonZeroUsize,
    n_frames: NonZeroUsize,
    allocation: &BitAllocationResult,
) -> NonEmptyVec<f32> {
    let bins = n_coefficients.get();
    let frames = n_frames.get();
    debug_assert_eq!(quantized.len().get(), bins * frames);

    let mut restored = vec![0.0f32; bins * frames];
    for band in allocation.allocations.iter() {
        let last_bin = band.end_bin.min(bins);
        for bin in band.start_bin..last_bin {
            let row = bin * frames;
            for frame in 0..frames {
                let pos = row + frame;
                restored[pos] = quantized[pos] as f32 * band.step_size;
            }
        }
    }

    // SAFETY: `restored` has `bins * frames` elements and both factors come
    // from `NonZeroUsize`, so (barring overflow of the product) it is
    // non-empty.
    unsafe { NonEmptyVec::new_unchecked(restored) }
}
/// Upper bound, in bits, on the per-band word length computed by
/// `refine_step_sizes`; the word length is clamped to `0..=MAX_WORDLENGTH_BITS`.
const MAX_WORDLENGTH_BITS: u32 = 24;
/// Recomputes every band's `step_size` from the actual coefficient data.
///
/// `coefficients` is indexed as `bin * n_frames + frame`. The function works
/// in two passes:
///
/// 1. For each band it measures the peak magnitude and the mean of the
///    squared values (used as the variance estimate) over bins
///    `start_bin..min(end_bin, n_coefficients)` and all frames.
/// 2. It derives a word length per band,
///    `w = r_avg + 0.5 * (log2(variance) - mean_log2_variance)`, clamped to
///    `0..=MAX_WORDLENGTH_BITS`. Here `r_avg` is the sum of all bands' `bits`
///    divided by the number of covered coefficients, and the mean is taken
///    over bands with positive variance. The step size becomes
///    `2 * peak / 2^w`.
///
/// Bands with no signal (peak or variance not positive) and bands whose word
/// length is below half a bit get `2 * peak` as their step size. Every step
/// size is floored at `1e-6`.
///
/// If the bands cover no coefficients at all, `allocation` is left untouched.
/// Only `step_size` is modified; `bits` and the bin ranges are unchanged.
///
/// # Panics
///
/// Panics if a covered position lies outside `coefficients`.
pub fn refine_step_sizes(
    allocation: &mut BitAllocationResult,
    coefficients: &NonEmptySlice<f32>,
    n_coefficients: NonZeroUsize,
    n_frames: NonZeroUsize,
) {
    /// Signal statistics of one band, gathered in the first pass.
    struct BandStat {
        peak: f32,
        variance: f32,
        n_coeffs: usize,
    }

    /// Smallest step size ever written back.
    const MIN_STEP: f32 = 1e-6;

    let bins = n_coefficients.get();
    let frames = n_frames.get();

    // First pass: peak and mean-square per band. An empty bin range simply
    // yields an all-zero statistic.
    let summarize = |start_bin: usize, end_bin: usize| -> BandStat {
        let last_bin = end_bin.min(bins);
        let mut peak = 0.0f32;
        let mut sum_sq = 0.0f64;
        let mut count = 0usize;
        for bin in start_bin..last_bin {
            for frame in 0..frames {
                let sample = coefficients[bin * frames + frame];
                peak = peak.max(sample.abs());
                // Accumulate in f64 to limit rounding error over long bands.
                let wide = f64::from(sample);
                sum_sq += wide * wide;
                count += 1;
            }
        }
        let variance = if count == 0 {
            0.0
        } else {
            (sum_sq / count as f64) as f32
        };
        BandStat {
            peak,
            variance,
            n_coeffs: count,
        }
    };
    let stats: Vec<BandStat> = allocation
        .allocations
        .iter()
        .map(|band| summarize(band.start_bin, band.end_bin))
        .collect();

    let bit_budget: u64 = allocation
        .allocations
        .iter()
        .map(|band| u64::from(band.bits))
        .sum();
    let covered: usize = stats.iter().map(|stat| stat.n_coeffs).sum();
    if covered == 0 {
        return;
    }
    // Average number of bits available per covered coefficient.
    let r_avg = bit_budget as f32 / covered as f32;

    // Mean of log2(variance) over the bands that carry signal, i.e. the log2
    // of the geometric mean of their variances.
    let (log_sum, active) = stats
        .iter()
        .filter(|stat| stat.variance > 0.0)
        .fold((0.0f32, 0usize), |(sum, n), stat| {
            (sum + stat.variance.log2(), n + 1)
        });
    let geomean_log2 = if active > 0 {
        log_sum / active as f32
    } else {
        0.0
    };

    // Second pass: turn each band's word length into a step size.
    for (band, stat) in allocation.allocations.iter_mut().zip(stats.iter()) {
        let full_range = 2.0 * stat.peak;
        let silent = stat.peak <= 0.0 || stat.variance <= 0.0;

        // `None` means "do not subdivide the range".
        let levels = if silent {
            None
        } else {
            let word_length = (r_avg + 0.5 * (stat.variance.log2() - geomean_log2))
                .clamp(0.0, MAX_WORDLENGTH_BITS as f32);
            if word_length < 0.5 {
                None
            } else {
                Some(2.0_f32.powf(word_length))
            }
        };

        band.step_size = match levels {
            Some(levels) => (full_range / levels).max(MIN_STEP),
            None => full_range.max(MIN_STEP),
        };
    }
}