use crate::{CodecError, CodecResult, SampleFormat};
use super::celt::CeltEncoder;
use super::packet::{OpusBandwidth, OpusMode, TocInfo};
use super::silk::SilkEncoder;
use super::vad::{VadConfig, VadDecision, VoiceActivityDetector};
/// Construction parameters for [`OpusEncoder`].
///
/// Start from [`OpusEncoderConfig::new`] (or `Default`) and refine with the
/// `with_*` builder methods; values are validated when the encoder is created.
#[derive(Debug, Clone)]
pub struct OpusEncoderConfig {
    /// Input sample rate in Hz.
    pub sample_rate: u32,
    /// Number of interleaved input channels.
    pub channels: usize,
    /// Target bitrate; `OpusEncoder::new` accepts 6000..=510000.
    pub bitrate: u32,
    /// Frame duration in milliseconds.
    pub frame_duration_ms: f32,
    /// Forced coding mode; `None` lets the encoder pick one.
    pub mode: Option<OpusMode>,
    /// Forced audio bandwidth; `None` lets the encoder pick one.
    pub bandwidth: Option<OpusBandwidth>,
    /// Complexity setting (`with_complexity` clamps it to at most 10).
    pub complexity: u32,
    /// Whether variable bitrate is requested.
    pub vbr: bool,
    /// Whether constrained variable bitrate is requested.
    pub cvbr: bool,
    /// Whether discontinuous transmission (silence suppression) is enabled.
    pub dtx: bool,
    /// Sample format of the input audio.
    pub sample_format: SampleFormat,
}

impl OpusEncoderConfig {
    /// Builds a config for the given sample rate, channel count and bitrate.
    ///
    /// Defaults: 20 ms frames, automatic mode and bandwidth, complexity 5,
    /// VBR on, constrained VBR off, DTX off, `SampleFormat::F32` input.
    pub fn new(sample_rate: u32, channels: usize, bitrate: u32) -> Self {
        let frame_duration_ms = 20.0;
        let complexity = 5;
        Self {
            sample_rate,
            channels,
            bitrate,
            frame_duration_ms,
            mode: None,
            bandwidth: None,
            complexity,
            vbr: true,
            cvbr: false,
            dtx: false,
            sample_format: SampleFormat::F32,
        }
    }

    /// Sets the frame duration in milliseconds.
    #[must_use]
    pub fn with_frame_duration(self, duration_ms: f32) -> Self {
        Self {
            frame_duration_ms: duration_ms,
            ..self
        }
    }

    /// Forces a specific coding mode instead of automatic selection.
    #[must_use]
    pub fn with_mode(self, mode: OpusMode) -> Self {
        Self {
            mode: Some(mode),
            ..self
        }
    }

    /// Forces a specific audio bandwidth instead of automatic selection.
    #[must_use]
    pub fn with_bandwidth(self, bandwidth: OpusBandwidth) -> Self {
        Self {
            bandwidth: Some(bandwidth),
            ..self
        }
    }

    /// Sets the complexity, saturating at 10.
    #[must_use]
    pub fn with_complexity(self, complexity: u32) -> Self {
        Self {
            complexity: complexity.min(10),
            ..self
        }
    }

    /// Enables or disables variable bitrate.
    #[must_use]
    pub fn with_vbr(self, vbr: bool) -> Self {
        Self { vbr, ..self }
    }

    /// Enables or disables discontinuous transmission.
    #[must_use]
    pub fn with_dtx(self, dtx: bool) -> Self {
        Self { dtx, ..self }
    }
}

