Skip to main content

opus_rs/
lib.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2#![allow(clippy::too_many_arguments)]
3#![allow(clippy::needless_range_loop)]
4
5pub mod bands;
6pub mod celt;
7pub mod celt_lpc;
8pub mod hp_cutoff;
9pub mod kiss_fft;
10pub mod mdct;
11pub mod modes;
12pub mod pitch;
13pub mod pvq;
14pub mod quant_bands;
15pub mod range_coder;
16pub mod rate;
17pub mod silk;
18
19pub use silk::{SilkResampler, SilkResamplerDown1_3, SilkResamplerDown1_6};
20
21pub use celt::{CeltDecoder, CeltEncoder};
22use hp_cutoff::hp_cutoff;
23use range_coder::RangeCoder;
24use silk::control_codec::silk_control_encoder;
25use silk::enc_api::silk_encode;
26use silk::init_encoder::silk_init_encoder;
27use silk::lin2log::silk_lin2log;
28use silk::log2lin::silk_log2lin;
29use silk::macros::*;
30use silk::resampler::{silk_resampler_down2, silk_resampler_down2_3};
31use silk::structs::SilkEncoderState;
32
/// Application profile the encoder is tuned for.
///
/// The explicit discriminants (2048, 2049, 2051) are kept exactly as written so that
/// `as i32` casts keep producing the same numbers; they coincide with the
/// `OPUS_APPLICATION_*` constants of the reference C implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Application {
    /// Voice-over-IP: raises the SILK/CELT decision threshold by 8000 bps and is the
    /// only profile for which the encoder runs its input high-pass filter.
    Voip = 2048,
    /// General-purpose audio.
    Audio = 2049,
    /// Restricted low delay: the encoder always produces CELT-only frames.
    RestrictedLowDelay = 2051,
}
39
/// Audio bandwidth signalled for a stream.
///
/// The explicit discriminants are kept exactly as written; they coincide with
/// `OPUS_AUTO` and the `OPUS_BANDWIDTH_*` constants of the reference C implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bandwidth {
    /// No bandwidth decided yet; the decoder starts in this state.
    Auto = -1000,
    /// Chosen by the encoder for 8 kHz input.
    Narrowband = 1101,
    /// Chosen by the encoder for 12 kHz input.
    Mediumband = 1102,
    /// Chosen by the encoder for 16 kHz input.
    Wideband = 1103,
    /// Chosen by the encoder for 24 kHz input.
    Superwideband = 1104,
    /// Chosen by the encoder for 48 kHz input.
    Fullband = 1105,
}
49
/// Coding mode of a single Opus frame (crate-internal).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OpusMode {
    /// The frame is coded by SILK alone.
    SilkOnly,
    /// SILK codes the frame first, then CELT continues in the same range coder
    /// starting at band 17.
    Hybrid,
    /// The frame is coded by CELT alone.
    CeltOnly,
}
56
57pub struct OpusEncoder {
58    celt_enc: CeltEncoder,
59    silk_enc: Box<SilkEncoderState>,
60    application: Application,
61    sampling_rate: i32,
62    channels: usize,
63    bandwidth: Bandwidth,
64    pub bitrate_bps: i32,
65    pub complexity: i32,
66    pub use_cbr: bool,
67
68    pub use_inband_fec: bool,
69
70    pub packet_loss_perc: i32,
71    silk_initialized: bool,
72    mode: OpusMode,
73    prev_enc_mode: Option<OpusMode>,
74
75    variable_hp_smth2_q15: i32,
76    hp_mem: Vec<i32>,
77
78    buf_filtered: Vec<i16>,
79    buf_silk_input: Vec<i16>,
80    buf_stereo_mid: Vec<i16>,
81    buf_stereo_side: Vec<i16>,
82    buf_celt_input: Vec<f32>,
83    down2_state_first: [i32; 2],
84    down2_state_second: [i32; 2],
85    down2_3_state: [i32; 6],
86    down_1_3_state: silk::resampler::SilkResamplerDown1_3,
87
88    rc: RangeCoder,
89}
90
/// Converts the nominal `bitrate` into an "equivalent" rate used for the SILK/CELT
/// mode decision, before the mode is known.
///
/// Adjustments, applied in this order:
/// 1. frame rates above 50 fps lose `(40 * channels + 20)` bps per extra frame per second;
/// 2. CBR (`vbr == false`) loses 1/12 of the rate;
/// 3. the rate is scaled by `(90 + complexity) / 100`;
/// 4. for `loss > 0`, `equiv * loss / (12 * loss + 20)` is subtracted.
///
/// All arithmetic is carried out in `i64` so that very large (unvalidated) bitrates
/// cannot overflow the intermediate products; the result is saturated to the `i32`
/// range. Integer divisions truncate toward zero exactly as the `i32` version did.
fn compute_equiv_rate(
    bitrate: i32,
    channels: usize,
    frame_rate: i32,
    vbr: bool,
    complexity: i32,
    loss: i32,
) -> i32 {
    let mut equiv = i64::from(bitrate);
    if frame_rate > 50 {
        // Per-frame overhead of frames shorter than 20 ms.
        equiv -= (40 * channels as i64 + 20) * (i64::from(frame_rate) - 50);
    }
    if !vbr {
        equiv -= equiv / 12;
    }
    equiv = equiv * (90 + i64::from(complexity)) / 100;
    if loss > 0 {
        let loss = i64::from(loss);
        equiv -= equiv * loss / (12 * loss + 20);
    }
    equiv.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}
112
113fn compute_mode_threshold(
114    application: Application,
115    channels: usize,
116    prev_was_celt: bool,
117    has_prev_mode: bool,
118    voice_est: i32,
119) -> i32 {
120    let mode_voice = if channels == 1 { 64000 } else { 44000 };
121    let mode_music = 10000;
122
123    let diff = mode_voice - mode_music;
124    let offset = (voice_est * voice_est * diff) >> 14;
125    let mut threshold = mode_music + offset;
126
127    if application == Application::Voip {
128        threshold += 8000;
129    }
130
131    if has_prev_mode {
132        if prev_was_celt {
133            threshold -= 4000;
134        } else {
135            threshold += 4000;
136        }
137    }
138
139    if application == Application::RestrictedLowDelay {
140        threshold = 0;
141    }
142
143    threshold
144}
145
/// Returns the share of `rate_bps` (total bitrate, bps) that SILK receives in a
/// Hybrid frame.
///
/// The share is read from a table of `(total, silk @ 10 ms, silk @ 20 ms)` rows and
/// linearly interpolated between neighbouring rows. Above the last row SILK gets the
/// last row's rate plus half of the excess. `frame20ms` selects the 20 ms column.
///
/// The interpolation is evaluated in `i64` so that out-of-range rates (for example a
/// negative, unvalidated bitrate) cannot overflow the intermediate products.
fn compute_silk_rate_for_hybrid(rate_bps: i32, frame20ms: bool) -> i32 {
    const RATE_TABLE: &[(i32, i32, i32)] = &[
        (0, 0, 0),
        (12000, 10000, 10000),
        (16000, 13500, 13500),
        (20000, 16000, 16000),
        (24000, 18000, 18000),
        (32000, 22000, 22000),
        (64000, 38000, 38000),
    ];

    let silk_rate = |row: &(i32, i32, i32)| if frame20ms { row.2 } else { row.1 };

    // Index of the first row (row 0 excluded) whose total rate exceeds `rate_bps`.
    let upper = RATE_TABLE
        .iter()
        .skip(1)
        .position(|row| row.0 > rate_bps)
        .map(|pos| pos + 1);

    match upper {
        None => {
            let last = &RATE_TABLE[RATE_TABLE.len() - 1];
            silk_rate(last) + (rate_bps - last.0) / 2
        }
        Some(i) => {
            let (below, above) = (&RATE_TABLE[i - 1], &RATE_TABLE[i]);
            let (x0, x1) = (i64::from(below.0), i64::from(above.0));
            let (lo, hi) = (i64::from(silk_rate(below)), i64::from(silk_rate(above)));
            let rate = i64::from(rate_bps);
            // |result| never exceeds |rate_bps| (table slopes are below 1), so the
            // narrowing cast cannot truncate.
            ((lo * (x1 - rate) + hi * (rate - x0)) / (x1 - x0)) as i32
        }
    }
}
176
/// Unit tests for `compute_silk_rate_for_hybrid`.
#[cfg(test)]
mod silk_rate_tests {
    use super::compute_silk_rate_for_hybrid;

