Skip to main content

opus_rs/
lib.rs

1pub mod bands;
2pub mod celt;
3pub mod celt_lpc;
4pub mod hp_cutoff;
5pub mod mdct;
6pub mod modes;
7pub mod pitch;
8pub mod pvq;
9pub mod quant_bands;
10pub mod range_coder;
11pub mod rate;
12pub mod silk;
13
14use celt::{CeltDecoder, CeltEncoder};
15use hp_cutoff::hp_cutoff;
16use range_coder::RangeCoder;
17use silk::control_codec::silk_control_encoder;
18use silk::enc_api::{silk_encode, silk_encode_prefill};
19use silk::init_encoder::silk_init_encoder;
20use silk::lin2log::silk_lin2log;
21use silk::log2lin::silk_log2lin;
22use silk::macros::*;
23use silk::resampler::{silk_resampler_down2, silk_resampler_down2_3};
24use silk::structs::SilkEncoderState;
25
/// Intended use of an [`OpusEncoder`], chosen at construction time.
///
/// Each variant has an explicit integer discriminant, so `variant as i32`
/// yields a stable numeric code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Application {
    /// Voice-over-IP. `OpusEncoder::new` starts in SILK-only mode for this
    /// variant, and `encode` applies the high-pass filter to the input.
    Voip = 2048,
    /// General audio. `OpusEncoder::new` starts in CELT-only mode.
    Audio = 2049,
    /// Restricted low-delay. `OpusEncoder::new` starts in CELT-only mode.
    RestrictedLowDelay = 2051,
}
32
/// Audio bandwidth signalled in the packet TOC byte.
///
/// The named bandwidths have consecutive discriminants (1101..=1105); the TOC
/// helpers in this file rely on that when they subtract one variant from
/// another to obtain the bandwidth bits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Bandwidth {
    /// No specific bandwidth; the decoder holds this value until a packet
    /// has been decoded.
    Auto = -1000,
    /// Narrowband.
    Narrowband = 1101,
    /// Mediumband.
    Mediumband = 1102,
    /// Wideband.
    Wideband = 1103,
    /// Super-wideband.
    Superwideband = 1104,
    /// Fullband.
    Fullband = 1105,
}
42
/// Coding mode of a packet, as selected by the encoder and as recovered
/// from the TOC byte by the decoder.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum OpusMode {
    /// Only the SILK layer is coded.
    SilkOnly,
    /// SILK and CELT are both coded into the same packet.
    Hybrid,
    /// Only the CELT layer is coded.
    CeltOnly,
}
49
50pub struct OpusEncoder {
51    celt_enc: CeltEncoder,
52    silk_enc: Box<SilkEncoderState>,
53    application: Application,
54    sampling_rate: i32,
55    channels: usize,
56    bandwidth: Bandwidth,
57    pub bitrate_bps: i32,
58    pub complexity: i32,
59    pub use_cbr: bool,
60    /// Enable in-band FEC (LBRR). When set to true, the encoder will
61    /// include redundant copies of previous frames (at lower quality) to
62    /// allow packet-loss recovery on the decoder side.
63    pub use_inband_fec: bool,
64    /// Expected packet loss percentage (0-100). Used to determine whether
65    /// LBRR encoding is worth the overhead. Non-zero value enables LBRR
66    /// when `use_inband_fec` is also true.
67    pub packet_loss_perc: i32,
68    silk_initialized: bool,
69    mode: OpusMode,
70    // HP filter state
71    variable_hp_smth2_q15: i32,
72    hp_mem: Vec<i32>, // [4] for stereo, [2] for mono
73    // Pre-allocated buffers for encoding (reuse across frames to avoid allocations)
74    buf_filtered: Vec<i16>,
75    buf_silk_input: Vec<i16>,
76    buf_stereo_mid: Vec<i16>,
77    buf_stereo_side: Vec<i16>,
78    down2_state_first: [i32; 2],
79    down2_state_second: [i32; 2],
80    down2_3_state: [i32; 6],
81}
82
83impl OpusEncoder {
84    pub fn new(
85        sampling_rate: i32,
86        channels: usize,
87        application: Application,
88    ) -> Result<Self, &'static str> {
89        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
90            return Err("Invalid sampling rate");
91        }
92        if ![1, 2].contains(&channels) {
93            return Err("Invalid number of channels");
94        }
95
96        let mode = modes::default_mode();
97        let celt_enc = CeltEncoder::new(mode, channels);
98
99        let mut silk_enc = Box::new(SilkEncoderState::default());
100        if silk_init_encoder(&mut *silk_enc, 0) != 0 {
101            return Err("SILK encoder initialization failed");
102        }
103
104        let (opus_mode, bw) = match application {
105            Application::Voip => (OpusMode::SilkOnly, Bandwidth::Narrowband),
106            _ => (OpusMode::CeltOnly, Bandwidth::Fullband),
107        };
108
109        // Initialize HP filter: min cutoff = 60 Hz, in log scale Q15
110        // VARIABLE_HP_MIN_CUTOFF_HZ = 60
111        use silk::lin2log::silk_lin2log;
112        let variable_hp_smth2_q15 = silk_lin2log(60) << 8;
113
114        Ok(Self {
115            celt_enc,
116            silk_enc,
117            application,
118            sampling_rate,
119            channels,
120            bandwidth: bw,
121            bitrate_bps: 64000,
122            complexity: 0,
123            use_cbr: false,
124            use_inband_fec: false,
125            packet_loss_perc: 0,
126            silk_initialized: false,
127            mode: opus_mode,
128            variable_hp_smth2_q15,
129            hp_mem: vec![0; channels * 2],
130            // Pre-allocated buffers (will be resized on first use)
131            buf_filtered: Vec::new(),
132            buf_silk_input: Vec::new(),
133            buf_stereo_mid: Vec::new(),
134            buf_stereo_side: Vec::new(),
135            down2_state_first: [0; 2],
136            down2_state_second: [0; 2],
137            down2_3_state: [0; 6],
138        })
139    }
140
141    /// Enable Hybrid mode (SILK + CELT).
142    /// This is only valid for sampling rates of 24000 or 48000 Hz.
143    /// In Hybrid mode, SILK handles the low-frequency portion and CELT handles the high-frequency.
144    /// IMPORTANT: Hybrid mode requires the input to be at the encoder's sampling rate (24kHz or 48kHz).
145    /// The SILK encoder will internally use 16kHz for its processing.
146    pub fn enable_hybrid_mode(&mut self) -> Result<(), &'static str> {
147        if self.sampling_rate != 24000 && self.sampling_rate != 48000 {
148            return Err("Hybrid mode requires 24kHz or 48kHz sampling rate");
149        }
150        let bw = if self.sampling_rate == 48000 {
151            Bandwidth::Fullband
152        } else {
153            Bandwidth::Superwideband
154        };
155        self.mode = OpusMode::Hybrid;
156        self.bandwidth = bw;
157        self.silk_initialized = false; // Force re-initialization at SILK internal rate
158        Ok(())
159    }
160
161    pub fn encode(
162        &mut self,
163        input: &[f32],
164        frame_size: usize,
165        output: &mut [u8],
166    ) -> Result<usize, &'static str> {
167        // Minimum output buffer size: 1 byte TOC + 1 byte frame count (for Code 3 packets)
168        if output.len() < 2 {
169            return Err("Output buffer too small");
170        }
171
172        let frame_rate = frame_rate_from_params(self.sampling_rate, frame_size)
173            .ok_or("Invalid frame size for sampling rate")?;
174
175        // Mode decision can be dynamic here, for now using stored mode
176        let mode = self.mode;
177        if mode == OpusMode::CeltOnly {
178            match frame_rate {
179                400 | 200 | 100 | 50 => {}
180                _ => return Err("Unsupported frame size for CELT-only mode"),
181            }
182        }
183
184        let toc = gen_toc(mode, frame_rate, self.bandwidth, self.channels);
185        output[0] = toc;
186
187        // C: cbr_bytes = IMIN((bitrate_to_bits(bitrate, Fs, frame_size)+4)/8, max_data_bytes)
188        // bitrate_to_bits = bitrate * 6 / (6 * Fs / frame_size) = bitrate * frame_size / Fs
189        let target_bits =
190            (self.bitrate_bps as i64 * frame_size as i64 / self.sampling_rate as i64) as i32;
191        let cbr_bytes = ((target_bits + 4) / 8) as usize;
192        let max_data_bytes = output.len();
193
194        // In CBR mode, use cbr_bytes; in VBR mode, use full buffer
195        let n_bytes = if self.use_cbr {
196            cbr_bytes.min(max_data_bytes).max(1)
197        } else {
198            max_data_bytes
199        };
200
201        // C: ec_enc_init(&enc, data+1, orig_max_data_bytes-1)
202        // Range coder buffer is payload only (excluding TOC byte)
203        let mut rc = RangeCoder::new_encoder((max_data_bytes - 1) as u32);
204
205        if mode == OpusMode::SilkOnly || mode == OpusMode::Hybrid {
206            /* Initialize/configure SILK encoder if needed.
207               For Hybrid mode, SILK operates at 16kHz internal rate regardless of
208               the encoder's API sampling rate (24kHz or 48kHz). The input must be
209               downsampled before passing to SILK.
210               For SilkOnly mode, SILK only supports 8, 12, or 16 kHz internal rates.
211               Sample rates above 16kHz are capped to 16kHz (WB mode).
212            */
213            let silk_fs_khz = if mode == OpusMode::Hybrid {
214                16 // SILK always uses 16kHz in Hybrid (SWB) mode
215            } else {
216                // Cap SILK internal rate to max 16kHz (WB)
217                self.sampling_rate.min(16000) / 1000
218            };
219            // SILK frame duration matches the API frame duration
220            let frame_ms = (frame_size as i32 * 1000) / self.sampling_rate;
221            if !self.silk_initialized || self.silk_enc.s_cmn.fs_khz != silk_fs_khz as i32 {
222                let silk_init_bitrate = (((n_bytes - 1) * 8) as i64 * self.sampling_rate as i64
223                    / frame_size as i64) as i32;
224                silk_control_encoder(
225                    &mut *self.silk_enc,
226                    silk_fs_khz as i32,
227                    frame_ms,
228                    silk_init_bitrate,
229                    self.complexity,
230                );
231                self.silk_enc.s_cmn.use_cbr = if self.use_cbr { 1 } else { 0 };
232                // Set number of channels for stereo support
233                self.silk_enc.s_cmn.n_channels = self.channels as i32;
234                self.silk_initialized = true;
235                self.down2_state_first = [0; 2];
236                self.down2_state_second = [0; 2];
237                self.down2_3_state = [0; 6];
238
239                // Use SILK internal sample rate for prefill buffer size
240                let silk_internal_rate = silk_fs_khz * 1000;
241                let encoder_buffer = silk_internal_rate as usize / 100; // Fs/100 = 80 samples @8kHz
242                let prefill_zeros = vec![0i16; encoder_buffer];
243                silk_encode_prefill(
244                    &mut *self.silk_enc,
245                    &prefill_zeros,
246                    0, // activity = inactive (silence)
247                );
248            }
249            // Always update FEC and packet loss settings (may change between calls)
250            self.silk_enc.s_cmn.use_in_band_fec = if self.use_inband_fec { 1 } else { 0 };
251            self.silk_enc.s_cmn.packet_loss_perc = self.packet_loss_perc.clamp(0, 100);
252            // Enable LBRR in the state struct when FEC is requested
253            self.silk_enc.s_cmn.lbrr_enabled = if self.use_inband_fec { 1 } else { 0 };
254            // LBRR gain increase (C default: 2 steps ~= 3dB quality reduction for LBRR copy)
255            if self.silk_enc.s_cmn.lbrr_gain_increases == 0 {
256                self.silk_enc.s_cmn.lbrr_gain_increases = 2;
257            }
258
259            // Apply HP filter before SILK encoding (matching C opus_encode_native)
260            // Update variable_HP_smth2_Q15 from SILK's smth1
261            let hp_freq_smth1 = if mode == OpusMode::CeltOnly {
262                silk_lin2log(60) << 8 // VARIABLE_HP_MIN_CUTOFF_HZ = 60
263            } else {
264                self.silk_enc.s_cmn.variable_hp_smth1_q15
265            };
266
267            // Second-order smoother: smth2 = smth2 + COEF2 * (smth1 - smth2)
268            // VARIABLE_HP_SMTH_COEF2 = 0.0025 (Q16 = 164)
269            const VARIABLE_HP_SMTH_COEF2_Q16: i32 = 164;
270            self.variable_hp_smth2_q15 = silk_smlawb(
271                self.variable_hp_smth2_q15,
272                hp_freq_smth1 - self.variable_hp_smth2_q15,
273                VARIABLE_HP_SMTH_COEF2_Q16,
274            );
275
276            // Convert from log scale to Hz
277            let cutoff_hz = silk_log2lin(silk_rshift(self.variable_hp_smth2_q15, 8));
278
279            // Apply HP filter to input (reuse pre-allocated buffer)
280            let required_size = frame_size * self.channels;
281            self.buf_filtered.resize(required_size, 0);
282            if self.application == Application::Voip {
283                hp_cutoff(
284                    input,
285                    cutoff_hz,
286                    &mut self.buf_filtered,
287                    &mut self.hp_mem,
288                    frame_size,
289                    self.channels,
290                    self.sampling_rate,
291                );
292            } else {
293                // No HP filter for non-VOIP, just convert
294                for (i, &x) in input.iter().enumerate() {
295                    self.buf_filtered[i] = (x * 32768.0).clamp(-32768.0, 32767.0) as i16;
296                }
297            }
298
299            let input_i16 = &self.buf_filtered;
300
301            // Convert input to SILK format (reuse pre-allocated buffers)
302            // For SILK stereo: convert L/R to Mid (L+R)/2 and Side (L-R) for encoding.
303            // For Hybrid mode: downsample from API rate to SILK's 16kHz internal rate.
304            let silk_input: &[i16] = if mode == OpusMode::SilkOnly && self.channels == 2 {
305                // Convert interleaved stereo (L R L R ...) to Mid channel
306                // Mid = (L + R) / 2, Side = L - R
307                let frame_length = input_i16.len() / 2;
308                self.buf_stereo_mid.resize(frame_length, 0);
309                self.buf_stereo_side.resize(frame_length, 0);
310                for i in 0..frame_length {
311                    // L is at index 2*i, R is at index 2*i+1
312                    let l = input_i16[2 * i] as i32;
313                    let r = input_i16[2 * i + 1] as i32;
314                    self.buf_stereo_mid[i] = ((l + r) / 2) as i16;
315                    self.buf_stereo_side[i] = (l - r) as i16;
316                }
317                // Store side for later stereo encoding
318                self.silk_enc.stereo.side = self.buf_stereo_side.clone();
319                &self.buf_stereo_mid
320            } else if mode == OpusMode::Hybrid && self.sampling_rate > 16000 {
321                if self.sampling_rate == 48000 {
322                    let stage1_size = frame_size / 2;
323                    let mut stage1_buf = vec![0i16; stage1_size];
324                    silk_resampler_down2(
325                        &mut self.down2_state_first,
326                        &mut stage1_buf,
327                        input_i16,
328                        frame_size as i32,
329                    );
330                    let silk_frame_size = stage1_size * 2 / 3;
331                    self.buf_silk_input.resize(silk_frame_size, 0);
332                    silk_resampler_down2_3(
333                        &mut self.down2_3_state,
334                        &mut self.buf_silk_input,
335                        &stage1_buf,
336                        stage1_size as i32,
337                    );
338                } else {
339                    // 24kHz → 16kHz via silk_resampler_down2_3 (2/3 ratio)
340                    let silk_frame_size = frame_size * 2 / 3;
341                    self.buf_silk_input.resize(silk_frame_size, 0);
342                    silk_resampler_down2_3(
343                        &mut self.down2_3_state,
344                        &mut self.buf_silk_input,
345                        input_i16,
346                        frame_size as i32,
347                    );
348                }
349                &self.buf_silk_input
350            } else {
351                input_i16
352            };
353
354            let mut pn_bytes = 0;
355
356            let silk_max_bits = if mode == OpusMode::Hybrid {
357                ((n_bytes - 1) * 8 * 2 / 5) as i32 // ~40% for SILK in Hybrid
358            } else {
359                ((n_bytes - 1) * 8) as i32
360            };
361            let silk_rate_for_calc = if mode == OpusMode::Hybrid {
362                16000
363            } else {
364                self.sampling_rate
365            };
366            let silk_frame_len = silk_input.len();
367            let silk_bitrate = if mode == OpusMode::Hybrid {
368                (silk_max_bits as i64 * silk_rate_for_calc as i64 / silk_frame_len as i64) as i32
369            } else {
370                // Match C: silk_mode.bitRate = 8 * cbr_bytes * (Fs / frame_size)
371                (8i64 * n_bytes as i64 * silk_rate_for_calc as i64 / silk_frame_len as i64) as i32
372            };
373            let ret = silk_encode(
374                &mut *self.silk_enc,
375                silk_input,
376                silk_input.len(),
377                &mut rc,
378                &mut pn_bytes,
379                silk_bitrate,
380                silk_max_bits,
381                if self.use_cbr { 1 } else { 0 },
382                1, // activity = 1 (assume active)
383            );
384            if ret != 0 {
385                return Err("SILK encoding failed");
386            }
387        }
388
389        if mode == OpusMode::CeltOnly || mode == OpusMode::Hybrid {
390            let start_band = if mode == OpusMode::Hybrid { 17 } else { 0 };
391            self.celt_enc
392                .encode_with_start_band(input, frame_size, &mut rc, start_band);
393        }
394
395        rc.done();
396
397        let silk_payload: Vec<u8> = if mode == OpusMode::SilkOnly {
398            // Build the complete output buffer
399            let mut combined = Vec::with_capacity(rc.storage as usize);
400            combined.extend_from_slice(&rc.buf[0..rc.offs as usize]);
401            combined.extend_from_slice(
402                &rc.buf[(rc.storage - rc.end_offs) as usize..rc.storage as usize],
403            );
404
405            // Strip trailing zeros (C: while(ret>2&&data[ret]==0)ret--)
406            while combined.len() > 2 && combined[combined.len() - 1] == 0 {
407                combined.pop();
408            }
409
410            combined
411        } else {
412            Vec::new()
413        };
414
415        let total_bytes = if mode == OpusMode::SilkOnly {
416            silk_payload.len()
417        } else {
418            n_bytes
419        };
420
421        let payload_bytes = total_bytes.min(output.len() - 1);
422        let ret_with_toc = payload_bytes + 1; // +1 for TOC byte
423
424        if mode == OpusMode::SilkOnly {
425            let target_total = if self.use_cbr {
426                n_bytes.min(output.len())
427            } else {
428                ret_with_toc
429            };
430
431            // Clamp frame data to the space available after TOC + count byte (2 bytes)
432            let max_frame_bytes = target_total
433                .saturating_sub(2)
434                .min(output.len().saturating_sub(2));
435            let frame_len = payload_bytes.min(max_frame_bytes);
436
437            output[0] = toc | 0x03;
438            output[1] = 0x01;
439
440            if frame_len > 0 {
441                output[2..2 + frame_len].copy_from_slice(&silk_payload[..frame_len]);
442            }
443
444            let out_len = output.len();
445            let written = 2 + frame_len;
446            let fill_end = target_total.min(out_len);
447            if written < fill_end {
448                for byte in output[written..fill_end].iter_mut() {
449                    *byte = 0;
450                }
451            }
452
453            return Ok(target_total.min(out_len));
454        }
455
456        if mode == OpusMode::SilkOnly {
457            output[1..1 + payload_bytes].copy_from_slice(&silk_payload[..payload_bytes]);
458        } else {
459            output[1..1 + payload_bytes].copy_from_slice(&rc.buf[..payload_bytes]);
460        }
461        Ok(ret_with_toc)
462    }
463}
464
465pub struct OpusDecoder {
466    celt_dec: CeltDecoder,
467    silk_dec: silk::dec_api::SilkDecoder,
468    sampling_rate: i32,
469    channels: usize,
470    // State tracking for mode transitions
471    prev_mode: Option<OpusMode>,
472    frame_size: usize,
473    /// Internal bandwidth from previous frame
474    bandwidth: Bandwidth,
475    /// Stream channels from previous frame
476    stream_channels: usize,
477    /// SILK decoder resampler (internal rate → API rate)
478    silk_resampler: silk::resampler::SilkResampler,
479    /// Previous internal sample rate (for resampler re-init)
480    prev_internal_rate: i32,
481}
482
483impl OpusDecoder {
484    pub fn new(sampling_rate: i32, channels: usize) -> Result<Self, &'static str> {
485        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
486            return Err("Invalid sampling rate");
487        }
488        if ![1, 2].contains(&channels) {
489            return Err("Invalid number of channels");
490        }
491
492        let mode = modes::default_mode();
493        let celt_dec = CeltDecoder::new(mode, channels);
494
495        let mut silk_dec = silk::dec_api::SilkDecoder::new();
496        silk_dec.init(sampling_rate.min(16000), channels as i32);
497        silk_dec.channel_state[0].fs_api_hz = sampling_rate;
498
499        Ok(Self {
500            celt_dec,
501            silk_dec,
502            sampling_rate,
503            channels,
504            prev_mode: None,
505            frame_size: 0,
506            bandwidth: Bandwidth::Auto,
507            stream_channels: channels,
508            silk_resampler: silk::resampler::SilkResampler::default(),
509            prev_internal_rate: 0,
510        })
511    }
512
513    pub fn decode(
514        &mut self,
515        input: &[u8],
516        frame_size: usize,
517        output: &mut [f32],
518    ) -> Result<usize, &'static str> {
519        if input.is_empty() {
520            return Err("Input packet empty");
521        }
522
523        let toc = input[0];
524        let mode = mode_from_toc(toc);
525        let packet_channels = channels_from_toc(toc);
526        let bandwidth = bandwidth_from_toc(toc);
527        let frame_duration_ms = frame_duration_ms_from_toc(toc);
528
529        if packet_channels != self.channels {
530            return Err("Channel count mismatch between packet and decoder");
531        }
532
533        // Parse packet structure: handle Code 0, 1, 2, 3
534        let code = toc & 0x03;
535        let payload_data;
536
537        match code {
538            0 => {
539                // Code 0: one frame
540                payload_data = &input[1..];
541            }
542            3 => {
543                // Code 3: arbitrary number of frames (CBR or VBR)
544                if input.len() < 2 {
545                    return Err("Code 3 packet too short");
546                }
547                let count_byte = input[1];
548                let _frame_count = (count_byte & 0x3F) as usize;
549                let padding_flag = (count_byte & 0x40) != 0;
550
551                // Parse padding
552                let mut ptr = 2usize;
553                if padding_flag {
554                    let mut pad_len = 0usize;
555                    loop {
556                        if ptr >= input.len() {
557                            return Err("Padding overflow");
558                        }
559                        let p = input[ptr] as usize;
560                        ptr += 1;
561                        if p == 255 {
562                            pad_len += 254; // 255 means 254 data bytes + next count byte
563                        } else {
564                            pad_len += p;
565                            break;
566                        }
567                    }
568                    // Frame data is between ptr and (input.len() - pad_len)
569                    let end = input.len().saturating_sub(pad_len);
570                    if ptr > end {
571                        return Err("Padding exceeds packet");
572                    }
573                    payload_data = &input[ptr..end];
574                } else {
575                    payload_data = &input[ptr..];
576                }
577            }
578            _ => {
579                // Code 1 / 2: two frames
580                payload_data = &input[1..];
581            }
582        }
583
584        self.frame_size = frame_size;
585        self.bandwidth = bandwidth;
586        self.stream_channels = packet_channels;
587
588        match mode {
589            OpusMode::SilkOnly => {
590                // Determine internal sample rate from bandwidth
591                let internal_sample_rate = match bandwidth {
592                    Bandwidth::Narrowband => 8000,
593                    Bandwidth::Mediumband => 12000,
594                    Bandwidth::Wideband => 16000,
595                    _ => 16000,
596                };
597
598                // Decode SILK frame
599                let mut rc = RangeCoder::new_decoder(payload_data.to_vec());
600                let internal_frame_size =
601                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
602                let mut pcm_i16 = vec![0i16; internal_frame_size * self.channels];
603
604                let payload_size_ms = frame_duration_ms;
605
606                let ret = self.silk_dec.decode(
607                    &mut rc,
608                    &mut pcm_i16,
609                    silk::decode_frame::FLAG_DECODE_NORMAL,
610                    true, // new_packet
611                    payload_size_ms,
612                    internal_sample_rate,
613                );
614
615                if ret < 0 {
616                    return Err("SILK decoding failed");
617                }
618
619                let decoded_samples = ret as usize;
620
621                // Sample rate conversion via SilkResampler
622                if self.sampling_rate == internal_sample_rate {
623                    // No resampling needed: direct i16 → f32 conversion
624                    let n = decoded_samples.min(frame_size).min(output.len());
625                    for i in 0..n {
626                        output[i] = pcm_i16[i] as f32 / 32768.0;
627                    }
628                    self.prev_mode = Some(OpusMode::SilkOnly);
629                    Ok(n)
630                } else {
631                    // (Re-)initialize resampler if internal rate changed
632                    if internal_sample_rate != self.prev_internal_rate {
633                        self.silk_resampler
634                            .init(internal_sample_rate, self.sampling_rate);
635                        self.prev_internal_rate = internal_sample_rate;
636                    }
637
638                    let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
639                    let out_len = ((decoded_samples as f64 * ratio) as usize).min(frame_size);
640                    let mut pcm_out = vec![0i16; out_len];
641                    self.silk_resampler.process(
642                        &mut pcm_out,
643                        &pcm_i16[..decoded_samples],
644                        decoded_samples as i32,
645                    );
646
647                    let n = out_len.min(output.len());
648                    for i in 0..n {
649                        output[i] = pcm_out[i] as f32 / 32768.0;
650                    }
651                    self.prev_mode = Some(OpusMode::SilkOnly);
652                    Ok(n)
653                }
654            }
655
656            OpusMode::CeltOnly => {
657                // CELT decoding (existing path)
658                self.celt_dec.decode(payload_data, frame_size, output);
659                self.prev_mode = Some(OpusMode::CeltOnly);
660                Ok(frame_size)
661            }
662
663            OpusMode::Hybrid => {
664                let internal_sample_rate = match bandwidth {
665                    Bandwidth::Superwideband => 16000, // 24kHz SWB: SILK at 16kHz
666                    Bandwidth::Fullband => 16000,      // 48kHz FB:  SILK at 16kHz
667                    _ => 16000,
668                };
669
670                // Step 1: Decode SILK low-frequency portion
671                let mut rc = RangeCoder::new_decoder(payload_data.to_vec());
672                let internal_frame_size =
673                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
674                let mut pcm_silk_i16 = vec![0i16; internal_frame_size * self.channels];
675
676                let ret = self.silk_dec.decode(
677                    &mut rc,
678                    &mut pcm_silk_i16,
679                    silk::decode_frame::FLAG_DECODE_NORMAL,
680                    true,
681                    frame_duration_ms,
682                    internal_sample_rate,
683                );
684
685                // Convert SILK PCM to f32 at output rate
686                let mut silk_out = vec![0.0f32; frame_size * self.channels];
687                if ret > 0 {
688                    let decoded_samples = ret as usize;
689                    if self.sampling_rate == internal_sample_rate {
690                        let n = decoded_samples.min(frame_size);
691                        for i in 0..n {
692                            silk_out[i] = pcm_silk_i16[i] as f32 / 32768.0;
693                        }
694                    } else {
695                        // Resample SILK output to API rate
696                        if internal_sample_rate != self.prev_internal_rate {
697                            self.silk_resampler
698                                .init(internal_sample_rate, self.sampling_rate);
699                            self.prev_internal_rate = internal_sample_rate;
700                        }
701                        let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
702                        let out_len = ((decoded_samples as f64 * ratio) as usize).min(frame_size);
703                        let mut pcm_resampled = vec![0i16; out_len];
704                        self.silk_resampler.process(
705                            &mut pcm_resampled,
706                            &pcm_silk_i16[..decoded_samples],
707                            decoded_samples as i32,
708                        );
709                        for i in 0..out_len.min(frame_size) {
710                            silk_out[i] = pcm_resampled[i] as f32 / 32768.0;
711                        }
712                    }
713                }
714
715                // Step 2: Decode CELT high-frequency portion (start_band = 17)
716                let mut celt_out = vec![0.0f32; frame_size * self.channels];
717                // start_band = 17 matches the SWB/FB Hybrid crossover point
718                self.celt_dec
719                    .decode_with_start_band(payload_data, frame_size, &mut celt_out, 17);
720
721                let n = frame_size.min(output.len());
722                for i in 0..n {
723                    output[i] = silk_out[i] + celt_out[i];
724                    // Clamp to avoid clipping
725                    output[i] = output[i].clamp(-1.0, 1.0);
726                }
727
728                self.prev_mode = Some(OpusMode::Hybrid);
729                Ok(n)
730            }
731        }
732    }
733}
734
/// Returns the number of frames per second for `frame_size` samples per
/// channel at `sampling_rate` Hz.
///
/// Returns `None` when the sampling rate is not positive, when `frame_size`
/// is zero or does not fit in an `i32`, or when `frame_size` does not divide
/// the sampling rate evenly.
fn frame_rate_from_params(sampling_rate: i32, frame_size: usize) -> Option<i32> {
    // Reject values that a plain `as i32` cast would truncate or wrap; a
    // wrapped value could otherwise "divide" the rate and yield a bogus or
    // negative frame rate.
    if sampling_rate <= 0 || frame_size == 0 || frame_size > i32::MAX as usize {
        return None;
    }
    let frame_size = frame_size as i32;
    if sampling_rate % frame_size != 0 {
        return None;
    }
    Some(sampling_rate / frame_size)
}
742
743fn gen_toc(mode: OpusMode, frame_rate: i32, bandwidth: Bandwidth, channels: usize) -> u8 {
744    let mut rate = frame_rate;
745    let mut period = 0;
746    while rate < 400 {
747        rate <<= 1;
748        period += 1;
749    }
750
751    let mut toc = match mode {
752        OpusMode::SilkOnly => {
753            let bw = (bandwidth as i32 - Bandwidth::Narrowband as i32) << 5;
754            let per = (period - 2) << 3;
755            (bw | per) as u8
756        }
757        OpusMode::CeltOnly => {
758            let mut tmp = bandwidth as i32 - Bandwidth::Mediumband as i32;
759            if tmp < 0 {
760                tmp = 0;
761            }
762            let per = period << 3;
763            (0x80 | (tmp << 5) | per) as u8
764        }
765        OpusMode::Hybrid => {
766            // Hybrid configs: 16-19 = SWB, 20-23 = FB
767            // Period 0 = 10ms, period 1 = 20ms
768            // (Hybrid only supports 10ms and 20ms)
769            let base_config = if bandwidth == Bandwidth::Superwideband {
770                16
771            } else {
772                20
773            };
774            let hybrid_period = if frame_rate >= 100 { 0 } else { 1 }; // 100fps=10ms, 50fps=20ms
775            ((base_config + hybrid_period) << 3) as u8
776        }
777    };
778
779    if channels == 2 {
780        toc |= 0x04;
781    }
782    toc
783}
784
785fn mode_from_toc(toc: u8) -> OpusMode {
786    if toc & 0x80 != 0 {
787        OpusMode::CeltOnly
788    } else if toc & 0x60 == 0x60 {
789        OpusMode::Hybrid
790    } else {
791        OpusMode::SilkOnly
792    }
793}
794
795/// Extract bandwidth from TOC byte
796fn bandwidth_from_toc(toc: u8) -> Bandwidth {
797    let mode = mode_from_toc(toc);
798    match mode {
799        OpusMode::SilkOnly => {
800            let bw_bits = (toc >> 5) & 0x03;
801            match bw_bits {
802                0 => Bandwidth::Narrowband,
803                1 => Bandwidth::Mediumband,
804                2 => Bandwidth::Wideband,
805                _ => Bandwidth::Wideband,
806            }
807        }
808        OpusMode::Hybrid => {
809            let bw_bit = (toc >> 4) & 0x01;
810            if bw_bit == 0 {
811                Bandwidth::Superwideband
812            } else {
813                Bandwidth::Fullband
814            }
815        }
816        OpusMode::CeltOnly => {
817            let bw_bits = (toc >> 5) & 0x03;
818            match bw_bits {
819                0 => Bandwidth::Mediumband,
820                1 => Bandwidth::Wideband,
821                2 => Bandwidth::Superwideband,
822                3 => Bandwidth::Fullband,
823                _ => Bandwidth::Fullband,
824            }
825        }
826    }
827}
828
829/// Extract frame duration in milliseconds from TOC byte
830fn frame_duration_ms_from_toc(toc: u8) -> i32 {
831    let mode = mode_from_toc(toc);
832    match mode {
833        OpusMode::SilkOnly => {
834            let config = (toc >> 3) & 0x03;
835            match config {
836                0 => 10,
837                1 => 20,
838                2 => 40,
839                3 => 60,
840                _ => 20,
841            }
842        }
843        OpusMode::Hybrid => {
844            let config = (toc >> 3) & 0x01;
845            if config == 0 { 10 } else { 20 }
846        }
847        OpusMode::CeltOnly => {
848            let config = (toc >> 3) & 0x03;
849            match config {
850                0 => {
851                    // 2.5 ms
852                    // Return 2 for 2.5ms (caller will handle)
853                    2 // Approximation; actual is 2.5ms
854                }
855                1 => 5,
856                2 => 10,
857                3 => 20,
858                _ => 20,
859            }
860        }
861    }
862}
863
/// Returns the channel count signalled by a TOC byte: 2 when the stereo flag
/// (bit `0x04`) is set, otherwise 1.
fn channels_from_toc(toc: u8) -> usize {
    match toc & 0x04 {
        0 => 1,
        _ => 2,
    }
}
867
#[cfg(test)]
mod tests {
    use super::*;

