Skip to main content

opus_rs/
lib.rs

1pub mod bands;
2pub mod celt;
3pub mod celt_lpc;
4pub mod hp_cutoff;
5pub mod kiss_fft;
6pub mod mdct;
7pub mod modes;
8pub mod pitch;
9pub mod pvq;
10pub mod quant_bands;
11pub mod range_coder;
12pub mod rate;
13pub mod silk;
14
15pub use silk::{SilkResampler, SilkResamplerDown1_3, SilkResamplerDown1_6};
16
17use celt::{CeltDecoder, CeltEncoder};
18use hp_cutoff::hp_cutoff;
19use range_coder::RangeCoder;
20use silk::control_codec::silk_control_encoder;
21use silk::enc_api::silk_encode;
22use silk::init_encoder::silk_init_encoder;
23use silk::lin2log::silk_lin2log;
24use silk::log2lin::silk_log2lin;
25use silk::macros::*;
26use silk::resampler::{silk_resampler_down2, silk_resampler_down2_3};
27use silk::structs::SilkEncoderState;
28
/// Application profile requested when the encoder is created.
///
/// The value is stored on the encoder. In the code of this file, only `Voip` changes
/// behaviour: it routes the SILK input through the variable high-pass pre-filter
/// (`hp_cutoff`) instead of a plain float-to-`i16` conversion.
///
/// NOTE(review): the discriminants look like libopus' `OPUS_APPLICATION_*` values;
/// confirm whether anything (FFI, ctl-style APIs) relies on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Application {
    /// Speech-oriented profile; enables the high-pass pre-filter on the SILK path.
    Voip = 2048,
    /// General audio profile.
    Audio = 2049,
    /// Low-delay profile.
    /// NOTE(review): currently configured exactly like `Audio` (SILK-only / hybrid);
    /// no CELT-only selection is made for it here.
    RestrictedLowDelay = 2051,
}
35
/// Audio bandwidth signalled in a packet's TOC byte.
///
/// The encoder picks it from the sampling rate; the decoder reads it from each packet's TOC
/// and starts out as `Auto` before the first packet is seen.
/// NOTE(review): the discriminants look like libopus' `OPUS_AUTO` / `OPUS_BANDWIDTH_*`
/// values; confirm whether anything relies on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bandwidth {
    /// Not yet known (decoder initial state).
    Auto = -1000,
    /// Used for 8 kHz SILK coding.
    Narrowband = 1101,
    /// Used for 12 kHz SILK coding.
    Mediumband = 1102,
    /// Used for 16 kHz SILK coding.
    Wideband = 1103,
    /// Used for 24 kHz input (hybrid).
    Superwideband = 1104,
    /// Used for 48 kHz input (hybrid).
    Fullband = 1105,
}
45
/// Internal coding mode of a frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OpusMode {
    /// SILK only (input at or resampled to at most 16 kHz).
    SilkOnly,
    /// SILK at 16 kHz for the low band plus CELT starting at band 17.
    Hybrid,
    /// CELT only.
    /// NOTE(review): nothing in the encoder code shown here selects this mode; the decoder
    /// derives it from the TOC via `mode_from_toc`.
    CeltOnly,
}
52
/// Opus encoder combining a SILK encoder and a CELT encoder.
///
/// Input is interleaved PCM (`f32`, nominally in `[-1, 1)`) at `sampling_rate` Hz.
pub struct OpusEncoder {
    /// CELT layer, used in hybrid (upper bands) and CELT-only modes.
    celt_enc: CeltEncoder,
    /// SILK layer state (boxed; configured lazily on the first SILK frame).
    silk_enc: Box<SilkEncoderState>,
    /// Application profile; `Voip` enables the high-pass pre-filter.
    application: Application,
    /// Input sampling rate in Hz (8000, 12000, 16000, 24000 or 48000).
    sampling_rate: i32,
    /// Number of interleaved input channels (1 or 2).
    channels: usize,
    /// Bandwidth written into the TOC byte.
    bandwidth: Bandwidth,
    /// Target bitrate in bits per second (default 64000). Sizes CBR packets and drives the
    /// SILK/CELT split in hybrid mode.
    pub bitrate_bps: i32,
    /// Encoder complexity, forwarded to `silk_control_encoder` when SILK is (re)configured.
    pub complexity: i32,
    /// Constant bitrate: packets are sized from `bitrate_bps`; otherwise the packet budget
    /// is derived from the output buffer size.
    pub use_cbr: bool,

    /// Enables SILK in-band FEC (LBRR).
    pub use_inband_fec: bool,

    /// Expected packet loss in percent; clamped to 0..=100 before being passed to SILK.
    pub packet_loss_perc: i32,
    /// Whether SILK has been configured for the current internal rate.
    silk_initialized: bool,
    /// Current coding mode.
    mode: OpusMode,

    /// Second-stage smoothed high-pass cutoff, log domain in Q15 (starts at 60 Hz).
    variable_hp_smth2_q15: i32,
    /// High-pass filter state, two values per channel.
    hp_mem: Vec<i32>,

