// opus_rs/lib.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2
3pub mod bands;
4pub mod celt;
5pub mod celt_lpc;
6pub mod hp_cutoff;
7pub mod kiss_fft;
8pub mod mdct;
9pub mod modes;
10pub mod pitch;
11pub mod pvq;
12pub mod quant_bands;
13pub mod range_coder;
14pub mod rate;
15pub mod silk;
16
17pub use silk::{SilkResampler, SilkResamplerDown1_3, SilkResamplerDown1_6};
18
19pub use celt::{CeltDecoder, CeltEncoder};
20use hp_cutoff::hp_cutoff;
21use range_coder::RangeCoder;
22use silk::control_codec::silk_control_encoder;
23use silk::enc_api::silk_encode;
24use silk::init_encoder::silk_init_encoder;
25use silk::lin2log::silk_lin2log;
26use silk::log2lin::silk_log2lin;
27use silk::macros::*;
28use silk::resampler::{silk_resampler_down2, silk_resampler_down2_3};
29use silk::structs::SilkEncoderState;
30
/// Helper for timing encoder stages.
///
/// A guard borrows a shared table of twelve per-stage accumulators and a shared frame
/// counter. Dropping the guard counts one frame; on every 5000th frame, and only if the
/// `STAGE_TIMING` environment variable is set, a per-stage breakdown is written to stderr.
/// The accumulators are interpreted as nanoseconds when the report is printed; this type
/// never writes to them itself.
pub struct StageTimerGuard<'a> {
    /// Per-stage accumulated time, indexed in the same order as the names in the report.
    times: &'a [std::sync::atomic::AtomicU64; 12],
    /// Number of frames finished so far (incremented once per dropped guard).
    count: &'a std::sync::atomic::AtomicU64,
    /// Creation time of the guard; not read anywhere in this type.
    #[allow(dead_code)]
    start: std::time::Instant,
}

impl<'a> StageTimerGuard<'a> {
    /// Creates a guard over the shared accumulators, recording the current instant.
    pub fn new(
        times: &'a [std::sync::atomic::AtomicU64; 12],
        count: &'a std::sync::atomic::AtomicU64,
    ) -> Self {
        Self {
            times,
            count,
            start: std::time::Instant::now(),
        }
    }
}

impl<'a> Drop for StageTimerGuard<'a> {
    /// Counts the finished frame and periodically prints the averaged stage timings.
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;

        /// A report is printed once every this many frames.
        const REPORT_INTERVAL: u64 = 5000;
        /// Display names, one per slot of `times`.
        const STAGE_NAMES: [&str; 12] = [
            "preemph",
            "transient",
            "prefilter",
            "mdct_fwd",
            "band_energy",
            "coarse_energy",
            "tf_spreading",
            "bit_alloc",
            "fine_energy",
            "quant_all_bands",
            "energy_finalise",
            "synthesis_mdct_back",
        ];

        let n = self.count.fetch_add(1, Ordering::Relaxed) + 1;
        // The environment is only consulted on reporting frames.
        if n % REPORT_INTERVAL != 0 || std::env::var("STAGE_TIMING").is_err() {
            return;
        }

        // Read every counter exactly once so that the total, the averages and the
        // percentages are all derived from the same snapshot even if other threads keep
        // accumulating while the report is being printed.
        let mut snapshot = [0u64; 12];
        for (slot, counter) in snapshot.iter_mut().zip(self.times.iter()) {
            *slot = counter.load(Ordering::Relaxed);
        }
        let total: u64 = snapshot.iter().sum();
        let frames = n as f64;

        eprintln!(
            "\n=== Stage timing (avg over {} frames, total {:.1}µs) ===",
            n,
            total as f64 / frames / 1000.0
        );
        for (name, &stage_total) in STAGE_NAMES.iter().zip(snapshot.iter()) {
            let ns = stage_total as f64 / frames;
            let pct = if total > 0 {
                100.0 * stage_total as f64 / total as f64
            } else {
                0.0
            };
            eprintln!("  {:25}: {:6.1} µs ({:5.1}%)", name, ns / 1000.0, pct);
        }
    }
}
67
/// Intended use of an encoder, chosen when the encoder is constructed.
///
/// NOTE(review): the discriminants look like libopus' `OPUS_APPLICATION_*` constants —
/// confirm before relying on them for FFI.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Application {
    /// Voice-over-IP: the encoder runs the input through `hp_cutoff` before SILK.
    Voip = 2048,
    /// General audio: the encoder may switch a hybrid configuration to CELT-only
    /// depending on the configured bitrate.
    Audio = 2049,
    /// Restricted low-delay. Apart from skipping the VoIP high-pass filter, this variant
    /// is not special-cased by the encoder code in this file.
    RestrictedLowDelay = 2051,
}
74
/// Audio bandwidth signalled in a packet's TOC byte.
///
/// NOTE(review): the discriminants look like libopus' `OPUS_BANDWIDTH_*` / `OPUS_AUTO`
/// constants — confirm before relying on them for FFI.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bandwidth {
    /// No bandwidth selected yet; the decoder starts in this state.
    Auto = -1000,
    /// Chosen by the encoder for 8 kHz input; SILK decodes it at 8 kHz internally.
    Narrowband = 1101,
    /// Chosen by the encoder for 12 kHz input; SILK decodes it at 12 kHz internally.
    Mediumband = 1102,
    /// Chosen by the encoder for 16 kHz input; SILK decodes it at 16 kHz internally.
    Wideband = 1103,
    /// Chosen by the encoder for 24 kHz input.
    Superwideband = 1104,
    /// Chosen by the encoder for 48 kHz input.
    Fullband = 1105,
}
84
/// Which codec layer(s) carry a frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OpusMode {
    /// Only the SILK layer is coded.
    SilkOnly,
    /// SILK is coded first, then CELT is coded into the same range coder starting at band 17.
    Hybrid,
    /// Only the CELT layer is coded, starting at band 0.
    CeltOnly,
}
91
/// Opus encoder combining a CELT encoder and a SILK encoder behind one packet-level API.
///
/// The public fields are read on every call to `encode`, so they can be changed between
/// frames.
pub struct OpusEncoder {
    /// CELT layer, used for CELT-only and hybrid frames.
    celt_enc: CeltEncoder,
    /// SILK layer state, used for SILK-only and hybrid frames.
    silk_enc: Box<SilkEncoderState>,
    /// Application selected at construction.
    application: Application,
    /// Input sampling rate in Hz (8000, 12000, 16000, 24000 or 48000).
    sampling_rate: i32,
    /// Number of input channels (1 or 2).
    channels: usize,
    /// Bandwidth written into the TOC byte of every packet.
    bandwidth: Bandwidth,
    /// Target bitrate in bits per second; starts at 64000.
    pub bitrate_bps: i32,
    /// Complexity setting forwarded to both the SILK and CELT encoders; starts at 0.
    pub complexity: i32,
    /// When true, SILK-only packets are padded up to the bitrate-derived size.
    pub use_cbr: bool,