impl Default for OpusEncoderConfig {
    /// 48 kHz stereo at a bitrate of 64000, otherwise the `new` defaults.
    fn default() -> Self {
        Self::new(48_000, 2, 64_000)
    }
}
/// Opus encoder front-end.
///
/// Buffers interleaved `f32` input, runs voice-activity detection (and optional
/// DTX) on a mono downmix, and emits one single-frame Opus packet (TOC frame-count
/// code 0) per full frame using the SILK and/or CELT sub-encoders chosen at
/// construction time.
pub struct OpusEncoder {
    /// Configuration the encoder was created with.
    config: OpusEncoderConfig,
    /// SILK sub-encoder; present in SILK and Hybrid modes.
    silk: Option<SilkEncoder>,
    /// CELT sub-encoder; present in CELT and Hybrid modes.
    celt: Option<CeltEncoder>,
    /// Coding mode selected at construction time.
    current_mode: OpusMode,
    /// Samples per channel in one frame, at `config.sample_rate`.
    frame_size: usize,
    /// Number of frames encoded since creation or the last `reset`.
    frame_count: u64,
    /// Interleaved samples waiting to be encoded.
    input_buffer: Vec<f32>,
    /// Whole samples per channel currently held in `input_buffer`.
    buffered_samples: usize,
    /// Voice-activity detector fed with a mono downmix of each frame.
    vad: VoiceActivityDetector,
    /// VAD decision for the most recently processed frame.
    last_vad_decision: VadDecision,
    /// Consecutive silent frames seen while DTX is enabled.
    dtx_silence_frames: u32,
}

impl OpusEncoder {
    /// Size of the scratch packet buffer: 1275 bytes, the largest frame length
    /// RFC 6716 allows.
    const MAX_PACKET_SIZE: usize = 1275;

    /// Creates an encoder for `config`.
    ///
    /// The coding mode and bandwidth come from the config when set; otherwise they
    /// are chosen by `select_mode` / `select_bandwidth`.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidParameter` if the sample rate is not 8, 12, 16,
    /// 24 or 48 kHz, the channel count is not 1 or 2, the bitrate is outside
    /// 6000..=510000, or the frame duration is not a legal Opus frame duration
    /// (2.5, 5, 10, 20, 40 or 60 ms) at that sample rate.
    pub fn new(config: OpusEncoderConfig) -> CodecResult<Self> {
        if !matches!(config.sample_rate, 8000 | 12000 | 16000 | 24000 | 48000) {
            return Err(CodecError::InvalidParameter(format!(
                "Invalid sample rate: {}",
                config.sample_rate
            )));
        }
        if !(1..=2).contains(&config.channels) {
            return Err(CodecError::InvalidParameter(format!(
                "Invalid channel count: {}",
                config.channels
            )));
        }
        if !(6000..=510_000).contains(&config.bitrate) {
            return Err(CodecError::InvalidParameter(format!(
                "Invalid bitrate: {} (must be 6000-510000)",
                config.bitrate
            )));
        }
        let frame_size = Self::calculate_frame_size(config.sample_rate, config.frame_duration_ms)?;
        let mode = config.mode.unwrap_or_else(|| Self::select_mode(&config));
        let bandwidth = config
            .bandwidth
            .unwrap_or_else(|| Self::select_bandwidth(&config));
        let vad_config = VadConfig::default();
        let mut encoder = Self {
            config,
            silk: None,
            celt: None,
            current_mode: mode,
            frame_size,
            frame_count: 0,
            input_buffer: Vec::new(),
            buffered_samples: 0,
            vad: VoiceActivityDetector::new(vad_config),
            last_vad_decision: VadDecision::Silence,
            dtx_silence_frames: 0,
        };
        encoder.initialize_encoder(mode, bandwidth)?;
        Ok(encoder)
    }

    /// Converts a frame duration into samples per channel at `sample_rate`.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidParameter` for an unsupported sample rate, or
    /// when the duration is not one of the legal Opus frame durations.
    fn calculate_frame_size(sample_rate: u32, duration_ms: f32) -> CodecResult<usize> {
        // `as` truncates; for every legal duration the product is already integral.
        let frame_size = (sample_rate as f32 * duration_ms / 1000.0) as usize;
        // 2.5, 5, 10, 20, 40 and 60 ms expressed in samples per channel.
        let valid_sizes: &[usize] = match sample_rate {
            48000 => &[120, 240, 480, 960, 1920, 2880],
            24000 => &[60, 120, 240, 480, 960, 1440],
            16000 => &[40, 80, 160, 320, 640, 960],
            12000 => &[30, 60, 120, 240, 480, 720],
            8000 => &[20, 40, 80, 160, 320, 480],
            _ => {
                return Err(CodecError::InvalidParameter(
                    "Invalid sample rate".to_string(),
                ))
            }
        };
        if !valid_sizes.contains(&frame_size) {
            return Err(CodecError::InvalidParameter(format!(
                "Invalid frame duration {duration_ms}ms for sample rate {sample_rate}Hz"
            )));
        }
        Ok(frame_size)
    }

