Skip to main content

opus_rs/
lib.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2#![allow(clippy::too_many_arguments)]
3#![allow(clippy::needless_range_loop)]
4
5pub mod bands;
6pub mod celt;
7pub mod celt_lpc;
8pub mod hp_cutoff;
9pub mod kiss_fft;
10pub mod mdct;
11pub mod modes;
12pub mod pitch;
13pub mod pvq;
14pub mod quant_bands;
15pub mod range_coder;
16pub mod rate;
17pub mod silk;
18
19pub use silk::{SilkResampler, SilkResamplerDown1_3, SilkResamplerDown1_6};
20
21pub use celt::{CeltDecoder, CeltEncoder};
22use hp_cutoff::hp_cutoff;
23use range_coder::RangeCoder;
24use silk::control_codec::silk_control_encoder;
25use silk::enc_api::silk_encode;
26use silk::init_encoder::silk_init_encoder;
27use silk::lin2log::silk_lin2log;
28use silk::log2lin::silk_log2lin;
29use silk::macros::*;
30use silk::resampler::{silk_resampler_down2, silk_resampler_down2_3};
31use silk::structs::SilkEncoderState;
32
/// Intended use of an [`OpusEncoder`], chosen once at construction.
///
/// The discriminants are the libopus `OPUS_APPLICATION_*` codes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Application {
    /// Speech-oriented: the encoder high-pass filters the input before SILK.
    Voip = 2048,
    /// General audio.
    Audio = 2049,
    /// The encoder always codes this application CELT-only.
    RestrictedLowDelay = 2051,
}
39
/// Audio bandwidth signalled in the Opus TOC byte.
///
/// The discriminants are the libopus `OPUS_AUTO` / `OPUS_BANDWIDTH_*` codes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Bandwidth {
    /// Not yet known / automatic (the decoder's initial value).
    Auto = -1000,
    Narrowband = 1101,
    Mediumband = 1102,
    Wideband = 1103,
    Superwideband = 1104,
    Fullband = 1105,
}
49
/// Coding mode of a single Opus frame.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum OpusMode {
    /// SILK layer only.
    SilkOnly,
    /// SILK for the low band plus CELT for the upper bands.
    Hybrid,
    /// CELT layer only.
    CeltOnly,
}
56
57pub struct OpusEncoder {
58    celt_enc: CeltEncoder,
59    silk_enc: Box<SilkEncoderState>,
60    application: Application,
61    sampling_rate: i32,
62    channels: usize,
63    bandwidth: Bandwidth,
64    pub bitrate_bps: i32,
65    pub complexity: i32,
66    pub use_cbr: bool,
67
68    pub use_inband_fec: bool,
69
70    pub packet_loss_perc: i32,
71    silk_initialized: bool,
72    mode: OpusMode,
73    prev_enc_mode: Option<OpusMode>,
74
75    variable_hp_smth2_q15: i32,
76    hp_mem: Vec<i32>,
77
78    buf_filtered: Vec<i16>,
79    buf_silk_input: Vec<i16>,
80    buf_stereo_mid: Vec<i16>,
81    buf_stereo_side: Vec<i16>,
82    buf_celt_input: Vec<f32>,
83    down2_state_first: [i32; 2],
84    down2_state_second: [i32; 2],
85    down2_3_state: [i32; 6],
86    down_1_3_state: silk::resampler::SilkResamplerDown1_3,
87
88    rc: RangeCoder,
89}
90
/// Estimates the "equivalent" bitrate used for SILK/CELT mode selection.
///
/// Starting from `bitrate`, the estimate is reduced in four steps:
/// - Per-frame overhead is subtracted for frame rates above 50 frames/s.
/// - One twelfth is removed when `vbr` is false.
/// - The result is scaled by `(90 + complexity) / 100`.
/// - A loss-dependent fraction is removed when `loss > 0`.
///
/// All arithmetic is integer and truncating.
fn compute_equiv_rate(
    bitrate: i32,
    channels: usize,
    frame_rate: i32,
    vbr: bool,
    complexity: i32,
    loss: i32,
) -> i32 {
    let equiv = if frame_rate > 50 {
        bitrate - (40 * channels as i32 + 20) * (frame_rate - 50)
    } else {
        bitrate
    };
    let equiv = if vbr { equiv } else { equiv - equiv / 12 };
    let equiv = equiv * (90 + complexity) / 100;
    if loss > 0 {
        equiv - equiv * loss / (12 * loss + 20)
    } else {
        equiv
    }
}
112
113fn compute_mode_threshold(application: Application, channels: usize, prev_was_celt: bool) -> i32 {
114    let mode_voice = if channels == 1 { 64000 } else { 44000 };
115    let mode_music = 10000;
116
117    let offset = (mode_voice - mode_music) >> 14;
118    let mut threshold = mode_music + offset;
119
120    if application == Application::Voip {
121        threshold += 8000;
122    }
123
124    if prev_was_celt {
125        threshold -= 4000;
126    } else {
127        threshold += 4000;
128    }
129
130    match application {
131        Application::Audio => threshold = threshold.max(55000),
132        Application::Voip => threshold = threshold.max(55000),
133        Application::RestrictedLowDelay => threshold = 0,
134    }
135
136    threshold
137}
138
/// SILK share of the total bitrate `rate_bps` in Hybrid mode.
///
/// The value is linearly interpolated from a (total, SILK 10 ms, SILK 20 ms) table.
/// Above the last row, SILK receives the last row's rate plus half of the excess.
/// `frame20ms` selects the 20 ms column.
fn compute_silk_rate_for_hybrid(rate_bps: i32, frame20ms: bool) -> i32 {
    // (total rate, SILK rate for 10 ms frames, SILK rate for 20 ms frames)
    const RATE_TABLE: &[(i32, i32, i32)] = &[
        (0, 0, 0),
        (12000, 10000, 10000),
        (16000, 13500, 13500),
        (20000, 16000, 16000),
        (24000, 18000, 18000),
        (32000, 22000, 22000),
        (64000, 38000, 38000),
    ];
    let silk_column = |row: (i32, i32, i32)| if frame20ms { row.2 } else { row.1 };

    // Index of the first row (after the zero row) whose total exceeds `rate_bps`.
    let upper = RATE_TABLE[1..]
        .iter()
        .position(|&(total, _, _)| total > rate_bps)
        .map(|offset| offset + 1);

    match upper {
        Some(hi_idx) => {
            let lo_row = RATE_TABLE[hi_idx - 1];
            let hi_row = RATE_TABLE[hi_idx];
            let (x0, x1) = (lo_row.0, hi_row.0);
            (silk_column(lo_row) * (x1 - rate_bps) + silk_column(hi_row) * (rate_bps - x0))
                / (x1 - x0)
        }
        None => {
            let last_row = RATE_TABLE[RATE_TABLE.len() - 1];
            silk_column(last_row) + (rate_bps - last_row.0) / 2
        }
    }
}
169
#[cfg(test)]
mod silk_rate_tests {
    use super::compute_silk_rate_for_hybrid;