    /// Forwarded to SILK as both `use_in_band_fec` and `lbrr_enabled`.
    pub use_inband_fec: bool,

    /// Expected packet loss in percent; clamped to 0..=100 before being handed to SILK.
    pub packet_loss_perc: i32,
    /// False until `silk_control_encoder` has been run for the current configuration.
    silk_initialized: bool,
    /// Configured coding mode; `encode` may override it per frame for `Application::Audio`.
    mode: OpusMode,

    /// Smoothed high-pass cutoff state, initialised to `silk_lin2log(60) << 8`.
    variable_hp_smth2_q15: i32,
    /// Filter memory passed to `hp_cutoff`, two entries per channel.
    hp_mem: Vec<i32>,

    /// Scratch: 16-bit version of the current input frame (filtered for VoIP).
    buf_filtered: Vec<i16>,
    /// Scratch: downsampled signal handed to SILK.
    buf_silk_input: Vec<i16>,
    /// Scratch: mid channel for SILK-only stereo.
    buf_stereo_mid: Vec<i16>,
    /// Scratch: side channel for SILK-only stereo.
    buf_stereo_side: Vec<i16>,
    /// State of the first 2:1 downsampling stage.
    down2_state_first: [i32; 2],
    /// State reserved for a second 2:1 stage; only reset, never used, in the code visible here.
    down2_state_second: [i32; 2],
    /// State of the 3:2 downsampler.
    down2_3_state: [i32; 6],
    /// State of the 3:1 downsampler used for 48 kHz hybrid frames.
    down_1_3_state: silk::resampler::SilkResamplerDown1_3,
    // Pre-allocated RangeCoder, reset at the start of every frame instead of being rebuilt
    rc: RangeCoder,
}
123
/// Returns the per-channel SILK bitrate (bps) to use in hybrid mode for a given total rate.
///
/// The total rate is looked up in a piecewise-linear table (reference:
/// `compute_silk_rate_for_hybrid` in opus_encoder.c, no-FEC columns). Between two table
/// rows the SILK rate is linearly interpolated with integer arithmetic; above the last row
/// SILK receives the last tabulated rate plus half of the excess.
///
/// * `rate_bps` – total target bitrate in bits per second.
/// * `frame20ms` – selects the 20 ms column instead of the 10 ms column.
fn compute_silk_rate_for_hybrid(rate_bps: i32, frame20ms: bool) -> i32 {
    // (total_rate_bps, silk_rate_10ms, silk_rate_20ms) — no FEC, single channel
    const RATE_TABLE: [(i32, i32, i32); 7] = [
        (0, 0, 0),
        (12000, 10000, 10000),
        (16000, 13500, 13500),
        (20000, 16000, 16000),
        (24000, 18000, 18000),
        (32000, 22000, 22000),
        (64000, 38000, 38000),
    ];

    // Picks the SILK column matching the frame duration.
    let silk_rate_of = |row: (i32, i32, i32)| if frame20ms { row.2 } else { row.1 };

    // Row 0 is only a lower anchor, so the search starts at row 1. `position` counts from
    // the first searched row, which makes its result the index of the row *below* the hit.
    let first_above = RATE_TABLE
        .iter()
        .skip(1)
        .position(|row| row.0 > rate_bps);

    match first_above {
        Some(below_idx) => {
            let below = RATE_TABLE[below_idx];
            let above = RATE_TABLE[below_idx + 1];
            let (x0, x1) = (below.0, above.0);
            let (lo, hi) = (silk_rate_of(below), silk_rate_of(above));
            (lo * (x1 - rate_bps) + hi * (rate_bps - x0)) / (x1 - x0)
        }
        None => {
            // Beyond the table: half of the extra bits go to SILK.
            let top = RATE_TABLE[RATE_TABLE.len() - 1];
            silk_rate_of(top) + (rate_bps - top.0) / 2
        }
    }
}
157
#[cfg(test)]
mod silk_rate_tests {
    use super::compute_silk_rate_for_hybrid;

    /// Each breakpoint of the static table maps to exactly its tabulated SILK rate.
    #[test]
    fn test_reference_table_exact_entries() {
        let breakpoints = [
            (12000, 10000),
            (16000, 13500),
            (20000, 16000),
            (24000, 18000),
            (32000, 22000),
            (64000, 38000),
        ];
        for &(total_bps, expected_silk_bps) in breakpoints.iter() {
            assert_eq!(compute_silk_rate_for_hybrid(total_bps, true), expected_silk_bps);
        }
    }

    /// The reference working point (48 kHz, 32 kbps, 20 ms) hands SILK 22 000 bps, which is
    /// more than the earlier fixed-fraction split (12 600 bps).
    #[test]
    fn test_32kbps_gives_22kbps_silk() {
        let silk_bps = compute_silk_rate_for_hybrid(32000, true);
        assert_eq!(silk_bps, 22000);
    }

    /// Halfway between the 16 kbps and 20 kbps rows the result is the mean of both rows:
    /// (13500 + 16000) / 2 = 14750.
    #[test]
    fn test_interpolation_between_table_entries() {
        assert_eq!(compute_silk_rate_for_hybrid(18000, true), 14750);
    }

