use crate::{CodecError, CodecResult};
use super::mdct::{Mdct, OverlapAdd};
use super::packet::OpusBandwidth;
use super::range_decoder::RangeDecoder;
use super::range_encoder::RangeEncoder;
/// Number of frequency bands each frame's spectrum is split into.
const CELT_BANDS: usize = 21;
/// Minimum width (in bins) given to every band by the decoder's band layout,
/// before the band sizes are fitted to the actual frame size.
const MIN_BAND_WIDTH: usize = 2;
/// Band edges in bins for a 240-bin reference frame: band `i` spans
/// `BARK_BAND_BOUNDARIES[i]..BARK_BAND_BOUNDARIES[i + 1]`. The decoder scales
/// these edges by `frame_size / 240`; bands get wider toward higher indices.
const BARK_BAND_BOUNDARIES: [usize; CELT_BANDS + 1] = [
0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, 68, 80, 96, 120, 156, 240,
];
/// Per-band upper limit on fine-energy bits; the decoder reads
/// `min(ENERGY_FINE_BITS[band], fine_bits)` bits for each band.
const ENERGY_FINE_BITS: [u8; CELT_BANDS] = [
3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1,
];
/// Bit-allocation offsets selected by the decoded trim index (0..11) and added
/// to every band's base allocation.
const ALLOCATION_TRIM: [f32; 11] = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
/// FIR taps the decoder's post-filter applies to the previous three input
/// samples (most recent first), i.e. `y[n] = x[n] + 0.85*x[n-1] - 0.85*x[n-3]`.
const POST_FILTER_COEFFS: [f32; 3] = [0.85, 0.0, -0.85];
/// CELT-style decoder state: one inverse MDCT, per-channel overlap-add and
/// post-filter history, and per-band energies carried from frame to frame.
#[derive(Debug)]
pub struct CeltDecoder {
/// Sample rate given at construction; only reported via `sample_rate()`.
sample_rate: u32,
/// Number of channels decoded per packet.
channels: usize,
/// Bandwidth given at construction. No decoding step reads it, hence the
/// lint allowance.
#[allow(dead_code)]
bandwidth: OpusBandwidth,
/// Inverse transform shared by all channels, built for the construction-time
/// `frame_size`.
mdct: Mdct,
/// One overlap-add accumulator per channel.
overlap_add: Vec<OverlapAdd>,
/// Per-band log-domain energy (the decoder applies `exp()` to it); the previous
/// frame's value is the prediction for coarse-energy decoding.
band_energy: Vec<f32>,
/// Frame size given at construction.
frame_size: usize,
/// Per channel, the last three pre-filter input samples, most recent first.
postfilter_state: Vec<Vec<f32>>,
/// Snapshot of each band's energy after fine-energy decoding.
/// NOTE(review): written every frame but not read by any decoding step.
fine_energy_prev: Vec<f32>,
}
impl CeltDecoder {
pub fn new(
sample_rate: u32,
channels: usize,
bandwidth: OpusBandwidth,
frame_size: usize,
) -> Self {
let mdct = Mdct::new(frame_size);
let overlap_add = (0..channels).map(|_| OverlapAdd::new(frame_size)).collect();
Self {
sample_rate,
channels,
bandwidth,
mdct,
overlap_add,
band_energy: vec![0.0; CELT_BANDS],
frame_size,
postfilter_state: vec![vec![0.0; 3]; channels],
fine_energy_prev: vec![0.0; CELT_BANDS],
}
}
pub fn decode(
&mut self,
data: &[u8],
output: &mut [f32],
frame_size: usize,
) -> CodecResult<()> {
if output.len() < frame_size * self.channels {
return Err(CodecError::InvalidData(
"Output buffer too small".to_string(),
));
}
let mut decoder = RangeDecoder::new(data)?;
let global_params = self.decode_global_params(&mut decoder)?;
for ch in 0..self.channels {
self.decode_channel(&mut decoder, output, frame_size, ch, &global_params)?;
}
Ok(())
}
fn decode_global_params(&mut self, decoder: &mut RangeDecoder) -> CodecResult<GlobalParams> {
let silence = decoder.decode_bit(16384)?;
let postfilter = decoder.decode_bit(16384)?;
let transient = decoder.decode_bit(16384)?;
let trim_index = decoder.decode_uniform(11)? as usize;
let allocation_trim = if trim_index < ALLOCATION_TRIM.len() {
ALLOCATION_TRIM[trim_index]
} else {
0.0
};
let fine_bits = decoder.decode_uniform(4)? as u8;
Ok(GlobalParams {
silence,
postfilter,
transient,
allocation_trim,
fine_bits,
})
}
fn decode_channel(
&mut self,
decoder: &mut RangeDecoder,
output: &mut [f32],
frame_size: usize,
channel: usize,
global_params: &GlobalParams,
) -> CodecResult<()> {
if global_params.silence {
let ch_offset = channel * frame_size;
for i in 0..frame_size {
output[ch_offset + i] = 0.0;
}
return Ok(());
}
let mut coeffs = vec![0.0f32; frame_size];
let mut time_domain = vec![0.0f32; 2 * frame_size];
let band_sizes = self.get_band_sizes(frame_size);
self.decode_coarse_energy(decoder, &band_sizes)?;
self.decode_fine_energy(decoder, &band_sizes, global_params.fine_bits)?;
let bit_allocation = self.compute_bit_allocation(&band_sizes, global_params)?;
self.decode_pvq(decoder, &mut coeffs, &band_sizes, &bit_allocation)?;
self.denormalize_coeffs(&mut coeffs, &band_sizes);
self.mdct.inverse(&coeffs, &mut time_domain);
let mut frame_samples = vec![0.0f32; frame_size];
self.overlap_add[channel].process(&time_domain, &mut frame_samples);
if global_params.postfilter {
self.apply_postfilter(&mut frame_samples, channel);
}
let ch_offset = channel * frame_size;
output[ch_offset..ch_offset + frame_size].copy_from_slice(&frame_samples);
Ok(())
}
fn decode_coarse_energy(
&mut self,
decoder: &mut RangeDecoder,
band_sizes: &[usize],
) -> CodecResult<()> {
for (band_idx, _) in band_sizes.iter().enumerate() {
if band_idx >= CELT_BANDS {
break;
}
let energy_delta = decoder.decode_int(6)? as f32;
let predicted = self.band_energy[band_idx];
let alpha = 0.9;
self.band_energy[band_idx] = alpha * predicted + energy_delta * 0.5;
}
Ok(())
}
fn decode_fine_energy(
&mut self,
decoder: &mut RangeDecoder,
band_sizes: &[usize],
fine_bits: u8,
) -> CodecResult<()> {
for (band_idx, _) in band_sizes.iter().enumerate() {
if band_idx >= CELT_BANDS {
break;
}
let bits = ENERGY_FINE_BITS[band_idx].min(fine_bits);
if bits > 0 {
let fine_energy = decoder.decode_uint(u32::from(bits))?;
let fine_scale = 1.0 / ((1 << bits) as f32);
self.band_energy[band_idx] += (fine_energy as f32) * fine_scale;
}
self.fine_energy_prev[band_idx] = self.band_energy[band_idx];
}
Ok(())
}
fn compute_bit_allocation(
&self,
band_sizes: &[usize],
global_params: &GlobalParams,
) -> CodecResult<Vec<u32>> {
let mut allocation = Vec::with_capacity(band_sizes.len());
for (band_idx, &band_size) in band_sizes.iter().enumerate() {
if band_idx >= CELT_BANDS {
allocation.push(0);
continue;
}
let base = (band_size as f32 * 2.0).log2().max(0.0);
let adjusted = base + global_params.allocation_trim;
let bits = adjusted.max(0.0) as u32;
allocation.push(bits);
}
Ok(allocation)
}
fn decode_pvq(
&mut self,
decoder: &mut RangeDecoder,
coeffs: &mut [f32],
band_sizes: &[usize],
bit_allocation: &[u32],
) -> CodecResult<()> {
let mut offset = 0;
for (band_idx, &band_size) in band_sizes.iter().enumerate() {
if band_idx >= band_sizes.len() || band_idx >= bit_allocation.len() {
break;
}
let bits = bit_allocation[band_idx];
if bits > 0 && offset + band_size <= coeffs.len() {
let k = self.decode_pulse_count(decoder, bits)?;
if k > 0 {
self.decode_pvq_vector(
decoder,
&mut coeffs[offset..offset + band_size],
k,
band_size,
)?;
} else {
for i in 0..band_size {
coeffs[offset + i] = 0.0;
}
}
}
offset += band_size;
}
Ok(())
}
fn decode_pulse_count(&self, decoder: &mut RangeDecoder, bits: u32) -> CodecResult<u32> {
if bits <= 3 {
Ok(bits)
} else {
decoder.decode_uniform(bits + 1)
}
}
fn decode_pvq_vector(
&self,
decoder: &mut RangeDecoder,
band: &mut [f32],
k: u32,
n: usize,
) -> CodecResult<()> {
if k == 0 || n == 0 {
band.fill(0.0);
return Ok(());
}
if n == 1 {
band[0] = k as f32;
let sign = decoder.decode_bit(16384)?;
if sign {
band[0] = -band[0];
}
return Ok(());
}
let mid = n / 2;
let k_left = self.decode_pvq_split(decoder, k, n)?;
let k_right = k.saturating_sub(k_left);
self.decode_pvq_vector(decoder, &mut band[..mid], k_left, mid)?;
self.decode_pvq_vector(decoder, &mut band[mid..], k_right, n - mid)?;
Ok(())
}
fn decode_pvq_split(&self, decoder: &mut RangeDecoder, k: u32, n: usize) -> CodecResult<u32> {
if k == 0 {
return Ok(0);
}
let max_split = k + 1;
let split = decoder.decode_uniform(max_split)?;
Ok(split.min(k))
}
fn denormalize_coeffs(&self, coeffs: &mut [f32], band_sizes: &[usize]) {
let mut offset = 0;
for (band_idx, &band_size) in band_sizes.iter().enumerate() {
if band_idx >= CELT_BANDS || offset >= coeffs.len() {
break;
}
let energy = self.band_energy[band_idx].exp();
let coeffs_len = coeffs.len();
let end = offset.saturating_add(band_size).min(coeffs_len);
let band_slice = &mut coeffs[offset..end];
let band_norm = self.compute_band_norm(band_slice);
if band_norm > 1e-10 {
let scale = energy / band_norm;
for coeff in band_slice.iter_mut() {
*coeff *= scale;
}
}
offset = offset.saturating_add(band_size);
}
}
fn compute_band_norm(&self, band: &[f32]) -> f32 {
band.iter().map(|x| x * x).sum::<f32>().sqrt()
}
fn get_band_sizes(&self, frame_size: usize) -> Vec<usize> {
let mut sizes = Vec::new();
let scale = frame_size as f32 / 240.0;
for i in 0..CELT_BANDS {
let start = (BARK_BAND_BOUNDARIES[i] as f32 * scale) as usize;
let end = (BARK_BAND_BOUNDARIES[i + 1] as f32 * scale) as usize;
let size = end.saturating_sub(start).max(MIN_BAND_WIDTH);
sizes.push(size);
}
let total: usize = sizes.iter().sum();
if total < frame_size {
if let Some(last) = sizes.last_mut() {
*last += frame_size - total;
}
} else if total > frame_size && !sizes.is_empty() {
if let Some(last) = sizes.last_mut() {
*last = last.saturating_sub(total - frame_size);
}
}
sizes
}
fn apply_postfilter(&mut self, samples: &mut [f32], channel: usize) {
let state = &mut self.postfilter_state[channel];
for sample in samples.iter_mut() {
let mut filtered = *sample;
for (i, &coeff) in POST_FILTER_COEFFS.iter().enumerate() {
if i < state.len() {
filtered += coeff * state[i];
}
}
state.rotate_right(1);
state[0] = *sample;
*sample = filtered;
}
}
pub fn reset(&mut self) {
for ola in &mut self.overlap_add {
ola.reset();
}
self.band_energy.fill(0.0);
self.fine_energy_prev.fill(0.0);
for state in &mut self.postfilter_state {
state.fill(0.0);
}
}
#[must_use]
pub const fn sample_rate(&self) -> u32 {
self.sample_rate
}
#[must_use]
pub const fn channels(&self) -> usize {
self.channels
}
#[must_use]
pub const fn frame_size(&self) -> usize {
self.frame_size
}
}
/// Frame-level parameters read once per packet and shared by every channel.
#[derive(Debug, Clone)]
struct GlobalParams {
/// When set, each channel is zero-filled and nothing else is decoded for it.
silence: bool,
/// When set, the post-filter runs on each channel's synthesised samples.
postfilter: bool,
/// Read from the bitstream (consuming its bit) but not consulted by any
/// decoding step, hence the lint allowance.
#[allow(dead_code)]
transient: bool,
/// Offset from `ALLOCATION_TRIM` added to every band's base bit allocation.
allocation_trim: f32,
/// Global cap on fine-energy bits, combined with `ENERGY_FINE_BITS` via `min`.
fine_bits: u8,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mono, 48 kHz, 480-sample decoder shared by most tests below.
    fn mono_decoder() -> CeltDecoder {
        CeltDecoder::new(48000, 1, OpusBandwidth::Fullband, 480)
    }