    /// Every non-zero table row is reproduced exactly for 20 ms frames.
    #[test]
    fn test_reference_table_exact_entries() {
        assert_eq!(compute_silk_rate_for_hybrid(12000, true), 10000);
        assert_eq!(compute_silk_rate_for_hybrid(16000, true), 13500);
        assert_eq!(compute_silk_rate_for_hybrid(20000, true), 16000);
        assert_eq!(compute_silk_rate_for_hybrid(24000, true), 18000);
        assert_eq!(compute_silk_rate_for_hybrid(32000, true), 22000);
        assert_eq!(compute_silk_rate_for_hybrid(64000, true), 38000);
    }

    #[test]
    fn test_32kbps_gives_22kbps_silk() {
        assert_eq!(compute_silk_rate_for_hybrid(32000, true), 22000);
    }

    #[test]
    fn test_interpolation_between_table_entries() {
        let r = compute_silk_rate_for_hybrid(18000, true);
        assert_eq!(r, 14750);
    }

    #[test]
    fn test_above_table_max_gives_half_extra() {
        let r = compute_silk_rate_for_hybrid(72000, true);
        assert_eq!(r, 38000 + (72000 - 64000) / 2);
    }

    /// The 10 ms column (`frame20ms == false`) is covered for exact rows,
    /// interpolation and extrapolation.
    #[test]
    fn test_10ms_frames_use_their_own_column() {
        assert_eq!(compute_silk_rate_for_hybrid(24000, false), 18000);
        assert_eq!(compute_silk_rate_for_hybrid(28000, false), 20000);
        assert_eq!(compute_silk_rate_for_hybrid(72000, false), 42000);
    }