    /// Past the last row (64 kbps → 38 000) SILK gets half of every additional bit.
    #[test]
    fn test_above_table_max_gives_half_extra() {
        let total_bps = 72000;
        let expected = 38000 + (total_bps - 64000) / 2; // 42 000
        assert_eq!(compute_silk_rate_for_hybrid(total_bps, true), expected);
    }
}
194
195impl OpusEncoder {
196    pub fn new(
197        sampling_rate: i32,
198        channels: usize,
199        application: Application,
200    ) -> Result<Self, &'static str> {
201        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
202            return Err("Invalid sampling rate");
203        }
204        if ![1, 2].contains(&channels) {
205            return Err("Invalid number of channels");
206        }
207
208        let mode = modes::default_mode();
209        let celt_enc = CeltEncoder::new(mode, channels);
210
211        let mut silk_enc = Box::new(SilkEncoderState::default());
212        if silk_init_encoder(&mut *silk_enc, 0) != 0 {
213            return Err("SILK encoder initialization failed");
214        }
215
216        let (opus_mode, bw) = match application {
217            Application::Voip => {
218                let bw = match sampling_rate {
219                    8000 => Bandwidth::Narrowband,
220                    12000 => Bandwidth::Mediumband,
221                    16000 => Bandwidth::Wideband,
222                    24000 => Bandwidth::Superwideband,
223                    48000 => Bandwidth::Fullband,
224                    _ => Bandwidth::Narrowband,
225                };
226                // For >16kHz, use Hybrid mode (SILK for low frequencies, CELT for high)
227                let mode = if sampling_rate > 16000 {
228                    OpusMode::Hybrid
229                } else {
230                    OpusMode::SilkOnly
231                };
232                (mode, bw)
233            }
234            _ => {
235                // Audio mode: use SILK for ≤16kHz, Hybrid for >16kHz.
236                // Hybrid uses SILK for the low-frequency speech content (0-8kHz) and
237                // CELT for the high-frequency extension, which gives much better quality
238                // at modest bitrates (≤50 kbps) than CELT-only across the full band.
239                if sampling_rate <= 16000 {
240                    let bw = match sampling_rate {
241                        8000 => Bandwidth::Narrowband,
242                        12000 => Bandwidth::Mediumband,
243                        _ => Bandwidth::Wideband,
244                    };
245                    (OpusMode::SilkOnly, bw)
246                } else {
247                    let bw = match sampling_rate {
248                        24000 => Bandwidth::Superwideband,
249                        _ => Bandwidth::Fullband,
250                    };
251                    (OpusMode::Hybrid, bw)
252                }
253            }
254        };
255
256        use silk::lin2log::silk_lin2log;
257        let variable_hp_smth2_q15 = silk_lin2log(60) << 8;
258
259        Ok(Self {
260            celt_enc,
261            silk_enc,
262            application,
263            sampling_rate,
264            channels,
265            bandwidth: bw,
266            bitrate_bps: 64000,
267            complexity: 0,
268            use_cbr: false,
269            use_inband_fec: false,
270            packet_loss_perc: 0,
271            silk_initialized: false,
272            mode: opus_mode,
273            variable_hp_smth2_q15,
274            hp_mem: vec![0; channels * 2],
275
276            buf_filtered: Vec::new(),
277            buf_silk_input: Vec::new(),
278            buf_stereo_mid: Vec::new(),
279            buf_stereo_side: Vec::new(),
280            down2_state_first: [0; 2],
281            down2_state_second: [0; 2],
282            down2_3_state: [0; 6],
283            down_1_3_state: silk::resampler::SilkResamplerDown1_3::default(),
284            rc: RangeCoder::new_encoder(1),
285        })
286    }
287
288    pub fn enable_hybrid_mode(&mut self) -> Result<(), &'static str> {
289        if self.sampling_rate != 24000 && self.sampling_rate != 48000 {
290            return Err("Hybrid mode requires 24kHz or 48kHz sampling rate");
291        }
292        let bw = if self.sampling_rate == 48000 {
293            Bandwidth::Fullband
294        } else {
295            Bandwidth::Superwideband
296        };
297        self.mode = OpusMode::Hybrid;
298        self.bandwidth = bw;
299        self.silk_initialized = false;
300        Ok(())
301    }
302
303    pub fn encode(
304        &mut self,
305        input: &[f32],
306        frame_size: usize,
307        output: &mut [u8],
308    ) -> Result<usize, &'static str> {
309        if output.len() < 2 {
310            return Err("Output buffer too small");
311        }
312
313        let frame_rate = frame_rate_from_params(self.sampling_rate, frame_size)
314            .ok_or("Invalid frame size for sampling rate")?;
315
316        // Dynamic mode selection matching C opus behavior:
317        // For Audio application, switch between SILK/CELT based on bitrate threshold.
318        // At 64kbps mono, C opus uses CELT-only (threshold ~17.6kbps for mono Audio).
319        let mode = if self.application == Application::Audio && self.mode == OpusMode::Hybrid {
320            // Mode thresholds from C opus: [mono_voice=64000, mono_music=10000]
321            // For mono Audio with default voice_est=48:
322            // threshold = 10000 + 48^2 * (64000-10000)/16384 ≈ 17600
323            // At 64kbps, equiv_rate (64000) > threshold → CELT_ONLY
324            let equiv_rate = self.bitrate_bps;
325            let threshold: i32 = 17_600; // Approximate threshold for mono Audio
326            if equiv_rate >= threshold {
327                OpusMode::CeltOnly
328            } else {
329                OpusMode::Hybrid
330            }
331        } else {
332            self.mode
333        };
334
335        if mode == OpusMode::CeltOnly {
336            match frame_rate {
337                400 | 200 | 100 | 50 => {}
338                _ => return Err("Unsupported frame size for CELT-only mode"),
339            }
340        }
341
342        let toc = gen_toc(mode, frame_rate, self.bandwidth, self.channels);
343        output[0] = toc;
344
345        let target_bits =
346            (self.bitrate_bps as i64 * frame_size as i64 / self.sampling_rate as i64) as i32;
347        let cbr_bytes = ((target_bits + 4) / 8) as usize;
348        let max_data_bytes = output.len();
349
350        // In VBR mode, cap n_bytes to the bitrate-appropriate size (matching C's
351        // IMIN(8*max_data_bytes, bitrate_to_bits(...))). This ensures the encoder
352        // and decoder agree on total_bits. VBR quality could be improved later
353        // with a reservoir mechanism + ec_enc_shrink for bursting above average.
354        let n_bytes = cbr_bytes.min(max_data_bytes).max(1);
355
356        // Use n_bytes-1 as RC storage so the full buffer can be passed to the decoder.
357        // This ensures encoder total_bits == decoder total_bits (no mismatch from done() overhead).
358        // Reuse pre-allocated RangeCoder to avoid per-frame heap allocation.
359        self.rc.reset_for_encode((n_bytes - 1) as u32);
360
361        if mode == OpusMode::SilkOnly || mode == OpusMode::Hybrid {
362            let silk_fs_khz = if mode == OpusMode::Hybrid {
363                16
364            } else {
365                self.sampling_rate.min(16000) / 1000
366            };
367
368            let frame_ms = (frame_size as i32 * 1000) / self.sampling_rate;
369            if !self.silk_initialized || self.silk_enc.s_cmn.fs_khz != silk_fs_khz as i32 {
370                let silk_init_bitrate = (((n_bytes - 1) * 8) as i64 * self.sampling_rate as i64
371                    / frame_size as i64) as i32;
372                silk_control_encoder(
373                    &mut *self.silk_enc,
374                    silk_fs_khz as i32,
375                    frame_ms,
376                    silk_init_bitrate,
377                    self.complexity,
378                );
379                self.silk_enc.s_cmn.use_cbr = if self.use_cbr { 1 } else { 0 };
380
381                self.silk_enc.s_cmn.n_channels = self.channels as i32;
382                self.silk_initialized = true;
383                self.down2_state_first = [0; 2];
384                self.down2_state_second = [0; 2];
385                self.down2_3_state = [0; 6];
386                self.down_1_3_state = silk::resampler::SilkResamplerDown1_3::default();
387            }
388
389            self.silk_enc.s_cmn.use_in_band_fec = if self.use_inband_fec { 1 } else { 0 };
390            self.silk_enc.s_cmn.packet_loss_perc = self.packet_loss_perc.clamp(0, 100);
391
392            self.silk_enc.s_cmn.lbrr_enabled = if self.use_inband_fec { 1 } else { 0 };
393
394            if self.silk_enc.s_cmn.lbrr_gain_increases == 0 {
395                self.silk_enc.s_cmn.lbrr_gain_increases = 2;
396            }
397
398            let hp_freq_smth1 = if mode == OpusMode::CeltOnly {
399                silk_lin2log(60) << 8
400            } else {
401                self.silk_enc.s_cmn.variable_hp_smth1_q15
402            };
403
404            const VARIABLE_HP_SMTH_COEF2_Q16: i32 = 984;
405            self.variable_hp_smth2_q15 = silk_smlawb(
406                self.variable_hp_smth2_q15,
407                hp_freq_smth1 - self.variable_hp_smth2_q15,
408                VARIABLE_HP_SMTH_COEF2_Q16,
409            );
410
411            let cutoff_hz = silk_log2lin(silk_rshift(self.variable_hp_smth2_q15, 8));
412
413            let required_size = frame_size * self.channels;
414            self.buf_filtered.resize(required_size, 0);
415            if self.application == Application::Voip {
416                hp_cutoff(
417                    input,
418                    cutoff_hz,
419                    &mut self.buf_filtered,
420                    &mut self.hp_mem,
421                    frame_size,
422                    self.channels,
423                    self.sampling_rate,
424                );
425            } else {
426                for (i, &x) in input.iter().enumerate() {
427                    self.buf_filtered[i] = (x * 32768.0).clamp(-32768.0, 32767.0) as i16;
428                }
429            }
430
431            let input_i16 = &self.buf_filtered;
432
433            let silk_input: &[i16] = if mode == OpusMode::SilkOnly && self.sampling_rate > 16000 {
434                // SilkOnly mode with >16kHz input needs downsampling to 16kHz
435                if self.sampling_rate == 48000 {
436                    let stage1_size = frame_size / 2;
437                    let mut stage1_buf = [0i16; 480]; // max frame_size/2 for 48kHz/20ms
438                    silk_resampler_down2(
439                        &mut self.down2_state_first,
440                        &mut stage1_buf[..stage1_size],
441                        input_i16,
442                        frame_size as i32,
443                    );
444                    let silk_frame_size = stage1_size * 2 / 3;
445                    self.buf_silk_input.resize(silk_frame_size, 0);
446                    silk_resampler_down2_3(
447                        &mut self.down2_3_state,
448                        &mut self.buf_silk_input,
449                        &stage1_buf[..stage1_size],
450                        stage1_size as i32,
451                    );
452                    &self.buf_silk_input
453                } else if self.sampling_rate == 24000 {
454                    let silk_frame_size = frame_size * 2 / 3;
455                    self.buf_silk_input.resize(silk_frame_size, 0);
456                    silk_resampler_down2_3(
457                        &mut self.down2_3_state,
458                        &mut self.buf_silk_input,
459                        input_i16,
460                        frame_size as i32,
461                    );
462                    &self.buf_silk_input
463                } else {
464                    // Fallback: just use the input directly
465                    input_i16
466                }
467            } else if mode == OpusMode::SilkOnly && self.channels == 2 {
468                let frame_length = input_i16.len() / 2;
469                self.buf_stereo_mid.resize(frame_length, 0);
470                self.buf_stereo_side.resize(frame_length, 0);
471                for i in 0..frame_length {
472                    let l = input_i16[2 * i] as i32;
473                    let r = input_i16[2 * i + 1] as i32;
474                    self.buf_stereo_mid[i] = ((l + r) / 2) as i16;
475                    self.buf_stereo_side[i] = (l - r) as i16;
476                }
477
478                self.silk_enc.stereo.side.resize(frame_length, 0);
479                self.silk_enc
480                    .stereo
481                    .side
482                    .copy_from_slice(&self.buf_stereo_side[..frame_length]);
483                &self.buf_stereo_mid
484            } else if mode == OpusMode::Hybrid && self.sampling_rate > 16000 {
485                if self.sampling_rate == 48000 {
486                    // Reference: use AR2+FIR 1:3 downsampler (silk_resampler_private_down_FIR)
487                    // This has complementary group delay to the decoder's IirFir upsampler,
488                    // avoiding waveform distortion from IIR cascade down2+down2_3.
489                    let silk_frame_size = frame_size / 3; // 960/3 = 320 at 16kHz
490                    self.buf_silk_input.resize(silk_frame_size, 0);
491                    silk::resampler::silk_resampler_down_1_3(
492                        &mut self.down_1_3_state,
493                        &mut self.buf_silk_input,
494                        input_i16,
495                    );
496                } else {
497                    let silk_frame_size = frame_size * 2 / 3;
498                    self.buf_silk_input.resize(silk_frame_size, 0);
499                    silk_resampler_down2_3(
500                        &mut self.down2_3_state,
501                        &mut self.buf_silk_input,
502                        input_i16,
503                        frame_size as i32,
504                    );
505                }
506                &self.buf_silk_input
507            } else {
508                input_i16
509            };
510
511            let mut pn_bytes = 0;
512
513            let silk_rate_for_calc = if mode == OpusMode::Hybrid {
514                16000
515            } else {
516                self.sampling_rate
517            };
518            let silk_frame_len = silk_input.len();
519            // For Hybrid mode, use the reference rate table to allocate SILK bits.
520            // At 32kbps total, the table gives SILK 22kbps (440 bits/20ms frame),
521            // much more than the naive 40% split (252 bits) we had before.
522            let silk_bitrate = if mode == OpusMode::Hybrid {
523                let frame_duration_ms = frame_size as i32 * 1000 / self.sampling_rate;
524                let frame20ms = frame_duration_ms >= 20;
525                compute_silk_rate_for_hybrid(self.bitrate_bps, frame20ms)
526            } else {
527                (8i64 * (n_bytes - 1) as i64 * silk_rate_for_calc as i64 / silk_frame_len as i64)
528                    as i32
529            };
530            let silk_max_bits = if mode == OpusMode::Hybrid {
531                (silk_bitrate as i64 * silk_frame_len as i64 / silk_rate_for_calc as i64) as i32
532            } else {
533                ((n_bytes - 1) * 8) as i32
534            };
535            let ret = silk_encode(
536                &mut *self.silk_enc,
537                silk_input,
538                silk_input.len(),
539                &mut self.rc,
540                &mut pn_bytes,
541                silk_bitrate,
542                silk_max_bits,
543                if self.use_cbr { 1 } else { 0 },
544                1,
545            );
546            if ret != 0 {
547                return Err("SILK encoding failed");
548            }
549        }
550
551        if mode == OpusMode::CeltOnly || mode == OpusMode::Hybrid {
552            // Propagate complexity to CeltEncoder (controls prefilter skip at low complexity)
553            self.celt_enc.complexity = self.complexity;
554            let start_band = if mode == OpusMode::Hybrid { 17 } else { 0 };
555            let total_packet_bits = ((n_bytes - 1) * 8) as i32;
556            self.celt_enc.encode_with_budget(
557                input,
558                frame_size,
559                &mut self.rc,
560                start_band,
561                total_packet_bits,
562            );
563        }
564
565        self.rc.done();
566
567        let silk_payload: Vec<u8> = if mode == OpusMode::SilkOnly {
568            let mut combined = Vec::with_capacity(self.rc.storage as usize);
569            combined.extend_from_slice(&self.rc.buf[0..self.rc.offs as usize]);
570            combined.extend_from_slice(
571                &self.rc.buf[(self.rc.storage - self.rc.end_offs) as usize..self.rc.storage as usize],
572            );
573
574            while combined.len() > 2 && combined[combined.len() - 1] == 0 {
575                combined.pop();
576            }
577
578            combined
579        } else {
580            Vec::new()
581        };
582
583        let total_bytes = if mode == OpusMode::SilkOnly {
584            silk_payload.len()
585        } else {
586            n_bytes
587        };
588
589        let payload_bytes = total_bytes.min(output.len() - 1);
590        let ret_with_toc = payload_bytes + 1;
591
592        if mode == OpusMode::SilkOnly {
593            let target_total = if self.use_cbr {
594                n_bytes.min(output.len())
595            } else {
596                ret_with_toc
597            };
598
599            let silk_len = silk_payload.len();
600
601            if silk_len + 1 >= target_total {
602                output[0] = toc;
603                let copy_len = (target_total - 1).min(silk_len);
604                output[1..1 + copy_len].copy_from_slice(&silk_payload[..copy_len]);
605                return Ok(target_total.min(output.len()));
606            }
607
608            output[0] = toc | 0x03;
609
610            if silk_len + 2 >= target_total {
611                output[1] = 0x01;
612                let copy_len = (target_total - 2).min(silk_len);
613                output[2..2 + copy_len].copy_from_slice(&silk_payload[..copy_len]);
614                return Ok(target_total.min(output.len()));
615            }
616
617            let pad_amount = target_total - silk_len - 2;
618            output[1] = 0x41;
619
620            let nb_255s = (pad_amount - 1) / 255;
621            let mut ptr = 2;
622            for _ in 0..nb_255s {
623                output[ptr] = 255;
624                ptr += 1;
625            }
626            output[ptr] = (pad_amount - 255 * nb_255s - 1) as u8;
627            ptr += 1;
628
629            output[ptr..ptr + silk_len].copy_from_slice(&silk_payload);
630            ptr += silk_len;
631
632            let fill_end = target_total.min(output.len());
633            for byte in output[ptr..fill_end].iter_mut() {
634                *byte = 0;
635            }
636
637            return Ok(target_total.min(output.len()));
638        }
639
640        if mode == OpusMode::SilkOnly {
641            output[1..1 + payload_bytes].copy_from_slice(&silk_payload[..payload_bytes]);
642            Ok(ret_with_toc)
643        } else {
644            // For Hybrid/CeltOnly: use the full RC buffer (front+zeros+back layout).
645            // rc.storage = n_bytes - 1, so rc.buf[..n_bytes-1] has the correct layout
646            // where the decoder can read forward from buf[0] and backward from buf[n_bytes-2].
647            // This ensures encoder total_bits == decoder total_bits (both use (n_bytes-1)*8).
648            let payload_len = n_bytes - 1;
649            output[1..1 + payload_len].copy_from_slice(&self.rc.buf[..payload_len]);
650            Ok(n_bytes)
651        }
652    }
653}
654
/// Opus decoder combining a CELT decoder, a SILK decoder and an output resampler.
pub struct OpusDecoder {
    /// CELT layer decoder.
    celt_dec: CeltDecoder,
    /// SILK layer decoder.
    silk_dec: silk::dec_api::SilkDecoder,
    /// Output sampling rate in Hz (8000, 12000, 16000, 24000 or 48000).
    sampling_rate: i32,
    /// Number of output channels (1 or 2); packets with another channel count are rejected.
    channels: usize,