    /// Heuristic mode choice used when the config does not force a mode.
    ///
    /// CELT when `bitrate > 32000` or the sample rate is at least 24 kHz; otherwise
    /// SILK when `bitrate < 20000` or the sample rate is at most 16 kHz; otherwise
    /// Hybrid. Every sample rate accepted by `new` satisfies one of the two rate
    /// conditions, so Hybrid is only used when requested via
    /// `OpusEncoderConfig::with_mode`.
    fn select_mode(config: &OpusEncoderConfig) -> OpusMode {
        if config.bitrate > 32000 || config.sample_rate >= 24000 {
            OpusMode::Celt
        } else if config.bitrate < 20000 || config.sample_rate <= 16000 {
            OpusMode::Silk
        } else {
            OpusMode::Hybrid
        }
    }

    /// Picks the bandwidth that matches the sample rate (8 kHz -> NB, 12 kHz -> MB,
    /// 16 kHz -> WB, 24 kHz -> SWB, anything else -> FB).
    fn select_bandwidth(config: &OpusEncoderConfig) -> OpusBandwidth {
        match config.sample_rate {
            8000 => OpusBandwidth::Narrowband,
            12000 => OpusBandwidth::Mediumband,
            16000 => OpusBandwidth::Wideband,
            24000 => OpusBandwidth::SuperWideband,
            _ => OpusBandwidth::Fullband,
        }
    }

    /// Creates the sub-encoder(s) required by `mode`: SILK, CELT, or both for Hybrid.
    fn initialize_encoder(&mut self, mode: OpusMode, bandwidth: OpusBandwidth) -> CodecResult<()> {
        match mode {
            OpusMode::Silk => {
                self.silk = Some(SilkEncoder::new(
                    self.config.sample_rate,
                    self.config.channels,
                    bandwidth,
                ));
            }
            OpusMode::Celt => {
                self.celt = Some(CeltEncoder::new(
                    self.config.sample_rate,
                    self.config.channels,
                    bandwidth,
                    self.frame_size,
                ));
            }
            OpusMode::Hybrid => {
                self.silk = Some(SilkEncoder::new(
                    self.config.sample_rate,
                    self.config.channels,
                    bandwidth,
                ));
                self.celt = Some(CeltEncoder::new(
                    self.config.sample_rate,
                    self.config.channels,
                    bandwidth,
                    self.frame_size,
                ));
            }
        }
        Ok(())
    }

    /// Appends interleaved `samples` and encodes one frame once enough audio is
    /// buffered.
    ///
    /// Returns `Ok(None)` while less than one full frame is buffered. At most one
    /// packet is produced per call, so input longer than a frame stays buffered for
    /// later calls (or `flush`). When DTX suppresses a frame, the returned packet is
    /// empty.
    ///
    /// # Errors
    ///
    /// Propagates errors from frame encoding. The frame's samples have already been
    /// removed from the buffer when that happens.
    pub fn encode(&mut self, samples: &[f32]) -> CodecResult<Option<Vec<u8>>> {
        let channels = self.config.channels;
        self.input_buffer.extend_from_slice(samples);
        // Count from the buffer itself: summing `len / channels` per call would drift
        // whenever a call's length is not a multiple of `channels`.
        self.buffered_samples = self.input_buffer.len() / channels;
        if self.buffered_samples < self.frame_size {
            return Ok(None);
        }
        let frame_samples = self.frame_size * channels;
        let frame_data: Vec<f32> = self.input_buffer.drain(..frame_samples).collect();
        self.buffered_samples = self.input_buffer.len() / channels;
        let packet = self.encode_frame(&frame_data)?;
        self.frame_count += 1;
        Ok(Some(packet))
    }