    /// Scratch: current frame converted/filtered to `i16`.
    buf_filtered: Vec<i16>,
    /// Scratch: SILK input after resampling to the internal rate.
    buf_silk_input: Vec<i16>,
    /// Scratch: stereo mid signal for SILK-only stereo.
    buf_stereo_mid: Vec<i16>,
    /// Scratch: stereo side signal for SILK-only stereo.
    buf_stereo_side: Vec<i16>,
    /// 2:1 downsampler state (first stage of 48 kHz -> 16 kHz).
    down2_state_first: [i32; 2],
    /// Second 2:1 downsampler state (reset with the others; not otherwise used here).
    down2_state_second: [i32; 2],
    /// 3:2 downsampler state.
    down2_3_state: [i32; 6],
    /// 3:1 downsampler state (48 kHz -> 16 kHz in hybrid mode).
    down_1_3_state: silk::resampler::SilkResamplerDown1_3,
}
82
/// SILK bitrate (bps) to use in hybrid mode for a total bitrate of `rate_bps`.
///
/// Uses the no-FEC columns of the rate table in the reference encoder's
/// `compute_silk_rate_for_hybrid` (opus_encoder.c): linear interpolation between
/// breakpoints, and above the last breakpoint SILK receives half of the extra bits.
/// Unlike the reference, the rate is not divided per channel, and the CBR and
/// superwideband boosts and the FEC columns are not applied.
fn compute_silk_rate_for_hybrid(rate_bps: i32, frame20ms: bool) -> i32 {
    // (total_rate_bps, silk_rate_10ms, silk_rate_20ms) — no FEC, single channel.
    const RATE_TABLE: [(i32, i32, i32); 7] = [
        (0, 0, 0),
        (12000, 10000, 10000),
        (16000, 13500, 13500),
        (20000, 16000, 16000),
        (24000, 18000, 18000),
        (32000, 22000, 22000),
        (64000, 38000, 38000),
    ];
    let silk_column = |row: (i32, i32, i32)| if frame20ms { row.2 } else { row.1 };

    // Locate the first breakpoint (after row 0) whose total rate exceeds the request.
    match RATE_TABLE[1..].iter().position(|row| row.0 > rate_bps) {
        None => {
            let last = RATE_TABLE[RATE_TABLE.len() - 1];
            silk_column(last) + (rate_bps - last.0) / 2
        }
        Some(offset) => {
            let lower = RATE_TABLE[offset];
            let upper = RATE_TABLE[offset + 1];
            let (x0, x1) = (lower.0, upper.0);
            (silk_column(lower) * (x1 - rate_bps) + silk_column(upper) * (rate_bps - x0))
                / (x1 - x0)
        }
    }
}
116
#[cfg(test)]
mod silk_rate_tests {
    use super::compute_silk_rate_for_hybrid;

    /// Evaluates the 20 ms column, which every case below exercises.
    fn silk_20ms(total_bps: i32) -> i32 {
        compute_silk_rate_for_hybrid(total_bps, true)
    }

    #[test]
    fn test_reference_table_exact_entries() {
        // At each breakpoint the interpolation must land exactly on the table value.
        let breakpoints = [
            (12000, 10000),
            (16000, 13500),
            (20000, 16000),
            (24000, 18000),
            (32000, 22000),
            (64000, 38000),
        ];
        for &(total_bps, expected_silk) in breakpoints.iter() {
            assert_eq!(silk_20ms(total_bps), expected_silk);
        }
    }

    #[test]
    fn test_32kbps_gives_22kbps_silk() {
        // Reference working point (voip/audio, 48 kHz, 32 kbit/s, 20 ms): SILK gets
        // 22 000 bps, more than the old 40 % split (12 600 bps).
        let silk = silk_20ms(32000);
        assert_eq!(silk, 22000);
    }

    #[test]
    fn test_interpolation_between_table_entries() {
        // 18 kbit/s is halfway between the 16k and 20k breakpoints: (13500 + 16000) / 2.
        assert_eq!(silk_20ms(18000), 14750);
    }