    /// Recovers the frame length in samples described by `toc` at
    /// `sampling_rate`.
    ///
    /// CELT-only frames are derived from the frame rate rather than from
    /// `frame_duration_ms_from_toc`, so the 2.5 ms size stays exact. Returns
    /// `None` if the sampling rate is not a multiple of that frame rate.
    fn frame_size_from_toc(toc: u8, sampling_rate: i32) -> Option<usize> {
        match mode_from_toc(toc) {
            OpusMode::CeltOnly => {
                let period = ((toc >> 3) & 0x03) as i32;
                let frame_rate = 400 >> period;
                if frame_rate != 0 && sampling_rate % frame_rate == 0 {
                    Some((sampling_rate / frame_rate) as usize)
                } else {
                    None
                }
            }
            OpusMode::SilkOnly | OpusMode::Hybrid => {
                let duration_ms = i64::from(frame_duration_ms_from_toc(toc));
                let samples = i64::from(sampling_rate) * duration_ms / 1000;
                Some(samples as usize)
            }
        }
    }

    /// Pushes one placeholder CELT-only packet per frame size through a fresh
    /// decoder. The decode result is deliberately ignored: the payload is not
    /// meaningful audio, the only requirement is that decoding does not panic.
    fn decode_placeholder_celt_packets(sampling_rate: i32, channels: usize, frame_sizes: &[usize]) {
        let mut decoder = OpusDecoder::new(sampling_rate, channels).unwrap();

        for &frame_size in frame_sizes {
            // Minimal packet: TOC byte followed by a few zero bytes.
            let frame_rate = frame_rate_from_params(sampling_rate, frame_size).unwrap();
            let toc = gen_toc(
                OpusMode::CeltOnly,
                frame_rate,
                Bandwidth::Fullband,
                channels,
            );
            let packet = [toc, 0, 0, 0, 0];

            let mut output = vec![0.0f32; frame_size * channels];
            let _ = decoder.decode(&packet, frame_size, &mut output);
        }
    }