    /// Encodes exactly one frame of interleaved samples into an Opus packet.
    ///
    /// The VAD decision is updated from a mono downmix first. With DTX enabled, the
    /// first silent frame of a run is still encoded, and every following consecutive
    /// silent frame yields an empty packet.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::Internal` if the sub-encoder needed by the current mode
    /// is missing, or `CodecError::BufferTooSmall` if a hybrid packet would exceed
    /// `MAX_PACKET_SIZE`. Sub-encoder errors are propagated.
    fn encode_frame(&mut self, samples: &[f32]) -> CodecResult<Vec<u8>> {
        let channels = self.config.channels.max(1);
        // Average the channels of each interleaved sample frame for the VAD.
        let mono: Vec<f32> = samples
            .chunks(channels)
            .map(|ch| ch.iter().copied().sum::<f32>() / channels as f32)
            .collect();
        self.last_vad_decision = self.vad.process_f32(&mono, self.config.sample_rate);
        if self.config.dtx && self.last_vad_decision == VadDecision::Silence {
            self.dtx_silence_frames = self.dtx_silence_frames.saturating_add(1);
            if self.dtx_silence_frames > 1 {
                return Ok(Vec::new());
            }
        } else {
            self.dtx_silence_frames = 0;
        }
        let mut packet = vec![0u8; Self::MAX_PACKET_SIZE];
        packet[0] = self.generate_toc_byte()?;
        let packet_len = match self.current_mode {
            OpusMode::Silk => {
                if let Some(silk) = &mut self.silk {
                    let bytes = silk.encode(samples, &mut packet[1..], self.frame_size)?;
                    bytes + 1
                } else {
                    return Err(CodecError::Internal(
                        "SILK encoder not initialized".to_string(),
                    ));
                }
            }
            OpusMode::Celt => {
                if let Some(celt) = &mut self.celt {
                    let bytes = celt.encode(samples, &mut packet[1..], self.frame_size)?;
                    bytes + 1
                } else {
                    return Err(CodecError::Internal(
                        "CELT encoder not initialized".to_string(),
                    ));
                }
            }
            OpusMode::Hybrid => {
                if let (Some(silk), Some(celt)) = (&mut self.silk, &mut self.celt) {
                    // Each layer gets half of the packet budget; its output is then
                    // concatenated after the TOC byte.
                    let mut silk_data = vec![0u8; Self::MAX_PACKET_SIZE / 2];
                    let mut celt_data = vec![0u8; Self::MAX_PACKET_SIZE / 2];
                    let silk_bytes = silk.encode(samples, &mut silk_data, self.frame_size)?;
                    let celt_bytes = celt.encode(samples, &mut celt_data, self.frame_size)?;
                    let total_bytes = 1 + silk_bytes + celt_bytes;
                    if total_bytes > Self::MAX_PACKET_SIZE {
                        return Err(CodecError::BufferTooSmall {
                            needed: total_bytes,
                            have: Self::MAX_PACKET_SIZE,
                        });
                    }
                    packet[1..1 + silk_bytes].copy_from_slice(&silk_data[..silk_bytes]);
                    packet[1 + silk_bytes..total_bytes]
                        .copy_from_slice(&celt_data[..celt_bytes]);
                    total_bytes
                } else {
                    return Err(CodecError::Internal(
                        "Hybrid encoders not initialized".to_string(),
                    ));
                }
            }
        };
        packet.truncate(packet_len);
        Ok(packet)
    }

    /// Builds the TOC byte: 5-bit config, stereo flag, then frame-count code 0
    /// (one frame per packet).
    fn generate_toc_byte(&self) -> CodecResult<u8> {
        let config = self.encode_configuration()?;
        // A value of 32 or more would lose its high bits in the shift below.
        debug_assert!(config < 32, "TOC config must fit in 5 bits");
        let stereo_flag = if self.config.channels == 2 {
            0x04
        } else {
            0x00
        };
        let frame_code = 0x00;
        Ok((config << 3) | stereo_flag | frame_code)
    }