    /// Mode of the previously decoded packet, `None` before the first one.
    prev_mode: Option<OpusMode>,
    /// `frame_size` argument of the most recent `decode` call.
    frame_size: usize,

    /// Bandwidth read from the TOC byte of the most recent packet.
    bandwidth: Bandwidth,

    /// Channel count read from the TOC byte of the most recent packet.
    stream_channels: usize,

    /// Resampler from SILK's internal rate to `sampling_rate`.
    silk_resampler: silk::resampler::SilkResampler,

    /// Internal rate the resampler was last initialised for (0 = not yet initialised).
    prev_internal_rate: i32,

    /// Diagnostic: when true, skip CELT contribution in hybrid mode (output SILK only).
    pub hybrid_skip_celt: bool,

    // Pre-allocated decode working buffers (avoid per-frame heap allocation); sized in `new`
    w_pcm_i16: Vec<i16>,       // SILK output at its internal sampling rate
    w_silk_out: Vec<f32>,      // SILK→f32 output: max frame_size×channels
    w_pcm_resampled: Vec<i16>, // resampler i16 output: max frame_size×channels
    w_celt_out: Vec<f32>,      // CELT output: max frame_size×channels
}
681
682impl OpusDecoder {
683    pub fn new(sampling_rate: i32, channels: usize) -> Result<Self, &'static str> {
684        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
685            return Err("Invalid sampling rate");
686        }
687        if ![1, 2].contains(&channels) {
688            return Err("Invalid number of channels");
689        }
690
691        let mode = modes::default_mode();
692        let celt_dec = CeltDecoder::new(mode, channels);
693
694        let mut silk_dec = silk::dec_api::SilkDecoder::new();
695        silk_dec.init(sampling_rate.min(16000), channels as i32);
696        silk_dec.channel_state[0].fs_api_hz = sampling_rate;
697
698        Ok(Self {
699            celt_dec,
700            silk_dec,
701            sampling_rate,
702            channels,
703            prev_mode: None,
704            frame_size: 0,
705            bandwidth: Bandwidth::Auto,
706            stream_channels: channels,
707            silk_resampler: silk::resampler::SilkResampler::default(),
708            prev_internal_rate: 0,
709            hybrid_skip_celt: false,
710            // 640 = 16kHz × 20ms × 2ch (SILK max internal frame)
711            w_pcm_i16: vec![0i16; 640],
712            // 5760 × 2 = 48kHz × 120ms × 2ch (max output frame)
713            w_silk_out: vec![0.0f32; 5760 * channels],
714            w_pcm_resampled: vec![0i16; 5760 * channels],
715            w_celt_out: vec![0.0f32; 5760 * channels],
716        })
717    }
718
719    pub fn decode(
720        &mut self,
721        input: &[u8],
722        frame_size: usize,
723        output: &mut [f32],
724    ) -> Result<usize, &'static str> {
725        if input.is_empty() {
726            return Err("Input packet empty");
727        }
728
729        let toc = input[0];
730        let mode = mode_from_toc(toc);
731        let packet_channels = channels_from_toc(toc);
732        let bandwidth = bandwidth_from_toc(toc);
733        let frame_duration_ms = frame_duration_ms_from_toc(toc);
734
735        if packet_channels != self.channels {
736            return Err("Channel count mismatch between packet and decoder");
737        }
738
739        let code = toc & 0x03;
740        let payload_data;
741
742        match code {
743            0 => {
744                payload_data = &input[1..];
745            }
746            3 => {
747                if input.len() < 2 {
748                    return Err("Code 3 packet too short");
749                }
750                let count_byte = input[1];
751                let _frame_count = (count_byte & 0x3F) as usize;
752                let padding_flag = (count_byte & 0x40) != 0;
753
754                let mut ptr = 2usize;
755                if padding_flag {
756                    let mut pad_len = 0usize;
757                    loop {
758                        if ptr >= input.len() {
759                            return Err("Padding overflow");
760                        }
761                        let p = input[ptr] as usize;
762                        ptr += 1;
763                        if p == 255 {
764                            pad_len += 254;
765                        } else {
766                            pad_len += p;
767                            break;
768                        }
769                    }
770
771                    let end = input.len().saturating_sub(pad_len);
772                    if ptr > end {
773                        return Err("Padding exceeds packet");
774                    }
775                    payload_data = &input[ptr..end];
776                } else {
777                    payload_data = &input[ptr..];
778                }
779            }
780            _ => {
781                payload_data = &input[1..];
782            }
783        }
784
785        self.frame_size = frame_size;
786        self.bandwidth = bandwidth;
787        self.stream_channels = packet_channels;
788
789        match mode {
790            OpusMode::SilkOnly => {
791                let internal_sample_rate = match bandwidth {
792                    Bandwidth::Narrowband => 8000,
793                    Bandwidth::Mediumband => 12000,
794                    Bandwidth::Wideband => 16000,
795                    _ => 16000,
796                };
797
798                let mut rc = RangeCoder::new_decoder(payload_data);
799                let internal_frame_size =
800                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
801                let pcm_i16_len = internal_frame_size * self.channels;
802                debug_assert!(pcm_i16_len <= self.w_pcm_i16.len());
803
804                let payload_size_ms = frame_duration_ms;
805
806                let ret = {
807                    let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
808                    silk_dec.decode(
809                        &mut rc,
810                        &mut pcm_i16[..pcm_i16_len],
811                        silk::decode_frame::FLAG_DECODE_NORMAL,
812                        true,
813                        payload_size_ms,
814                        internal_sample_rate,
815                    )
816                };
817
818                if ret < 0 {
819                    return Err("SILK decoding failed");
820                }
821
822                let decoded_samples = ret as usize;
823
824                if self.sampling_rate == internal_sample_rate {
825                    let n = decoded_samples.min(frame_size).min(output.len());
826                    for i in 0..n {
827                        output[i] = self.w_pcm_i16[i] as f32 / 32768.0;
828                    }
829                    self.prev_mode = Some(OpusMode::SilkOnly);
830                    Ok(n)
831                } else {
832                    if internal_sample_rate != self.prev_internal_rate {
833                        self.silk_resampler
834                            .init(internal_sample_rate, self.sampling_rate);
835                        self.prev_internal_rate = internal_sample_rate;
836                    }
837
838                    let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
839                    let out_len = ((decoded_samples as f64 * ratio) as usize).min(frame_size);
840                    debug_assert!(out_len <= self.w_pcm_resampled.len());
841                    let ret2 = {
842                        let (silk_res, pcm_i16, pcm_out) = (
843                            &mut self.silk_resampler,
844                            &self.w_pcm_i16,
845                            &mut self.w_pcm_resampled,
846                        );
847                        silk_res.process(
848                            &mut pcm_out[..out_len],
849                            &pcm_i16[..decoded_samples],
850                            decoded_samples as i32,
851                        )
852                    };
853                    let _ = ret2;
854
855                    let n = out_len.min(output.len());
856                    for i in 0..n {
857                        output[i] = self.w_pcm_resampled[i] as f32 / 32768.0;
858                    }
859                    self.prev_mode = Some(OpusMode::SilkOnly);
860                    Ok(n)
861                }
862            }
863
864            OpusMode::CeltOnly => {
865                self.celt_dec.decode(payload_data, frame_size, output);
866                self.prev_mode = Some(OpusMode::CeltOnly);
867                Ok(frame_size)
868            }
869
870            OpusMode::Hybrid => {
871                let internal_sample_rate = match bandwidth {
872                    Bandwidth::Superwideband => 16000,
873                    Bandwidth::Fullband => 16000,
874                    _ => 16000,
875                };
876
877                let mut rc = RangeCoder::new_decoder(payload_data);
878                let internal_frame_size =
879                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
880                let pcm_silk_i16_len = internal_frame_size * self.channels;
881                debug_assert!(pcm_silk_i16_len <= self.w_pcm_i16.len());
882
883                let ret = {
884                    let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
885                    silk_dec.decode(
886                        &mut rc,
887                        &mut pcm_i16[..pcm_silk_i16_len],
888                        silk::decode_frame::FLAG_DECODE_NORMAL,
889                        true,
890                        frame_duration_ms,
891                        internal_sample_rate,
892                    )
893                };
894
895                let silk_out_len = frame_size * self.channels;
896                debug_assert!(silk_out_len <= self.w_silk_out.len());
897                self.w_silk_out[..silk_out_len].fill(0.0);
898                if ret > 0 {
899                    let decoded_samples = ret as usize;
900                    if self.sampling_rate == internal_sample_rate {
901                        let n = decoded_samples.min(frame_size);
902                        for i in 0..n {
903                            self.w_silk_out[i] = self.w_pcm_i16[i] as f32 / 32768.0;
904                        }
905                    } else {
906                        if internal_sample_rate != self.prev_internal_rate {
907                            self.silk_resampler
908                                .init(internal_sample_rate, self.sampling_rate);
909                            self.prev_internal_rate = internal_sample_rate;
910                        }
911                        let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
912                        let out_len = ((decoded_samples as f64 * ratio) as usize).min(frame_size);
913                        debug_assert!(out_len <= self.w_pcm_resampled.len());
914                        {
915                            let (silk_res, pcm_i16, pcm_resampled) = (
916                                &mut self.silk_resampler,
917                                &self.w_pcm_i16,
918                                &mut self.w_pcm_resampled,
919                            );
920                            silk_res.process(
921                                &mut pcm_resampled[..out_len],
922                                &pcm_i16[..decoded_samples],
923                                decoded_samples as i32,
924                            );
925                        }
926                        for i in 0..out_len.min(frame_size) {
927                            self.w_silk_out[i] = self.w_pcm_resampled[i] as f32 / 32768.0;
928                        }
929                    }
930                }
931
932                let celt_out_len = frame_size * self.channels;
933                debug_assert!(celt_out_len <= self.w_celt_out.len());
934                if self.hybrid_skip_celt {
935                    // Diagnostic: skip CELT decode entirely, output SILK only
936                    self.w_celt_out[..celt_out_len].fill(0.0);
937                } else {
938                    let total_bits = (payload_data.len() * 8) as i32;
939                    let (celt_dec, celt_out) = (&mut self.celt_dec, &mut self.w_celt_out);
940                    celt_dec.decode_from_range_coder(
941                        &mut rc,
942                        total_bits,
943                        frame_size,
944                        &mut celt_out[..celt_out_len],
945                        17,
946                    );
947                }
948
949                let n = frame_size.min(output.len());
950                for i in 0..n {
951                    output[i] = (self.w_silk_out[i] + self.w_celt_out[i]).clamp(-1.0, 1.0);
952                }
953
954                self.prev_mode = Some(OpusMode::Hybrid);
955                Ok(n)
956            }
957        }
958    }
959}
960
/// Converts a frame size (samples per channel) into a frame rate in frames per second.
///
/// Returns `None` when no positive integer frame rate exists, i.e. when:
/// - `sampling_rate` is zero or negative,
/// - `frame_size` is zero or does not fit in an `i32`, or
/// - `sampling_rate` is not an exact multiple of `frame_size`.
fn frame_rate_from_params(sampling_rate: i32, frame_size: usize) -> Option<i32> {
    if sampling_rate <= 0 {
        return None;
    }
    // Checked conversion: a plain `as i32` cast wraps oversized values to negative numbers
    // (`usize::MAX` becomes -1), which would produce a negative frame rate.
    let frame_size = <i32 as std::convert::TryFrom<usize>>::try_from(frame_size).ok()?;
    if frame_size == 0 || sampling_rate % frame_size != 0 {
        return None;
    }
    Some(sampling_rate / frame_size)
}
968
969fn gen_toc(mode: OpusMode, frame_rate: i32, bandwidth: Bandwidth, channels: usize) -> u8 {
970    let mut rate = frame_rate;
971    let mut period = 0;
972    while rate < 400 {
973        rate <<= 1;
974        period += 1;
975    }
976
977    let mut toc = match mode {
978        OpusMode::SilkOnly => {
979            let bw = (bandwidth as i32 - Bandwidth::Narrowband as i32) << 5;
980            let per = (period - 2) << 3;
981            (bw | per) as u8
982        }
983        OpusMode::CeltOnly => {
984            let mut tmp = bandwidth as i32 - Bandwidth::Mediumband as i32;
985            if tmp < 0 {
986                tmp = 0;
987            }
988            let per = period << 3;
989            (0x80 | (tmp << 5) | per) as u8
990        }
991        OpusMode::Hybrid => {
992            // Hybrid config per RFC 6716:
993            // 12: 10ms SWB, 13: 20ms SWB, 14: 10ms FB, 15: 20ms FB
994            let base_config = if bandwidth == Bandwidth::Superwideband {
995                12 // SWB starts at 12
996            } else {
997                14 // FB starts at 14
998            };
999            let period_offset = if frame_rate >= 100 { 0 } else { 1 };
1000            ((base_config + period_offset) << 3) as u8
1001        }
1002    };
1003
1004    if channels == 2 {
1005        toc |= 0x04;
1006    }
1007    toc
1008}
1009
1010fn mode_from_toc(toc: u8) -> OpusMode {
1011    if toc & 0x80 != 0 {
1012        OpusMode::CeltOnly
1013    } else if toc & 0x60 == 0x60 {
1014        OpusMode::Hybrid
1015    } else {
1016        OpusMode::SilkOnly
1017    }
1018}
1019
1020fn bandwidth_from_toc(toc: u8) -> Bandwidth {
1021    let mode = mode_from_toc(toc);
1022    match mode {
1023        OpusMode::SilkOnly => {
1024            let bw_bits = (toc >> 5) & 0x03;
1025            match bw_bits {
1026                0 => Bandwidth::Narrowband,
1027                1 => Bandwidth::Mediumband,
1028                2 => Bandwidth::Wideband,
1029                _ => Bandwidth::Wideband,
1030            }
1031        }
1032        OpusMode::Hybrid => {
1033            let bw_bit = (toc >> 4) & 0x01;
1034            if bw_bit == 0 {
1035                Bandwidth::Superwideband
1036            } else {
1037                Bandwidth::Fullband
1038            }
1039        }
1040        OpusMode::CeltOnly => {
1041            let bw_bits = (toc >> 5) & 0x03;
1042            match bw_bits {
1043                0 => Bandwidth::Mediumband,
1044                1 => Bandwidth::Wideband,
1045                2 => Bandwidth::Superwideband,
1046                3 => Bandwidth::Fullband,
1047                _ => Bandwidth::Fullband,
1048            }
1049        }
1050    }
1051}
1052
1053fn frame_duration_ms_from_toc(toc: u8) -> i32 {
1054    let mode = mode_from_toc(toc);
1055    match mode {
1056        OpusMode::SilkOnly => {
1057            let config = (toc >> 3) & 0x03;
1058            match config {
1059                0 => 10,
1060                1 => 20,
1061                2 => 40,
1062                3 => 60,
1063                _ => 20,
1064            }
1065        }
1066        OpusMode::Hybrid => {
1067            let config = (toc >> 3) & 0x01;
1068            if config == 0 { 10 } else { 20 }
1069        }
1070        OpusMode::CeltOnly => {
1071            let config = (toc >> 3) & 0x03;
1072            match config {
1073                0 => 2,
1074                1 => 5,
1075                2 => 10,
1076                3 => 20,
1077                _ => 20,
1078            }
1079        }
1080    }
1081}
1082
/// Reads the stereo flag (bit 2) of a TOC byte: 2 channels when it is set, otherwise 1.
fn channels_from_toc(toc: u8) -> usize {
    match toc & 0x04 {
        0 => 1,
        _ => 2,
    }
}
1086
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only inverse of `gen_toc`'s frame-rate encoding: recovers the frame size in
    /// samples per channel that a TOC byte implies at `sampling_rate`.
    fn frame_size_from_toc(toc: u8, sampling_rate: i32) -> Option<usize> {
        match mode_from_toc(toc) {
            OpusMode::CeltOnly => {
                let shift = ((toc >> 3) & 0x03) as i32;
                let frame_rate: i32 = 400 >> shift;
                let divides_evenly = frame_rate != 0 && sampling_rate % frame_rate == 0;
                divides_evenly.then(|| (sampling_rate / frame_rate) as usize)
            }
            // SILK-only and hybrid frames are both sized from their millisecond duration.
            OpusMode::SilkOnly | OpusMode::Hybrid => {
                let duration_ms = frame_duration_ms_from_toc(toc);
                let samples = sampling_rate as i64 * duration_ms as i64 / 1000;
                Some(samples as usize)
            }
        }
    }

    /// CELT-only fullband mono TOC bytes must match the reference values and round-trip
    /// back to the frame size they were generated from.
    #[test]
    fn gen_toc_matches_celt_reference_values() {
        let sampling_rate = 48_000;
        let cases = [
            (120usize, 0xE0u8),
            (240usize, 0xE8u8),
            (480usize, 0xF0u8),
            (960usize, 0xF8u8),
        ];

        for &(frame_size, expected_toc) in &cases {
            let frame_rate = frame_rate_from_params(sampling_rate, frame_size).unwrap();
            let toc = gen_toc(OpusMode::CeltOnly, frame_rate, Bandwidth::Fullband, 1);
            assert_eq!(
                toc, expected_toc,
                "frame_size {} expected TOC {:02X} got {:02X}",
                frame_size, expected_toc, toc
            );
            assert_eq!(frame_size_from_toc(toc, sampling_rate).unwrap(), frame_size);
        }

        let stereo_rate = frame_rate_from_params(sampling_rate, 960).unwrap();
        let stereo_toc = gen_toc(OpusMode::CeltOnly, stereo_rate, Bandwidth::Fullband, 2);
        assert_eq!(channels_from_toc(stereo_toc), 2);
    }

    /// Smoke test: decoding a minimal CELT packet at every standard frame size must not
    /// panic, for mono and then stereo. The decode result itself is not inspected.
    #[test]
    fn test_celt_decoder_large_frame_sizes() {
        let sampling_rate = 48000;
        let frame_sizes = [120, 240, 480, 960];

        for channels in [1, 2] {
            let mut decoder = OpusDecoder::new(sampling_rate, channels).unwrap();

            for frame_size in frame_sizes {
                let frame_rate = frame_rate_from_params(sampling_rate, frame_size).unwrap();
                let toc = gen_toc(OpusMode::CeltOnly, frame_rate, Bandwidth::Fullband, channels);
                let packet = [toc, 0, 0, 0, 0];
                let mut output = vec![0.0f32; frame_size * channels];

                let _ = decoder.decode(&packet, frame_size, &mut output);
            }
        }
    }

    /// Smoke test: unusual frame sizes passed alongside a CELT TOC byte must not panic.
    /// The decode result itself is not inspected.
    #[test]
    fn test_celt_decoder_edge_case_frame_sizes() {
        let sampling_rate = 48000;
        let channels = 1;
        let mut decoder = OpusDecoder::new(sampling_rate, channels).unwrap();

        let packet = [0x80u8, 0, 0, 0];
        let edge_sizes = [2048, 2167, 2168, 2169, 2880, 3072];

        for frame_size in edge_sizes {
            let mut output = vec![0.0f32; frame_size * channels];
            let _ = decoder.decode(&packet, frame_size, &mut output);
        }
    }
}