    /// A total rate that coincides with a table row maps to that row's SILK rate.
    #[test]
    fn test_reference_table_exact_entries() {
        const EXPECTED: [(i32, i32); 6] = [
            (12000, 10000),
            (16000, 13500),
            (20000, 16000),
            (24000, 18000),
            (32000, 22000),
            (64000, 38000),
        ];
        for &(total, silk) in EXPECTED.iter() {
            assert_eq!(
                compute_silk_rate_for_hybrid(total, true),
                silk,
                "total rate {total}"
            );
        }
    }

    #[test]
    fn test_32kbps_gives_22kbps_silk() {
        assert_eq!(compute_silk_rate_for_hybrid(32000, true), 22000);
    }

    /// Halfway between the 16000 and 20000 rows.
    #[test]
    fn test_interpolation_between_table_entries() {
        let r = compute_silk_rate_for_hybrid(18000, true);
        assert_eq!(r, 14750);
    }

    /// Beyond the last row SILK receives half of the excess.
    #[test]
    fn test_above_table_max_gives_half_extra() {
        let r = compute_silk_rate_for_hybrid(72000, true);
        assert_eq!(r, 38000 + (72000 - 64000) / 2);
    }

    /// The 10 ms column currently holds the same values as the 20 ms column.
    #[test]
    fn test_10ms_column_matches_20ms() {
        assert_eq!(compute_silk_rate_for_hybrid(18000, false), 14750);
        assert_eq!(compute_silk_rate_for_hybrid(32000, false), 22000);
    }