    #[test]
    fn test_above_table_max_gives_half_extra() {
        // Past the last breakpoint SILK receives half of every extra bit.
        let total = 72000;
        let expected = 38000 + (total - 64000) / 2; // 42 000
        assert_eq!(silk_20ms(total), expected);
    }
}
153
154impl OpusEncoder {
155    pub fn new(
156        sampling_rate: i32,
157        channels: usize,
158        application: Application,
159    ) -> Result<Self, &'static str> {
160        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
161            return Err("Invalid sampling rate");
162        }
163        if ![1, 2].contains(&channels) {
164            return Err("Invalid number of channels");
165        }
166
167        let mode = modes::default_mode();
168        let celt_enc = CeltEncoder::new(mode, channels);
169
170        let mut silk_enc = Box::new(SilkEncoderState::default());
171        if silk_init_encoder(&mut *silk_enc, 0) != 0 {
172            return Err("SILK encoder initialization failed");
173        }
174
175        let (opus_mode, bw) = match application {
176            Application::Voip => {
177                let bw = match sampling_rate {
178                    8000 => Bandwidth::Narrowband,
179                    12000 => Bandwidth::Mediumband,
180                    16000 => Bandwidth::Wideband,
181                    24000 => Bandwidth::Superwideband,
182                    48000 => Bandwidth::Fullband,
183                    _ => Bandwidth::Narrowband,
184                };
185                // For >16kHz, use Hybrid mode (SILK for low frequencies, CELT for high)
186                let mode = if sampling_rate > 16000 {
187                    OpusMode::Hybrid
188                } else {
189                    OpusMode::SilkOnly
190                };
191                (mode, bw)
192            }
193            _ => {
194                // Audio mode: use SILK for ≤16kHz, Hybrid for >16kHz.
195                // Hybrid uses SILK for the low-frequency speech content (0-8kHz) and
196                // CELT for the high-frequency extension, which gives much better quality
197                // at modest bitrates (≤50 kbps) than CELT-only across the full band.
198                if sampling_rate <= 16000 {
199                    let bw = match sampling_rate {
200                        8000 => Bandwidth::Narrowband,
201                        12000 => Bandwidth::Mediumband,
202                        _ => Bandwidth::Wideband,
203                    };
204                    (OpusMode::SilkOnly, bw)
205                } else {
206                    let bw = match sampling_rate {
207                        24000 => Bandwidth::Superwideband,
208                        _ => Bandwidth::Fullband,
209                    };
210                    (OpusMode::Hybrid, bw)
211                }
212            }
213        };
214
215        use silk::lin2log::silk_lin2log;
216        let variable_hp_smth2_q15 = silk_lin2log(60) << 8;
217
218        Ok(Self {
219            celt_enc,
220            silk_enc,
221            application,
222            sampling_rate,
223            channels,
224            bandwidth: bw,
225            bitrate_bps: 64000,
226            complexity: 0,
227            use_cbr: false,
228            use_inband_fec: false,
229            packet_loss_perc: 0,
230            silk_initialized: false,
231            mode: opus_mode,
232            variable_hp_smth2_q15,
233            hp_mem: vec![0; channels * 2],
234
235            buf_filtered: Vec::new(),
236            buf_silk_input: Vec::new(),
237            buf_stereo_mid: Vec::new(),
238            buf_stereo_side: Vec::new(),
239            down2_state_first: [0; 2],
240            down2_state_second: [0; 2],
241            down2_3_state: [0; 6],
242            down_1_3_state: silk::resampler::SilkResamplerDown1_3::default(),
243        })
244    }
245
246    pub fn enable_hybrid_mode(&mut self) -> Result<(), &'static str> {
247        if self.sampling_rate != 24000 && self.sampling_rate != 48000 {
248            return Err("Hybrid mode requires 24kHz or 48kHz sampling rate");
249        }
250        let bw = if self.sampling_rate == 48000 {
251            Bandwidth::Fullband
252        } else {
253            Bandwidth::Superwideband
254        };
255        self.mode = OpusMode::Hybrid;
256        self.bandwidth = bw;
257        self.silk_initialized = false;
258        Ok(())
259    }
260
261    pub fn encode(
262        &mut self,
263        input: &[f32],
264        frame_size: usize,
265        output: &mut [u8],
266    ) -> Result<usize, &'static str> {
267        if output.len() < 2 {
268            return Err("Output buffer too small");
269        }
270
271        let frame_rate = frame_rate_from_params(self.sampling_rate, frame_size)
272            .ok_or("Invalid frame size for sampling rate")?;
273
274        let mode = self.mode;
275        if mode == OpusMode::CeltOnly {
276            match frame_rate {
277                400 | 200 | 100 | 50 => {}
278                _ => return Err("Unsupported frame size for CELT-only mode"),
279            }
280        }
281
282        let toc = gen_toc(mode, frame_rate, self.bandwidth, self.channels);
283        output[0] = toc;
284
285        let target_bits =
286            (self.bitrate_bps as i64 * frame_size as i64 / self.sampling_rate as i64) as i32;
287        let cbr_bytes = ((target_bits + 4) / 8) as usize;
288        let max_data_bytes = output.len();
289
290        let n_bytes = if self.use_cbr {
291            cbr_bytes.min(max_data_bytes).max(1)
292        } else {
293            max_data_bytes
294        };
295
296        // Use n_bytes-1 as RC storage so the full buffer can be passed to the decoder.
297        // This ensures encoder total_bits == decoder total_bits (no mismatch from done() overhead).
298        let mut rc = RangeCoder::new_encoder((n_bytes - 1) as u32);
299
300        if mode == OpusMode::SilkOnly || mode == OpusMode::Hybrid {
301            let silk_fs_khz = if mode == OpusMode::Hybrid {
302                16
303            } else {
304                self.sampling_rate.min(16000) / 1000
305            };
306
307            let frame_ms = (frame_size as i32 * 1000) / self.sampling_rate;
308            if !self.silk_initialized || self.silk_enc.s_cmn.fs_khz != silk_fs_khz as i32 {
309                let silk_init_bitrate = (((n_bytes - 1) * 8) as i64 * self.sampling_rate as i64
310                    / frame_size as i64) as i32;
311                silk_control_encoder(
312                    &mut *self.silk_enc,
313                    silk_fs_khz as i32,
314                    frame_ms,
315                    silk_init_bitrate,
316                    self.complexity,
317                );
318                self.silk_enc.s_cmn.use_cbr = if self.use_cbr { 1 } else { 0 };
319
320                self.silk_enc.s_cmn.n_channels = self.channels as i32;
321                self.silk_initialized = true;
322                self.down2_state_first = [0; 2];
323                self.down2_state_second = [0; 2];
324                self.down2_3_state = [0; 6];
325                self.down_1_3_state = silk::resampler::SilkResamplerDown1_3::default();
326            }
327
328            self.silk_enc.s_cmn.use_in_band_fec = if self.use_inband_fec { 1 } else { 0 };
329            self.silk_enc.s_cmn.packet_loss_perc = self.packet_loss_perc.clamp(0, 100);
330
331            self.silk_enc.s_cmn.lbrr_enabled = if self.use_inband_fec { 1 } else { 0 };
332
333            if self.silk_enc.s_cmn.lbrr_gain_increases == 0 {
334                self.silk_enc.s_cmn.lbrr_gain_increases = 2;
335            }
336
337            let hp_freq_smth1 = if mode == OpusMode::CeltOnly {
338                silk_lin2log(60) << 8
339            } else {
340                self.silk_enc.s_cmn.variable_hp_smth1_q15
341            };
342
343            const VARIABLE_HP_SMTH_COEF2_Q16: i32 = 984;
344            self.variable_hp_smth2_q15 = silk_smlawb(
345                self.variable_hp_smth2_q15,
346                hp_freq_smth1 - self.variable_hp_smth2_q15,
347                VARIABLE_HP_SMTH_COEF2_Q16,
348            );
349
350            let cutoff_hz = silk_log2lin(silk_rshift(self.variable_hp_smth2_q15, 8));
351
352            let required_size = frame_size * self.channels;
353            self.buf_filtered.resize(required_size, 0);
354            if self.application == Application::Voip {
355                hp_cutoff(
356                    input,
357                    cutoff_hz,
358                    &mut self.buf_filtered,
359                    &mut self.hp_mem,
360                    frame_size,
361                    self.channels,
362                    self.sampling_rate,
363                );
364            } else {
365                for (i, &x) in input.iter().enumerate() {
366                    self.buf_filtered[i] = (x * 32768.0).clamp(-32768.0, 32767.0) as i16;
367                }
368            }
369
370            let input_i16 = &self.buf_filtered;
371
372            let silk_input: &[i16] = if mode == OpusMode::SilkOnly && self.sampling_rate > 16000 {
373                // SilkOnly mode with >16kHz input needs downsampling to 16kHz
374                if self.sampling_rate == 48000 {
375                    let stage1_size = frame_size / 2;
376                    let mut stage1_buf = [0i16; 480]; // max frame_size/2 for 48kHz/20ms
377                    silk_resampler_down2(
378                        &mut self.down2_state_first,
379                        &mut stage1_buf[..stage1_size],
380                        input_i16,
381                        frame_size as i32,
382                    );
383                    let silk_frame_size = stage1_size * 2 / 3;
384                    self.buf_silk_input.resize(silk_frame_size, 0);
385                    silk_resampler_down2_3(
386                        &mut self.down2_3_state,
387                        &mut self.buf_silk_input,
388                        &stage1_buf[..stage1_size],
389                        stage1_size as i32,
390                    );
391                    &self.buf_silk_input
392                } else if self.sampling_rate == 24000 {
393                    let silk_frame_size = frame_size * 2 / 3;
394                    self.buf_silk_input.resize(silk_frame_size, 0);
395                    silk_resampler_down2_3(
396                        &mut self.down2_3_state,
397                        &mut self.buf_silk_input,
398                        input_i16,
399                        frame_size as i32,
400                    );
401                    &self.buf_silk_input
402                } else {
403                    // Fallback: just use the input directly
404                    input_i16
405                }
406            } else if mode == OpusMode::SilkOnly && self.channels == 2 {
407                let frame_length = input_i16.len() / 2;
408                self.buf_stereo_mid.resize(frame_length, 0);
409                self.buf_stereo_side.resize(frame_length, 0);
410                for i in 0..frame_length {
411                    let l = input_i16[2 * i] as i32;
412                    let r = input_i16[2 * i + 1] as i32;
413                    self.buf_stereo_mid[i] = ((l + r) / 2) as i16;
414                    self.buf_stereo_side[i] = (l - r) as i16;
415                }
416
417                self.silk_enc.stereo.side.resize(frame_length, 0);
418                self.silk_enc
419                    .stereo
420                    .side
421                    .copy_from_slice(&self.buf_stereo_side[..frame_length]);
422                &self.buf_stereo_mid
423            } else if mode == OpusMode::Hybrid && self.sampling_rate > 16000 {
424                if self.sampling_rate == 48000 {
425                    // Reference: use AR2+FIR 1:3 downsampler (silk_resampler_private_down_FIR)
426                    // This has complementary group delay to the decoder's IirFir upsampler,
427                    // avoiding waveform distortion from IIR cascade down2+down2_3.
428                    let silk_frame_size = frame_size / 3; // 960/3 = 320 at 16kHz
429                    self.buf_silk_input.resize(silk_frame_size, 0);
430                    silk::resampler::silk_resampler_down_1_3(
431                        &mut self.down_1_3_state,
432                        &mut self.buf_silk_input,
433                        input_i16,
434                    );
435                } else {
436                    let silk_frame_size = frame_size * 2 / 3;
437                    self.buf_silk_input.resize(silk_frame_size, 0);
438                    silk_resampler_down2_3(
439                        &mut self.down2_3_state,
440                        &mut self.buf_silk_input,
441                        input_i16,
442                        frame_size as i32,
443                    );
444                }
445                &self.buf_silk_input
446            } else {
447                input_i16
448            };
449
450            let mut pn_bytes = 0;
451
452            let silk_rate_for_calc = if mode == OpusMode::Hybrid {
453                16000
454            } else {
455                self.sampling_rate
456            };
457            let silk_frame_len = silk_input.len();
458            // For Hybrid mode, use the reference rate table to allocate SILK bits.
459            // At 32kbps total, the table gives SILK 22kbps (440 bits/20ms frame),
460            // much more than the naive 40% split (252 bits) we had before.
461            let silk_bitrate = if mode == OpusMode::Hybrid {
462                let frame_duration_ms = frame_size as i32 * 1000 / self.sampling_rate;
463                let frame20ms = frame_duration_ms >= 20;
464                compute_silk_rate_for_hybrid(self.bitrate_bps, frame20ms)
465            } else {
466                (8i64 * (n_bytes - 1) as i64 * silk_rate_for_calc as i64 / silk_frame_len as i64)
467                    as i32
468            };
469            let silk_max_bits = if mode == OpusMode::Hybrid {
470                (silk_bitrate as i64 * silk_frame_len as i64 / silk_rate_for_calc as i64) as i32
471            } else {
472                ((n_bytes - 1) * 8) as i32
473            };
474            let ret = silk_encode(
475                &mut *self.silk_enc,
476                silk_input,
477                silk_input.len(),
478                &mut rc,
479                &mut pn_bytes,
480                silk_bitrate,
481                silk_max_bits,
482                if self.use_cbr { 1 } else { 0 },
483                1,
484            );
485            if ret != 0 {
486                return Err("SILK encoding failed");
487            }
488        }
489
490        if mode == OpusMode::CeltOnly || mode == OpusMode::Hybrid {
491            let start_band = if mode == OpusMode::Hybrid { 17 } else { 0 };
492            if mode == OpusMode::Hybrid {
493                // For Hybrid mode, pass total packet bits; clt_compute_allocation
494                // subtracts rc.tell() (= silk bits used) internally.
495                let total_packet_bits = ((n_bytes - 1) * 8) as i32;
496                self.celt_enc.encode_with_budget(
497                    input,
498                    frame_size,
499                    &mut rc,
500                    start_band,
501                    total_packet_bits,
502                );
503            } else {
504                // For CeltOnly mode, use the full packet budget
505                let total_packet_bits = ((n_bytes - 1) * 8) as i32;
506                self.celt_enc.encode_with_budget(
507                    input,
508                    frame_size,
509                    &mut rc,
510                    start_band,
511                    total_packet_bits,
512                );
513            }
514        }
515
516        rc.done();
517
518        let silk_payload: Vec<u8> = if mode == OpusMode::SilkOnly {
519            let mut combined = Vec::with_capacity(rc.storage as usize);
520            combined.extend_from_slice(&rc.buf[0..rc.offs as usize]);
521            combined.extend_from_slice(
522                &rc.buf[(rc.storage - rc.end_offs) as usize..rc.storage as usize],
523            );
524
525            while combined.len() > 2 && combined[combined.len() - 1] == 0 {
526                combined.pop();
527            }
528
529            combined
530        } else {
531            Vec::new()
532        };
533
534        let total_bytes = if mode == OpusMode::SilkOnly {
535            silk_payload.len()
536        } else {
537            n_bytes
538        };
539
540        let payload_bytes = total_bytes.min(output.len() - 1);
541        let ret_with_toc = payload_bytes + 1;
542
543        if mode == OpusMode::SilkOnly {
544            let target_total = if self.use_cbr {
545                n_bytes.min(output.len())
546            } else {
547                ret_with_toc
548            };
549
550            let silk_len = silk_payload.len();
551
552            if silk_len + 1 >= target_total {
553                output[0] = toc;
554                let copy_len = (target_total - 1).min(silk_len);
555                output[1..1 + copy_len].copy_from_slice(&silk_payload[..copy_len]);
556                return Ok(target_total.min(output.len()));
557            }
558
559            output[0] = toc | 0x03;
560
561            if silk_len + 2 >= target_total {
562                output[1] = 0x01;
563                let copy_len = (target_total - 2).min(silk_len);
564                output[2..2 + copy_len].copy_from_slice(&silk_payload[..copy_len]);
565                return Ok(target_total.min(output.len()));
566            }
567
568            let pad_amount = target_total - silk_len - 2;
569            output[1] = 0x41;
570
571            let nb_255s = (pad_amount - 1) / 255;
572            let mut ptr = 2;
573            for _ in 0..nb_255s {
574                output[ptr] = 255;
575                ptr += 1;
576            }
577            output[ptr] = (pad_amount - 255 * nb_255s - 1) as u8;
578            ptr += 1;
579
580            output[ptr..ptr + silk_len].copy_from_slice(&silk_payload);
581            ptr += silk_len;
582
583            let fill_end = target_total.min(output.len());
584            for byte in output[ptr..fill_end].iter_mut() {
585                *byte = 0;
586            }
587
588            return Ok(target_total.min(output.len()));
589        }
590
591        if mode == OpusMode::SilkOnly {
592            output[1..1 + payload_bytes].copy_from_slice(&silk_payload[..payload_bytes]);
593            Ok(ret_with_toc)
594        } else {
595            // For Hybrid/CeltOnly: use the full RC buffer (front+zeros+back layout).
596            // rc.storage = n_bytes - 1, so rc.buf[..n_bytes-1] has the correct layout
597            // where the decoder can read forward from buf[0] and backward from buf[n_bytes-2].
598            // This ensures encoder total_bits == decoder total_bits (both use (n_bytes-1)*8).
599            let payload_len = n_bytes - 1;
600            output[1..1 + payload_len].copy_from_slice(&rc.buf[..payload_len]);
601            Ok(n_bytes)
602        }
603    }
604}
605
/// Opus decoder combining a SILK decoder, a CELT decoder and a resampler from the SILK
/// internal rate to the output rate.
pub struct OpusDecoder {
    /// CELT layer (hybrid high band and CELT-only packets).
    celt_dec: CeltDecoder,
    /// SILK layer (SILK-only packets and the hybrid low band).
    silk_dec: silk::dec_api::SilkDecoder,
    /// Output sampling rate in Hz (8000, 12000, 16000, 24000 or 48000).
    sampling_rate: i32,
    /// Number of output channels (1 or 2); packets must match it.
    channels: usize,

