use crate::{CodecError, CodecResult};
use super::celt::CeltDecoder;
use super::packet::OpusBandwidth;
use super::range_decoder::RangeDecoder;
use super::silk::SilkDecoder;
use std::f32::consts::PI;
/// Opus hybrid-mode decoder: runs a SILK decoder and a CELT decoder on the same
/// frame and merges their outputs through a low-pass / high-pass crossover.
#[derive(Debug)]
pub struct HybridDecoder {
    /// Decoder for the SILK (low-band) layer.
    silk: SilkDecoder,
    /// Decoder for the CELT (high-band) layer.
    celt: CeltDecoder,
    /// Sample rate handed to both layer decoders and used to derive the crossover filters.
    sample_rate: u32,
    /// Number of interleaved channels in every sample buffer.
    channels: usize,
    /// Bandwidth requested at construction. It is stored but never read, which is
    /// why `dead_code` is allowed here.
    #[allow(dead_code)]
    bandwidth: OpusBandwidth,
    /// Crossover frequency separating the SILK (below) and CELT (above) contributions.
    crossover_freq: u32,
    /// Per-channel history of the low-pass filter applied to SILK output.
    lowpass_state: Vec<LowPassState>,
    /// Per-channel history of the high-pass filter applied to CELT output.
    highpass_state: Vec<HighPassState>,
}
/// Delay-line history of one second-order (biquad) filter section for one channel.
///
/// The crossover low-pass and high-pass sections need exactly the same history,
/// so a single type backs both of the aliases below.
#[derive(Debug, Clone, Default)]
struct BiquadState {
    /// Previous two inputs. The filter code stores the most recent one at index 0.
    prev_input: [f32; 2],
    /// Previous two outputs. The filter code stores the most recent one at index 0.
    prev_output: [f32; 2],
}

impl BiquadState {
    /// Creates an all-zero history, i.e. a filter that has only seen silence.
    fn new() -> Self {
        Self::default()
    }
}

/// History of a crossover low-pass section (same layout as the high-pass one).
type LowPassState = BiquadState;

/// History of a crossover high-pass section (same layout as the low-pass one).
type HighPassState = BiquadState;
impl HybridDecoder {
    /// Q factor of both crossover sections, applied as `alpha = sin(ω) / (2·Q)`.
    ///
    /// NOTE(review): 1.414 ≈ √2, which is twice the Butterworth Q of 1/√2. It is
    /// kept as-is to preserve the existing response; confirm the intended value.
    const CROSSOVER_Q: f32 = 1.414;

    /// Creates a hybrid decoder whose SILK/CELT crossover sits at 8000 Hz.
    ///
    /// The SILK layer is always built for [`OpusBandwidth::Wideband`]. The CELT layer
    /// receives the requested `bandwidth` and `frame_size`. One low-pass and one
    /// high-pass filter history is allocated per channel.
    pub fn new(
        sample_rate: u32,
        channels: usize,
        bandwidth: OpusBandwidth,
        frame_size: usize,
    ) -> Self {
        let crossover_freq = 8000;
        let silk = SilkDecoder::new(sample_rate, channels, OpusBandwidth::Wideband);
        let celt = CeltDecoder::new(sample_rate, channels, bandwidth, frame_size);
        let lowpass_state = (0..channels).map(|_| LowPassState::new()).collect();
        let highpass_state = (0..channels).map(|_| HighPassState::new()).collect();
        Self {
            silk,
            celt,
            sample_rate,
            channels,
            bandwidth,
            crossover_freq,
            lowpass_state,
            highpass_state,
        }
    }