    /// Rates below the first non-zero row interpolate from the all-zero row.
    #[test]
    fn test_below_first_entry_interpolates_from_zero() {
        assert_eq!(compute_silk_rate_for_hybrid(6000, true), 5000);
        assert_eq!(compute_silk_rate_for_hybrid(0, true), 0);
    }
}
208
209impl OpusEncoder {
210    pub fn new(
211        sampling_rate: i32,
212        channels: usize,
213        application: Application,
214    ) -> Result<Self, &'static str> {
215        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
216            return Err("Invalid sampling rate");
217        }
218        if ![1, 2].contains(&channels) {
219            return Err("Invalid number of channels");
220        }
221
222        let mode = modes::default_mode();
223        let celt_enc = CeltEncoder::new(mode, channels);
224
225        let mut silk_enc = Box::new(SilkEncoderState::default());
226        if silk_init_encoder(&mut silk_enc, 0) != 0 {
227            return Err("SILK encoder initialization failed");
228        }
229
230        let (opus_mode, bw) = match application {
231            Application::Voip => {
232                let bw = match sampling_rate {
233                    8000 => Bandwidth::Narrowband,
234                    12000 => Bandwidth::Mediumband,
235                    16000 => Bandwidth::Wideband,
236                    24000 => Bandwidth::Superwideband,
237                    48000 => Bandwidth::Fullband,
238                    _ => Bandwidth::Narrowband,
239                };
240
241                let mode = if sampling_rate > 16000 {
242                    OpusMode::Hybrid
243                } else {
244                    OpusMode::SilkOnly
245                };
246                (mode, bw)
247            }
248            Application::RestrictedLowDelay => {
249                let bw = match sampling_rate {
250                    8000 => Bandwidth::Narrowband,
251                    12000 => Bandwidth::Mediumband,
252                    16000 => Bandwidth::Wideband,
253                    24000 => Bandwidth::Superwideband,
254                    _ => Bandwidth::Fullband,
255                };
256                (OpusMode::CeltOnly, bw)
257            }
258            Application::Audio => {
259                if sampling_rate <= 16000 {
260                    let bw = match sampling_rate {
261                        8000 => Bandwidth::Narrowband,
262                        12000 => Bandwidth::Mediumband,
263                        _ => Bandwidth::Wideband,
264                    };
265                    (OpusMode::SilkOnly, bw)
266                } else {
267                    let bw = match sampling_rate {
268                        24000 => Bandwidth::Superwideband,
269                        _ => Bandwidth::Fullband,
270                    };
271                    (OpusMode::Hybrid, bw)
272                }
273            }
274        };
275
276        use silk::lin2log::silk_lin2log;
277        let variable_hp_smth2_q15 = silk_lin2log(60) << 8;
278
279        Ok(Self {
280            celt_enc,
281            silk_enc,
282            application,
283            sampling_rate,
284            channels,
285            bandwidth: bw,
286            bitrate_bps: 64000,
287            complexity: 9,
288            use_cbr: false,
289            use_inband_fec: false,
290            packet_loss_perc: 0,
291            silk_initialized: false,
292            prev_enc_mode: None,
293            mode: opus_mode,
294            variable_hp_smth2_q15,
295            hp_mem: vec![0; channels * 2],
296
297            buf_filtered: Vec::new(),
298            buf_silk_input: Vec::new(),
299            buf_stereo_mid: Vec::new(),
300            buf_stereo_side: Vec::new(),
301            buf_celt_input: Vec::new(),
302            down2_state_first: [0; 2],
303            down2_state_second: [0; 2],
304            down2_3_state: [0; 6],
305            down_1_3_state: silk::resampler::SilkResamplerDown1_3::default(),
306            rc: RangeCoder::new_encoder(1),
307        })
308    }
309
310    pub fn enable_hybrid_mode(&mut self) -> Result<(), &'static str> {
311        if self.sampling_rate != 24000 && self.sampling_rate != 48000 {
312            return Err("Hybrid mode requires 24kHz or 48kHz sampling rate");
313        }
314        let bw = if self.sampling_rate == 48000 {
315            Bandwidth::Fullband
316        } else {
317            Bandwidth::Superwideband
318        };
319        self.mode = OpusMode::Hybrid;
320        self.bandwidth = bw;
321        self.silk_initialized = false;
322        Ok(())
323    }
324
325    pub fn encode(
326        &mut self,
327        input: &[f32],
328        frame_size: usize,
329        output: &mut [u8],
330    ) -> Result<usize, &'static str> {
331        if output.len() < 2 {
332            return Err("Output buffer too small");
333        }
334
335        let frame_rate = frame_rate_from_params(self.sampling_rate, frame_size)
336            .ok_or("Invalid frame size for sampling rate")?;
337
338        // Mode selection: match C's opus_encode_native() behavior.
339        // C reference auto-selects between SILK_ONLY and CELT_ONLY; Hybrid is
340        // produced afterwards by bandwidth overrides (SILK-only + FB/SWB → Hybrid).
341        let mut mode = if self.application == Application::RestrictedLowDelay {
342            OpusMode::CeltOnly
343        } else {
344            let equiv = compute_equiv_rate(
345                self.bitrate_bps,
346                self.channels,
347                frame_rate,
348                !self.use_cbr,
349                self.complexity,
350                self.packet_loss_perc,
351            );
352            let prev_was_celt = self.prev_enc_mode == Some(OpusMode::CeltOnly);
353            let has_prev_mode = self.prev_enc_mode.is_some();
354            let voice_est = match self.application {
355                Application::Voip => 115,
356                Application::Audio => 48,
357                Application::RestrictedLowDelay => 0,
358            };
359            let threshold = compute_mode_threshold(
360                self.application,
361                self.channels,
362                prev_was_celt,
363                has_prev_mode,
364                voice_est,
365            );
366            if equiv >= threshold && self.sampling_rate >= 24000 {
367                OpusMode::CeltOnly
368            } else {
369                OpusMode::SilkOnly
370            }
371        };
372
373        let curr_bw = self.bandwidth;
374        if mode == OpusMode::SilkOnly
375            && (curr_bw == Bandwidth::Superwideband || curr_bw == Bandwidth::Fullband)
376        {
377            mode = OpusMode::Hybrid;
378        }
379        if mode == OpusMode::Hybrid
380            && (curr_bw == Bandwidth::Narrowband
381                || curr_bw == Bandwidth::Mediumband
382                || curr_bw == Bandwidth::Wideband)
383        {
384            mode = OpusMode::SilkOnly;
385        }
386
387        if mode == OpusMode::CeltOnly {
388            match frame_rate {
389                400 | 200 | 100 | 50 => {}
390                _ => return Err("Unsupported frame size for CELT-only mode"),
391            }
392        }
393
394        if mode == OpusMode::Hybrid {
395            match frame_rate {
396                100 | 50 => {}
397                _ => return Err("Unsupported frame size for Hybrid mode"),
398            }
399        }
400
401        if mode == OpusMode::SilkOnly {
402            match frame_rate {
403                400 | 200 | 100 | 50 | 25 => {}
404                _ => return Err("Unsupported frame size for SILK-only mode"),
405            }
406        }
407
408        let toc = gen_toc(mode, frame_rate, self.bandwidth, self.channels);
409        output[0] = toc;
410
411        let target_bits =
412            (self.bitrate_bps as i64 * frame_size as i64 / self.sampling_rate as i64) as i32;
413        let cbr_bytes = ((target_bits + 4) / 8) as usize;
414        let max_data_bytes = output.len();
415
416        let n_bytes = cbr_bytes.min(max_data_bytes).max(1);
417
418        let init_rc_size = n_bytes - 1;
419        self.rc.reset_for_encode(init_rc_size as u32);
420
421        if mode == OpusMode::SilkOnly || mode == OpusMode::Hybrid {
422            let silk_fs_khz = if mode == OpusMode::Hybrid {
423                16
424            } else {
425                self.sampling_rate.min(16000) / 1000
426            };
427
428            let frame_ms = (frame_size as i32 * 1000) / self.sampling_rate;
429            if !self.silk_initialized || self.silk_enc.s_cmn.fs_khz != silk_fs_khz {
430                let silk_init_bitrate = (((n_bytes - 1) * 8) as i64 * self.sampling_rate as i64
431                    / frame_size as i64) as i32;
432                silk_control_encoder(
433                    &mut self.silk_enc,
434                    silk_fs_khz,
435                    frame_ms,
436                    silk_init_bitrate,
437                    self.complexity,
438                );
439                self.silk_enc.s_cmn.use_cbr = if self.use_cbr { 1 } else { 0 };
440
441                self.silk_enc.s_cmn.n_channels = self.channels as i32;
442                self.silk_initialized = true;
443                self.down2_state_first = [0; 2];
444                self.down2_state_second = [0; 2];
445                self.down2_3_state = [0; 6];
446                self.down_1_3_state = silk::resampler::SilkResamplerDown1_3::default();
447            }
448
449            self.silk_enc.s_cmn.use_in_band_fec = if self.use_inband_fec { 1 } else { 0 };
450            self.silk_enc.s_cmn.packet_loss_perc = self.packet_loss_perc.clamp(0, 100);
451
452            self.silk_enc.s_cmn.lbrr_enabled = if self.use_inband_fec { 1 } else { 0 };
453
454            if self.silk_enc.s_cmn.lbrr_gain_increases == 0 {
455                self.silk_enc.s_cmn.lbrr_gain_increases = 2;
456            }
457
458            let hp_freq_smth1 = if mode == OpusMode::CeltOnly {
459                silk_lin2log(60) << 8
460            } else {
461                self.silk_enc.s_cmn.variable_hp_smth1_q15
462            };
463
464            const VARIABLE_HP_SMTH_COEF2_Q16: i32 = 984;
465            self.variable_hp_smth2_q15 = silk_smlawb(
466                self.variable_hp_smth2_q15,
467                hp_freq_smth1 - self.variable_hp_smth2_q15,
468                VARIABLE_HP_SMTH_COEF2_Q16,
469            );
470
471            let cutoff_hz = silk_log2lin(silk_rshift(self.variable_hp_smth2_q15, 8));
472
473            let required_size = frame_size * self.channels;
474            self.buf_filtered.resize(required_size, 0);
475            if self.application == Application::Voip {
476                hp_cutoff(
477                    input,
478                    cutoff_hz,
479                    &mut self.buf_filtered,
480                    &mut self.hp_mem,
481                    frame_size,
482                    self.channels,
483                    self.sampling_rate,
484                );
485            } else {
486                for (i, &x) in input.iter().enumerate() {
487                    self.buf_filtered[i] = (x * 32768.0).clamp(-32768.0, 32767.0) as i16;
488                }
489            }
490
491            let input_i16 = &self.buf_filtered;
492
493            let silk_input: &[i16] = if mode == OpusMode::SilkOnly && self.sampling_rate > 16000 {
494                if self.sampling_rate == 48000 {
495                    let stage1_size = frame_size / 2;
496                    let mut stage1_buf = [0i16; 480];
497                    silk_resampler_down2(
498                        &mut self.down2_state_first,
499                        &mut stage1_buf[..stage1_size],
500                        input_i16,
501                        frame_size as i32,
502                    );
503                    let silk_frame_size = stage1_size * 2 / 3;
504                    self.buf_silk_input.resize(silk_frame_size, 0);
505                    silk_resampler_down2_3(
506                        &mut self.down2_3_state,
507                        &mut self.buf_silk_input,
508                        &stage1_buf[..stage1_size],
509                        stage1_size as i32,
510                    );
511                    &self.buf_silk_input
512                } else if self.sampling_rate == 24000 {
513                    let silk_frame_size = frame_size * 2 / 3;
514                    self.buf_silk_input.resize(silk_frame_size, 0);
515                    silk_resampler_down2_3(
516                        &mut self.down2_3_state,
517                        &mut self.buf_silk_input,
518                        input_i16,
519                        frame_size as i32,
520                    );
521                    &self.buf_silk_input
522                } else {
523                    input_i16
524                }
525            } else if mode == OpusMode::SilkOnly && self.channels == 2 {
526                let frame_length = input_i16.len() / 2;
527                self.buf_stereo_mid.resize(frame_length, 0);
528                self.buf_stereo_side.resize(frame_length, 0);
529                for i in 0..frame_length {
530                    let l = input_i16[2 * i] as i32;
531                    let r = input_i16[2 * i + 1] as i32;
532                    self.buf_stereo_mid[i] = ((l + r) / 2) as i16;
533                    self.buf_stereo_side[i] = (l - r) as i16;
534                }
535
536                self.silk_enc.stereo.side.resize(frame_length, 0);
537                self.silk_enc
538                    .stereo
539                    .side
540                    .copy_from_slice(&self.buf_stereo_side[..frame_length]);
541                &self.buf_stereo_mid
542            } else if mode == OpusMode::Hybrid && self.sampling_rate > 16000 {
543                if self.sampling_rate == 48000 {
544                    let silk_frame_size = frame_size / 3;
545                    self.buf_silk_input.resize(silk_frame_size, 0);
546                    silk::resampler::silk_resampler_down_1_3(
547                        &mut self.down_1_3_state,
548                        &mut self.buf_silk_input,
549                        input_i16,
550                    );
551                } else {
552                    let silk_frame_size = frame_size * 2 / 3;
553                    self.buf_silk_input.resize(silk_frame_size, 0);
554                    silk_resampler_down2_3(
555                        &mut self.down2_3_state,
556                        &mut self.buf_silk_input,
557                        input_i16,
558                        frame_size as i32,
559                    );
560                }
561                &self.buf_silk_input
562            } else {
563                input_i16
564            };
565
566            let mut pn_bytes = 0;
567
568            let silk_rate_for_calc = if mode == OpusMode::Hybrid {
569                16000
570            } else {
571                self.sampling_rate
572            };
573            let silk_frame_len = silk_input.len();
574
575            let silk_bitrate = if mode == OpusMode::Hybrid {
576                let frame_duration_ms = frame_size as i32 * 1000 / self.sampling_rate;
577                let frame20ms = frame_duration_ms >= 20;
578                compute_silk_rate_for_hybrid(self.bitrate_bps, frame20ms)
579            } else {
580                (8i64 * (n_bytes - 1) as i64 * silk_rate_for_calc as i64 / silk_frame_len as i64)
581                    as i32
582            };
583            let silk_max_bits = if mode == OpusMode::Hybrid {
584                let total_max_bits = ((n_bytes - 1) * 8) as i32;
585                if self.use_cbr {
586                    let silk_bits = (silk_bitrate as i64 * silk_frame_len as i64
587                        / silk_rate_for_calc as i64) as i32;
588                    let other_bits = 0i32.max(total_max_bits - silk_bits);
589                    0i32.max(total_max_bits - other_bits * 3 / 4)
590                } else {
591                    let frame_duration_ms = frame_size as i32 * 1000 / self.sampling_rate;
592                    let frame20ms = frame_duration_ms >= 20;
593                    let max_bit_rate = compute_silk_rate_for_hybrid(
594                        total_max_bits * self.sampling_rate / frame_size as i32,
595                        frame20ms,
596                    );
597                    max_bit_rate * frame_size as i32 / self.sampling_rate
598                }
599            } else {
600                ((n_bytes - 1) * 8) as i32
601            };
602            let silk_use_cbr = if mode == OpusMode::Hybrid && self.use_cbr {
603                0
604            } else if self.use_cbr {
605                1
606            } else {
607                0
608            };
609            let ret = silk_encode(
610                &mut self.silk_enc,
611                silk_input,
612                silk_input.len(),
613                &mut self.rc,
614                &mut pn_bytes,
615                silk_bitrate,
616                silk_max_bits,
617                silk_use_cbr,
618                1,
619            );
620            if ret != 0 {
621                return Err("SILK encoding failed");
622            }
623        }
624
625        if mode == OpusMode::Hybrid {
626            self.rc.encode_bit_logp(false, 12); // redundancy = 0
627        }
628
629        if mode == OpusMode::Hybrid {
630            let nb_compr_bytes = (n_bytes - 1) as u32;
631            self.rc.shrink(nb_compr_bytes);
632        }
633
634        let silk_ret_bytes = if mode == OpusMode::SilkOnly {
635            ((self.rc.tell() + 7) >> 3) as usize
636        } else {
637            0
638        };
639
640        if mode == OpusMode::CeltOnly || mode == OpusMode::Hybrid {
641            self.celt_enc.complexity = self.complexity;
642            let start_band = if mode == OpusMode::Hybrid { 17 } else { 0 };
643            let total_packet_bits = ((n_bytes - 1) * 8) as i32;
644
645            let celt_input: &[f32] = if self.channels == 1 {
646                input
647            } else {
648                let n = frame_size * self.channels;
649                self.buf_celt_input.resize(n, 0.0);
650                for i in 0..frame_size {
651                    for ch in 0..self.channels {
652                        self.buf_celt_input[ch * frame_size + i] = input[i * self.channels + ch];
653                    }
654                }
655                &self.buf_celt_input
656            };
657
658            if self.rc.tell() <= total_packet_bits {
659                self.celt_enc.encode_with_budget(
660                    celt_input,
661                    frame_size,
662                    &mut self.rc,
663                    start_band,
664                    total_packet_bits,
665                );
666            }
667        }
668
669        self.rc.done();
670
671        if mode == OpusMode::SilkOnly {
672            let mut ret = silk_ret_bytes.min(self.rc.storage as usize);
673            while ret > 2 && self.rc.buf[ret - 1] == 0 {
674                ret -= 1;
675            }
676
677            let target_total = if self.use_cbr {
678                n_bytes.min(output.len())
679            } else {
680                (ret + 1).min(output.len())
681            };
682
683            let silk_len = ret;
684
685            if !self.use_cbr || silk_len + 1 >= target_total {
686                // VBR or payload fills the target: simple code 0 packet
687                output[0] = toc;
688                let copy_len = silk_len.min(target_total - 1);
689                output[1..1 + copy_len].copy_from_slice(&self.rc.buf[..copy_len]);
690                return Ok((copy_len + 1).min(output.len()));
691            }
692
693            output[0] = toc | 0x03;
694
695            if silk_len + 2 >= target_total {
696                output[1] = 0x01;
697                let copy_len = (target_total - 2).min(silk_len);
698                output[2..2 + copy_len].copy_from_slice(&self.rc.buf[..copy_len]);
699                self.prev_enc_mode = Some(mode);
700                return Ok(target_total.min(output.len()));
701            }
702
703            let pad_amount = target_total - silk_len - 2;
704            output[1] = 0x41;
705
706            let nb_255s = (pad_amount - 1) / 255;
707            let mut ptr = 2;
708            for _ in 0..nb_255s {
709                output[ptr] = 255;
710                ptr += 1;
711            }
712            output[ptr] = (pad_amount - 255 * nb_255s - 1) as u8;
713            ptr += 1;
714
715            output[ptr..ptr + silk_len].copy_from_slice(&self.rc.buf[..silk_len]);
716            ptr += silk_len;
717
718            let fill_end = target_total.min(output.len());
719            for byte in output[ptr..fill_end].iter_mut() {
720                *byte = 0;
721            }
722
723            self.prev_enc_mode = Some(mode);
724            return Ok(target_total.min(output.len()));
725        }
726
727        let payload_len = n_bytes - 1;
728        output[1..1 + payload_len].copy_from_slice(&self.rc.buf[..payload_len]);
729        self.prev_enc_mode = Some(mode);
730        Ok(n_bytes)
731    }
732}
733
734pub struct OpusDecoder {
735    celt_dec: CeltDecoder,
736    silk_dec: silk::dec_api::SilkDecoder,
737    sampling_rate: i32,
738    channels: usize,
739
740    prev_mode: Option<OpusMode>,
741    frame_size: usize,
742
743    bandwidth: Bandwidth,
744
745    stream_channels: usize,
746
747    silk_resampler: silk::resampler::SilkResampler,
748
749    prev_internal_rate: i32,
750
751    pub hybrid_skip_celt: bool,
752
753    w_pcm_i16: Vec<i16>,
754    w_silk_out: Vec<f32>,
755    w_pcm_resampled: Vec<i16>,
756    w_celt_planar: Vec<f32>,
757    w_celt_out: Vec<f32>,
758}
759
760impl OpusDecoder {
761    pub fn new(sampling_rate: i32, channels: usize) -> Result<Self, &'static str> {
762        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
763            return Err("Invalid sampling rate");
764        }
765        if ![1, 2].contains(&channels) {
766            return Err("Invalid number of channels");
767        }
768
769        let mode = modes::default_mode();
770        let celt_dec = CeltDecoder::new(mode, channels);
771
772        let mut silk_dec = silk::dec_api::SilkDecoder::new();
773        silk_dec.init(sampling_rate.min(16000), channels as i32);
774        silk_dec.channel_state[0].fs_api_hz = sampling_rate;
775
776        Ok(Self {
777            celt_dec,
778            silk_dec,
779            sampling_rate,
780            channels,
781            prev_mode: None,
782            frame_size: 0,
783            bandwidth: Bandwidth::Auto,
784            stream_channels: channels,
785            silk_resampler: silk::resampler::SilkResampler::default(),
786            prev_internal_rate: 0,
787            hybrid_skip_celt: false,
788
789            w_pcm_i16: vec![0i16; 640],
790
791            w_silk_out: vec![0.0f32; 5760 * channels],
792            w_pcm_resampled: vec![0i16; 5760 * channels],
793            w_celt_planar: vec![0.0f32; 5760 * channels],
794            w_celt_out: vec![0.0f32; 5760 * channels],
795        })
796    }
797
798    pub fn decode(
799        &mut self,
800        input: &[u8],
801        frame_size: usize,
802        output: &mut [f32],
803    ) -> Result<usize, &'static str> {
804        if input.is_empty() {
805            return Err("Input packet empty");
806        }
807
808        let toc = input[0];
809        let mode = mode_from_toc(toc);
810        let packet_channels = channels_from_toc(toc);
811        let bandwidth = bandwidth_from_toc(toc);
812        let frame_duration_ms = frame_duration_ms_from_toc(toc);
813
814        if packet_channels != self.channels {
815            return Err("Channel count mismatch between packet and decoder");
816        }
817
818        let code = toc & 0x03;
819        let frame_count: usize;
820        let frame_payloads: Vec<&[u8]>;
821
822        match code {
823            0 => {
824                frame_count = 1;
825                frame_payloads = vec![&input[1..]];
826            }
827            1 => {
828                frame_count = 2;
829                let half = (input.len() - 1) / 2;
830                if half == 0 {
831                    return Err("Code 1: empty frame");
832                }
833                frame_payloads = vec![&input[1..1 + half], &input[1 + half..]];
834            }
835            2 => {
836                frame_count = 2;
837                let data = &input[1..];
838                if data.is_empty() {
839                    return Err("Code 2 packet has no data");
840                }
841                let (first_len, header_size) = if data[0] & 0x80 != 0 {
842                    if data.len() < 2 {
843                        return Err("Code 2 packet too short for 2-byte length");
844                    }
845                    (((data[0] & 0x7F) as usize) << 8 | data[1] as usize, 2)
846                } else {
847                    (data[0] as usize, 1)
848                };
849                if header_size + first_len > data.len() {
850                    return Err("Code 2: first frame size exceeds packet");
851                }
852                frame_payloads = vec![
853                    &data[header_size..header_size + first_len],
854                    &data[header_size + first_len..],
855                ];
856            }
857            3 => {
858                if input.len() < 2 {
859                    return Err("Code 3 packet too short");
860                }
861                let count_byte = input[1];
862                let n_frames = (count_byte & 0x3F) as usize;
863                if n_frames < 1 || n_frames > 48 {
864                    return Err("Code 3: invalid frame count");
865                }
866                frame_count = n_frames;
867                let padding_flag = (count_byte & 0x40) != 0;
868
869                if padding_flag {
870                    let mut ptr = 2usize;
871                    let mut pad_len = 0usize;
872                    loop {
873                        if ptr >= input.len() {
874                            return Err("Padding overflow");
875                        }
876                        let p = input[ptr] as usize;
877                        ptr += 1;
878                        if p == 255 {
879                            pad_len += 254;
880                        } else {
881                            pad_len += p;
882                            break;
883                        }
884                    }
885
886                    let end = input.len().saturating_sub(pad_len);
887                    if ptr > end {
888                        return Err("Padding exceeds packet");
889                    }
890                    let compressed = &input[ptr..end];
891                    let frame_len = compressed.len() / frame_count;
892                    if frame_len == 0 {
893                        return Err("Code 3 with padding: empty frame");
894                    }
895                    frame_payloads = compressed.chunks(frame_len).collect();
896                    if frame_payloads.len() != frame_count {
897                        return Err("Code 3: frame count mismatch");
898                    }
899                } else {
900                    // Self-delimiting format:
901                    // - Single frame: no length prefix, remaining data is the frame
902                    // - Multi-frame: lengths for all frames except the last
903                    let mut payload_ptr = 2usize;
904                    if frame_count == 1 {
905                        frame_payloads = vec![&input[payload_ptr..]];
906                    } else {
907                        let mut payloads = Vec::with_capacity(frame_count);
908                        for _f in 0..frame_count - 1 {
909                            if payload_ptr >= input.len() {
910                                return Err("Code 3: unexpected end in self-delimiting header");
911                            }
912                            let (frame_len, header_bytes) = if input[payload_ptr] & 0x80 != 0 {
913                                if payload_ptr + 2 > input.len() {
914                                    return Err("Code 3: short frame length");
915                                }
916                                (
917                                    ((input[payload_ptr] & 0x7F) as usize) << 8
918                                        | input[payload_ptr + 1] as usize,
919                                    2,
920                                )
921                            } else {
922                                (input[payload_ptr] as usize, 1)
923                            };
924                            payload_ptr += header_bytes;
925                            if payload_ptr + frame_len > input.len() {
926                                return Err("Code 3: frame length exceeds packet");
927                            }
928                            payloads.push(&input[payload_ptr..payload_ptr + frame_len]);
929                            payload_ptr += frame_len;
930                        }
931                        // Last frame: no length prefix, remaining data
932                        if payload_ptr > input.len() {
933                            return Err("Code 3: no data for last frame");
934                        }
935                        payloads.push(&input[payload_ptr..]);
936                        frame_payloads = payloads;
937                    }
938                }
939            }
940            _ => unreachable!(),
941        }
942
943        self.frame_size = frame_size;
944        self.bandwidth = bandwidth;
945        self.stream_channels = packet_channels;
946
947        let sub_frame_size = frame_size / frame_count;
948        let sub_output_len = sub_frame_size * self.channels;
949
950        match mode {
951            OpusMode::SilkOnly => {
952                let internal_sample_rate = match bandwidth {
953                    Bandwidth::Narrowband => 8000,
954                    Bandwidth::Mediumband => 12000,
955                    Bandwidth::Wideband => 16000,
956                    _ => 16000,
957                };
958                let internal_frame_size =
959                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
960
961                if self.sampling_rate != internal_sample_rate
962                    && internal_sample_rate != self.prev_internal_rate
963                {
964                    self.silk_resampler
965                        .init(internal_sample_rate, self.sampling_rate);
966                    self.prev_internal_rate = internal_sample_rate;
967                }
968
969                for (fi, payload) in frame_payloads.iter().enumerate() {
970                    let mut rc = RangeCoder::new_decoder(payload);
971                    let pcm_i16_len = internal_frame_size * self.channels;
972                    debug_assert!(pcm_i16_len <= self.w_pcm_i16.len());
973
974                    let ret = {
975                        let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
976                        silk_dec.decode(
977                            &mut rc,
978                            &mut pcm_i16[..pcm_i16_len],
979                            silk::decode_frame::FLAG_DECODE_NORMAL,
980                            true,
981                            frame_duration_ms,
982                            internal_sample_rate,
983                        )
984                    };
985
986                    if ret < 0 {
987                        return Err("SILK decoding failed");
988                    }
989
990                    let decoded_samples = ret as usize;
991                    let out_start = fi * sub_output_len;
992
993                    // SILK only decodes channel 0 (mono). For multi-channel output,
994                    // replicate the mono samples to every channel.
995                    if self.sampling_rate == internal_sample_rate {
996                        let frames = decoded_samples.min(sub_frame_size);
997                        for i in 0..frames {
998                            let v = self.w_pcm_i16[i] as f32 / 32768.0;
999                            for ch in 0..self.channels {
1000                                let idx = out_start + i * self.channels + ch;
1001                                if idx < output.len() {
1002                                    output[idx] = v;
1003                                }
1004                            }
1005                        }
1006                    } else {
1007                        let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
1008                        let out_len =
1009                            ((decoded_samples as f64 * ratio) as usize).min(sub_frame_size);
1010                        debug_assert!(out_len <= self.w_pcm_resampled.len());
1011                        {
1012                            let (silk_res, pcm_i16, pcm_out) = (
1013                                &mut self.silk_resampler,
1014                                &self.w_pcm_i16,
1015                                &mut self.w_pcm_resampled,
1016                            );
1017                            silk_res.process(
1018                                &mut pcm_out[..out_len],
1019                                &pcm_i16[..decoded_samples],
1020                                decoded_samples as i32,
1021                            );
1022                        }
1023                        let frames = out_len.min(sub_frame_size);
1024                        for i in 0..frames {
1025                            let v = self.w_pcm_resampled[i] as f32 / 32768.0;
1026                            for ch in 0..self.channels {
1027                                let idx = out_start + i * self.channels + ch;
1028                                if idx < output.len() {
1029                                    output[idx] = v;
1030                                }
1031                            }
1032                        }
1033                    }
1034                }
1035                self.prev_mode = Some(OpusMode::SilkOnly);
1036                Ok(frame_size)
1037            }
1038
1039            OpusMode::CeltOnly => {
1040                let celt_end_band = self.celt_end_band_from_toc(toc);
1041
1042                for (fi, payload) in frame_payloads.iter().enumerate() {
1043                    let mut rc = RangeCoder::new_decoder(payload);
1044                    let total_bits = (payload.len() * 8) as i32;
1045                    let needed = sub_frame_size * self.channels;
1046                    let out_start = fi * needed;
1047                    let out_end = (out_start + needed).min(output.len());
1048
1049                    if output.len() < out_end {
1050                        return Err("Output buffer too small");
1051                    }
1052
1053                    if self.channels == 1 {
1054                        self.celt_dec.decode_from_range_coder_with_band_range(
1055                            &mut rc,
1056                            total_bits,
1057                            sub_frame_size,
1058                            &mut output[out_start..out_end],
1059                            0,
1060                            celt_end_band,
1061                        );
1062                        for sample in &mut output[out_start..out_end] {
1063                            *sample = sample.clamp(-1.0, 1.0);
1064                        }
1065                    } else {
1066                        self.celt_dec.decode_from_range_coder_with_band_range(
1067                            &mut rc,
1068                            total_bits,
1069                            sub_frame_size,
1070                            &mut self.w_celt_planar[..needed],
1071                            0,
1072                            celt_end_band,
1073                        );
1074                        for i in 0..sub_frame_size {
1075                            for ch in 0..self.channels {
1076                                let idx = out_start + i * self.channels + ch;
1077                                output[idx] =
1078                                    self.w_celt_planar[ch * sub_frame_size + i].clamp(-1.0, 1.0);
1079                            }
1080                        }
1081                    }
1082                }
1083                self.prev_mode = Some(OpusMode::CeltOnly);
1084                Ok(frame_size)
1085            }
1086
1087            OpusMode::Hybrid => {
1088                let internal_sample_rate = 16000;
1089                let internal_frame_size =
1090                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
1091                let celt_end_band = self.celt_end_band_from_toc(toc);
1092
1093                if self.sampling_rate != internal_sample_rate
1094                    && internal_sample_rate != self.prev_internal_rate
1095                {
1096                    self.silk_resampler
1097                        .init(internal_sample_rate, self.sampling_rate);
1098                    self.prev_internal_rate = internal_sample_rate;
1099                }
1100
1101                for (fi, payload) in frame_payloads.iter().enumerate() {
1102                    let mut rc = RangeCoder::new_decoder(payload);
1103                    let pcm_silk_i16_len = internal_frame_size * self.channels;
1104                    debug_assert!(pcm_silk_i16_len <= self.w_pcm_i16.len());
1105
1106                    let ret = {
1107                        let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
1108                        silk_dec.decode(
1109                            &mut rc,
1110                            &mut pcm_i16[..pcm_silk_i16_len],
1111                            silk::decode_frame::FLAG_DECODE_NORMAL,
1112                            true,
1113                            frame_duration_ms,
1114                            internal_sample_rate,
1115                        )
1116                    };
1117
1118                    if ret < 0 {
1119                        return Err("SILK decoding failed");
1120                    }
1121
1122                    let silk_out_len = sub_frame_size * self.channels;
1123                    self.w_silk_out[..silk_out_len].fill(0.0);
1124                    if ret > 0 {
1125                        let decoded_samples = ret as usize;
1126                        // SILK only decodes channel 0 (mono). For multi-channel output,
1127                        // replicate the mono samples to every channel in w_silk_out.
1128                        if self.sampling_rate == internal_sample_rate {
1129                            let frames = decoded_samples.min(sub_frame_size);
1130                            for i in 0..frames {
1131                                let v = self.w_pcm_i16[i] as f32 / 32768.0;
1132                                for ch in 0..self.channels {
1133                                    let idx = i * self.channels + ch;
1134                                    if idx < silk_out_len {
1135                                        self.w_silk_out[idx] = v;
1136                                    }
1137                                }
1138                            }
1139                        } else {
1140                            let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
1141                            let out_len =
1142                                ((decoded_samples as f64 * ratio) as usize).min(sub_frame_size);
1143                            debug_assert!(out_len <= self.w_pcm_resampled.len());
1144                            {
1145                                let (silk_res, pcm_i16, pcm_resampled) = (
1146                                    &mut self.silk_resampler,
1147                                    &self.w_pcm_i16,
1148                                    &mut self.w_pcm_resampled,
1149                                );
1150                                silk_res.process(
1151                                    &mut pcm_resampled[..out_len],
1152                                    &pcm_i16[..decoded_samples],
1153                                    decoded_samples as i32,
1154                                );
1155                            }
1156                            let frames = out_len.min(sub_frame_size);
1157                            for i in 0..frames {
1158                                let v = self.w_pcm_resampled[i] as f32 / 32768.0;
1159                                for ch in 0..self.channels {
1160                                    let idx = i * self.channels + ch;
1161                                    if idx < silk_out_len {
1162                                        self.w_silk_out[idx] = v;
1163                                    }
1164                                }
1165                            }
1166                        }
1167                    }
1168
1169                    let total_bits = (payload.len() * 8) as i32;
1170                    let redundancy = rc.decode_bit_logp(12);
1171                    let skip_celt = if redundancy {
1172                        let _ = rc.decode_bit_logp(1);
1173                        true
1174                    } else {
1175                        false
1176                    };
1177
1178                    if skip_celt {
1179                        self.w_celt_out[..silk_out_len].fill(0.0);
1180                    } else {
1181                        let (celt_dec, celt_planar) = (&mut self.celt_dec, &mut self.w_celt_planar);
1182                        celt_dec.decode_from_range_coder_with_band_range(
1183                            &mut rc,
1184                            total_bits,
1185                            sub_frame_size,
1186                            &mut celt_planar[..silk_out_len],
1187                            17,
1188                            celt_end_band,
1189                        );
1190
1191                        if self.channels == 1 {
1192                            self.w_celt_out[..silk_out_len]
1193                                .copy_from_slice(&self.w_celt_planar[..silk_out_len]);
1194                        } else {
1195                            for i in 0..sub_frame_size {
1196                                for ch in 0..self.channels {
1197                                    self.w_celt_out[i * self.channels + ch] =
1198                                        self.w_celt_planar[ch * sub_frame_size + i];
1199                                }
1200                            }
1201                        }
1202                    }
1203
1204                    let out_start = fi * silk_out_len;
1205                    let total = silk_out_len.min(output.len() - out_start);
1206                    for j in 0..total {
1207                        output[out_start + j] =
1208                            (self.w_silk_out[j] + self.w_celt_out[j]).clamp(-1.0, 1.0);
1209                    }
1210                }
1211                self.prev_mode = Some(OpusMode::Hybrid);
1212                Ok(frame_size)
1213            }
1214        }
1215    }
1216}
1217
1218impl OpusDecoder {
1219    #[inline(always)]
1220    fn celt_end_band_from_toc(&self, toc: u8) -> usize {
1221        let mode = modes::default_mode();
1222        let top = mode.eff_ebands;
1223        if mode_from_toc(toc) == OpusMode::CeltOnly && toc >= 0x80 {
1224            const FROM_OPUS_TABLE: [u8; 16] = [
1225                0x80, 0x88, 0x90, 0x98, 0x40, 0x48, 0x50, 0x58, 0x20, 0x28, 0x30, 0x38, 0x00, 0x08,
1226                0x10, 0x18,
1227            ];
1228            let idx = ((toc >> 3) - 16) as usize;
1229            let data0 = FROM_OPUS_TABLE[idx] | (toc & 0x7);
1230            let trim = (data0 >> 5) as usize;
1231            return top.saturating_sub(2 * trim).max(1);
1232        }
1233        top
1234    }
1235}
1236
/// Converts a frame length in samples into a frame rate in frames per second.
///
/// Returns `None` when `frame_size` is zero, is too large to be represented
/// as an `i32`, or does not divide `sampling_rate` exactly; otherwise returns
/// `Some(sampling_rate / frame_size)`.
fn frame_rate_from_params(sampling_rate: i32, frame_size: usize) -> Option<i32> {
    // A bare `as i32` cast wraps oversized values: `usize::MAX` would become
    // `-1` (yielding a negative "rate"), and sizes just above 2^32 would alias
    // to small valid-looking sizes. Reject them before narrowing.
    if frame_size == 0 || frame_size > i32::MAX as usize {
        return None;
    }
    let frame_size = frame_size as i32;

    if sampling_rate % frame_size != 0 {
        return None;
    }
    Some(sampling_rate / frame_size)
}
1244
1245fn gen_toc(mode: OpusMode, frame_rate: i32, bandwidth: Bandwidth, channels: usize) -> u8 {
1246    let mut rate = frame_rate;
1247    let mut period = 0;
1248    while rate < 400 {
1249        rate <<= 1;
1250        period += 1;
1251    }
1252
1253    let mut toc = match mode {
1254        OpusMode::SilkOnly => {
1255            let bw = (bandwidth as i32 - Bandwidth::Narrowband as i32) << 5;
1256            let per = (period - 2) << 3;
1257            (bw | per) as u8
1258        }
1259        OpusMode::CeltOnly => {
1260            let mut tmp = bandwidth as i32 - Bandwidth::Mediumband as i32;
1261            if tmp < 0 {
1262                tmp = 0;
1263            }
1264            let per = period << 3;
1265            (0x80 | (tmp << 5) | per) as u8
1266        }
1267        OpusMode::Hybrid => {
1268            let base_config = if bandwidth == Bandwidth::Superwideband {
1269                12
1270            } else {
1271                14
1272            };
1273            let period_offset = if frame_rate >= 100 { 0 } else { 1 };
1274            ((base_config + period_offset) << 3) as u8
1275        }
1276    };
1277
1278    if channels == 2 {
1279        toc |= 0x04;
1280    }
1281    toc
1282}
1283
1284fn mode_from_toc(toc: u8) -> OpusMode {
1285    if toc & 0x80 != 0 {
1286        OpusMode::CeltOnly
1287    } else if toc & 0x60 == 0x60 {
1288        OpusMode::Hybrid
1289    } else {
1290        OpusMode::SilkOnly
1291    }
1292}
1293
1294fn bandwidth_from_toc(toc: u8) -> Bandwidth {
1295    let mode = mode_from_toc(toc);
1296    match mode {
1297        OpusMode::SilkOnly => {
1298            let bw_bits = (toc >> 5) & 0x03;
1299            match bw_bits {
1300                0 => Bandwidth::Narrowband,
1301                1 => Bandwidth::Mediumband,
1302                2 => Bandwidth::Wideband,
1303                _ => Bandwidth::Wideband,
1304            }
1305        }
1306        OpusMode::Hybrid => {
1307            let bw_bit = (toc >> 4) & 0x01;
1308            if bw_bit == 0 {
1309                Bandwidth::Superwideband
1310            } else {
1311                Bandwidth::Fullband
1312            }
1313        }
1314        OpusMode::CeltOnly => {
1315            let bw_bits = (toc >> 5) & 0x03;
1316            match bw_bits {
1317                0 => Bandwidth::Mediumband,
1318                1 => Bandwidth::Wideband,
1319                2 => Bandwidth::Superwideband,
1320                3 => Bandwidth::Fullband,
1321                _ => Bandwidth::Fullband,
1322            }
1323        }
1324    }
1325}
1326
1327fn frame_duration_ms_from_toc(toc: u8) -> i32 {
1328    let mode = mode_from_toc(toc);
1329    match mode {
1330        OpusMode::SilkOnly => {
1331            let config = (toc >> 3) & 0x03;
1332            match config {
1333                0 => 10,
1334                1 => 20,
1335                2 => 40,
1336                3 => 60,
1337                _ => 20,
1338            }
1339        }
1340        OpusMode::Hybrid => {
1341            let config = (toc >> 3) & 0x01;
1342            if config == 0 { 10 } else { 20 }
1343        }
1344        OpusMode::CeltOnly => {
1345            let config = (toc >> 3) & 0x03;
1346            match config {
1347                0 => 2,
1348                1 => 5,
1349                2 => 10,
1350                3 => 20,
1351                _ => 20,
1352            }
1353        }
1354    }
1355}
1356
/// Returns the channel count signalled by a TOC byte: 2 when bit 2 (`0x04`)
/// is set, otherwise 1.
fn channels_from_toc(toc: u8) -> usize {
    let stereo_flag = (toc >> 2) & 1;
    1 + usize::from(stereo_flag)
}
1360
#[cfg(test)]
mod tests {
    use super::*;