    #[test]
    fn test_celt_decoder_creation() {
        let stereo = CeltDecoder::new(48000, 2, OpusBandwidth::Fullband, 480);
        assert_eq!(
            (stereo.sample_rate(), stereo.channels(), stereo.frame_size()),
            (48000, 2, 480)
        );
    }

    #[test]
    fn test_celt_decoder_decode() {
        let mut dec = mono_decoder();
        let packet = [0x80u8, 0x00, 0x00, 0x00];
        let mut pcm = [0.0f32; 480];
        assert!(dec.decode(&packet, &mut pcm, 480).is_ok());
    }

    #[test]
    fn test_celt_band_sizes() {
        let sizes = mono_decoder().get_band_sizes(480);
        assert_eq!(sizes.len(), CELT_BANDS);
        assert_eq!(sizes.iter().sum::<usize>(), 480);
    }

    #[test]
    fn test_celt_band_norm() {
        let norm = mono_decoder().compute_band_norm(&[3.0, 4.0]);
        assert!((norm - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_celt_reset() {
        let mut dec = mono_decoder();
        dec.band_energy[0] = 10.0;
        dec.reset();
        assert_eq!(dec.band_energy[0], 0.0);
    }

    #[test]
    fn test_bark_band_boundaries() {
        assert_eq!(BARK_BAND_BOUNDARIES.len(), CELT_BANDS + 1);
        assert_eq!(BARK_BAND_BOUNDARIES.first(), Some(&0));
        // Edges must be strictly increasing.
        assert!(BARK_BAND_BOUNDARIES
            .windows(2)
            .all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn test_energy_fine_bits() {
        assert_eq!(ENERGY_FINE_BITS.len(), CELT_BANDS);
        assert!(ENERGY_FINE_BITS.iter().all(|&bits| bits <= 4));
    }
}
/// CELT-style encoder state: one forward MDCT and per-band energies from the
/// most recently encoded channel.
#[derive(Debug)]
pub struct CeltEncoder {
/// Sample rate given at construction; only reported via `sample_rate()`.
sample_rate: u32,
/// Number of interleaved channels expected in the input.
channels: usize,
/// Bandwidth given at construction; only reported via `bandwidth()`.
bandwidth: OpusBandwidth,
/// Forward transform shared by all channels, built for the construction-time
/// `frame_size`.
mdct: Mdct,
/// One accumulator per channel.
/// NOTE(review): only `reset()` touches these; encoding never uses them.
overlap_add: Vec<OverlapAdd>,
/// Natural-log RMS energy per band from the last band-energy pass, used to
/// normalize the coefficients.
band_energy: Vec<f32>,
/// Frame size given at construction. No encoding step reads it, hence the
/// lint allowance.
#[allow(dead_code)]
frame_size: usize,
}
impl CeltEncoder {
    /// Creates an encoder for `channels` interleaved channels and frames of
    /// `frame_size` samples.
    pub fn new(
        sample_rate: u32,
        channels: usize,
        bandwidth: OpusBandwidth,
        frame_size: usize,
    ) -> Self {
        let mdct = Mdct::new(frame_size);
        let overlap_add = (0..channels).map(|_| OverlapAdd::new(frame_size)).collect();
        Self {
            sample_rate,
            channels,
            bandwidth,
            mdct,
            overlap_add,
            band_energy: vec![0.0; CELT_BANDS],
            frame_size,
        }
    }

    /// Encodes one frame of interleaved samples into `output` and returns the
    /// number of bytes written.
    ///
    /// # Errors
    ///
    /// - [`CodecError::InvalidData`] if `input` holds fewer than
    ///   `frame_size * channels` samples (including when that product overflows).
    /// - [`CodecError::BufferTooSmall`] if the compressed frame does not fit in `output`.
    /// - Any range-encoder error is propagated.
    ///
    /// NOTE(review): this stream carries none of the silence/post-filter/transient
    /// flags, trim index or fine-bit count that `CeltDecoder::decode` reads first,
    /// so the two are not bitstream-compatible as written.
    pub fn encode(
        &mut self,
        input: &[f32],
        output: &mut [u8],
        frame_size: usize,
    ) -> CodecResult<usize> {
        // `checked_mul` stops a huge `frame_size` from wrapping (in release builds)
        // to a small product that would slip past the length check.
        let required = frame_size.checked_mul(self.channels);
        if !matches!(required, Some(n) if input.len() >= n) {
            return Err(CodecError::InvalidData(
                "Input buffer too small".to_string(),
            ));
        }
        let mut encoder = RangeEncoder::new(output.len());
        for ch in 0..self.channels {
            self.encode_channel(&mut encoder, input, frame_size, ch)?;
        }
        let compressed = encoder.finalize()?;
        if compressed.len() > output.len() {
            return Err(CodecError::BufferTooSmall {
                needed: compressed.len(),
                have: output.len(),
            });
        }
        output[..compressed.len()].copy_from_slice(&compressed);
        Ok(compressed.len())
    }

    /// De-interleaves one channel, transforms it, and codes its band energies
    /// and spectral shape.
    fn encode_channel(
        &mut self,
        encoder: &mut RangeEncoder,
        input: &[f32],
        frame_size: usize,
        channel: usize,
    ) -> CodecResult<()> {
        // The current frame fills the first half of the MDCT input; the second
        // half stays zero.
        // NOTE(review): no previous-frame history is kept for the transform overlap.
        let mut channel_samples = vec![0.0f32; 2 * frame_size];
        let channel_input = input
            .iter()
            .skip(channel)
            .step_by(self.channels)
            .take(frame_size);
        for (dst, &src) in channel_samples.iter_mut().zip(channel_input) {
            *dst = src;
        }
        let mut coeffs = vec![0.0f32; frame_size];
        self.mdct.forward(&channel_samples, &mut coeffs);
        self.encode_band_energy(encoder, &coeffs, frame_size)?;
        self.normalize_coeffs(&mut coeffs, frame_size);
        self.encode_spectral_shape(encoder, &coeffs, frame_size)?;
        Ok(())
    }

    /// Measures each band's RMS energy and codes `round(2 * ln(rms))`, clamped to
    /// ±31, with `encode_int(_, 6)`.
    ///
    /// The unquantized `ln(rms)` is stored in `band_energy` for normalization.
    /// The RMS is floored at `1e-10`, which also covers empty bands.
    fn encode_band_energy(
        &mut self,
        encoder: &mut RangeEncoder,
        coeffs: &[f32],
        frame_size: usize,
    ) -> CodecResult<()> {
        let band_sizes = self.get_band_sizes(frame_size);
        let mut offset = 0;
        for (band_idx, &band_size) in band_sizes.iter().enumerate().take(CELT_BANDS) {
            let end = (offset + band_size).min(coeffs.len());
            let band = coeffs.get(offset..end).unwrap_or(&[]);
            let sum_sq: f32 = band.iter().map(|c| c * c).sum();
            let rms = if band_size == 0 {
                1e-10
            } else {
                (sum_sq / band_size as f32).sqrt().max(1e-10)
            };
            let log_energy = rms.ln();
            let quantized = ((log_energy * 2.0).round() as i32).clamp(-31, 31);
            encoder.encode_int(quantized, 6)?;
            self.band_energy[band_idx] = log_energy;
            offset += band_size;
        }
        Ok(())
    }

    /// Divides every band by its RMS energy (`exp(band_energy)`).
    ///
    /// Bands whose energy is at or below `1e-10` are zeroed.
    fn normalize_coeffs(&self, coeffs: &mut [f32], frame_size: usize) {
        let band_sizes = self.get_band_sizes(frame_size);
        let mut offset = 0;
        for (&band_size, &log_energy) in band_sizes.iter().zip(&self.band_energy) {
            let energy = log_energy.exp();
            let norm_factor = if energy > 1e-10 { 1.0 / energy } else { 0.0 };
            let end = (offset + band_size).min(coeffs.len());
            if let Some(band) = coeffs.get_mut(offset..end) {
                band.iter_mut().for_each(|c| *c *= norm_factor);
            }
            offset += band_size;
        }
    }

    /// Codes a coarse shape per band.
    ///
    /// For each band this writes a pulse count in `0..=15` (from `10 * |c|`
    /// summed over the band). When that count is non-zero, it also writes a
    /// `0..=3` magnitude for each of the band's first four coefficients.
    fn encode_spectral_shape(
        &self,
        encoder: &mut RangeEncoder,
        coeffs: &[f32],
        frame_size: usize,
    ) -> CodecResult<()> {
        let band_sizes = self.get_band_sizes(frame_size);
        let mut offset = 0;
        for &band_size in band_sizes.iter().take(CELT_BANDS) {
            let end = (offset + band_size).min(coeffs.len());
            let band = coeffs.get(offset..end).unwrap_or(&[]);
            // Saturate so extreme inputs cannot overflow the total before the cap.
            let pulse_count = band
                .iter()
                .map(|c| (c.abs() * 10.0).round() as u32)
                .fold(0u32, u32::saturating_add)
                .min(15);
            encoder.encode_uniform(pulse_count, 16)?;
            if pulse_count > 0 {
                for c in band.iter().take(4) {
                    let pulse_val = (c.abs() * 4.0).round() as u32;
                    encoder.encode_uniform(pulse_val.min(3), 4)?;
                }
            }
            offset += band_size;
        }
        Ok(())
    }

    /// Splits `frame_size` bins into `CELT_BANDS` contiguous, near-equal bands.
    ///
    /// Every bin is assigned. The `frame_size % CELT_BANDS` leftover bins go one
    /// each to the highest bands, so sizes never decrease with band index and
    /// always sum to `frame_size`.
    ///
    /// NOTE(review): this uniform layout differs from the decoder's
    /// `BARK_BAND_BOUNDARIES` layout, so band energies are measured over
    /// different bins than the decoder applies them to.
    fn get_band_sizes(&self, frame_size: usize) -> Vec<usize> {
        let base = frame_size / CELT_BANDS;
        let extra = frame_size % CELT_BANDS;
        let first_wide_band = CELT_BANDS - extra;
        (0..CELT_BANDS)
            .map(|band| if band >= first_wide_band { base + 1 } else { base })
            .collect()
    }

    /// Clears the overlap-add buffers and stored band energies.
    pub fn reset(&mut self) {
        for ola in &mut self.overlap_add {
            ola.reset();
        }
        self.band_energy.fill(0.0);
    }

    /// Returns the sample rate given at construction.
    #[must_use]
    pub const fn sample_rate(&self) -> u32 {
        self.sample_rate
    }

    /// Returns the number of interleaved input channels.
    #[must_use]
    pub const fn channels(&self) -> usize {
        self.channels
    }

    /// Returns the bandwidth given at construction.
    #[must_use]
    pub const fn bandwidth(&self) -> OpusBandwidth {
        self.bandwidth
    }

    /// Returns the frame size given at construction.
    #[must_use]
    pub const fn frame_size(&self) -> usize {
        self.frame_size
    }
}