    /// Decodes one hybrid frame of `frame_size` samples per channel into
    /// interleaved `output`.
    ///
    /// `silk_data` is truncated to the size read by [`Self::decode_split_point`]
    /// and fed to SILK. `celt_data` is fed to CELT unchanged. The two layer
    /// outputs are then merged by [`Self::combine_outputs`].
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::InvalidData`] if `output` holds fewer than
    /// `frame_size * channels` samples. Errors from the split-point range
    /// decoder and from the SILK/CELT decoders are propagated.
    pub fn decode(
        &mut self,
        silk_data: &[u8],
        celt_data: &[u8],
        output: &mut [f32],
        frame_size: usize,
    ) -> CodecResult<()> {
        let needed = frame_size * self.channels;
        if output.len() < needed {
            return Err(CodecError::InvalidData(
                "Output buffer too small".to_string(),
            ));
        }
        // Only the SILK byte count is used; the remainder is not consumed here.
        let (silk_bytes, _remaining) = self.decode_split_point(silk_data)?;
        let silk_slice = &silk_data[..silk_bytes.min(silk_data.len())];
        let mut silk_output = vec![0.0f32; needed];
        let mut celt_output = vec![0.0f32; needed];
        self.silk.decode(silk_slice, &mut silk_output, frame_size)?;
        self.celt.decode(celt_data, &mut celt_output, frame_size)?;
        self.combine_outputs(&silk_output, &celt_output, output, frame_size)
    }

    /// Reads the SILK byte count from the front of `data` using a range decoder.
    ///
    /// Returns `(split_size, data.len() - split_size)` (saturating at 0), or
    /// `(0, 0)` for empty input. `split_size` comes from `decode_uniform(256)`,
    /// presumably a value in `0..256`, and may exceed `data.len()`.
    ///
    /// NOTE(review): `decode` passes `data[..split_size]` (starting at offset 0) to
    /// SILK, i.e. including the bytes read here. Confirm this framing is intended.
    fn decode_split_point(&self, data: &[u8]) -> CodecResult<(usize, usize)> {
        if data.is_empty() {
            return Ok((0, 0));
        }
        let mut decoder = RangeDecoder::new(data)?;
        let split_size = decoder.decode_uniform(256)? as usize;
        Ok((split_size, data.len().saturating_sub(split_size)))
    }

    /// Low-passes `silk_output`, high-passes `celt_output`, and writes their
    /// sample-wise sum into the first `frame_size * channels` slots of `output`.
    ///
    /// Filter history persists across calls. Only the first
    /// `frame_size * channels` samples of each input are read.
    ///
    /// NOTE(review): both sections share a denominator, so the summed numerator is
    /// `1 - 2·cos(ω)·z⁻¹ + z⁻²`, i.e. a notch at the crossover frequency. Confirm
    /// this is acceptable.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::InvalidData`] if any of the three buffers holds fewer
    /// than `frame_size * channels` samples. The filter history is left untouched
    /// in that case.
    fn combine_outputs(
        &mut self,
        silk_output: &[f32],
        celt_output: &[f32],
        output: &mut [f32],
        frame_size: usize,
    ) -> CodecResult<()> {
        let total = frame_size * self.channels;
        if silk_output.len() < total || celt_output.len() < total || output.len() < total {
            return Err(CodecError::InvalidData(
                "Hybrid buffers too small for frame".to_string(),
            ));
        }
        let mut silk_filtered = silk_output[..total].to_vec();
        let mut celt_filtered = celt_output[..total].to_vec();
        self.apply_lowpass(&mut silk_filtered, frame_size);
        self.apply_highpass(&mut celt_filtered, frame_size);
        for ((out, low), high) in output.iter_mut().zip(&silk_filtered).zip(&celt_filtered) {
            *out = low + high;
        }
        Ok(())
    }

    /// Returns `(cos ω, alpha)` for the crossover, where `ω = 2π·crossover / sample_rate`.
    ///
    /// Returns `None` unless the crossover lies strictly below Nyquist. At or
    /// above Nyquist the cookbook formulas alias and the filter becomes unstable
    /// or marginally stable. For example, at `sample_rate = 12000` we get
    /// `sin ω < 0`, a negative `alpha` and `|a2 / a0| > 1`. The check also
    /// excludes `sample_rate == 0`, which would otherwise divide by zero.
    fn crossover_params(&self) -> Option<(f32, f32)> {
        if 2 * u64::from(self.crossover_freq) >= u64::from(self.sample_rate) {
            return None;
        }
        let cutoff_ratio = self.crossover_freq as f32 / self.sample_rate as f32;
        let omega = 2.0 * PI * cutoff_ratio;
        Some((omega.cos(), omega.sin() / (2.0 * Self::CROSSOVER_Q)))
    }

    /// Normalises cookbook biquad coefficients by `a0 = 1 + alpha`.
    ///
    /// Returns `[b0, b1, b2, a1, a2]`, with `a1 = -2·cos ω` and `a2 = 1 - alpha`
    /// before division.
    fn normalized_coefficients(
        [b0, b1, b2]: [f32; 3],
        cos_omega: f32,
        alpha: f32,
    ) -> [f32; 5] {
        let a0 = 1.0 + alpha;
        [
            b0 / a0,
            b1 / a0,
            b2 / a0,
            (-2.0 * cos_omega) / a0,
            (1.0 - alpha) / a0,
        ]
    }

    /// Runs one direct-form-I biquad in place over `samples`, updating the
    /// two-sample input and output history (most recent value at index 0).
    fn filter_channel<'a>(
        [b0, b1, b2, a1, a2]: [f32; 5],
        prev_input: &mut [f32; 2],
        prev_output: &mut [f32; 2],
        samples: impl Iterator<Item = &'a mut f32>,
    ) {
        for sample in samples {
            let input = *sample;
            let output = b0 * input + b1 * prev_input[0] + b2 * prev_input[1]
                - a1 * prev_output[0]
                - a2 * prev_output[1];
            *prev_input = [input, prev_input[0]];
            *prev_output = [output, prev_output[0]];
            *sample = output;
        }
    }

    /// Low-passes the first `frame_size` interleaved frames of `samples` in place.
    ///
    /// When the crossover is not below Nyquist, the entire representable band
    /// belongs to SILK, so the samples are left untouched.
    fn apply_lowpass(&mut self, samples: &mut [f32], frame_size: usize) {
        let (cos_omega, alpha) = match self.crossover_params() {
            Some(params) => params,
            None => return,
        };
        let half = (1.0 - cos_omega) / 2.0;
        let coeffs =
            Self::normalized_coefficients([half, 1.0 - cos_omega, half], cos_omega, alpha);
        let channels = self.channels;
        for (ch, state) in self.lowpass_state.iter_mut().enumerate() {
            // Interleaved layout: channel `ch` occupies every `channels`-th slot.
            let channel_samples = samples.iter_mut().skip(ch).step_by(channels).take(frame_size);
            Self::filter_channel(
                coeffs,
                &mut state.prev_input,
                &mut state.prev_output,
                channel_samples,
            );
        }
    }

    /// High-passes the first `frame_size` interleaved frames of `samples` in place.
    ///
    /// When the crossover is not below Nyquist, nothing representable lies above
    /// it, so those samples are zeroed. This is the limit of a high-pass at Nyquist.
    fn apply_highpass(&mut self, samples: &mut [f32], frame_size: usize) {
        let (cos_omega, alpha) = match self.crossover_params() {
            Some(params) => params,
            None => {
                let end = (frame_size * self.channels).min(samples.len());
                samples[..end].fill(0.0);
                return;
            }
        };
        let half = (1.0 + cos_omega) / 2.0;
        let coeffs =
            Self::normalized_coefficients([half, -(1.0 + cos_omega), half], cos_omega, alpha);
        let channels = self.channels;
        for (ch, state) in self.highpass_state.iter_mut().enumerate() {
            // Interleaved layout: channel `ch` occupies every `channels`-th slot.
            let channel_samples = samples.iter_mut().skip(ch).step_by(channels).take(frame_size);
            Self::filter_channel(
                coeffs,
                &mut state.prev_input,
                &mut state.prev_output,
                channel_samples,
            );
        }
    }

    /// Resets both layer decoders and zeroes every crossover filter history.
    pub fn reset(&mut self) {
        self.silk.reset();
        self.celt.reset();
        for state in &mut self.lowpass_state {
            state.prev_input.fill(0.0);
            state.prev_output.fill(0.0);
        }
        for state in &mut self.highpass_state {
            state.prev_input.fill(0.0);
            state.prev_output.fill(0.0);
        }
    }

    /// Sample rate this decoder was created with.
    #[must_use]
    pub const fn sample_rate(&self) -> u32 {
        self.sample_rate
    }

    /// Number of interleaved channels.
    #[must_use]
    pub const fn channels(&self) -> usize {
        self.channels
    }

    /// Crossover frequency between the SILK and CELT contributions.
    #[must_use]
    pub const fn crossover_frequency(&self) -> u32 {
        self.crossover_freq
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hybrid_decoder_creation() {
        let decoder = HybridDecoder::new(48000, 2, OpusBandwidth::SuperWideband, 480);
        assert_eq!(decoder.sample_rate(), 48000);
        assert_eq!(decoder.channels(), 2);
        assert_eq!(decoder.crossover_frequency(), 8000);
        // One filter history per channel.
        assert_eq!(decoder.lowpass_state.len(), 2);
        assert_eq!(decoder.highpass_state.len(), 2);
    }

    #[test]
    fn test_hybrid_decoder_decode() {
        let mut decoder = HybridDecoder::new(48000, 1, OpusBandwidth::SuperWideband, 480);
        let silk_data = vec![0x80, 0x00, 0x00];
        let celt_data = vec![0x80, 0x00, 0x00, 0x00];
        let mut output = vec![0.0f32; 480];
        let result = decoder.decode(&silk_data, &celt_data, &mut output, 480);
        assert!(result.is_ok());
    }

    #[test]
    fn test_decode_rejects_small_output() {
        let mut decoder = HybridDecoder::new(48000, 2, OpusBandwidth::SuperWideband, 480);
        // Stereo needs 960 interleaved samples; 480 is too few.
        let mut output = vec![0.0f32; 480];
        let result = decoder.decode(&[0x80, 0x00], &[0x80, 0x00], &mut output, 480);
        assert!(result.is_err());
    }

    #[test]
    fn test_hybrid_reset() {
        let mut decoder = HybridDecoder::new(48000, 1, OpusBandwidth::SuperWideband, 480);
        let silk_output = vec![1.0f32; 480];
        let celt_output = vec![2.0f32; 480];
        let mut output = vec![0.0f32; 480];
        assert!(decoder
            .combine_outputs(&silk_output, &celt_output, &mut output, 480)
            .is_ok());
        // Filtering non-zero input must leave history behind for reset to clear.
        assert!(decoder.lowpass_state[0].prev_input.iter().any(|&x| x != 0.0));
        decoder.reset();
        for state in &decoder.lowpass_state {
            assert_eq!(state.prev_input, [0.0; 2]);
            assert_eq!(state.prev_output, [0.0; 2]);
        }
        for state in &decoder.highpass_state {
            assert_eq!(state.prev_input, [0.0; 2]);
            assert_eq!(state.prev_output, [0.0; 2]);
        }
    }

    #[test]
    fn test_combine_outputs() {
        let mut decoder = HybridDecoder::new(48000, 1, OpusBandwidth::SuperWideband, 480);
        let silk_output = vec![1.0f32; 480];
        let celt_output = vec![2.0f32; 480];
        let mut output = vec![0.0f32; 480];
        let result = decoder.combine_outputs(&silk_output, &celt_output, &mut output, 480);
        assert!(result.is_ok());
        let last = output[479];
        assert!(
            last.is_finite() && last.abs() < 10.0,
            "Expected finite output within reasonable range, got {last}"
        );
        // Constant input settles to the DC gains: low-pass passes 1.0, high-pass
        // removes 2.0 entirely, so the sum converges to 1.0.
        assert!(
            (last - 1.0).abs() < 1e-3,
            "Expected DC steady state near 1.0, got {last}"
        );
    }
}