    /// Returns the 5-bit TOC configuration number (RFC 6716, section 3.1, Table 2)
    /// for the current mode, bandwidth and frame duration.
    ///
    /// Table 2 layout: 0-11 SILK NB/MB/WB x {10, 20, 40, 60 ms}; 12-15 Hybrid SWB/FB
    /// x {10, 20 ms}; 16-31 CELT NB/WB/SWB/FB x {2.5, 5, 10, 20 ms}. Combinations a
    /// mode cannot signal are clamped to the nearest representable bandwidth, and
    /// unsupported durations fall back to a fixed code (20 ms for SILK/Hybrid,
    /// 10 ms for CELT).
    fn encode_configuration(&self) -> CodecResult<u8> {
        let bandwidth = self
            .config
            .bandwidth
            .unwrap_or_else(|| Self::select_bandwidth(&self.config));
        // Frame duration in 48 kHz samples (120 = 2.5 ms ... 2880 = 60 ms), so the
        // duration codes below do not depend on the configured sample rate.
        let duration_48k = self.frame_size * 48_000 / self.config.sample_rate as usize;
        let config: u8 = match self.current_mode {
            OpusMode::Silk => {
                let bw_code: u8 = match bandwidth {
                    OpusBandwidth::Narrowband => 0,
                    OpusBandwidth::Mediumband => 1,
                    // SILK-only packets cannot signal more than wideband.
                    OpusBandwidth::Wideband
                    | OpusBandwidth::SuperWideband
                    | OpusBandwidth::Fullband => 2,
                };
                let frame_code: u8 = match duration_48k {
                    480 => 0,
                    960 => 1,
                    1920 => 2,
                    2880 => 3,
                    _ => 1,
                };
                (bw_code << 2) | frame_code
            }
            OpusMode::Hybrid => {
                // Hybrid only exists for SWB (12-13) and FB (14-15).
                let bw_code: u8 = match bandwidth {
                    OpusBandwidth::Fullband => 1,
                    _ => 0,
                };
                let frame_code: u8 = match duration_48k {
                    480 => 0,
                    _ => 1,
                };
                12 + (bw_code << 1) + frame_code
            }
            OpusMode::Celt => {
                // CELT has no mediumband configuration; it is signalled as wideband.
                let bw_code: u8 = match bandwidth {
                    OpusBandwidth::Narrowband => 0,
                    OpusBandwidth::Mediumband | OpusBandwidth::Wideband => 1,
                    OpusBandwidth::SuperWideband => 2,
                    OpusBandwidth::Fullband => 3,
                };
                let frame_code: u8 = match duration_48k {
                    120 => 0,
                    240 => 1,
                    480 => 2,
                    960 => 3,
                    _ => 2,
                };
                16 + (bw_code << 2) + frame_code
            }
        };
        Ok(config)
    }

    /// Encodes one frame from whatever input is still buffered, zero-padding it to a
    /// full frame when needed.
    ///
    /// Returns `Ok(None)` when the buffer is empty. Because `encode` emits at most
    /// one packet per call, several frames may be pending; call `flush` until it
    /// returns `Ok(None)` to drain them all.
    ///
    /// # Errors
    ///
    /// Propagates errors from frame encoding.
    pub fn flush(&mut self) -> CodecResult<Option<Vec<u8>>> {
        if self.input_buffer.is_empty() {
            return Ok(None);
        }
        let channels = self.config.channels;
        let frame_samples = self.frame_size * channels;
        // Pad only a short tail; a buffer holding a full frame or more must not be
        // padded (the old subtraction underflowed in that case).
        if self.input_buffer.len() < frame_samples {
            self.input_buffer.resize(frame_samples, 0.0);
        }
        let frame_data: Vec<f32> = self.input_buffer.drain(..frame_samples).collect();
        self.buffered_samples = self.input_buffer.len() / channels;
        let packet = self.encode_frame(&frame_data)?;
        self.frame_count += 1;
        Ok(Some(packet))
    }