    /// CELT-only fullband TOC bytes must match the expected values for every
    /// CELT frame size at 48 kHz and must round-trip back to the frame size.
    #[test]
    fn gen_toc_matches_celt_reference_values() {
        const SAMPLING_RATE: i32 = 48_000;
        let expectations: [(usize, u8); 4] = [(120, 0xE0), (240, 0xE8), (480, 0xF0), (960, 0xF8)];

        for &(frame_size, expected_toc) in expectations.iter() {
            let frame_rate = frame_rate_from_params(SAMPLING_RATE, frame_size).unwrap();
            let toc = gen_toc(OpusMode::CeltOnly, frame_rate, Bandwidth::Fullband, 1);
            assert_eq!(
                toc, expected_toc,
                "frame_size {} expected TOC {:02X} got {:02X}",
                frame_size, expected_toc, toc
            );

            let round_tripped = frame_size_from_toc(toc, SAMPLING_RATE).unwrap();
            assert_eq!(round_tripped, frame_size);
        }

        let frame_rate_20ms = frame_rate_from_params(SAMPLING_RATE, 960).unwrap();
        let stereo_toc = gen_toc(
            OpusMode::CeltOnly,
            frame_rate_20ms,
            Bandwidth::Fullband,
            2,
        );
        assert_eq!(channels_from_toc(stereo_toc), 2);
    }