    /// Mode of the last decoded packet, `None` before the first one.
    prev_mode: Option<OpusMode>,
    /// Frame size requested in the last `decode` call (samples per channel).
    frame_size: usize,

    /// Bandwidth from the last packet's TOC (`Auto` before the first packet).
    bandwidth: Bandwidth,

    /// Channel count signalled by the last packet.
    stream_channels: usize,

    /// Resampler from the SILK internal rate to `sampling_rate`.
    silk_resampler: silk::resampler::SilkResampler,

    /// SILK internal rate the resampler was last initialised for (0 = never).
    prev_internal_rate: i32,

    // Pre-allocated decode working buffers (avoid per-frame heap allocation)
    // NOTE(review): `decode` slices w_pcm_i16 with frame_duration_ms from the TOC; a 40/60 ms
    // SILK packet (e.g. 60 ms x 16 kHz x 2 ch = 1920) would exceed 640 and only a
    // debug_assert guards it — confirm the maximum frame duration handled.
    w_pcm_i16: Vec<i16>,       // SILK internal PCM: max 16kHz×20ms×2ch = 640
    w_silk_out: Vec<f32>,      // SILK→f32 output: max frame_size×channels
    w_pcm_resampled: Vec<i16>, // resampler i16 output: max frame_size×channels
    w_celt_out: Vec<f32>,      // CELT output: max frame_size×channels
}
629
630impl OpusDecoder {
631    pub fn new(sampling_rate: i32, channels: usize) -> Result<Self, &'static str> {
632        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
633            return Err("Invalid sampling rate");
634        }
635        if ![1, 2].contains(&channels) {
636            return Err("Invalid number of channels");
637        }
638
639        let mode = modes::default_mode();
640        let celt_dec = CeltDecoder::new(mode, channels);
641
642        let mut silk_dec = silk::dec_api::SilkDecoder::new();
643        silk_dec.init(sampling_rate.min(16000), channels as i32);
644        silk_dec.channel_state[0].fs_api_hz = sampling_rate;
645
646        Ok(Self {
647            celt_dec,
648            silk_dec,
649            sampling_rate,
650            channels,
651            prev_mode: None,
652            frame_size: 0,
653            bandwidth: Bandwidth::Auto,
654            stream_channels: channels,
655            silk_resampler: silk::resampler::SilkResampler::default(),
656            prev_internal_rate: 0,
657            // 640 = 16kHz × 20ms × 2ch (SILK max internal frame)
658            w_pcm_i16: vec![0i16; 640],
659            // 5760 × 2 = 48kHz × 120ms × 2ch (max output frame)
660            w_silk_out: vec![0.0f32; 5760 * channels],
661            w_pcm_resampled: vec![0i16; 5760 * channels],
662            w_celt_out: vec![0.0f32; 5760 * channels],
663        })
664    }
665
666    pub fn decode(
667        &mut self,
668        input: &[u8],
669        frame_size: usize,
670        output: &mut [f32],
671    ) -> Result<usize, &'static str> {
672        if input.is_empty() {
673            return Err("Input packet empty");
674        }
675
676        let toc = input[0];
677        let mode = mode_from_toc(toc);
678        let packet_channels = channels_from_toc(toc);
679        let bandwidth = bandwidth_from_toc(toc);
680        let frame_duration_ms = frame_duration_ms_from_toc(toc);
681
682        if packet_channels != self.channels {
683            return Err("Channel count mismatch between packet and decoder");
684        }
685
686        let code = toc & 0x03;
687        let payload_data;
688
689        match code {
690            0 => {
691                payload_data = &input[1..];
692            }
693            3 => {
694                if input.len() < 2 {
695                    return Err("Code 3 packet too short");
696                }
697                let count_byte = input[1];
698                let _frame_count = (count_byte & 0x3F) as usize;
699                let padding_flag = (count_byte & 0x40) != 0;
700
701                let mut ptr = 2usize;
702                if padding_flag {
703                    let mut pad_len = 0usize;
704                    loop {
705                        if ptr >= input.len() {
706                            return Err("Padding overflow");
707                        }
708                        let p = input[ptr] as usize;
709                        ptr += 1;
710                        if p == 255 {
711                            pad_len += 254;
712                        } else {
713                            pad_len += p;
714                            break;
715                        }
716                    }
717
718                    let end = input.len().saturating_sub(pad_len);
719                    if ptr > end {
720                        return Err("Padding exceeds packet");
721                    }
722                    payload_data = &input[ptr..end];
723                } else {
724                    payload_data = &input[ptr..];
725                }
726            }
727            _ => {
728                payload_data = &input[1..];
729            }
730        }
731
732        self.frame_size = frame_size;
733        self.bandwidth = bandwidth;
734        self.stream_channels = packet_channels;
735
736        match mode {
737            OpusMode::SilkOnly => {
738                let internal_sample_rate = match bandwidth {
739                    Bandwidth::Narrowband => 8000,
740                    Bandwidth::Mediumband => 12000,
741                    Bandwidth::Wideband => 16000,
742                    _ => 16000,
743                };
744
745                let mut rc = RangeCoder::new_decoder(payload_data);
746                let internal_frame_size =
747                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
748                let pcm_i16_len = internal_frame_size * self.channels;
749                debug_assert!(pcm_i16_len <= self.w_pcm_i16.len());
750
751                let payload_size_ms = frame_duration_ms;
752
753                let ret = {
754                    let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
755                    silk_dec.decode(
756                        &mut rc,
757                        &mut pcm_i16[..pcm_i16_len],
758                        silk::decode_frame::FLAG_DECODE_NORMAL,
759                        true,
760                        payload_size_ms,
761                        internal_sample_rate,
762                    )
763                };
764
765                if ret < 0 {
766                    return Err("SILK decoding failed");
767                }
768
769                let decoded_samples = ret as usize;
770
771                if self.sampling_rate == internal_sample_rate {
772                    let n = decoded_samples.min(frame_size).min(output.len());
773                    for i in 0..n {
774                        output[i] = self.w_pcm_i16[i] as f32 / 32768.0;
775                    }
776                    self.prev_mode = Some(OpusMode::SilkOnly);
777                    Ok(n)
778                } else {
779                    if internal_sample_rate != self.prev_internal_rate {
780                        self.silk_resampler
781                            .init(internal_sample_rate, self.sampling_rate);
782                        self.prev_internal_rate = internal_sample_rate;
783                    }
784
785                    let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
786                    let out_len = ((decoded_samples as f64 * ratio) as usize).min(frame_size);
787                    debug_assert!(out_len <= self.w_pcm_resampled.len());
788                    let ret2 = {
789                        let (silk_res, pcm_i16, pcm_out) = (
790                            &mut self.silk_resampler,
791                            &self.w_pcm_i16,
792                            &mut self.w_pcm_resampled,
793                        );
794                        silk_res.process(
795                            &mut pcm_out[..out_len],
796                            &pcm_i16[..decoded_samples],
797                            decoded_samples as i32,
798                        )
799                    };
800                    let _ = ret2;
801
802                    let n = out_len.min(output.len());
803                    for i in 0..n {
804                        output[i] = self.w_pcm_resampled[i] as f32 / 32768.0;
805                    }
806                    self.prev_mode = Some(OpusMode::SilkOnly);
807                    Ok(n)
808                }
809            }
810
811            OpusMode::CeltOnly => {
812                self.celt_dec.decode(payload_data, frame_size, output);
813                self.prev_mode = Some(OpusMode::CeltOnly);
814                Ok(frame_size)
815            }
816
817            OpusMode::Hybrid => {
818                let internal_sample_rate = match bandwidth {
819                    Bandwidth::Superwideband => 16000,
820                    Bandwidth::Fullband => 16000,
821                    _ => 16000,
822                };
823
824                let mut rc = RangeCoder::new_decoder(payload_data);
825                let internal_frame_size =
826                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
827                let pcm_silk_i16_len = internal_frame_size * self.channels;
828                debug_assert!(pcm_silk_i16_len <= self.w_pcm_i16.len());
829
830                let ret = {
831                    let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
832                    silk_dec.decode(
833                        &mut rc,
834                        &mut pcm_i16[..pcm_silk_i16_len],
835                        silk::decode_frame::FLAG_DECODE_NORMAL,
836                        true,
837                        frame_duration_ms,
838                        internal_sample_rate,
839                    )
840                };
841
842                let silk_out_len = frame_size * self.channels;
843                debug_assert!(silk_out_len <= self.w_silk_out.len());
844                self.w_silk_out[..silk_out_len].fill(0.0);
845                if ret > 0 {
846                    let decoded_samples = ret as usize;
847                    if self.sampling_rate == internal_sample_rate {
848                        let n = decoded_samples.min(frame_size);
849                        for i in 0..n {
850                            self.w_silk_out[i] = self.w_pcm_i16[i] as f32 / 32768.0;
851                        }
852                    } else {
853                        if internal_sample_rate != self.prev_internal_rate {
854                            self.silk_resampler
855                                .init(internal_sample_rate, self.sampling_rate);
856                            self.prev_internal_rate = internal_sample_rate;
857                        }
858                        let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
859                        let out_len = ((decoded_samples as f64 * ratio) as usize).min(frame_size);
860                        debug_assert!(out_len <= self.w_pcm_resampled.len());
861                        {
862                            let (silk_res, pcm_i16, pcm_resampled) = (
863                                &mut self.silk_resampler,
864                                &self.w_pcm_i16,
865                                &mut self.w_pcm_resampled,
866                            );
867                            silk_res.process(
868                                &mut pcm_resampled[..out_len],
869                                &pcm_i16[..decoded_samples],
870                                decoded_samples as i32,
871                            );
872                        }
873                        for i in 0..out_len.min(frame_size) {
874                            self.w_silk_out[i] = self.w_pcm_resampled[i] as f32 / 32768.0;
875                        }
876                    }
877                }
878
879                let celt_out_len = frame_size * self.channels;
880                debug_assert!(celt_out_len <= self.w_celt_out.len());
881                {
882                    let total_bits = (payload_data.len() * 8) as i32;
883                    let (celt_dec, celt_out) = (&mut self.celt_dec, &mut self.w_celt_out);
884                    celt_dec.decode_from_range_coder(
885                        &mut rc,
886                        total_bits,
887                        frame_size,
888                        &mut celt_out[..celt_out_len],
889                        17,
890                    );
891                }
892
893                let n = frame_size.min(output.len());
894                for i in 0..n {
895                    output[i] = (self.w_silk_out[i] + self.w_celt_out[i]).clamp(-1.0, 1.0);
896                }
897
898                self.prev_mode = Some(OpusMode::Hybrid);
899                Ok(n)
900            }
901        }
902    }
903}
904
/// Returns the frame rate (frames per second) implied by `frame_size` samples per channel at
/// `sampling_rate` Hz.
///
/// Returns `None` unless `frame_size` divides `sampling_rate` exactly and the resulting rate
/// is positive. A non-positive rate must never reach `gen_toc`, whose doubling loop would not
/// terminate.
fn frame_rate_from_params(sampling_rate: i32, frame_size: usize) -> Option<i32> {
    // Reject sizes that do not fit in `i32`: a bare `as i32` would wrap them (e.g.
    // `usize::MAX` -> -1) and could produce a negative "frame rate".
    if frame_size == 0 || frame_size > i32::MAX as usize || sampling_rate <= 0 {
        return None;
    }
    let frame_size = frame_size as i32;
    if sampling_rate % frame_size != 0 {
        return None;
    }
    Some(sampling_rate / frame_size)
}
912
913fn gen_toc(mode: OpusMode, frame_rate: i32, bandwidth: Bandwidth, channels: usize) -> u8 {
914    let mut rate = frame_rate;
915    let mut period = 0;
916    while rate < 400 {
917        rate <<= 1;
918        period += 1;
919    }
920
921    let mut toc = match mode {
922        OpusMode::SilkOnly => {
923            let bw = (bandwidth as i32 - Bandwidth::Narrowband as i32) << 5;
924            let per = (period - 2) << 3;
925            (bw | per) as u8
926        }
927        OpusMode::CeltOnly => {
928            let mut tmp = bandwidth as i32 - Bandwidth::Mediumband as i32;
929            if tmp < 0 {
930                tmp = 0;
931            }
932            let per = period << 3;
933            (0x80 | (tmp << 5) | per) as u8
934        }
935        OpusMode::Hybrid => {
936            // Hybrid config per RFC 6716:
937            // 12: 10ms SWB, 13: 20ms SWB, 14: 10ms FB, 15: 20ms FB
938            let base_config = if bandwidth == Bandwidth::Superwideband {
939                12 // SWB starts at 12
940            } else {
941                14 // FB starts at 14
942            };
943            let period_offset = if frame_rate >= 100 { 0 } else { 1 };
944            ((base_config + period_offset) << 3) as u8
945        }
946    };
947
948    if channels == 2 {
949        toc |= 0x04;
950    }
951    toc
952}
953
954fn mode_from_toc(toc: u8) -> OpusMode {
955    if toc & 0x80 != 0 {
956        OpusMode::CeltOnly
957    } else if toc & 0x60 == 0x60 {
958        OpusMode::Hybrid
959    } else {
960        OpusMode::SilkOnly
961    }
962}
963
964fn bandwidth_from_toc(toc: u8) -> Bandwidth {
965    let mode = mode_from_toc(toc);
966    match mode {
967        OpusMode::SilkOnly => {
968            let bw_bits = (toc >> 5) & 0x03;
969            match bw_bits {
970                0 => Bandwidth::Narrowband,
971                1 => Bandwidth::Mediumband,
972                2 => Bandwidth::Wideband,
973                _ => Bandwidth::Wideband,
974            }
975        }
976        OpusMode::Hybrid => {
977            let bw_bit = (toc >> 4) & 0x01;
978            if bw_bit == 0 {
979                Bandwidth::Superwideband
980            } else {
981                Bandwidth::Fullband
982            }
983        }
984        OpusMode::CeltOnly => {
985            let bw_bits = (toc >> 5) & 0x03;
986            match bw_bits {
987                0 => Bandwidth::Mediumband,
988                1 => Bandwidth::Wideband,
989                2 => Bandwidth::Superwideband,
990                3 => Bandwidth::Fullband,
991                _ => Bandwidth::Fullband,
992            }
993        }
994    }
995}
996
997fn frame_duration_ms_from_toc(toc: u8) -> i32 {
998    let mode = mode_from_toc(toc);
999    match mode {
1000        OpusMode::SilkOnly => {
1001            let config = (toc >> 3) & 0x03;
1002            match config {
1003                0 => 10,
1004                1 => 20,
1005                2 => 40,
1006                3 => 60,
1007                _ => 20,
1008            }
1009        }
1010        OpusMode::Hybrid => {
1011            let config = (toc >> 3) & 0x01;
1012            if config == 0 { 10 } else { 20 }
1013        }
1014        OpusMode::CeltOnly => {
1015            let config = (toc >> 3) & 0x03;
1016            match config {
1017                0 => 2,
1018                1 => 5,
1019                2 => 10,
1020                3 => 20,
1021                _ => 20,
1022            }
1023        }
1024    }
1025}
1026
/// Returns the channel count signalled by a TOC byte: bit 2 is the stereo flag.
fn channels_from_toc(toc: u8) -> usize {
    1 + usize::from((toc >> 2) & 0x01)
}
1030
#[cfg(test)]
mod tests {
    use super::*;