    /// Resets sub-encoders, VAD/DTX state, buffered input and the frame counter.
    pub fn reset(&mut self) {
        if let Some(silk) = &mut self.silk {
            silk.reset();
        }
        if let Some(celt) = &mut self.celt {
            celt.reset();
        }
        self.frame_count = 0;
        self.input_buffer.clear();
        self.buffered_samples = 0;
        self.vad.reset();
        self.last_vad_decision = VadDecision::Silence;
        self.dtx_silence_frames = 0;
    }

    /// VAD decision for the most recently processed frame.
    #[must_use]
    pub const fn last_vad_decision(&self) -> VadDecision {
        self.last_vad_decision
    }

    /// Number of consecutive silent frames seen while DTX is enabled.
    #[must_use]
    pub const fn dtx_silence_frames(&self) -> u32 {
        self.dtx_silence_frames
    }

    /// Configuration this encoder was created with.
    #[must_use]
    pub const fn config(&self) -> &OpusEncoderConfig {
        &self.config
    }

    /// Number of frames encoded (by `encode` or `flush`) since creation or `reset`.
    #[must_use]
    pub const fn frame_count(&self) -> u64 {
        self.frame_count
    }

    /// Coding mode in use.
    #[must_use]
    pub const fn current_mode(&self) -> OpusMode {
        self.current_mode
    }

    /// Samples per channel in one frame.
    #[must_use]
    pub const fn frame_size(&self) -> usize {
        self.frame_size
    }
}

#[cfg(test)]
mod toc_and_buffering_tests {
    use super::*;

    fn toc(config: OpusEncoderConfig) -> u8 {
        OpusEncoder::new(config)
            .expect("encoder creation")
            .generate_toc_byte()
            .expect("toc byte")
    }

    #[test]
    fn celt_fullband_toc_uses_config_31() {
        // 48 kHz stereo, 20 ms -> CELT-only FB 20 ms (config 31) with the stereo bit.
        assert_eq!(toc(OpusEncoderConfig::new(48000, 2, 64000)), (31 << 3) | 0x04);
    }

    #[test]
    fn celt_superwideband_10ms_toc() {
        // 24 kHz mono, 10 ms -> CELT-only SWB 10 ms (config 26).
        let cfg = OpusEncoderConfig::new(24000, 1, 64000).with_frame_duration(10.0);
        assert_eq!(toc(cfg), 26 << 3);
    }

    #[test]
    fn silk_toc_frame_code_respects_sample_rate() {
        // 16 kHz mono: 20 ms -> SILK WB 20 ms (config 9), 10 ms -> config 8.
        assert_eq!(toc(OpusEncoderConfig::new(16000, 1, 16000)), 9 << 3);
        let ten_ms = OpusEncoderConfig::new(16000, 1, 16000).with_frame_duration(10.0);
        assert_eq!(toc(ten_ms), 8 << 3);
        // 8 kHz mono, 60 ms -> SILK NB 60 ms (config 3).
        let nb = OpusEncoderConfig::new(8000, 1, 12000).with_frame_duration(60.0);
        assert_eq!(toc(nb), 3 << 3);
    }

    #[test]
    fn hybrid_toc_encodes_bandwidth_and_duration() {
        // Forced hybrid at 48 kHz (FB): 20 ms -> config 15, 10 ms -> config 14.
        let twenty = OpusEncoderConfig::new(48000, 1, 24000).with_mode(OpusMode::Hybrid);
        assert_eq!(toc(twenty), 15 << 3);
        let ten = OpusEncoderConfig::new(48000, 1, 24000)
            .with_mode(OpusMode::Hybrid)
            .with_frame_duration(10.0);
        assert_eq!(toc(ten), 14 << 3);
    }