    /// Regression test for an integer underflow in the CELT decoder that
    /// panicked when frame_size > decode_buffer_size + overlap
    /// (e.g. 2880 > 2048 + 120 = 2168).
    ///
    /// Exercises every valid CELT frame size at 48 kHz (2.5, 5, 10 and 20 ms,
    /// i.e. 120, 240, 480 and 960 samples) in mono and then stereo. 40 ms and
    /// 60 ms are SILK-only frame sizes and are not covered here.
    #[test]
    fn test_celt_decoder_large_frame_sizes() {
        let sampling_rate = 48000;
        let frame_sizes = [120, 240, 480, 960];

        for channels in [1, 2] {
            decode_placeholder_celt_packets(sampling_rate, channels, &frame_sizes);
        }
    }

    /// Test edge cases around the old boundary (2048 + 120 = 2168).
    /// These frame sizes are not valid for CELT; decoding may fail, but it
    /// must return gracefully instead of panicking on integer underflow.
    #[test]
    fn test_celt_decoder_edge_case_frame_sizes() {
        let channels = 1;
        let mut decoder = OpusDecoder::new(48000, channels).unwrap();

        // Arbitrary CELT-only packet reused for every size.
        let packet: [u8; 4] = [0x80, 0, 0, 0];

        for frame_size in [2048, 2167, 2168, 2169, 2880, 3072] {
            let mut output = vec![0.0f32; frame_size * channels];
            let _ = decoder.decode(&packet, frame_size, &mut output);
        }
    }
}