    /// Test helper: samples per channel described by a TOC byte at `sampling_rate`.
    ///
    /// CELT sizes come from the frame rate (`400 >> code` frames per second) so that 2.5 ms
    /// frames are exact. SILK and hybrid sizes come from the millisecond duration.
    /// Returns `None` when the CELT frame rate does not divide `sampling_rate`.
    fn frame_size_from_toc(toc: u8, sampling_rate: i32) -> Option<usize> {
        match mode_from_toc(toc) {
            OpusMode::CeltOnly => {
                let period = ((toc >> 3) & 0x03) as i32;
                let frame_rate = 400 >> period;
                if frame_rate == 0 || sampling_rate % frame_rate != 0 {
                    return None;
                }
                Some((sampling_rate / frame_rate) as usize)
            }
            OpusMode::SilkOnly | OpusMode::Hybrid => {
                let duration_ms = frame_duration_ms_from_toc(toc);
                Some((sampling_rate as i64 * duration_ms as i64 / 1000) as usize)
            }
        }
    }

    #[test]
    fn gen_toc_matches_celt_reference_values() {
        let sampling_rate = 48_000;
        let cases = [
            (120usize, 0xE0u8),
            (240usize, 0xE8u8),
            (480usize, 0xF0u8),
            (960usize, 0xF8u8),
        ];

        for (frame_size, expected_toc) in cases {
            let frame_rate = frame_rate_from_params(sampling_rate, frame_size).unwrap();
            let toc = gen_toc(OpusMode::CeltOnly, frame_rate, Bandwidth::Fullband, 1);
            assert_eq!(
                toc, expected_toc,
                "frame_size {} expected TOC {:02X} got {:02X}",
                frame_size, expected_toc, toc
            );
            let decoded_size = frame_size_from_toc(toc, sampling_rate).unwrap();
            assert_eq!(decoded_size, frame_size);
        }

        let stereo_toc = gen_toc(
            OpusMode::CeltOnly,
            frame_rate_from_params(sampling_rate, 960).unwrap(),
            Bandwidth::Fullband,
            2,
        );
        assert_eq!(channels_from_toc(stereo_toc), 2);
    }