    #[test]
    fn flush_drains_more_than_one_buffered_frame_without_underflow() {
        let mut encoder =
            OpusEncoder::new(OpusEncoderConfig::new(16000, 1, 16000)).expect("encoder creation");
        let frame = encoder.frame_size();
        // Two and a half frames in: one is encoded, one and a half stay buffered.
        let _ = encoder.encode(&vec![0.0f32; frame * 5 / 2]);
        assert_eq!(encoder.buffered_samples, frame * 3 / 2);
        let _ = encoder.flush();
        assert_eq!(encoder.buffered_samples, frame / 2);
        let _ = encoder.flush();
        assert_eq!(encoder.buffered_samples, 0);
        assert!(encoder.input_buffer.is_empty());
        assert!(matches!(encoder.flush(), Ok(None)));
    }

    #[test]
    fn encode_counts_samples_from_buffer_for_misaligned_stereo_input() {
        let mut encoder =
            OpusEncoder::new(OpusEncoderConfig::new(16000, 2, 16000)).expect("encoder creation");
        let frame_samples = encoder.frame_size() * 2;
        let _ = encoder.encode(&vec![0.0f32; frame_samples + 1]);
        let _ = encoder.encode(&vec![0.0f32; frame_samples - 1]);
        // Exactly two frames were supplied, so both must have been consumed.
        assert!(encoder.input_buffer.is_empty());
        assert_eq!(encoder.buffered_samples, 0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 16 kHz mono encoder (bitrate 16000) with DTX switched on or off.
    fn wideband_mono_encoder(dtx: bool) -> OpusEncoder {
        let config = OpusEncoderConfig::new(16000, 1, 16000).with_dtx(dtx);
        OpusEncoder::new(config).expect("encoder creation")
    }

    /// `len` samples of a 200 Hz sine at amplitude 0.8, sampled at `sample_rate`.
    fn speech_frame_f32(len: usize, sample_rate: u32) -> Vec<f32> {
        let rate = sample_rate as f32;
        let mut frame = Vec::with_capacity(len);
        for i in 0..len {
            let t = i as f32 / rate;
            frame.push((2.0 * std::f32::consts::PI * 200.0 * t).sin() * 0.8);
        }
        frame
    }

    #[test]
    fn test_encoder_creation() {
        assert!(OpusEncoder::new(OpusEncoderConfig::new(48000, 2, 64000)).is_ok());
    }

    #[test]
    fn test_encoder_invalid_sample_rate() {
        assert!(OpusEncoder::new(OpusEncoderConfig::new(44100, 2, 64000)).is_err());
    }

    #[test]
    fn test_encoder_invalid_channels() {
        assert!(OpusEncoder::new(OpusEncoderConfig::new(48000, 0, 64000)).is_err());
    }

    #[test]
    fn test_encoder_invalid_bitrate() {
        assert!(OpusEncoder::new(OpusEncoderConfig::new(48000, 2, 1000)).is_err());
    }

    #[test]
    fn test_config_builder() {
        let cfg = OpusEncoderConfig::new(48000, 2, 64000)
            .with_frame_duration(10.0)
            .with_complexity(8)
            .with_vbr(true);
        assert_eq!(
            (cfg.sample_rate, cfg.channels, cfg.bitrate, cfg.complexity),
            (48000, 2, 64000, 8)
        );
        assert!((cfg.frame_duration_ms - 10.0).abs() < f32::EPSILON);
        assert!(cfg.vbr);
    }

    #[test]
    fn test_mode_selection() {
        let speech = OpusEncoderConfig::new(16000, 1, 16000);
        let music = OpusEncoderConfig::new(48000, 2, 64000);
        assert_eq!(OpusEncoder::select_mode(&speech), OpusMode::Silk);
        assert_eq!(OpusEncoder::select_mode(&music), OpusMode::Celt);
    }

    #[test]
    fn test_bandwidth_selection() {
        let cases = [
            (OpusEncoderConfig::new(8000, 1, 16000), OpusBandwidth::Narrowband),
            (OpusEncoderConfig::new(48000, 2, 64000), OpusBandwidth::Fullband),
        ];
        for (cfg, expected) in cases {
            assert_eq!(OpusEncoder::select_bandwidth(&cfg), expected);
        }
    }

    #[test]
    fn test_frame_size_calculation() {
        for (duration, expected) in [(20.0, 960), (10.0, 480)] {
            let size = OpusEncoder::calculate_frame_size(48000, duration);
            assert!(size.is_ok());
            assert_eq!(size.expect("should succeed"), expected);
        }
    }

    #[test]
    fn test_encoder_reset() {
        let mut encoder =
            OpusEncoder::new(OpusEncoderConfig::new(48000, 2, 64000)).expect("should succeed");
        encoder.reset();
        assert_eq!(encoder.frame_count(), 0);
    }

    #[test]
    fn test_vad_initial_decision_is_silence() {
        let encoder = wideband_mono_encoder(false);
        assert_eq!(encoder.last_vad_decision(), VadDecision::Silence);
    }

    #[test]
    fn test_vad_dtx_silence_frames_zero_initially() {
        let encoder = wideband_mono_encoder(false);
        assert_eq!(encoder.dtx_silence_frames(), 0);
    }

    #[test]
    fn test_vad_decision_updates_after_encode() {
        let mut encoder = wideband_mono_encoder(false);
        let n = encoder.frame_size();
        let silence = vec![0.0f32; n];
        (0..10).for_each(|_| {
            let _ = encoder.encode(&silence);
        });
        let _ = encoder.encode(&speech_frame_f32(n, 16000));
        assert_eq!(
            encoder.last_vad_decision(),
            VadDecision::Voice,
            "Loud speech should produce Voice decision"
        );
    }

    #[test]
    fn test_vad_reset_clears_decision() {
        let mut encoder = wideband_mono_encoder(false);
        let silence = vec![0.0f32; encoder.frame_size()];
        (0..5).for_each(|_| {
            let _ = encoder.encode(&silence);
        });
        encoder.reset();
        assert_eq!(encoder.last_vad_decision(), VadDecision::Silence);
        assert_eq!(encoder.dtx_silence_frames(), 0);
    }

    #[test]
    fn test_dtx_suppresses_continuous_silence() {
        let mut encoder = wideband_mono_encoder(true);
        let silence = vec![0.0f32; encoder.frame_size()];
        let suppressed = (0..40)
            .map(|_| encoder.encode(&silence))
            .filter(|outcome| matches!(outcome, Ok(Some(pkt)) if pkt.is_empty()))
            .count();
        assert!(
            suppressed > 0,
            "DTX must suppress at least one silence frame over 40 frames"
        );
    }

    #[test]
    fn test_dtx_does_not_suppress_speech() {
        let mut encoder = wideband_mono_encoder(true);
        let n = encoder.frame_size();
        let silence = vec![0.0f32; n];
        (0..10).for_each(|_| {
            let _ = encoder.encode(&silence);
        });
        let outcome = encoder.encode(&speech_frame_f32(n, 16000));
        if let Ok(Some(packet)) = outcome {
            assert!(!packet.is_empty(), "DTX must NOT suppress speech frames");
        }
    }

    #[test]
    fn test_dtx_disabled_never_suppresses() {
        let config = OpusEncoderConfig::new(16000, 1, 16000);
        assert!(!config.dtx);
        let mut encoder = OpusEncoder::new(config).expect("encoder creation");
        let silence = vec![0.0f32; encoder.frame_size()];
        for packet in (0..30).filter_map(|_| encoder.encode(&silence).ok().flatten()) {
            assert!(
                !packet.is_empty(),
                "Without DTX, no packets should be suppressed"
            );
        }
    }

    #[test]
    fn test_dtx_silence_frame_counter_increases() {
        let mut encoder = wideband_mono_encoder(true);
        let silence = vec![0.0f32; encoder.frame_size()];
        (0..40).for_each(|_| {
            let _ = encoder.encode(&silence);
        });
        assert!(
            encoder.dtx_silence_frames() > 0,
            "dtx_silence_frames must be > 0 after sustained silence with DTX"
        );
    }
}