    /// Test helper: frame length in samples implied by `toc` at `sampling_rate`.
    ///
    /// CELT-only TOCs are resolved through the frame rate `400 >> period`
    /// (returning `None` if it does not divide `sampling_rate`); SILK-only and
    /// hybrid TOCs go through `frame_duration_ms_from_toc`.
    fn frame_size_from_toc(toc: u8, sampling_rate: i32) -> Option<usize> {
        match mode_from_toc(toc) {
            OpusMode::CeltOnly => {
                let period = i32::from((toc >> 3) & 0x03);
                let frame_rate = 400 >> period;
                if frame_rate != 0 && sampling_rate % frame_rate == 0 {
                    Some((sampling_rate / frame_rate) as usize)
                } else {
                    None
                }
            }
            OpusMode::SilkOnly | OpusMode::Hybrid => {
                let duration_ms = i64::from(frame_duration_ms_from_toc(toc));
                Some((i64::from(sampling_rate) * duration_ms / 1000) as usize)
            }
        }
    }

    #[test]
    fn gen_toc_matches_celt_reference_values() {
        const SAMPLING_RATE: i32 = 48_000;
        // (frame size in samples, expected mono fullband CELT TOC)
        let expectations: [(usize, u8); 4] = [(120, 0xE0), (240, 0xE8), (480, 0xF0), (960, 0xF8)];

        for (frame_size, expected_toc) in expectations {
            let frame_rate = frame_rate_from_params(SAMPLING_RATE, frame_size).unwrap();
            let toc = gen_toc(OpusMode::CeltOnly, frame_rate, Bandwidth::Fullband, 1);
            assert_eq!(
                toc, expected_toc,
                "frame_size {} expected TOC {:02X} got {:02X}",
                frame_size, expected_toc, toc
            );
            // The generated TOC must decode back to the frame size it came from.
            assert_eq!(frame_size_from_toc(toc, SAMPLING_RATE).unwrap(), frame_size);
        }

        let frame_rate = frame_rate_from_params(SAMPLING_RATE, 960).unwrap();
        let stereo_toc = gen_toc(OpusMode::CeltOnly, frame_rate, Bandwidth::Fullband, 2);
        assert_eq!(channels_from_toc(stereo_toc), 2);
    }

