use crate::{CodecError, CodecResult};
use super::celt::CeltDecoder;
use super::packet::OpusBandwidth;
use super::silk::SilkDecoder;
use super::silk_range::SilkRangeDecoder;
/// Decoder that pairs a SILK decoder with a CELT decoder and mixes their output.
///
/// `decode` runs both sub-decoders on one packet. The SILK result is low-pass
/// filtered and the CELT result is high-pass filtered, both at
/// `crossover_freq`, and the two filtered signals are summed into the caller's
/// buffer.
#[derive(Debug)]
pub struct HybridDecoder {
/// Sub-decoder whose output is low-pass filtered; built with `OpusBandwidth::Wideband`.
silk: SilkDecoder,
/// Sub-decoder whose output is high-pass filtered; built with the caller's bandwidth.
celt: CeltDecoder,
/// Sample rate passed to `new`; also used when computing the filter coefficients.
sample_rate: u32,
/// Number of interleaved channels; sizes the per-channel filter state vectors.
channels: usize,
/// Bandwidth passed to `new`. It is forwarded to the CELT decoder there and
/// kept here, but no method in this impl reads it back, hence the lint allowance.
#[allow(dead_code)]
bandwidth: OpusBandwidth,
/// Cutoff frequency shared by the low-pass and high-pass filters (set to 8000 in `new`).
crossover_freq: u32,
/// Low-pass filter memory, one entry per channel.
lowpass_state: Vec<BiquadState>,
/// High-pass filter memory, one entry per channel.
highpass_state: Vec<BiquadState>,
}
/// Memory of one second-order IIR filter section.
///
/// Both arrays are ordered newest first: index 0 holds the previous sample,
/// index 1 the one before it. `Default` yields all zeros.
#[derive(Debug, Clone, Default)]
struct BiquadState {
/// The two most recent input samples.
prev_input: [f32; 2],
/// The two most recent output samples.
prev_output: [f32; 2],
}
impl HybridDecoder {
/// Creates a hybrid decoder.
///
/// The SILK sub-decoder is always constructed with
/// `OpusBandwidth::Wideband`; the CELT sub-decoder receives the caller's
/// `bandwidth` and `frame_size`. The crossover frequency is fixed at 8000 and
/// one zeroed filter state per channel is allocated for each of the two
/// band-splitting filters.
pub fn new(
    sample_rate: u32,
    channels: usize,
    bandwidth: OpusBandwidth,
    frame_size: usize,
) -> Self {
    // Split point between the low-passed SILK band and the high-passed CELT band.
    const CROSSOVER_FREQ: u32 = 8000;

    let zeroed_filter_bank = || vec![BiquadState::default(); channels];

    Self {
        silk: SilkDecoder::new(sample_rate, channels, OpusBandwidth::Wideband),
        celt: CeltDecoder::new(sample_rate, channels, bandwidth, frame_size),
        sample_rate,
        channels,
        bandwidth,
        crossover_freq: CROSSOVER_FREQ,
        lowpass_state: zeroed_filter_bank(),
        highpass_state: zeroed_filter_bank(),
    }
}
/// Decodes one hybrid packet into `output` (interleaved, `frame_size` samples
/// per channel).
///
/// An empty `data` slice is handed straight to the SILK decoder, which writes
/// directly into `output`. Otherwise the SILK layer is decoded first through a
/// shared range decoder; the bytes it reports as consumed from the front of
/// the packet are skipped and the remainder, if any, is given to the CELT
/// decoder. When nothing remains the CELT contribution is silence. The two
/// layers are then band-split and summed by `combine_outputs`.
///
/// # Errors
///
/// Returns `CodecError::InvalidData` when `output` cannot hold
/// `frame_size * channels` samples (including when that product overflows
/// `usize`), and propagates any error from the range decoder or either
/// sub-decoder.
pub fn decode(
    &mut self,
    data: &[u8],
    output: &mut [f32],
    frame_size: usize,
) -> CodecResult<()> {
    // `checked_mul` keeps a wrapped product from slipping past the size check
    // in release builds.
    let total = frame_size
        .checked_mul(self.channels)
        .filter(|&needed| needed <= output.len())
        .ok_or_else(|| CodecError::InvalidData("Output buffer too small".to_string()))?;

    if data.is_empty() {
        return self.silk.decode(data, output, frame_size);
    }

    let mut silk_output = vec![0.0f32; total];
    // Starts zeroed, so it already represents silence if CELT has no payload.
    let mut celt_output = vec![0.0f32; total];

    let mut shared = SilkRangeDecoder::new(data)?;
    self.silk
        .decode_with(&mut shared, &mut silk_output, frame_size)?;

    // Clamp so a decoder that over-reports consumption cannot slice past the end.
    let silk_consumed = shared.front_bytes_consumed().min(data.len());
    let celt_data = &data[silk_consumed..];
    if !celt_data.is_empty() {
        self.celt.decode(celt_data, &mut celt_output, frame_size)?;
    }

    self.combine_outputs(&silk_output, &celt_output, output, frame_size)
}
/// Band-splits and sums the two decoded layers into `output`.
///
/// `silk_output` is low-pass filtered and `celt_output` is high-pass filtered
/// (on private copies, so the inputs are left untouched); the first
/// `frame_size * channels` samples of `output` receive their sample-wise sum.
/// Both filters update the decoder's per-channel filter memory.
///
/// # Panics
///
/// Panics if any of the three buffers holds fewer than
/// `frame_size * channels` samples. The check happens before any filter state
/// is modified.
fn combine_outputs(
    &mut self,
    silk_output: &[f32],
    celt_output: &[f32],
    output: &mut [f32],
    frame_size: usize,
) -> CodecResult<()> {
    let total = frame_size * self.channels;

    // Slice once up front: a short buffer fails here, before the filters run,
    // and the mixing loop below needs no per-sample bounds checks.
    let mixed = &mut output[..total];
    let mut low_band = silk_output[..total].to_vec();
    let mut high_band = celt_output[..total].to_vec();

    self.apply_lowpass(&mut low_band, frame_size);
    self.apply_highpass(&mut high_band, frame_size);

    for (dst, (lo, hi)) in mixed.iter_mut().zip(low_band.iter().zip(&high_band)) {
        *dst = lo + hi;
    }
    Ok(())
}
/// Low-pass filters `samples` in place at the crossover frequency.
///
/// `samples` is read as interleaved audio with `self.channels` channels; at
/// most `frame_size` samples are processed per channel, and positions beyond
/// the end of the slice are skipped. Each channel uses and updates its own
/// entry in `lowpass_state`, so filtering continues seamlessly across calls.
fn apply_lowpass(&mut self, samples: &mut [f32], frame_size: usize) {
    let (b0, b1, b2, a1, a2) = lowpass_coeffs(self.crossover_freq, self.sample_rate);
    let stride = self.channels;
    for ch in 0..stride {
        let memory = &mut self.lowpass_state[ch];
        // Walk this channel's samples: ch, ch + stride, ch + 2 * stride, ...
        for sample in samples.iter_mut().skip(ch).step_by(stride).take(frame_size) {
            let x0 = *sample;
            let y0 = b0 * x0 + b1 * memory.prev_input[0] + b2 * memory.prev_input[1]
                - a1 * memory.prev_output[0]
                - a2 * memory.prev_output[1];
            // Shift the history: the newest value goes to index 0.
            memory.prev_input = [x0, memory.prev_input[0]];
            memory.prev_output = [y0, memory.prev_output[0]];
            *sample = y0;
        }
    }
}
/// High-pass filters `samples` in place at the crossover frequency.
///
/// `samples` is read as interleaved audio with `self.channels` channels; at
/// most `frame_size` samples are processed per channel, and positions beyond
/// the end of the slice are skipped. Each channel uses and updates its own
/// entry in `highpass_state`, so filtering continues seamlessly across calls.
fn apply_highpass(&mut self, samples: &mut [f32], frame_size: usize) {
    let (b0, b1, b2, a1, a2) = highpass_coeffs(self.crossover_freq, self.sample_rate);
    let stride = self.channels;
    for ch in 0..stride {
        let memory = &mut self.highpass_state[ch];
        // Walk this channel's samples: ch, ch + stride, ch + 2 * stride, ...
        for sample in samples.iter_mut().skip(ch).step_by(stride).take(frame_size) {
            let x0 = *sample;
            let y0 = b0 * x0 + b1 * memory.prev_input[0] + b2 * memory.prev_input[1]
                - a1 * memory.prev_output[0]
                - a2 * memory.prev_output[1];
            // Shift the history: the newest value goes to index 0.
            memory.prev_input = [x0, memory.prev_input[0]];
            memory.prev_output = [y0, memory.prev_output[0]];
            *sample = y0;
        }
    }
}
/// Resets both sub-decoders and zeroes every per-channel filter memory.
pub fn reset(&mut self) {
    self.silk.reset();
    self.celt.reset();
    self.lowpass_state.fill_with(BiquadState::default);
    self.highpass_state.fill_with(BiquadState::default);
}
/// Returns the sample rate this decoder was constructed with.
#[must_use]
pub const fn sample_rate(&self) -> u32 {
self.sample_rate
}
/// Returns the channel count this decoder was constructed with.
#[must_use]
pub const fn channels(&self) -> usize {
self.channels
}
/// Returns the cutoff frequency shared by the low-pass and high-pass
/// band-splitting filters (fixed at 8000 by `new`).
#[must_use]
pub const fn crossover_frequency(&self) -> u32 {
self.crossover_freq
}
}
/// Computes second-order low-pass coefficients `(b0, b1, b2, a1, a2)`.
///
/// The values are already divided by `a0`, so they plug directly into
/// `y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]`.
///
/// When `cutoff` is at or above the Nyquist frequency
/// (`2 * cutoff >= sample_rate`, which also covers `sample_rate == 0`) the
/// regular formula produces NaN or poles on/outside the unit circle. In that
/// case the pass-through set `(1, 0, 0, 0, 0)` is returned, which is the limit
/// of the design as the cutoff approaches Nyquist: everything representable
/// lies below the cutoff.
fn lowpass_coeffs(cutoff: u32, sample_rate: u32) -> (f32, f32, f32, f32, f32) {
    use std::f32::consts::{FRAC_1_SQRT_2, PI};

    // Compare in u64 so doubling the cutoff cannot overflow.
    if u64::from(cutoff) * 2 >= u64::from(sample_rate) {
        return (1.0, 0.0, 0.0, 0.0, 0.0);
    }

    let omega = 2.0 * PI * cutoff as f32 / sample_rate as f32;
    let cos_omega = omega.cos();
    // NOTE(review): this evaluates to sin(omega) / (2 * sqrt(2)). If a
    // Butterworth response (Q = 1/sqrt(2), alpha = sin(omega) / (2 * Q)) was
    // intended, the divisor would be 2 * FRAC_1_SQRT_2 — confirm before changing.
    let alpha = omega.sin() / (2.0 * FRAC_1_SQRT_2.recip());

    let b0 = (1.0 - cos_omega) / 2.0;
    let b1 = 1.0 - cos_omega;
    let b2 = (1.0 - cos_omega) / 2.0;
    let a0 = 1.0 + alpha;
    let a1 = -2.0 * cos_omega;
    let a2 = 1.0 - alpha;

    (b0 / a0, b1 / a0, b2 / a0, a1 / a0, a2 / a0)
}
/// Computes second-order high-pass coefficients `(b0, b1, b2, a1, a2)`.
///
/// The values are already divided by `a0`, so they plug directly into
/// `y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]`.
///
/// When `cutoff` is at or above the Nyquist frequency
/// (`2 * cutoff >= sample_rate`, which also covers `sample_rate == 0`) the
/// regular formula produces NaN or poles on/outside the unit circle. In that
/// case the all-zero set is returned, which is the limit of the design as the
/// cutoff approaches Nyquist: nothing representable lies above the cutoff, so
/// the filter outputs silence.
fn highpass_coeffs(cutoff: u32, sample_rate: u32) -> (f32, f32, f32, f32, f32) {
    use std::f32::consts::{FRAC_1_SQRT_2, PI};

    // Compare in u64 so doubling the cutoff cannot overflow.
    if u64::from(cutoff) * 2 >= u64::from(sample_rate) {
        return (0.0, 0.0, 0.0, 0.0, 0.0);
    }

    let omega = 2.0 * PI * cutoff as f32 / sample_rate as f32;
    let cos_omega = omega.cos();
    // NOTE(review): this evaluates to sin(omega) / (2 * sqrt(2)). If a
    // Butterworth response (Q = 1/sqrt(2), alpha = sin(omega) / (2 * Q)) was
    // intended, the divisor would be 2 * FRAC_1_SQRT_2 — confirm before changing.
    let alpha = omega.sin() / (2.0 * FRAC_1_SQRT_2.recip());

    let b0 = (1.0 + cos_omega) / 2.0;
    let b1 = -(1.0 + cos_omega);
    let b2 = (1.0 + cos_omega) / 2.0;
    let a0 = 1.0 + alpha;
    let a1 = -2.0 * cos_omega;
    let a2 = 1.0 - alpha;

    (b0 / a0, b1 / a0, b2 / a0, a1 / a0, a2 / a0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic packet filler: byte `i` is `i * mul + add` (wrapping) for `i` in `0..len`.
    fn patterned_bytes(len: u8, mul: u8, add: u8) -> Vec<u8> {
        (0u8..len)
            .map(|i| i.wrapping_mul(mul).wrapping_add(add))
            .collect()
    }

    /// Construction stores the requested parameters and the fixed crossover.
    #[test]
    fn test_hybrid_decoder_creation() {
        let hybrid = HybridDecoder::new(48000, 2, OpusBandwidth::SuperWideband, 480);
        assert_eq!(hybrid.sample_rate(), 48000);
        assert_eq!(hybrid.channels(), 2);
        assert_eq!(hybrid.crossover_frequency(), 8000);
    }

    /// A mono packet decodes without error and yields only finite samples.
    #[test]
    fn test_hybrid_decode_single_bitstream() {
        let mut hybrid = HybridDecoder::new(48000, 1, OpusBandwidth::SuperWideband, 480);
        let packet = patterned_bytes(64, 43, 17);
        let mut pcm = vec![0.0f32; 480];
        let outcome = hybrid.decode(&packet, &mut pcm, 480);
        assert!(outcome.is_ok(), "hybrid decode should succeed");
        assert!(
            pcm.iter().all(|sample| sample.is_finite()),
            "hybrid output must be finite"
        );
    }

    /// A stereo packet decodes into an interleaved buffer of finite samples.
    #[test]
    fn test_hybrid_decode_stereo() {
        let mut hybrid = HybridDecoder::new(48000, 2, OpusBandwidth::Fullband, 480);
        let packet = patterned_bytes(96, 31, 9);
        let mut pcm = vec![0.0f32; 480 * 2];
        hybrid
            .decode(&packet, &mut pcm, 480)
            .expect("stereo hybrid");
        assert!(pcm.iter().all(|sample| sample.is_finite()));
    }

    /// An empty packet is accepted and still produces finite output.
    #[test]
    fn test_hybrid_decode_empty_packet() {
        let mut hybrid = HybridDecoder::new(48000, 1, OpusBandwidth::SuperWideband, 480);
        let mut pcm = vec![0.0f32; 480];
        let outcome = hybrid.decode(&[], &mut pcm, 480);
        assert!(outcome.is_ok());
        assert!(pcm.iter().all(|sample| sample.is_finite()));
    }

    /// Resetting a freshly built decoder must not panic.
    #[test]
    fn test_hybrid_reset() {
        let mut hybrid = HybridDecoder::new(48000, 1, OpusBandwidth::SuperWideband, 480);
        hybrid.reset();
    }

    /// Mixing two constant signals gives a finite, bounded final sample.
    #[test]
    fn test_combine_outputs() {
        let mut hybrid = HybridDecoder::new(48000, 1, OpusBandwidth::SuperWideband, 480);
        let low_layer = vec![1.0f32; 480];
        let high_layer = vec![2.0f32; 480];
        let mut mixed = vec![0.0f32; 480];
        let outcome = hybrid.combine_outputs(&low_layer, &high_layer, &mut mixed, 480);
        assert!(outcome.is_ok());
        let last = mixed[479];
        assert!(
            last.is_finite() && last.abs() < 10.0,
            "expected finite output within reasonable range, got {last}"
        );
    }
}