    /// Rates below the first non-zero row interpolate from the (0, 0) row.
    #[test]
    fn test_below_first_entry_interpolates_from_zero() {
        assert_eq!(compute_silk_rate_for_hybrid(0, true), 0);
        assert_eq!(compute_silk_rate_for_hybrid(6000, true), 5000);
    }
}
201
202impl OpusEncoder {
203    pub fn new(
204        sampling_rate: i32,
205        channels: usize,
206        application: Application,
207    ) -> Result<Self, &'static str> {
208        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
209            return Err("Invalid sampling rate");
210        }
211        if ![1, 2].contains(&channels) {
212            return Err("Invalid number of channels");
213        }
214
215        let mode = modes::default_mode();
216        let celt_enc = CeltEncoder::new(mode, channels);
217
218        let mut silk_enc = Box::new(SilkEncoderState::default());
219        if silk_init_encoder(&mut silk_enc, 0) != 0 {
220            return Err("SILK encoder initialization failed");
221        }
222
223        let (opus_mode, bw) = match application {
224            Application::Voip => {
225                let bw = match sampling_rate {
226                    8000 => Bandwidth::Narrowband,
227                    12000 => Bandwidth::Mediumband,
228                    16000 => Bandwidth::Wideband,
229                    24000 => Bandwidth::Superwideband,
230                    48000 => Bandwidth::Fullband,
231                    _ => Bandwidth::Narrowband,
232                };
233
234                let mode = if sampling_rate > 16000 {
235                    OpusMode::Hybrid
236                } else {
237                    OpusMode::SilkOnly
238                };
239                (mode, bw)
240            }
241            Application::RestrictedLowDelay => {
242                let bw = match sampling_rate {
243                    8000 => Bandwidth::Narrowband,
244                    12000 => Bandwidth::Mediumband,
245                    16000 => Bandwidth::Wideband,
246                    24000 => Bandwidth::Superwideband,
247                    _ => Bandwidth::Fullband,
248                };
249                (OpusMode::CeltOnly, bw)
250            }
251            Application::Audio => {
252                if sampling_rate <= 16000 {
253                    let bw = match sampling_rate {
254                        8000 => Bandwidth::Narrowband,
255                        12000 => Bandwidth::Mediumband,
256                        _ => Bandwidth::Wideband,
257                    };
258                    (OpusMode::SilkOnly, bw)
259                } else {
260                    let bw = match sampling_rate {
261                        24000 => Bandwidth::Superwideband,
262                        _ => Bandwidth::Fullband,
263                    };
264                    (OpusMode::Hybrid, bw)
265                }
266            }
267        };
268
269        use silk::lin2log::silk_lin2log;
270        let variable_hp_smth2_q15 = silk_lin2log(60) << 8;
271
272        Ok(Self {
273            celt_enc,
274            silk_enc,
275            application,
276            sampling_rate,
277            channels,
278            bandwidth: bw,
279            bitrate_bps: 64000,
280            complexity: 9,
281            use_cbr: false,
282            use_inband_fec: false,
283            packet_loss_perc: 0,
284            silk_initialized: false,
285            prev_enc_mode: None,
286            mode: opus_mode,
287            variable_hp_smth2_q15,
288            hp_mem: vec![0; channels * 2],
289
290            buf_filtered: Vec::new(),
291            buf_silk_input: Vec::new(),
292            buf_stereo_mid: Vec::new(),
293            buf_stereo_side: Vec::new(),
294            buf_celt_input: Vec::new(),
295            down2_state_first: [0; 2],
296            down2_state_second: [0; 2],
297            down2_3_state: [0; 6],
298            down_1_3_state: silk::resampler::SilkResamplerDown1_3::default(),
299            rc: RangeCoder::new_encoder(1),
300        })
301    }
302
303    pub fn enable_hybrid_mode(&mut self) -> Result<(), &'static str> {
304        if self.sampling_rate != 24000 && self.sampling_rate != 48000 {
305            return Err("Hybrid mode requires 24kHz or 48kHz sampling rate");
306        }
307        let bw = if self.sampling_rate == 48000 {
308            Bandwidth::Fullband
309        } else {
310            Bandwidth::Superwideband
311        };
312        self.mode = OpusMode::Hybrid;
313        self.bandwidth = bw;
314        self.silk_initialized = false;
315        Ok(())
316    }
317
318    pub fn encode(
319        &mut self,
320        input: &[f32],
321        frame_size: usize,
322        output: &mut [u8],
323    ) -> Result<usize, &'static str> {
324        if output.len() < 2 {
325            return Err("Output buffer too small");
326        }
327
328        let frame_rate = frame_rate_from_params(self.sampling_rate, frame_size)
329            .ok_or("Invalid frame size for sampling rate")?;
330
331        // Mode selection: match C's opus_encode_native() behavior.
332        // C reference auto-selects between SILK_ONLY and CELT_ONLY; Hybrid is
333        // produced afterwards by bandwidth overrides (SILK-only + FB/SWB → Hybrid).
334        let mut mode = if self.application == Application::RestrictedLowDelay {
335            OpusMode::CeltOnly
336        } else {
337            let equiv = compute_equiv_rate(
338                self.bitrate_bps,
339                self.channels,
340                frame_rate,
341                !self.use_cbr,
342                self.complexity,
343                self.packet_loss_perc,
344            );
345            let prev_was_celt = self.prev_enc_mode == Some(OpusMode::CeltOnly);
346            let threshold = compute_mode_threshold(self.application, self.channels, prev_was_celt);
347            if equiv >= threshold {
348                OpusMode::CeltOnly
349            } else {
350                OpusMode::SilkOnly
351            }
352        };
353
354        let curr_bw = self.bandwidth;
355        if mode == OpusMode::SilkOnly
356            && (curr_bw == Bandwidth::Superwideband || curr_bw == Bandwidth::Fullband)
357        {
358            mode = OpusMode::Hybrid;
359        }
360        if mode == OpusMode::Hybrid
361            && (curr_bw == Bandwidth::Narrowband
362                || curr_bw == Bandwidth::Mediumband
363                || curr_bw == Bandwidth::Wideband)
364        {
365            mode = OpusMode::SilkOnly;
366        }
367
368        if mode == OpusMode::CeltOnly {
369            match frame_rate {
370                400 | 200 | 100 | 50 => {}
371                _ => return Err("Unsupported frame size for CELT-only mode"),
372            }
373        }
374
375        let toc = gen_toc(mode, frame_rate, self.bandwidth, self.channels);
376        output[0] = toc;
377
378        let target_bits =
379            (self.bitrate_bps as i64 * frame_size as i64 / self.sampling_rate as i64) as i32;
380        let cbr_bytes = ((target_bits + 4) / 8) as usize;
381        let max_data_bytes = output.len();
382
383        let n_bytes = cbr_bytes.min(max_data_bytes).max(1);
384
385        // Initialize range coder with CBR-capped buffer size, matching C's
386        // ec_enc_init(&enc, data, max_data_bytes-1) where max_data_bytes is CBR-capped.
387        let init_rc_size = n_bytes - 1;
388        self.rc.reset_for_encode(init_rc_size as u32);
389
390        if mode == OpusMode::SilkOnly || mode == OpusMode::Hybrid {
391            let silk_fs_khz = if mode == OpusMode::Hybrid {
392                16
393            } else {
394                self.sampling_rate.min(16000) / 1000
395            };
396
397            let frame_ms = (frame_size as i32 * 1000) / self.sampling_rate;
398            if !self.silk_initialized || self.silk_enc.s_cmn.fs_khz != silk_fs_khz {
399                let silk_init_bitrate = (((n_bytes - 1) * 8) as i64 * self.sampling_rate as i64
400                    / frame_size as i64) as i32;
401                silk_control_encoder(
402                    &mut self.silk_enc,
403                    silk_fs_khz,
404                    frame_ms,
405                    silk_init_bitrate,
406                    self.complexity,
407                );
408                self.silk_enc.s_cmn.use_cbr = if self.use_cbr { 1 } else { 0 };
409
410                self.silk_enc.s_cmn.n_channels = self.channels as i32;
411                self.silk_initialized = true;
412                self.down2_state_first = [0; 2];
413                self.down2_state_second = [0; 2];
414                self.down2_3_state = [0; 6];
415                self.down_1_3_state = silk::resampler::SilkResamplerDown1_3::default();
416            }
417
418            self.silk_enc.s_cmn.use_in_band_fec = if self.use_inband_fec { 1 } else { 0 };
419            self.silk_enc.s_cmn.packet_loss_perc = self.packet_loss_perc.clamp(0, 100);
420
421            self.silk_enc.s_cmn.lbrr_enabled = if self.use_inband_fec { 1 } else { 0 };
422
423            if self.silk_enc.s_cmn.lbrr_gain_increases == 0 {
424                self.silk_enc.s_cmn.lbrr_gain_increases = 2;
425            }
426
427            let hp_freq_smth1 = if mode == OpusMode::CeltOnly {
428                silk_lin2log(60) << 8
429            } else {
430                self.silk_enc.s_cmn.variable_hp_smth1_q15
431            };
432
433            const VARIABLE_HP_SMTH_COEF2_Q16: i32 = 984;
434            self.variable_hp_smth2_q15 = silk_smlawb(
435                self.variable_hp_smth2_q15,
436                hp_freq_smth1 - self.variable_hp_smth2_q15,
437                VARIABLE_HP_SMTH_COEF2_Q16,
438            );
439
440            let cutoff_hz = silk_log2lin(silk_rshift(self.variable_hp_smth2_q15, 8));
441
442            let required_size = frame_size * self.channels;
443            self.buf_filtered.resize(required_size, 0);
444            if self.application == Application::Voip {
445                hp_cutoff(
446                    input,
447                    cutoff_hz,
448                    &mut self.buf_filtered,
449                    &mut self.hp_mem,
450                    frame_size,
451                    self.channels,
452                    self.sampling_rate,
453                );
454            } else {
455                for (i, &x) in input.iter().enumerate() {
456                    self.buf_filtered[i] = (x * 32768.0).clamp(-32768.0, 32767.0) as i16;
457                }
458            }
459
460            let input_i16 = &self.buf_filtered;
461
462            let silk_input: &[i16] = if mode == OpusMode::SilkOnly && self.sampling_rate > 16000 {
463                if self.sampling_rate == 48000 {
464                    let stage1_size = frame_size / 2;
465                    let mut stage1_buf = [0i16; 480];
466                    silk_resampler_down2(
467                        &mut self.down2_state_first,
468                        &mut stage1_buf[..stage1_size],
469                        input_i16,
470                        frame_size as i32,
471                    );
472                    let silk_frame_size = stage1_size * 2 / 3;
473                    self.buf_silk_input.resize(silk_frame_size, 0);
474                    silk_resampler_down2_3(
475                        &mut self.down2_3_state,
476                        &mut self.buf_silk_input,
477                        &stage1_buf[..stage1_size],
478                        stage1_size as i32,
479                    );
480                    &self.buf_silk_input
481                } else if self.sampling_rate == 24000 {
482                    let silk_frame_size = frame_size * 2 / 3;
483                    self.buf_silk_input.resize(silk_frame_size, 0);
484                    silk_resampler_down2_3(
485                        &mut self.down2_3_state,
486                        &mut self.buf_silk_input,
487                        input_i16,
488                        frame_size as i32,
489                    );
490                    &self.buf_silk_input
491                } else {
492                    input_i16
493                }
494            } else if mode == OpusMode::SilkOnly && self.channels == 2 {
495                let frame_length = input_i16.len() / 2;
496                self.buf_stereo_mid.resize(frame_length, 0);
497                self.buf_stereo_side.resize(frame_length, 0);
498                for i in 0..frame_length {
499                    let l = input_i16[2 * i] as i32;
500                    let r = input_i16[2 * i + 1] as i32;
501                    self.buf_stereo_mid[i] = ((l + r) / 2) as i16;
502                    self.buf_stereo_side[i] = (l - r) as i16;
503                }
504
505                self.silk_enc.stereo.side.resize(frame_length, 0);
506                self.silk_enc
507                    .stereo
508                    .side
509                    .copy_from_slice(&self.buf_stereo_side[..frame_length]);
510                &self.buf_stereo_mid
511            } else if mode == OpusMode::Hybrid && self.sampling_rate > 16000 {
512                if self.sampling_rate == 48000 {
513                    let silk_frame_size = frame_size / 3;
514                    self.buf_silk_input.resize(silk_frame_size, 0);
515                    silk::resampler::silk_resampler_down_1_3(
516                        &mut self.down_1_3_state,
517                        &mut self.buf_silk_input,
518                        input_i16,
519                    );
520                } else {
521                    let silk_frame_size = frame_size * 2 / 3;
522                    self.buf_silk_input.resize(silk_frame_size, 0);
523                    silk_resampler_down2_3(
524                        &mut self.down2_3_state,
525                        &mut self.buf_silk_input,
526                        input_i16,
527                        frame_size as i32,
528                    );
529                }
530                &self.buf_silk_input
531            } else {
532                input_i16
533            };
534
535            let mut pn_bytes = 0;
536
537            let silk_rate_for_calc = if mode == OpusMode::Hybrid {
538                16000
539            } else {
540                self.sampling_rate
541            };
542            let silk_frame_len = silk_input.len();
543
544            let silk_bitrate = if mode == OpusMode::Hybrid {
545                let frame_duration_ms = frame_size as i32 * 1000 / self.sampling_rate;
546                let frame20ms = frame_duration_ms >= 20;
547                compute_silk_rate_for_hybrid(self.bitrate_bps, frame20ms)
548            } else {
549                (8i64 * (n_bytes - 1) as i64 * silk_rate_for_calc as i64 / silk_frame_len as i64)
550                    as i32
551            };
552            let silk_max_bits = if mode == OpusMode::Hybrid {
553                // Match C: for VBR Hybrid, maxBits starts at (max_data_bytes-1)*8,
554                // then constrained via compute_silk_rate_for_hybrid.
555                // For CBR Hybrid, allow SILK to steal up to 25% of remaining bits.
556                let total_max_bits = ((n_bytes - 1) * 8) as i32;
557                if self.use_cbr {
558                    let silk_bits = (silk_bitrate as i64 * silk_frame_len as i64
559                        / silk_rate_for_calc as i64) as i32;
560                    let other_bits = 0i32.max(total_max_bits - silk_bits);
561                    0i32.max(total_max_bits - other_bits * 3 / 4)
562                } else {
563                    // Constrained VBR: compute max SILK bits from total budget
564                    let frame_duration_ms = frame_size as i32 * 1000 / self.sampling_rate;
565                    let frame20ms = frame_duration_ms >= 20;
566                    let max_bit_rate = compute_silk_rate_for_hybrid(
567                        total_max_bits * self.sampling_rate / frame_size as i32,
568                        frame20ms,
569                    );
570                    max_bit_rate * frame_size as i32 / self.sampling_rate
571                }
572            } else {
573                ((n_bytes - 1) * 8) as i32
574            };
575            // For Hybrid CBR, force SILK to use VBR (matching C behavior)
576            let silk_use_cbr = if mode == OpusMode::Hybrid && self.use_cbr {
577                0
578            } else if self.use_cbr {
579                1
580            } else {
581                0
582            };
583            let ret = silk_encode(
584                &mut self.silk_enc,
585                silk_input,
586                silk_input.len(),
587                &mut self.rc,
588                &mut pn_bytes,
589                silk_bitrate,
590                silk_max_bits,
591                silk_use_cbr,
592                1,
593            );
594            if ret != 0 {
595                return Err("SILK encoding failed");
596            }
597        }
598
599        if mode == OpusMode::Hybrid {
600            self.rc.encode_bit_logp(false, 12); // redundancy = 0
601        }
602
603        if mode == OpusMode::Hybrid {
604            let nb_compr_bytes = (n_bytes - 1) as u32;
605            self.rc.shrink(nb_compr_bytes);
606        }
607
608        let silk_ret_bytes = if mode == OpusMode::SilkOnly {
609            ((self.rc.tell() + 7) >> 3) as usize
610        } else {
611            0
612        };
613
614        if mode == OpusMode::CeltOnly || mode == OpusMode::Hybrid {
615            self.celt_enc.complexity = self.complexity;
616            let start_band = if mode == OpusMode::Hybrid { 17 } else { 0 };
617            let total_packet_bits = ((n_bytes - 1) * 8) as i32;
618
619            let celt_input: &[f32] = if self.channels == 1 {
620                input
621            } else {
622                let n = frame_size * self.channels;
623                self.buf_celt_input.resize(n, 0.0);
624                for i in 0..frame_size {
625                    for ch in 0..self.channels {
626                        self.buf_celt_input[ch * frame_size + i] = input[i * self.channels + ch];
627                    }
628                }
629                &self.buf_celt_input
630            };
631
632            if self.rc.tell() <= total_packet_bits {
633                self.celt_enc.encode_with_budget(
634                    celt_input,
635                    frame_size,
636                    &mut self.rc,
637                    start_band,
638                    total_packet_bits,
639                );
640            }
641        }
642
643        self.rc.done();
644
645        if mode == OpusMode::SilkOnly {
646            let mut ret = silk_ret_bytes.min(self.rc.storage as usize);
647            while ret > 2 && self.rc.buf[ret - 1] == 0 {
648                ret -= 1;
649            }
650
651            let target_total = if self.use_cbr {
652                n_bytes.min(output.len())
653            } else {
654                (ret + 1).min(output.len())
655            };
656
657            let silk_len = ret;
658
659            if !self.use_cbr || silk_len + 1 >= target_total {
660                // VBR or payload fills the target: simple code 0 packet
661                output[0] = toc;
662                let copy_len = silk_len.min(target_total - 1);
663                output[1..1 + copy_len].copy_from_slice(&self.rc.buf[..copy_len]);
664                return Ok((copy_len + 1).min(output.len()));
665            }
666
667            output[0] = toc | 0x03;
668
669            if silk_len + 2 >= target_total {
670                output[1] = 0x01;
671                let copy_len = (target_total - 2).min(silk_len);
672                output[2..2 + copy_len].copy_from_slice(&self.rc.buf[..copy_len]);
673                self.prev_enc_mode = Some(mode);
674                return Ok(target_total.min(output.len()));
675            }
676
677            let pad_amount = target_total - silk_len - 2;
678            output[1] = 0x41;
679
680            let nb_255s = (pad_amount - 1) / 255;
681            let mut ptr = 2;
682            for _ in 0..nb_255s {
683                output[ptr] = 255;
684                ptr += 1;
685            }
686            output[ptr] = (pad_amount - 255 * nb_255s - 1) as u8;
687            ptr += 1;
688
689            output[ptr..ptr + silk_len].copy_from_slice(&self.rc.buf[..silk_len]);
690            ptr += silk_len;
691
692            let fill_end = target_total.min(output.len());
693            for byte in output[ptr..fill_end].iter_mut() {
694                *byte = 0;
695            }
696
697            self.prev_enc_mode = Some(mode);
698            return Ok(target_total.min(output.len()));
699        }
700
701        let payload_len = n_bytes - 1;
702        output[1..1 + payload_len].copy_from_slice(&self.rc.buf[..payload_len]);
703        self.prev_enc_mode = Some(mode);
704        Ok(n_bytes)
705    }
706}
707
708pub struct OpusDecoder {
709    celt_dec: CeltDecoder,
710    silk_dec: silk::dec_api::SilkDecoder,
711    sampling_rate: i32,
712    channels: usize,
713
714    prev_mode: Option<OpusMode>,
715    frame_size: usize,
716
717    bandwidth: Bandwidth,
718
719    stream_channels: usize,
720
721    silk_resampler: silk::resampler::SilkResampler,
722
723    prev_internal_rate: i32,
724
725    pub hybrid_skip_celt: bool,
726
727    w_pcm_i16: Vec<i16>,
728    w_silk_out: Vec<f32>,
729    w_pcm_resampled: Vec<i16>,
730    w_celt_planar: Vec<f32>,
731    w_celt_out: Vec<f32>,
732}
733
734impl OpusDecoder {
735    pub fn new(sampling_rate: i32, channels: usize) -> Result<Self, &'static str> {
736        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
737            return Err("Invalid sampling rate");
738        }
739        if ![1, 2].contains(&channels) {
740            return Err("Invalid number of channels");
741        }
742
743        let mode = modes::default_mode();
744        let celt_dec = CeltDecoder::new(mode, channels);
745
746        let mut silk_dec = silk::dec_api::SilkDecoder::new();
747        silk_dec.init(sampling_rate.min(16000), channels as i32);
748        silk_dec.channel_state[0].fs_api_hz = sampling_rate;
749
750        Ok(Self {
751            celt_dec,
752            silk_dec,
753            sampling_rate,
754            channels,
755            prev_mode: None,
756            frame_size: 0,
757            bandwidth: Bandwidth::Auto,
758            stream_channels: channels,
759            silk_resampler: silk::resampler::SilkResampler::default(),
760            prev_internal_rate: 0,
761            hybrid_skip_celt: false,
762
763            w_pcm_i16: vec![0i16; 640],
764
765            w_silk_out: vec![0.0f32; 5760 * channels],
766            w_pcm_resampled: vec![0i16; 5760 * channels],
767            w_celt_planar: vec![0.0f32; 5760 * channels],
768            w_celt_out: vec![0.0f32; 5760 * channels],
769        })
770    }
771
772    pub fn decode(
773        &mut self,
774        input: &[u8],
775        frame_size: usize,
776        output: &mut [f32],
777    ) -> Result<usize, &'static str> {
778        if input.is_empty() {
779            return Err("Input packet empty");
780        }
781
782        let toc = input[0];
783        let mode = mode_from_toc(toc);
784        let packet_channels = channels_from_toc(toc);
785        let bandwidth = bandwidth_from_toc(toc);
786        let frame_duration_ms = frame_duration_ms_from_toc(toc);
787
788        if packet_channels != self.channels {
789            return Err("Channel count mismatch between packet and decoder");
790        }
791
792        let code = toc & 0x03;
793        let frame_count: usize;
794        let frame_payloads: Vec<&[u8]>;
795
796        match code {
797            0 => {
798                frame_count = 1;
799                frame_payloads = vec![&input[1..]];
800            }
801            1 => {
802                frame_count = 2;
803                let half = (input.len() - 1) / 2;
804                if half == 0 {
805                    return Err("Code 1: empty frame");
806                }
807                frame_payloads = vec![&input[1..1 + half], &input[1 + half..]];
808            }
809            2 => {
810                frame_count = 2;
811                let data = &input[1..];
812                if data.is_empty() {
813                    return Err("Code 2 packet has no data");
814                }
815                let (first_len, header_size) = if data[0] & 0x80 != 0 {
816                    if data.len() < 2 {
817                        return Err("Code 2 packet too short for 2-byte length");
818                    }
819                    (((data[0] & 0x7F) as usize) << 8 | data[1] as usize, 2)
820                } else {
821                    (data[0] as usize, 1)
822                };
823                if header_size + first_len > data.len() {
824                    return Err("Code 2: first frame size exceeds packet");
825                }
826                frame_payloads = vec![
827                    &data[header_size..header_size + first_len],
828                    &data[header_size + first_len..],
829                ];
830            }
831            3 => {
832                if input.len() < 2 {
833                    return Err("Code 3 packet too short");
834                }
835                let count_byte = input[1];
836                let n_frames = (count_byte & 0x3F) as usize;
837                if n_frames < 1 || n_frames > 48 {
838                    return Err("Code 3: invalid frame count");
839                }
840                frame_count = n_frames;
841                let padding_flag = (count_byte & 0x40) != 0;
842
843                if padding_flag {
844                    let mut ptr = 2usize;
845                    let mut pad_len = 0usize;
846                    loop {
847                        if ptr >= input.len() {
848                            return Err("Padding overflow");
849                        }
850                        let p = input[ptr] as usize;
851                        ptr += 1;
852                        if p == 255 {
853                            pad_len += 254;
854                        } else {
855                            pad_len += p;
856                            break;
857                        }
858                    }
859
860                    let end = input.len().saturating_sub(pad_len);
861                    if ptr > end {
862                        return Err("Padding exceeds packet");
863                    }
864                    let compressed = &input[ptr..end];
865                    let frame_len = compressed.len() / frame_count;
866                    if frame_len == 0 {
867                        return Err("Code 3 with padding: empty frame");
868                    }
869                    frame_payloads = compressed.chunks(frame_len).collect();
870                    if frame_payloads.len() != frame_count {
871                        return Err("Code 3: frame count mismatch");
872                    }
873                } else {
874                    // Self-delimiting format:
875                    // - Single frame: no length prefix, remaining data is the frame
876                    // - Multi-frame: lengths for all frames except the last
877                    let mut payload_ptr = 2usize;
878                    if frame_count == 1 {
879                        frame_payloads = vec![&input[payload_ptr..]];
880                    } else {
881                        let mut payloads = Vec::with_capacity(frame_count);
882                        for _f in 0..frame_count - 1 {
883                            if payload_ptr >= input.len() {
884                                return Err("Code 3: unexpected end in self-delimiting header");
885                            }
886                            let (frame_len, header_bytes) = if input[payload_ptr] & 0x80 != 0 {
887                                if payload_ptr + 2 > input.len() {
888                                    return Err("Code 3: short frame length");
889                                }
890                                (
891                                    ((input[payload_ptr] & 0x7F) as usize) << 8
892                                        | input[payload_ptr + 1] as usize,
893                                    2,
894                                )
895                            } else {
896                                (input[payload_ptr] as usize, 1)
897                            };
898                            payload_ptr += header_bytes;
899                            if payload_ptr + frame_len > input.len() {
900                                return Err("Code 3: frame length exceeds packet");
901                            }
902                            payloads.push(&input[payload_ptr..payload_ptr + frame_len]);
903                            payload_ptr += frame_len;
904                        }
905                        // Last frame: no length prefix, remaining data
906                        if payload_ptr > input.len() {
907                            return Err("Code 3: no data for last frame");
908                        }
909                        payloads.push(&input[payload_ptr..]);
910                        frame_payloads = payloads;
911                    }
912                }
913            }
914            _ => unreachable!(),
915        }
916
917        self.frame_size = frame_size;
918        self.bandwidth = bandwidth;
919        self.stream_channels = packet_channels;
920
921        let sub_frame_size = frame_size / frame_count;
922        let sub_output_len = sub_frame_size * self.channels;
923
924        match mode {
925            OpusMode::SilkOnly => {
926                let internal_sample_rate = match bandwidth {
927                    Bandwidth::Narrowband => 8000,
928                    Bandwidth::Mediumband => 12000,
929                    Bandwidth::Wideband => 16000,
930                    _ => 16000,
931                };
932                let internal_frame_size =
933                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
934
935                if self.sampling_rate != internal_sample_rate
936                    && internal_sample_rate != self.prev_internal_rate
937                {
938                    self.silk_resampler
939                        .init(internal_sample_rate, self.sampling_rate);
940                    self.prev_internal_rate = internal_sample_rate;
941                }
942
943                for (fi, payload) in frame_payloads.iter().enumerate() {
944                    let mut rc = RangeCoder::new_decoder(payload);
945                    let pcm_i16_len = internal_frame_size * self.channels;
946                    debug_assert!(pcm_i16_len <= self.w_pcm_i16.len());
947
948                    let ret = {
949                        let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
950                        silk_dec.decode(
951                            &mut rc,
952                            &mut pcm_i16[..pcm_i16_len],
953                            silk::decode_frame::FLAG_DECODE_NORMAL,
954                            true,
955                            frame_duration_ms,
956                            internal_sample_rate,
957                        )
958                    };
959
960                    if ret < 0 {
961                        return Err("SILK decoding failed");
962                    }
963
964                    let decoded_samples = ret as usize;
965                    let out_start = fi * sub_output_len;
966
967                    // SILK only decodes channel 0 (mono). For multi-channel output,
968                    // replicate the mono samples to every channel.
969                    if self.sampling_rate == internal_sample_rate {
970                        let frames = decoded_samples.min(sub_frame_size);
971                        for i in 0..frames {
972                            let v = self.w_pcm_i16[i] as f32 / 32768.0;
973                            for ch in 0..self.channels {
974                                let idx = out_start + i * self.channels + ch;
975                                if idx < output.len() {
976                                    output[idx] = v;
977                                }
978                            }
979                        }
980                    } else {
981                        let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
982                        let out_len =
983                            ((decoded_samples as f64 * ratio) as usize).min(sub_frame_size);
984                        debug_assert!(out_len <= self.w_pcm_resampled.len());
985                        {
986                            let (silk_res, pcm_i16, pcm_out) = (
987                                &mut self.silk_resampler,
988                                &self.w_pcm_i16,
989                                &mut self.w_pcm_resampled,
990                            );
991                            silk_res.process(
992                                &mut pcm_out[..out_len],
993                                &pcm_i16[..decoded_samples],
994                                decoded_samples as i32,
995                            );
996                        }
997                        let frames = out_len.min(sub_frame_size);
998                        for i in 0..frames {
999                            let v = self.w_pcm_resampled[i] as f32 / 32768.0;
1000                            for ch in 0..self.channels {
1001                                let idx = out_start + i * self.channels + ch;
1002                                if idx < output.len() {
1003                                    output[idx] = v;
1004                                }
1005                            }
1006                        }
1007                    }
1008                }
1009                self.prev_mode = Some(OpusMode::SilkOnly);
1010                Ok(frame_size)
1011            }
1012
1013            OpusMode::CeltOnly => {
1014                let celt_end_band = self.celt_end_band_from_toc(toc);
1015
1016                for (fi, payload) in frame_payloads.iter().enumerate() {
1017                    let mut rc = RangeCoder::new_decoder(payload);
1018                    let total_bits = (payload.len() * 8) as i32;
1019                    let needed = sub_frame_size * self.channels;
1020                    let out_start = fi * needed;
1021                    let out_end = (out_start + needed).min(output.len());
1022
1023                    if output.len() < out_end {
1024                        return Err("Output buffer too small");
1025                    }
1026
1027                    if self.channels == 1 {
1028                        self.celt_dec.decode_from_range_coder_with_band_range(
1029                            &mut rc,
1030                            total_bits,
1031                            sub_frame_size,
1032                            &mut output[out_start..out_end],
1033                            0,
1034                            celt_end_band,
1035                        );
1036                        for sample in &mut output[out_start..out_end] {
1037                            *sample = sample.clamp(-1.0, 1.0);
1038                        }
1039                    } else {
1040                        self.celt_dec.decode_from_range_coder_with_band_range(
1041                            &mut rc,
1042                            total_bits,
1043                            sub_frame_size,
1044                            &mut self.w_celt_planar[..needed],
1045                            0,
1046                            celt_end_band,
1047                        );
1048                        for i in 0..sub_frame_size {
1049                            for ch in 0..self.channels {
1050                                let idx = out_start + i * self.channels + ch;
1051                                output[idx] =
1052                                    self.w_celt_planar[ch * sub_frame_size + i].clamp(-1.0, 1.0);
1053                            }
1054                        }
1055                    }
1056                }
1057                self.prev_mode = Some(OpusMode::CeltOnly);
1058                Ok(frame_size)
1059            }
1060
1061            OpusMode::Hybrid => {
1062                let internal_sample_rate = 16000;
1063                let internal_frame_size =
1064                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
1065                let celt_end_band = self.celt_end_band_from_toc(toc);
1066
1067                if self.sampling_rate != internal_sample_rate
1068                    && internal_sample_rate != self.prev_internal_rate
1069                {
1070                    self.silk_resampler
1071                        .init(internal_sample_rate, self.sampling_rate);
1072                    self.prev_internal_rate = internal_sample_rate;
1073                }
1074
1075                for (fi, payload) in frame_payloads.iter().enumerate() {
1076                    let mut rc = RangeCoder::new_decoder(payload);
1077                    let pcm_silk_i16_len = internal_frame_size * self.channels;
1078                    debug_assert!(pcm_silk_i16_len <= self.w_pcm_i16.len());
1079
1080                    let ret = {
1081                        let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
1082                        silk_dec.decode(
1083                            &mut rc,
1084                            &mut pcm_i16[..pcm_silk_i16_len],
1085                            silk::decode_frame::FLAG_DECODE_NORMAL,
1086                            true,
1087                            frame_duration_ms,
1088                            internal_sample_rate,
1089                        )
1090                    };
1091
1092                    if ret < 0 {
1093                        return Err("SILK decoding failed");
1094                    }
1095
1096                    let silk_out_len = sub_frame_size * self.channels;
1097                    self.w_silk_out[..silk_out_len].fill(0.0);
1098                    if ret > 0 {
1099                        let decoded_samples = ret as usize;
1100                        // SILK only decodes channel 0 (mono). For multi-channel output,
1101                        // replicate the mono samples to every channel in w_silk_out.
1102                        if self.sampling_rate == internal_sample_rate {
1103                            let frames = decoded_samples.min(sub_frame_size);
1104                            for i in 0..frames {
1105                                let v = self.w_pcm_i16[i] as f32 / 32768.0;
1106                                for ch in 0..self.channels {
1107                                    let idx = i * self.channels + ch;
1108                                    if idx < silk_out_len {
1109                                        self.w_silk_out[idx] = v;
1110                                    }
1111                                }
1112                            }
1113                        } else {
1114                            let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
1115                            let out_len =
1116                                ((decoded_samples as f64 * ratio) as usize).min(sub_frame_size);
1117                            debug_assert!(out_len <= self.w_pcm_resampled.len());
1118                            {
1119                                let (silk_res, pcm_i16, pcm_resampled) = (
1120                                    &mut self.silk_resampler,
1121                                    &self.w_pcm_i16,
1122                                    &mut self.w_pcm_resampled,
1123                                );
1124                                silk_res.process(
1125                                    &mut pcm_resampled[..out_len],
1126                                    &pcm_i16[..decoded_samples],
1127                                    decoded_samples as i32,
1128                                );
1129                            }
1130                            let frames = out_len.min(sub_frame_size);
1131                            for i in 0..frames {
1132                                let v = self.w_pcm_resampled[i] as f32 / 32768.0;
1133                                for ch in 0..self.channels {
1134                                    let idx = i * self.channels + ch;
1135                                    if idx < silk_out_len {
1136                                        self.w_silk_out[idx] = v;
1137                                    }
1138                                }
1139                            }
1140                        }
1141                    }
1142
1143                    let total_bits = (payload.len() * 8) as i32;
1144                    let redundancy = rc.decode_bit_logp(12);
1145                    let skip_celt = if redundancy {
1146                        let _ = rc.decode_bit_logp(1);
1147                        true
1148                    } else {
1149                        false
1150                    };
1151
1152                    if skip_celt {
1153                        self.w_celt_out[..silk_out_len].fill(0.0);
1154                    } else {
1155                        let (celt_dec, celt_planar) = (&mut self.celt_dec, &mut self.w_celt_planar);
1156                        celt_dec.decode_from_range_coder_with_band_range(
1157                            &mut rc,
1158                            total_bits,
1159                            sub_frame_size,
1160                            &mut celt_planar[..silk_out_len],
1161                            17,
1162                            celt_end_band,
1163                        );
1164
1165                        if self.channels == 1 {
1166                            self.w_celt_out[..silk_out_len]
1167                                .copy_from_slice(&self.w_celt_planar[..silk_out_len]);
1168                        } else {
1169                            for i in 0..sub_frame_size {
1170                                for ch in 0..self.channels {
1171                                    self.w_celt_out[i * self.channels + ch] =
1172                                        self.w_celt_planar[ch * sub_frame_size + i];
1173                                }
1174                            }
1175                        }
1176                    }
1177
1178                    let out_start = fi * silk_out_len;
1179                    let total = silk_out_len.min(output.len() - out_start);
1180                    for j in 0..total {
1181                        output[out_start + j] =
1182                            (self.w_silk_out[j] + self.w_celt_out[j]).clamp(-1.0, 1.0);
1183                    }
1184                }
1185                self.prev_mode = Some(OpusMode::Hybrid);
1186                Ok(frame_size)
1187            }
1188        }
1189    }
1190}
1191
1192impl OpusDecoder {
1193    #[inline(always)]
1194    fn celt_end_band_from_toc(&self, toc: u8) -> usize {
1195        let mode = modes::default_mode();
1196        let top = mode.eff_ebands;
1197        if mode_from_toc(toc) == OpusMode::CeltOnly && toc >= 0x80 {
1198            const FROM_OPUS_TABLE: [u8; 16] = [
1199                0x80, 0x88, 0x90, 0x98, 0x40, 0x48, 0x50, 0x58, 0x20, 0x28, 0x30, 0x38, 0x00, 0x08,
1200                0x10, 0x18,
1201            ];
1202            let idx = ((toc >> 3) - 16) as usize;
1203            let data0 = FROM_OPUS_TABLE[idx] | (toc & 0x7);
1204            let trim = (data0 >> 5) as usize;
1205            return top.saturating_sub(2 * trim).max(1);
1206        }
1207        top
1208    }
1209}
1210
/// Converts a frame size in samples into a frame rate (frames per second) at
/// `sampling_rate`.
///
/// Returns `None` when `frame_size` is zero, does not fit in an `i32`, or does not
/// divide `sampling_rate` exactly (such sizes have no integral frame rate).
fn frame_rate_from_params(sampling_rate: i32, frame_size: usize) -> Option<i32> {
    // `try_from` rejects sizes that an `as` cast would silently wrap; e.g. `usize::MAX`
    // would become -1, which divides every rate and produced a bogus negative result.
    let frame_size = i32::try_from(frame_size).ok().filter(|&size| size != 0)?;
    (sampling_rate % frame_size == 0).then(|| sampling_rate / frame_size)
}
1218
1219fn gen_toc(mode: OpusMode, frame_rate: i32, bandwidth: Bandwidth, channels: usize) -> u8 {
1220    let mut rate = frame_rate;
1221    let mut period = 0;
1222    while rate < 400 {
1223        rate <<= 1;
1224        period += 1;
1225    }
1226
1227    let mut toc = match mode {
1228        OpusMode::SilkOnly => {
1229            let bw = (bandwidth as i32 - Bandwidth::Narrowband as i32) << 5;
1230            let per = (period - 2) << 3;
1231            (bw | per) as u8
1232        }
1233        OpusMode::CeltOnly => {
1234            let mut tmp = bandwidth as i32 - Bandwidth::Mediumband as i32;
1235            if tmp < 0 {
1236                tmp = 0;
1237            }
1238            let per = period << 3;
1239            (0x80 | (tmp << 5) | per) as u8
1240        }
1241        OpusMode::Hybrid => {
1242            let base_config = if bandwidth == Bandwidth::Superwideband {
1243                12
1244            } else {
1245                14
1246            };
1247            let period_offset = if frame_rate >= 100 { 0 } else { 1 };
1248            ((base_config + period_offset) << 3) as u8
1249        }
1250    };
1251
1252    if channels == 2 {
1253        toc |= 0x04;
1254    }
1255    toc
1256}
1257
1258fn mode_from_toc(toc: u8) -> OpusMode {
1259    if toc & 0x80 != 0 {
1260        OpusMode::CeltOnly
1261    } else if toc & 0x60 == 0x60 {
1262        OpusMode::Hybrid
1263    } else {
1264        OpusMode::SilkOnly
1265    }
1266}
1267
1268fn bandwidth_from_toc(toc: u8) -> Bandwidth {
1269    let mode = mode_from_toc(toc);
1270    match mode {
1271        OpusMode::SilkOnly => {
1272            let bw_bits = (toc >> 5) & 0x03;
1273            match bw_bits {
1274                0 => Bandwidth::Narrowband,
1275                1 => Bandwidth::Mediumband,
1276                2 => Bandwidth::Wideband,
1277                _ => Bandwidth::Wideband,
1278            }
1279        }
1280        OpusMode::Hybrid => {
1281            let bw_bit = (toc >> 4) & 0x01;
1282            if bw_bit == 0 {
1283                Bandwidth::Superwideband
1284            } else {
1285                Bandwidth::Fullband
1286            }
1287        }
1288        OpusMode::CeltOnly => {
1289            let bw_bits = (toc >> 5) & 0x03;
1290            match bw_bits {
1291                0 => Bandwidth::Mediumband,
1292                1 => Bandwidth::Wideband,
1293                2 => Bandwidth::Superwideband,
1294                3 => Bandwidth::Fullband,
1295                _ => Bandwidth::Fullband,
1296            }
1297        }
1298    }
1299}
1300
1301fn frame_duration_ms_from_toc(toc: u8) -> i32 {
1302    let mode = mode_from_toc(toc);
1303    match mode {
1304        OpusMode::SilkOnly => {
1305            let config = (toc >> 3) & 0x03;
1306            match config {
1307                0 => 10,
1308                1 => 20,
1309                2 => 40,
1310                3 => 60,
1311                _ => 20,
1312            }
1313        }
1314        OpusMode::Hybrid => {
1315            let config = (toc >> 3) & 0x01;
1316            if config == 0 { 10 } else { 20 }
1317        }
1318        OpusMode::CeltOnly => {
1319            let config = (toc >> 3) & 0x03;
1320            match config {
1321                0 => 2,
1322                1 => 5,
1323                2 => 10,
1324                3 => 20,
1325                _ => 20,
1326            }
1327        }
1328    }
1329}
1330
/// Returns the channel count signalled by the TOC stereo flag (bit 2): 2 when set,
/// 1 otherwise.
fn channels_from_toc(toc: u8) -> usize {
    1 + usize::from((toc >> 2) & 0x01)
}
1334
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only inverse of `gen_toc`: derives the frame size in samples at
    /// `sampling_rate` from a TOC byte. Returns `None` when a CELT frame rate does not
    /// divide the sampling rate evenly.
    fn frame_size_from_toc(toc: u8, sampling_rate: i32) -> Option<usize> {
        if mode_from_toc(toc) == OpusMode::CeltOnly {
            // CELT frame rates are 400, 200, 100 and 50 Hz for period codes 0..=3.
            let period = ((toc >> 3) & 0x03) as i32;
            let frame_rate = 400 >> period;
            if frame_rate == 0 || sampling_rate % frame_rate != 0 {
                return None;
            }
            return Some((sampling_rate / frame_rate) as usize);
        }
        // SILK-only and hybrid: derive the size from the millisecond duration.
        let duration_ms = frame_duration_ms_from_toc(toc) as i64;
        Some((sampling_rate as i64 * duration_ms / 1000) as usize)
    }

    #[test]
    fn gen_toc_matches_celt_reference_values() {
        const SAMPLING_RATE: i32 = 48_000;
        // (frame size in samples at 48 kHz, expected mono fullband CELT-only TOC)
        let expectations = [(120usize, 0xE0u8), (240, 0xE8), (480, 0xF0), (960, 0xF8)];

        for &(frame_size, expected_toc) in &expectations {
            let frame_rate = frame_rate_from_params(SAMPLING_RATE, frame_size).unwrap();
            let toc = gen_toc(OpusMode::CeltOnly, frame_rate, Bandwidth::Fullband, 1);
            assert_eq!(
                toc, expected_toc,
                "frame_size {} expected TOC {:02X} got {:02X}",
                frame_size, expected_toc, toc
            );
            assert_eq!(frame_size_from_toc(toc, SAMPLING_RATE).unwrap(), frame_size);
        }

        let frame_rate = frame_rate_from_params(SAMPLING_RATE, 960).unwrap();
        let stereo_toc = gen_toc(OpusMode::CeltOnly, frame_rate, Bandwidth::Fullband, 2);
        assert_eq!(channels_from_toc(stereo_toc), 2);
    }

    /// Smoke test: decoding a tiny CELT-only packet at every standard CELT frame size,
    /// first mono then stereo, must not panic. Decode results are intentionally ignored.
    #[test]
    fn test_celt_decoder_large_frame_sizes() {
        let sampling_rate = 48000;

        for channels in [1, 2] {
            let mut decoder = OpusDecoder::new(sampling_rate, channels).unwrap();
            for frame_size in [120, 240, 480, 960] {
                let frame_rate = frame_rate_from_params(sampling_rate, frame_size).unwrap();
                let toc = gen_toc(OpusMode::CeltOnly, frame_rate, Bandwidth::Fullband, channels);
                let packet = [toc, 0, 0, 0, 0];
                let mut output = vec![0.0f32; frame_size * channels];
                let _ = decoder.decode(&packet, frame_size, &mut output);
            }
        }
    }

    /// Smoke test: a CELT-only packet with TOC 0x80 decoded with unusual requested frame
    /// sizes must not panic. Decode results are intentionally ignored.
    #[test]
    fn test_celt_decoder_edge_case_frame_sizes() {
        let channels = 1;
        let mut decoder = OpusDecoder::new(48000, channels).unwrap();

        for frame_size in [2048, 2167, 2168, 2169, 2880, 3072] {
            let mut output = vec![0.0f32; frame_size * channels];
            let _ = decoder.decode(&[0x80, 0, 0, 0], frame_size, &mut output);
        }
    }
}