    #[test]
    fn test_celt_decoder_large_frame_sizes() {
        const SAMPLING_RATE: i32 = 48000;
        const FRAME_SIZES: [usize; 4] = [120, 240, 480, 960];

        // Mono first, then stereo, each with a fresh decoder.
        for channels in [1usize, 2] {
            let mut decoder = OpusDecoder::new(SAMPLING_RATE, channels).unwrap();

            for frame_size in FRAME_SIZES {
                let frame_rate = frame_rate_from_params(SAMPLING_RATE, frame_size).unwrap();
                let toc = gen_toc(
                    OpusMode::CeltOnly,
                    frame_rate,
                    Bandwidth::Fullband,
                    channels,
                );
                let packet = [toc, 0, 0, 0, 0];
                let mut output = vec![0.0f32; frame_size * channels];

                // The result is deliberately ignored: the test only requires
                // that decoding does not panic.
                let _ = decoder.decode(&packet, frame_size, &mut output);
            }
        }
    }

    #[test]
    fn test_celt_decoder_edge_case_frame_sizes() {
        let mut decoder = OpusDecoder::new(48000, 1).unwrap();

        for frame_size in [2048usize, 2167, 2168, 2169, 2880, 3072] {
            let mut output = vec![0.0f32; frame_size];

            // As above, only the absence of a panic is being checked.
            let _ = decoder.decode(&[0x80, 0, 0, 0], frame_size, &mut output);
        }
    }