    /// Every SILK/hybrid TOC produced by `gen_toc` must decode back to the same mode,
    /// bandwidth, duration, channel count and frame size.
    #[test]
    fn toc_round_trips_for_silk_and_hybrid() {
        let sampling_rate = 48_000;
        let cases = [
            (OpusMode::SilkOnly, Bandwidth::Narrowband, 480usize, 10),
            (OpusMode::SilkOnly, Bandwidth::Mediumband, 960usize, 20),
            (OpusMode::SilkOnly, Bandwidth::Wideband, 1920usize, 40),
            (OpusMode::Hybrid, Bandwidth::Superwideband, 480usize, 10),
            (OpusMode::Hybrid, Bandwidth::Superwideband, 960usize, 20),
            (OpusMode::Hybrid, Bandwidth::Fullband, 480usize, 10),
            (OpusMode::Hybrid, Bandwidth::Fullband, 960usize, 20),
        ];

        for (mode, bandwidth, frame_size, duration_ms) in cases {
            for channels in [1usize, 2] {
                let frame_rate = frame_rate_from_params(sampling_rate, frame_size).unwrap();
                let toc = gen_toc(mode, frame_rate, bandwidth, channels);
                assert_eq!(mode_from_toc(toc), mode, "TOC {:02X}", toc);
                assert_eq!(bandwidth_from_toc(toc), bandwidth, "TOC {:02X}", toc);
                assert_eq!(frame_duration_ms_from_toc(toc), duration_ms, "TOC {:02X}", toc);
                assert_eq!(channels_from_toc(toc), channels, "TOC {:02X}", toc);
                assert_eq!(
                    frame_size_from_toc(toc, sampling_rate),
                    Some(frame_size),
                    "TOC {:02X}",
                    toc
                );
            }
        }
    }