    // Regression test for the panic
    // "index out of bounds: the len is 48 but the index is 119".
    // Root cause: frame_size=48 at 48kHz yields frame_rate=1000, which is not a
    // valid Hybrid-mode frame rate but went unvalidated. CELT's lm-search then
    // silently fell back to lm=0, computed n2=120, and wrote output[119] into a
    // 48-element slice. It was triggered by G.729-decoded PCM (8kHz) being fed
    // to a 48kHz Opus encoder without proper resampling, so the encoder got 48
    // samples instead of 480.
    #[test]
    fn test_invalid_small_frame_size_returns_error_not_panic() {
        let mut enc = OpusEncoder::new(48000, 2, Application::Voip).unwrap();
        enc.bitrate_bps = 64000;
        enc.complexity = 5;
        enc.use_cbr = true;

        // 48 samples at 48kHz = 1ms, i.e. frame_rate=1000: invalid for Hybrid mode.
        let input = vec![0.0f32; 48 * 2]; // stereo interleaved
        let mut output = vec![0u8; 256];

        assert!(
            enc.encode(&input, 48, &mut output).is_err(),
            "encode with invalid frame_size=48 should return Err, not panic"
        );
    }

    // The Audio application path (always Hybrid at 48 kHz) must reject the same
    // bad frame size.
    #[test]
    fn test_invalid_small_frame_size_audio_application_returns_error() {
        let mut enc = OpusEncoder::new(48000, 1, Application::Audio).unwrap();
        let input = vec![0.0f32; 48];
        let mut output = vec![0u8; 256];

        assert!(
            enc.encode(&input, 48, &mut output).is_err(),
            "Audio/48kHz encoder with frame_size=48 should return Err"
        );
    }
}