    /// Smoke test: decoding minimal CELT packets must not panic for any standard frame size,
    /// in mono or stereo. A fresh decoder is created for each channel count.
    #[test]
    fn test_celt_decoder_large_frame_sizes() {
        let sampling_rate = 48000;
        let frame_sizes = [120, 240, 480, 960];

        for channels in [1, 2] {
            let mut decoder = OpusDecoder::new(sampling_rate, channels).unwrap();

            for frame_size in frame_sizes {
                let toc = gen_toc(
                    OpusMode::CeltOnly,
                    frame_rate_from_params(sampling_rate, frame_size).unwrap(),
                    Bandwidth::Fullband,
                    channels,
                );
                let packet = [toc, 0, 0, 0, 0];

                let mut output = vec![0.0f32; frame_size * channels];
                let _ = decoder.decode(&packet, frame_size, &mut output);
            }
        }
    }

    /// Smoke test: oversized and non-standard frame sizes passed alongside a CELT TOC must
    /// not panic. Only the absence of a panic is checked.
    #[test]
    fn test_celt_decoder_edge_case_frame_sizes() {
        let sampling_rate = 48000;
        let channels = 1;
        let mut decoder = OpusDecoder::new(sampling_rate, channels).unwrap();

        let edge_sizes = [2048, 2167, 2168, 2169, 2880, 3072];

        for frame_size in edge_sizes {
            let mut output = vec![0.0f32; frame_size * channels];

            let _ = decoder.decode(&[0x80, 0, 0, 0], frame_size, &mut output);
        }
    }
}