Skip to main content

opus_rs/
lib.rs

// Crate-wide lint allowances.
// NOTE(review): these presumably exist because much of the crate is a close
// port of the C reference (unsafe helpers, long DSP argument lists,
// index-driven loops) — confirm, and consider narrowing them to the modules
// that actually need them.
#![allow(
    unsafe_op_in_unsafe_fn,
    clippy::too_many_arguments,
    clippy::needless_range_loop
)]
4
5pub mod bands;
6pub mod celt;
7pub mod celt_lpc;
8pub mod hp_cutoff;
9pub mod kiss_fft;
10pub mod mdct;
11pub mod modes;
12pub mod pitch;
13pub mod pvq;
14pub mod quant_bands;
15pub mod range_coder;
16pub mod rate;
17pub mod silk;
18
19pub use silk::{SilkResampler, SilkResamplerDown1_3, SilkResamplerDown1_6};
20
21pub use celt::{CeltDecoder, CeltEncoder};
22use hp_cutoff::hp_cutoff;
23use range_coder::RangeCoder;
24use silk::control_codec::silk_control_encoder;
25use silk::enc_api::silk_encode;
26use silk::init_encoder::silk_init_encoder;
27use silk::lin2log::silk_lin2log;
28use silk::log2lin::silk_log2lin;
29use silk::macros::*;
30use silk::resampler::{silk_resampler_down2, silk_resampler_down2_3};
31use silk::structs::SilkEncoderState;
32
/// Application profile that steers the encoder's mode decisions.
///
/// * `Voip` biases the CELT/SILK decision toward SILK and enables the variable
///   high-pass filter on the SILK input path.
/// * `Audio` uses the plain CELT/SILK decision.
/// * `RestrictedLowDelay` always encodes CELT-only.
///
/// The discriminants appear to mirror libopus's `OPUS_APPLICATION_*` values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Application {
    Voip = 2048,
    Audio = 2049,
    RestrictedLowDelay = 2051,
}
39
/// Audio bandwidth carried in the packet TOC byte.
///
/// `OpusEncoder::new` picks Narrowband / Mediumband / Wideband /
/// Superwideband / Fullband for 8 / 12 / 16 / 24 / 48 kHz input (subject to the
/// application profile). `Auto` is the "not yet decided" value.
/// The discriminants appear to mirror libopus's `OPUS_AUTO` and
/// `OPUS_BANDWIDTH_*` constants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Bandwidth {
    Auto = -1000,
    Narrowband = 1101,
    Mediumband = 1102,
    Wideband = 1103,
    Superwideband = 1104,
    Fullband = 1105,
}
49
/// Coding mode of a single Opus frame.
///
/// * `SilkOnly` — only the SILK layer is coded.
/// * `Hybrid` — SILK runs at 16 kHz and CELT codes bands from 17 upward.
/// * `CeltOnly` — only the CELT layer is coded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum OpusMode {
    SilkOnly,
    Hybrid,
    CeltOnly,
}
56
57pub struct OpusEncoder {
58    celt_enc: CeltEncoder,
59    silk_enc: Box<SilkEncoderState>,
60    application: Application,
61    sampling_rate: i32,
62    channels: usize,
63    bandwidth: Bandwidth,
64    pub bitrate_bps: i32,
65    pub complexity: i32,
66    pub use_cbr: bool,
67
68    pub use_inband_fec: bool,
69
70    pub packet_loss_perc: i32,
71    silk_initialized: bool,
72    mode: OpusMode,
73    prev_enc_mode: Option<OpusMode>,
74
75    variable_hp_smth2_q15: i32,
76    hp_mem: Vec<i32>,
77
78    buf_filtered: Vec<i16>,
79    buf_silk_input: Vec<i16>,
80    buf_stereo_mid: Vec<i16>,
81    buf_stereo_side: Vec<i16>,
82    down2_state_first: [i32; 2],
83    down2_state_second: [i32; 2],
84    down2_3_state: [i32; 6],
85    down_1_3_state: silk::resampler::SilkResamplerDown1_3,
86
87    rc: RangeCoder,
88}
89
/// Estimates the "equivalent" bitrate used to choose between CELT and SILK.
///
/// Starting from `bitrate` (bits per second), this:
/// * subtracts `(40 * channels + 20)` b/s for every frame per second above 50,
///   accounting for per-frame overhead of short frames;
/// * removes 1/12 (about 8%) when `vbr` is false;
/// * scales by `(90 + complexity)%`;
/// * for `loss > 0` percent, removes `loss / (12 * loss + 20)` of the rate.
///
/// The arithmetic runs in `i64` because `bitrate` and `complexity` come from
/// public, unvalidated encoder fields; the result saturates to the `i32` range
/// instead of overflowing.
fn compute_equiv_rate(
    bitrate: i32,
    channels: usize,
    frame_rate: i32,
    vbr: bool,
    complexity: i32,
    loss: i32,
) -> i32 {
    let mut equiv = i64::from(bitrate);
    if frame_rate > 50 {
        equiv -= (40 * channels as i64 + 20) * (i64::from(frame_rate) - 50);
    }
    if !vbr {
        equiv -= equiv / 12;
    }
    equiv = equiv * (90 + i64::from(complexity)) / 100;
    if loss > 0 {
        let loss = i64::from(loss);
        equiv -= equiv * loss / (12 * loss + 20);
    }
    equiv.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}
111
112fn compute_mode_threshold(application: Application, channels: usize, prev_was_celt: bool) -> i32 {
113    let mode_voice = if channels == 1 { 64000 } else { 44000 };
114    let mode_music = 10000;
115
116    let offset = (mode_voice - mode_music) >> 14;
117    let mut threshold = mode_music + offset;
118
119    if application == Application::Voip {
120        threshold += 8000;
121    }
122
123    if prev_was_celt {
124        threshold -= 4000;
125    } else {
126        threshold += 4000;
127    }
128
129    match application {
130        Application::Audio => threshold = threshold.max(25000),
131        Application::Voip => threshold = threshold.max(55000),
132        Application::RestrictedLowDelay => threshold = 0,
133    }
134
135    threshold
136}
137
/// Returns the SILK share (b/s) of a Hybrid frame whose total rate is `rate_bps`.
///
/// The share is linearly interpolated between the breakpoints of a fixed
/// table; above the last breakpoint SILK receives the last table value plus
/// half of the excess. `frame20ms` selects the 20 ms column instead of the
/// 10 ms one (the two columns are currently identical).
fn compute_silk_rate_for_hybrid(rate_bps: i32, frame20ms: bool) -> i32 {
    // (total rate, SILK rate for 10 ms frames, SILK rate for 20 ms frames)
    const RATE_TABLE: [(i32, i32, i32); 7] = [
        (0, 0, 0),
        (12000, 10000, 10000),
        (16000, 13500, 13500),
        (20000, 16000, 16000),
        (24000, 18000, 18000),
        (32000, 22000, 22000),
        (64000, 38000, 38000),
    ];

    let silk_at = |idx: usize| -> i32 {
        let (_, per10, per20) = RATE_TABLE[idx];
        if frame20ms {
            per20
        } else {
            per10
        }
    };

    // First breakpoint (past the zero row) whose total rate exceeds the request.
    let upper = RATE_TABLE
        .iter()
        .skip(1)
        .position(|&(total, _, _)| total > rate_bps)
        .map(|offset| offset + 1);

    match upper {
        None => {
            let last = RATE_TABLE.len() - 1;
            silk_at(last) + (rate_bps - RATE_TABLE[last].0) / 2
        }
        Some(hi_idx) => {
            let lo_idx = hi_idx - 1;
            let (x0, x1) = (RATE_TABLE[lo_idx].0, RATE_TABLE[hi_idx].0);
            let (lo, hi) = (silk_at(lo_idx), silk_at(hi_idx));
            (lo * (x1 - rate_bps) + hi * (rate_bps - x0)) / (x1 - x0)
        }
    }
}
168
#[cfg(test)]
mod silk_rate_tests {
    use super::compute_silk_rate_for_hybrid;

    /// SILK share for a 20 ms Hybrid frame at `total_bps`.
    fn silk20(total_bps: i32) -> i32 {
        compute_silk_rate_for_hybrid(total_bps, true)
    }

    #[test]
    fn test_reference_table_exact_entries() {
        let breakpoints = [
            (12000, 10000),
            (16000, 13500),
            (20000, 16000),
            (24000, 18000),
            (32000, 22000),
            (64000, 38000),
        ];
        for &(total_bps, silk_bps) in breakpoints.iter() {
            assert_eq!(silk20(total_bps), silk_bps);
        }
    }

    #[test]
    fn test_32kbps_gives_22kbps_silk() {
        assert_eq!(silk20(32000), 22000);
    }

    #[test]
    fn test_interpolation_between_table_entries() {
        // Halfway between the 16 kb/s (13.5 kb/s SILK) and 20 kb/s (16 kb/s SILK) rows.
        assert_eq!(silk20(18000), 14750);
    }

    #[test]
    fn test_above_table_max_gives_half_extra() {
        let excess = 72000 - 64000;
        assert_eq!(silk20(72000), 38000 + excess / 2);
    }
}
200
201impl OpusEncoder {
202    pub fn new(
203        sampling_rate: i32,
204        channels: usize,
205        application: Application,
206    ) -> Result<Self, &'static str> {
207        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
208            return Err("Invalid sampling rate");
209        }
210        if ![1, 2].contains(&channels) {
211            return Err("Invalid number of channels");
212        }
213
214        let mode = modes::default_mode();
215        let celt_enc = CeltEncoder::new(mode, channels);
216
217        let mut silk_enc = Box::new(SilkEncoderState::default());
218        if silk_init_encoder(&mut silk_enc, 0) != 0 {
219            return Err("SILK encoder initialization failed");
220        }
221
222        let (opus_mode, bw) = match application {
223            Application::Voip => {
224                let bw = match sampling_rate {
225                    8000 => Bandwidth::Narrowband,
226                    12000 => Bandwidth::Mediumband,
227                    16000 => Bandwidth::Wideband,
228                    24000 => Bandwidth::Superwideband,
229                    48000 => Bandwidth::Fullband,
230                    _ => Bandwidth::Narrowband,
231                };
232
233                let mode = if sampling_rate > 16000 {
234                    OpusMode::Hybrid
235                } else {
236                    OpusMode::SilkOnly
237                };
238                (mode, bw)
239            }
240            Application::RestrictedLowDelay => {
241                let bw = match sampling_rate {
242                    8000 => Bandwidth::Narrowband,
243                    12000 => Bandwidth::Mediumband,
244                    16000 => Bandwidth::Wideband,
245                    24000 => Bandwidth::Superwideband,
246                    _ => Bandwidth::Fullband,
247                };
248                (OpusMode::CeltOnly, bw)
249            }
250            Application::Audio => {
251                if sampling_rate <= 16000 {
252                    let bw = match sampling_rate {
253                        8000 => Bandwidth::Narrowband,
254                        12000 => Bandwidth::Mediumband,
255                        _ => Bandwidth::Wideband,
256                    };
257                    (OpusMode::SilkOnly, bw)
258                } else {
259                    let bw = match sampling_rate {
260                        24000 => Bandwidth::Superwideband,
261                        _ => Bandwidth::Fullband,
262                    };
263                    (OpusMode::Hybrid, bw)
264                }
265            }
266        };
267
268        use silk::lin2log::silk_lin2log;
269        let variable_hp_smth2_q15 = silk_lin2log(60) << 8;
270
271        Ok(Self {
272            celt_enc,
273            silk_enc,
274            application,
275            sampling_rate,
276            channels,
277            bandwidth: bw,
278            bitrate_bps: 64000,
279            complexity: 9,
280            use_cbr: false,
281            use_inband_fec: false,
282            packet_loss_perc: 0,
283            silk_initialized: false,
284            prev_enc_mode: None,
285            mode: opus_mode,
286            variable_hp_smth2_q15,
287            hp_mem: vec![0; channels * 2],
288
289            buf_filtered: Vec::new(),
290            buf_silk_input: Vec::new(),
291            buf_stereo_mid: Vec::new(),
292            buf_stereo_side: Vec::new(),
293            down2_state_first: [0; 2],
294            down2_state_second: [0; 2],
295            down2_3_state: [0; 6],
296            down_1_3_state: silk::resampler::SilkResamplerDown1_3::default(),
297            rc: RangeCoder::new_encoder(1),
298        })
299    }
300
301    pub fn enable_hybrid_mode(&mut self) -> Result<(), &'static str> {
302        if self.sampling_rate != 24000 && self.sampling_rate != 48000 {
303            return Err("Hybrid mode requires 24kHz or 48kHz sampling rate");
304        }
305        let bw = if self.sampling_rate == 48000 {
306            Bandwidth::Fullband
307        } else {
308            Bandwidth::Superwideband
309        };
310        self.mode = OpusMode::Hybrid;
311        self.bandwidth = bw;
312        self.silk_initialized = false;
313        Ok(())
314    }
315
316    pub fn encode(
317        &mut self,
318        input: &[f32],
319        frame_size: usize,
320        output: &mut [u8],
321    ) -> Result<usize, &'static str> {
322        if output.len() < 2 {
323            return Err("Output buffer too small");
324        }
325
326        let frame_rate = frame_rate_from_params(self.sampling_rate, frame_size)
327            .ok_or("Invalid frame size for sampling rate")?;
328
329        // Mode selection: match C's opus_encode_native() behavior.
330        // C reference auto-selects between SILK_ONLY and CELT_ONLY; Hybrid is
331        // produced afterwards by bandwidth overrides (SILK-only + FB/SWB → Hybrid).
332        let mut mode = if self.application == Application::RestrictedLowDelay {
333            OpusMode::CeltOnly
334        } else {
335            let equiv = compute_equiv_rate(
336                self.bitrate_bps,
337                self.channels,
338                frame_rate,
339                !self.use_cbr,
340                self.complexity,
341                self.packet_loss_perc,
342            );
343            let prev_was_celt = self.prev_enc_mode == Some(OpusMode::CeltOnly);
344            let threshold = compute_mode_threshold(self.application, self.channels, prev_was_celt);
345            if equiv >= threshold {
346                OpusMode::CeltOnly
347            } else {
348                OpusMode::SilkOnly
349            }
350        };
351
352        let curr_bw = self.bandwidth;
353        if mode == OpusMode::SilkOnly
354            && (curr_bw == Bandwidth::Superwideband || curr_bw == Bandwidth::Fullband)
355        {
356            mode = OpusMode::Hybrid;
357        }
358        if mode == OpusMode::Hybrid
359            && (curr_bw == Bandwidth::Narrowband
360                || curr_bw == Bandwidth::Mediumband
361                || curr_bw == Bandwidth::Wideband)
362        {
363            mode = OpusMode::SilkOnly;
364        }
365
366        if mode == OpusMode::CeltOnly {
367            match frame_rate {
368                400 | 200 | 100 | 50 => {}
369                _ => return Err("Unsupported frame size for CELT-only mode"),
370            }
371        }
372
373        let toc = gen_toc(mode, frame_rate, self.bandwidth, self.channels);
374        output[0] = toc;
375
376        let target_bits =
377            (self.bitrate_bps as i64 * frame_size as i64 / self.sampling_rate as i64) as i32;
378        let cbr_bytes = ((target_bits + 4) / 8) as usize;
379        let max_data_bytes = output.len();
380
381        let n_bytes = cbr_bytes.min(max_data_bytes).max(1);
382
383        // Initialize range coder with CBR-capped buffer size, matching C's
384        // ec_enc_init(&enc, data, max_data_bytes-1) where max_data_bytes is CBR-capped.
385        let init_rc_size = n_bytes - 1;
386        self.rc.reset_for_encode(init_rc_size as u32);
387
388        if mode == OpusMode::SilkOnly || mode == OpusMode::Hybrid {
389            let silk_fs_khz = if mode == OpusMode::Hybrid {
390                16
391            } else {
392                self.sampling_rate.min(16000) / 1000
393            };
394
395            let frame_ms = (frame_size as i32 * 1000) / self.sampling_rate;
396            if !self.silk_initialized || self.silk_enc.s_cmn.fs_khz != silk_fs_khz {
397                let silk_init_bitrate = (((n_bytes - 1) * 8) as i64 * self.sampling_rate as i64
398                    / frame_size as i64) as i32;
399                silk_control_encoder(
400                    &mut self.silk_enc,
401                    silk_fs_khz,
402                    frame_ms,
403                    silk_init_bitrate,
404                    self.complexity,
405                );
406                self.silk_enc.s_cmn.use_cbr = if self.use_cbr { 1 } else { 0 };
407
408                self.silk_enc.s_cmn.n_channels = self.channels as i32;
409                self.silk_initialized = true;
410                self.down2_state_first = [0; 2];
411                self.down2_state_second = [0; 2];
412                self.down2_3_state = [0; 6];
413                self.down_1_3_state = silk::resampler::SilkResamplerDown1_3::default();
414            }
415
416            self.silk_enc.s_cmn.use_in_band_fec = if self.use_inband_fec { 1 } else { 0 };
417            self.silk_enc.s_cmn.packet_loss_perc = self.packet_loss_perc.clamp(0, 100);
418
419            self.silk_enc.s_cmn.lbrr_enabled = if self.use_inband_fec { 1 } else { 0 };
420
421            if self.silk_enc.s_cmn.lbrr_gain_increases == 0 {
422                self.silk_enc.s_cmn.lbrr_gain_increases = 2;
423            }
424
425            let hp_freq_smth1 = if mode == OpusMode::CeltOnly {
426                silk_lin2log(60) << 8
427            } else {
428                self.silk_enc.s_cmn.variable_hp_smth1_q15
429            };
430
431            const VARIABLE_HP_SMTH_COEF2_Q16: i32 = 984;
432            self.variable_hp_smth2_q15 = silk_smlawb(
433                self.variable_hp_smth2_q15,
434                hp_freq_smth1 - self.variable_hp_smth2_q15,
435                VARIABLE_HP_SMTH_COEF2_Q16,
436            );
437
438            let cutoff_hz = silk_log2lin(silk_rshift(self.variable_hp_smth2_q15, 8));
439
440            let required_size = frame_size * self.channels;
441            self.buf_filtered.resize(required_size, 0);
442            if self.application == Application::Voip {
443                hp_cutoff(
444                    input,
445                    cutoff_hz,
446                    &mut self.buf_filtered,
447                    &mut self.hp_mem,
448                    frame_size,
449                    self.channels,
450                    self.sampling_rate,
451                );
452            } else {
453                for (i, &x) in input.iter().enumerate() {
454                    self.buf_filtered[i] = (x * 32768.0).clamp(-32768.0, 32767.0) as i16;
455                }
456            }
457
458            let input_i16 = &self.buf_filtered;
459
460            let silk_input: &[i16] = if mode == OpusMode::SilkOnly && self.sampling_rate > 16000 {
461                if self.sampling_rate == 48000 {
462                    let stage1_size = frame_size / 2;
463                    let mut stage1_buf = [0i16; 480];
464                    silk_resampler_down2(
465                        &mut self.down2_state_first,
466                        &mut stage1_buf[..stage1_size],
467                        input_i16,
468                        frame_size as i32,
469                    );
470                    let silk_frame_size = stage1_size * 2 / 3;
471                    self.buf_silk_input.resize(silk_frame_size, 0);
472                    silk_resampler_down2_3(
473                        &mut self.down2_3_state,
474                        &mut self.buf_silk_input,
475                        &stage1_buf[..stage1_size],
476                        stage1_size as i32,
477                    );
478                    &self.buf_silk_input
479                } else if self.sampling_rate == 24000 {
480                    let silk_frame_size = frame_size * 2 / 3;
481                    self.buf_silk_input.resize(silk_frame_size, 0);
482                    silk_resampler_down2_3(
483                        &mut self.down2_3_state,
484                        &mut self.buf_silk_input,
485                        input_i16,
486                        frame_size as i32,
487                    );
488                    &self.buf_silk_input
489                } else {
490                    input_i16
491                }
492            } else if mode == OpusMode::SilkOnly && self.channels == 2 {
493                let frame_length = input_i16.len() / 2;
494                self.buf_stereo_mid.resize(frame_length, 0);
495                self.buf_stereo_side.resize(frame_length, 0);
496                for i in 0..frame_length {
497                    let l = input_i16[2 * i] as i32;
498                    let r = input_i16[2 * i + 1] as i32;
499                    self.buf_stereo_mid[i] = ((l + r) / 2) as i16;
500                    self.buf_stereo_side[i] = (l - r) as i16;
501                }
502
503                self.silk_enc.stereo.side.resize(frame_length, 0);
504                self.silk_enc
505                    .stereo
506                    .side
507                    .copy_from_slice(&self.buf_stereo_side[..frame_length]);
508                &self.buf_stereo_mid
509            } else if mode == OpusMode::Hybrid && self.sampling_rate > 16000 {
510                if self.sampling_rate == 48000 {
511                    let silk_frame_size = frame_size / 3;
512                    self.buf_silk_input.resize(silk_frame_size, 0);
513                    silk::resampler::silk_resampler_down_1_3(
514                        &mut self.down_1_3_state,
515                        &mut self.buf_silk_input,
516                        input_i16,
517                    );
518                } else {
519                    let silk_frame_size = frame_size * 2 / 3;
520                    self.buf_silk_input.resize(silk_frame_size, 0);
521                    silk_resampler_down2_3(
522                        &mut self.down2_3_state,
523                        &mut self.buf_silk_input,
524                        input_i16,
525                        frame_size as i32,
526                    );
527                }
528                &self.buf_silk_input
529            } else {
530                input_i16
531            };
532
533            let mut pn_bytes = 0;
534
535            let silk_rate_for_calc = if mode == OpusMode::Hybrid {
536                16000
537            } else {
538                self.sampling_rate
539            };
540            let silk_frame_len = silk_input.len();
541
542            let silk_bitrate = if mode == OpusMode::Hybrid {
543                let frame_duration_ms = frame_size as i32 * 1000 / self.sampling_rate;
544                let frame20ms = frame_duration_ms >= 20;
545                compute_silk_rate_for_hybrid(self.bitrate_bps, frame20ms)
546            } else {
547                (8i64 * (n_bytes - 1) as i64 * silk_rate_for_calc as i64 / silk_frame_len as i64)
548                    as i32
549            };
550            let silk_max_bits = if mode == OpusMode::Hybrid {
551                // Match C: for VBR Hybrid, maxBits starts at (max_data_bytes-1)*8,
552                // then constrained via compute_silk_rate_for_hybrid.
553                // For CBR Hybrid, allow SILK to steal up to 25% of remaining bits.
554                let total_max_bits = ((n_bytes - 1) * 8) as i32;
555                if self.use_cbr {
556                    let silk_bits = (silk_bitrate as i64 * silk_frame_len as i64
557                        / silk_rate_for_calc as i64) as i32;
558                    let other_bits = 0i32.max(total_max_bits - silk_bits);
559                    0i32.max(total_max_bits - other_bits * 3 / 4)
560                } else {
561                    // Constrained VBR: compute max SILK bits from total budget
562                    let frame_duration_ms = frame_size as i32 * 1000 / self.sampling_rate;
563                    let frame20ms = frame_duration_ms >= 20;
564                    let max_bit_rate = compute_silk_rate_for_hybrid(
565                        total_max_bits * self.sampling_rate / frame_size as i32,
566                        frame20ms,
567                    );
568                    max_bit_rate * frame_size as i32 / self.sampling_rate
569                }
570            } else {
571                ((n_bytes - 1) * 8) as i32
572            };
573            // For Hybrid CBR, force SILK to use VBR (matching C behavior)
574            let silk_use_cbr = if mode == OpusMode::Hybrid && self.use_cbr {
575                0
576            } else {
577                if self.use_cbr { 1 } else { 0 }
578            };
579            let ret = silk_encode(
580                &mut self.silk_enc,
581                silk_input,
582                silk_input.len(),
583                &mut self.rc,
584                &mut pn_bytes,
585                silk_bitrate,
586                silk_max_bits,
587                silk_use_cbr,
588                1,
589            );
590            if ret != 0 {
591                return Err("SILK encoding failed");
592            }
593        }
594
595        if mode == OpusMode::Hybrid {
596            self.rc.encode_bit_logp(false, 12); // redundancy = 0
597        }
598
599        if mode == OpusMode::Hybrid {
600            let nb_compr_bytes = (n_bytes - 1) as u32;
601            self.rc.shrink(nb_compr_bytes);
602        }
603
604        let silk_ret_bytes = if mode == OpusMode::SilkOnly {
605            ((self.rc.tell() + 7) >> 3) as usize
606        } else {
607            0
608        };
609
610        if mode == OpusMode::CeltOnly || mode == OpusMode::Hybrid {
611            self.celt_enc.complexity = self.complexity;
612            let start_band = if mode == OpusMode::Hybrid { 17 } else { 0 };
613            let total_packet_bits = ((n_bytes - 1) * 8) as i32;
614
615            if self.rc.tell() <= total_packet_bits {
616                self.celt_enc.encode_with_budget(
617                    input,
618                    frame_size,
619                    &mut self.rc,
620                    start_band,
621                    total_packet_bits,
622                );
623            }
624        }
625
626        self.rc.done();
627
628        if mode == OpusMode::SilkOnly {
629            let mut ret = silk_ret_bytes.min(self.rc.storage as usize);
630            while ret > 2 && self.rc.buf[ret - 1] == 0 {
631                ret -= 1;
632            }
633
634            let target_total = if self.use_cbr {
635                n_bytes.min(output.len())
636            } else {
637                (ret + 1).min(output.len())
638            };
639
640            let silk_len = ret;
641
642            if !self.use_cbr || silk_len + 1 >= target_total {
643                // VBR or payload fills the target: simple code 0 packet
644                output[0] = toc;
645                let copy_len = silk_len.min(target_total - 1);
646                output[1..1 + copy_len].copy_from_slice(&self.rc.buf[..copy_len]);
647                return Ok((copy_len + 1).min(output.len()));
648            }
649
650            output[0] = toc | 0x03;
651
652            if silk_len + 2 >= target_total {
653                output[1] = 0x01;
654                let copy_len = (target_total - 2).min(silk_len);
655                output[2..2 + copy_len].copy_from_slice(&self.rc.buf[..copy_len]);
656                self.prev_enc_mode = Some(mode);
657                return Ok(target_total.min(output.len()));
658            }
659
660            let pad_amount = target_total - silk_len - 2;
661            output[1] = 0x41;
662
663            let nb_255s = (pad_amount - 1) / 255;
664            let mut ptr = 2;
665            for _ in 0..nb_255s {
666                output[ptr] = 255;
667                ptr += 1;
668            }
669            output[ptr] = (pad_amount - 255 * nb_255s - 1) as u8;
670            ptr += 1;
671
672            output[ptr..ptr + silk_len].copy_from_slice(&self.rc.buf[..silk_len]);
673            ptr += silk_len;
674
675            let fill_end = target_total.min(output.len());
676            for byte in output[ptr..fill_end].iter_mut() {
677                *byte = 0;
678            }
679
680            self.prev_enc_mode = Some(mode);
681            return Ok(target_total.min(output.len()));
682        }
683
684        let payload_len = n_bytes - 1;
685        output[1..1 + payload_len].copy_from_slice(&self.rc.buf[..payload_len]);
686        self.prev_enc_mode = Some(mode);
687        Ok(n_bytes)
688    }
689}
690
691pub struct OpusDecoder {
692    celt_dec: CeltDecoder,
693    silk_dec: silk::dec_api::SilkDecoder,
694    sampling_rate: i32,
695    channels: usize,
696
697    prev_mode: Option<OpusMode>,
698    frame_size: usize,
699
700    bandwidth: Bandwidth,
701
702    stream_channels: usize,
703
704    silk_resampler: silk::resampler::SilkResampler,
705
706    prev_internal_rate: i32,
707
708    pub hybrid_skip_celt: bool,
709
710    w_pcm_i16: Vec<i16>,
711    w_silk_out: Vec<f32>,
712    w_pcm_resampled: Vec<i16>,
713    w_celt_out: Vec<f32>,
714}
715
716impl OpusDecoder {
717    pub fn new(sampling_rate: i32, channels: usize) -> Result<Self, &'static str> {
718        if ![8000, 12000, 16000, 24000, 48000].contains(&sampling_rate) {
719            return Err("Invalid sampling rate");
720        }
721        if ![1, 2].contains(&channels) {
722            return Err("Invalid number of channels");
723        }
724
725        let mode = modes::default_mode();
726        let celt_dec = CeltDecoder::new(mode, channels);
727
728        let mut silk_dec = silk::dec_api::SilkDecoder::new();
729        silk_dec.init(sampling_rate.min(16000), channels as i32);
730        silk_dec.channel_state[0].fs_api_hz = sampling_rate;
731
732        Ok(Self {
733            celt_dec,
734            silk_dec,
735            sampling_rate,
736            channels,
737            prev_mode: None,
738            frame_size: 0,
739            bandwidth: Bandwidth::Auto,
740            stream_channels: channels,
741            silk_resampler: silk::resampler::SilkResampler::default(),
742            prev_internal_rate: 0,
743            hybrid_skip_celt: false,
744
745            w_pcm_i16: vec![0i16; 640],
746
747            w_silk_out: vec![0.0f32; 5760 * channels],
748            w_pcm_resampled: vec![0i16; 5760 * channels],
749            w_celt_out: vec![0.0f32; 5760 * channels],
750        })
751    }
752
753    pub fn decode(
754        &mut self,
755        input: &[u8],
756        frame_size: usize,
757        output: &mut [f32],
758    ) -> Result<usize, &'static str> {
759        if input.is_empty() {
760            return Err("Input packet empty");
761        }
762
763        let toc = input[0];
764        let mode = mode_from_toc(toc);
765        let packet_channels = channels_from_toc(toc);
766        let bandwidth = bandwidth_from_toc(toc);
767        let frame_duration_ms = frame_duration_ms_from_toc(toc);
768
769        if packet_channels != self.channels {
770            return Err("Channel count mismatch between packet and decoder");
771        }
772
773        let code = toc & 0x03;
774        let payload_data;
775
776        match code {
777            0 => {
778                payload_data = &input[1..];
779            }
780            3 => {
781                if input.len() < 2 {
782                    return Err("Code 3 packet too short");
783                }
784                let count_byte = input[1];
785                let _frame_count = (count_byte & 0x3F) as usize;
786                let padding_flag = (count_byte & 0x40) != 0;
787
788                let mut ptr = 2usize;
789                if padding_flag {
790                    let mut pad_len = 0usize;
791                    loop {
792                        if ptr >= input.len() {
793                            return Err("Padding overflow");
794                        }
795                        let p = input[ptr] as usize;
796                        ptr += 1;
797                        if p == 255 {
798                            pad_len += 254;
799                        } else {
800                            pad_len += p;
801                            break;
802                        }
803                    }
804
805                    let end = input.len().saturating_sub(pad_len);
806                    if ptr > end {
807                        return Err("Padding exceeds packet");
808                    }
809                    payload_data = &input[ptr..end];
810                } else {
811                    payload_data = &input[ptr..];
812                }
813            }
814            _ => {
815                payload_data = &input[1..];
816            }
817        }
818
819        self.frame_size = frame_size;
820        self.bandwidth = bandwidth;
821        self.stream_channels = packet_channels;
822
823        match mode {
824            OpusMode::SilkOnly => {
825                let internal_sample_rate = match bandwidth {
826                    Bandwidth::Narrowband => 8000,
827                    Bandwidth::Mediumband => 12000,
828                    Bandwidth::Wideband => 16000,
829                    _ => 16000,
830                };
831
832                let mut rc = RangeCoder::new_decoder(payload_data);
833                let internal_frame_size =
834                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
835                let pcm_i16_len = internal_frame_size * self.channels;
836                debug_assert!(pcm_i16_len <= self.w_pcm_i16.len());
837
838                let payload_size_ms = frame_duration_ms;
839
840                let ret = {
841                    let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
842                    silk_dec.decode(
843                        &mut rc,
844                        &mut pcm_i16[..pcm_i16_len],
845                        silk::decode_frame::FLAG_DECODE_NORMAL,
846                        true,
847                        payload_size_ms,
848                        internal_sample_rate,
849                    )
850                };
851
852                if ret < 0 {
853                    return Err("SILK decoding failed");
854                }
855
856                let decoded_samples = ret as usize;
857
858                if self.sampling_rate == internal_sample_rate {
859                    let n = decoded_samples.min(frame_size).min(output.len());
860                    for i in 0..n {
861                        output[i] = self.w_pcm_i16[i] as f32 / 32768.0;
862                    }
863                    self.prev_mode = Some(OpusMode::SilkOnly);
864                    Ok(n)
865                } else {
866                    if internal_sample_rate != self.prev_internal_rate {
867                        self.silk_resampler
868                            .init(internal_sample_rate, self.sampling_rate);
869                        self.prev_internal_rate = internal_sample_rate;
870                    }
871
872                    let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
873                    let out_len = ((decoded_samples as f64 * ratio) as usize).min(frame_size);
874                    debug_assert!(out_len <= self.w_pcm_resampled.len());
875                    let ret2 = {
876                        let (silk_res, pcm_i16, pcm_out) = (
877                            &mut self.silk_resampler,
878                            &self.w_pcm_i16,
879                            &mut self.w_pcm_resampled,
880                        );
881                        silk_res.process(
882                            &mut pcm_out[..out_len],
883                            &pcm_i16[..decoded_samples],
884                            decoded_samples as i32,
885                        )
886                    };
887                    let _ = ret2;
888
889                    let n = out_len.min(output.len());
890                    for i in 0..n {
891                        output[i] = self.w_pcm_resampled[i] as f32 / 32768.0;
892                    }
893                    self.prev_mode = Some(OpusMode::SilkOnly);
894                    Ok(n)
895                }
896            }
897
898            OpusMode::CeltOnly => {
899                let mut rc = RangeCoder::new_decoder(payload_data);
900                let total_bits = (payload_data.len() * 8) as i32;
901                let needed = frame_size * self.channels;
902                if output.len() < needed {
903                    return Err("Output buffer too small");
904                }
905                self.celt_dec.decode_from_range_coder(
906                    &mut rc,
907                    total_bits,
908                    frame_size,
909                    &mut output[..needed],
910                    0,
911                );
912                self.prev_mode = Some(OpusMode::CeltOnly);
913                Ok(frame_size)
914            }
915
916            OpusMode::Hybrid => {
917                let internal_sample_rate = match bandwidth {
918                    Bandwidth::Superwideband => 16000,
919                    Bandwidth::Fullband => 16000,
920                    _ => 16000,
921                };
922
923                let mut rc = RangeCoder::new_decoder(payload_data);
924                let internal_frame_size =
925                    (frame_duration_ms * internal_sample_rate / 1000) as usize;
926                let pcm_silk_i16_len = internal_frame_size * self.channels;
927                debug_assert!(pcm_silk_i16_len <= self.w_pcm_i16.len());
928
929                let ret = {
930                    let (silk_dec, pcm_i16) = (&mut self.silk_dec, &mut self.w_pcm_i16);
931                    silk_dec.decode(
932                        &mut rc,
933                        &mut pcm_i16[..pcm_silk_i16_len],
934                        silk::decode_frame::FLAG_DECODE_NORMAL,
935                        true,
936                        frame_duration_ms,
937                        internal_sample_rate,
938                    )
939                };
940
941                let silk_out_len = frame_size * self.channels;
942                debug_assert!(silk_out_len <= self.w_silk_out.len());
943                self.w_silk_out[..silk_out_len].fill(0.0);
944                if ret > 0 {
945                    let decoded_samples = ret as usize;
946                    if self.sampling_rate == internal_sample_rate {
947                        let n = decoded_samples.min(frame_size);
948                        for i in 0..n {
949                            self.w_silk_out[i] = self.w_pcm_i16[i] as f32 / 32768.0;
950                        }
951                    } else {
952                        if internal_sample_rate != self.prev_internal_rate {
953                            self.silk_resampler
954                                .init(internal_sample_rate, self.sampling_rate);
955                            self.prev_internal_rate = internal_sample_rate;
956                        }
957                        let ratio = self.sampling_rate as f64 / internal_sample_rate as f64;
958                        let out_len = ((decoded_samples as f64 * ratio) as usize).min(frame_size);
959                        debug_assert!(out_len <= self.w_pcm_resampled.len());
960                        {
961                            let (silk_res, pcm_i16, pcm_resampled) = (
962                                &mut self.silk_resampler,
963                                &self.w_pcm_i16,
964                                &mut self.w_pcm_resampled,
965                            );
966                            silk_res.process(
967                                &mut pcm_resampled[..out_len],
968                                &pcm_i16[..decoded_samples],
969                                decoded_samples as i32,
970                            );
971                        }
972                        for i in 0..out_len.min(frame_size) {
973                            self.w_silk_out[i] = self.w_pcm_resampled[i] as f32 / 32768.0;
974                        }
975                    }
976                }
977
978                // In Hybrid mode the encoder always writes a redundancy flag (logp=12)
979                // immediately after SILK.  Read it here to keep the range-coder in sync.
980                // We don't support redundancy frames, so if the flag is somehow 1 we
981                // still just consume the extra bit and fall through to normal CELT decode.
982                let total_bits = (payload_data.len() * 8) as i32;
983                let redundancy = rc.decode_bit_logp(12);
984                if redundancy {
985                    // Consume celt_to_silk flag; actual redundancy frame handling not implemented.
986                    let _ = rc.decode_bit_logp(1);
987                }
988
989                let celt_out_len = frame_size * self.channels;
990                debug_assert!(celt_out_len <= self.w_celt_out.len());
991                if self.hybrid_skip_celt {
992                    self.w_celt_out[..celt_out_len].fill(0.0);
993                } else {
994                    let (celt_dec, celt_out) = (&mut self.celt_dec, &mut self.w_celt_out);
995                    celt_dec.decode_from_range_coder(
996                        &mut rc,
997                        total_bits,
998                        frame_size,
999                        &mut celt_out[..celt_out_len],
1000                        17,
1001                    );
1002                }
1003
1004                let n = frame_size.min(output.len());
1005                for i in 0..n {
1006                    output[i] = (self.w_silk_out[i] + self.w_celt_out[i]).clamp(-1.0, 1.0);
1007                }
1008
1009                self.prev_mode = Some(OpusMode::Hybrid);
1010                Ok(n)
1011            }
1012        }
1013    }
1014}
1015
/// Converts a frame length in samples into a frame rate (frames per second).
///
/// Returns `None` when:
/// - `frame_size` is zero or larger than `i32::MAX`;
/// - `frame_size` does not evenly divide `sampling_rate`;
/// - the resulting rate is not positive.
///
/// `gen_toc` doubles the rate until it reaches 400, which never terminates for
/// a zero or negative rate, so such rates are rejected here.
fn frame_rate_from_params(sampling_rate: i32, frame_size: usize) -> Option<i32> {
    // Reject sizes that `as i32` would truncate or wrap. For example,
    // `usize::MAX` became -1, which also made `i32::MIN % -1` panic.
    if frame_size == 0 || frame_size > i32::MAX as usize {
        return None;
    }
    let frame_size = frame_size as i32;
    if sampling_rate % frame_size != 0 {
        return None;
    }
    let frame_rate = sampling_rate / frame_size;
    if frame_rate > 0 {
        Some(frame_rate)
    } else {
        None
    }
}
1023
1024fn gen_toc(mode: OpusMode, frame_rate: i32, bandwidth: Bandwidth, channels: usize) -> u8 {
1025    let mut rate = frame_rate;
1026    let mut period = 0;
1027    while rate < 400 {
1028        rate <<= 1;
1029        period += 1;
1030    }
1031
1032    let mut toc = match mode {
1033        OpusMode::SilkOnly => {
1034            let bw = (bandwidth as i32 - Bandwidth::Narrowband as i32) << 5;
1035            let per = (period - 2) << 3;
1036            (bw | per) as u8
1037        }
1038        OpusMode::CeltOnly => {
1039            let mut tmp = bandwidth as i32 - Bandwidth::Mediumband as i32;
1040            if tmp < 0 {
1041                tmp = 0;
1042            }
1043            let per = period << 3;
1044            (0x80 | (tmp << 5) | per) as u8
1045        }
1046        OpusMode::Hybrid => {
1047            let base_config = if bandwidth == Bandwidth::Superwideband {
1048                12
1049            } else {
1050                14
1051            };
1052            let period_offset = if frame_rate >= 100 { 0 } else { 1 };
1053            ((base_config + period_offset) << 3) as u8
1054        }
1055    };
1056
1057    if channels == 2 {
1058        toc |= 0x04;
1059    }
1060    toc
1061}
1062
1063fn mode_from_toc(toc: u8) -> OpusMode {
1064    if toc & 0x80 != 0 {
1065        OpusMode::CeltOnly
1066    } else if toc & 0x60 == 0x60 {
1067        OpusMode::Hybrid
1068    } else {
1069        OpusMode::SilkOnly
1070    }
1071}
1072
1073fn bandwidth_from_toc(toc: u8) -> Bandwidth {
1074    let mode = mode_from_toc(toc);
1075    match mode {
1076        OpusMode::SilkOnly => {
1077            let bw_bits = (toc >> 5) & 0x03;
1078            match bw_bits {
1079                0 => Bandwidth::Narrowband,
1080                1 => Bandwidth::Mediumband,
1081                2 => Bandwidth::Wideband,
1082                _ => Bandwidth::Wideband,
1083            }
1084        }
1085        OpusMode::Hybrid => {
1086            let bw_bit = (toc >> 4) & 0x01;
1087            if bw_bit == 0 {
1088                Bandwidth::Superwideband
1089            } else {
1090                Bandwidth::Fullband
1091            }
1092        }
1093        OpusMode::CeltOnly => {
1094            let bw_bits = (toc >> 5) & 0x03;
1095            match bw_bits {
1096                0 => Bandwidth::Mediumband,
1097                1 => Bandwidth::Wideband,
1098                2 => Bandwidth::Superwideband,
1099                3 => Bandwidth::Fullband,
1100                _ => Bandwidth::Fullband,
1101            }
1102        }
1103    }
1104}
1105
1106fn frame_duration_ms_from_toc(toc: u8) -> i32 {
1107    let mode = mode_from_toc(toc);
1108    match mode {
1109        OpusMode::SilkOnly => {
1110            let config = (toc >> 3) & 0x03;
1111            match config {
1112                0 => 10,
1113                1 => 20,
1114                2 => 40,
1115                3 => 60,
1116                _ => 20,
1117            }
1118        }
1119        OpusMode::Hybrid => {
1120            let config = (toc >> 3) & 0x01;
1121            if config == 0 { 10 } else { 20 }
1122        }
1123        OpusMode::CeltOnly => {
1124            let config = (toc >> 3) & 0x03;
1125            match config {
1126                0 => 2,
1127                1 => 5,
1128                2 => 10,
1129                3 => 20,
1130                _ => 20,
1131            }
1132        }
1133    }
1134}
1135
/// Reads the stereo flag (bit 2, mask 0x04) of a TOC byte.
///
/// Returns 2 when the flag is set and 1 otherwise.
fn channels_from_toc(toc: u8) -> usize {
    1 + usize::from((toc >> 2) & 0x01)
}
1139
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only inverse of `gen_toc`: recovers the frame size, in samples at
    /// `sampling_rate`, described by a TOC byte.
    ///
    /// Returns `None` for CELT-only configs whose frame rate does not evenly
    /// divide `sampling_rate`.
    fn frame_size_from_toc(toc: u8, sampling_rate: i32) -> Option<usize> {
        match mode_from_toc(toc) {
            OpusMode::SilkOnly | OpusMode::Hybrid => {
                let duration_ms = i64::from(frame_duration_ms_from_toc(toc));
                Some((i64::from(sampling_rate) * duration_ms / 1000) as usize)
            }
            OpusMode::CeltOnly => {
                // The size code picks 400, 200, 100 or 50 frames per second.
                let frames_per_sec = 400 >> i32::from((toc >> 3) & 0x03);
                if frames_per_sec == 0 || sampling_rate % frames_per_sec != 0 {
                    None
                } else {
                    Some((sampling_rate / frames_per_sec) as usize)
                }
            }
        }
    }

    #[test]
    fn gen_toc_matches_celt_reference_values() {
        let fs = 48_000;
        let reference: [(usize, u8); 4] = [(120, 0xE0), (240, 0xE8), (480, 0xF0), (960, 0xF8)];

        for &(frame_size, want) in reference.iter() {
            let rate = frame_rate_from_params(fs, frame_size).unwrap();
            let got = gen_toc(OpusMode::CeltOnly, rate, Bandwidth::Fullband, 1);
            assert_eq!(
                got, want,
                "frame_size {} expected TOC {:02X} got {:02X}",
                frame_size, want, got
            );
            assert_eq!(frame_size_from_toc(got, fs).unwrap(), frame_size);
        }

        let stereo_rate = frame_rate_from_params(fs, 960).unwrap();
        let stereo_toc = gen_toc(OpusMode::CeltOnly, stereo_rate, Bandwidth::Fullband, 2);
        assert_eq!(channels_from_toc(stereo_toc), 2);
    }

    #[test]
    fn test_celt_decoder_large_frame_sizes() {
        let sampling_rate = 48000;
        let frame_sizes = [120, 240, 480, 960];

        // Smoke test: decoding short CELT-only packets must not panic for mono
        // or stereo. The decode result itself is intentionally not checked.
        for channels in [1, 2] {
            let mut decoder = OpusDecoder::new(sampling_rate, channels).unwrap();
            for frame_size in frame_sizes {
                let rate = frame_rate_from_params(sampling_rate, frame_size).unwrap();
                let toc = gen_toc(OpusMode::CeltOnly, rate, Bandwidth::Fullband, channels);
                let packet = [toc, 0, 0, 0, 0];
                let mut pcm = vec![0.0f32; frame_size * channels];
                let _ = decoder.decode(&packet, frame_size, &mut pcm);
            }
        }
    }

    #[test]
    fn test_celt_decoder_edge_case_frame_sizes() {
        let mut decoder = OpusDecoder::new(48000, 1).unwrap();

        // 0x80 is CELT-only config 16 (2.5 ms, lowest bandwidth field, mono).
        // The requested frame sizes are far larger than that frame. This only
        // checks that decoding does not panic.
        for frame_size in [2048, 2167, 2168, 2169, 2880, 3072] {
            let mut pcm = vec![0.0f32; frame_size];
            let _ = decoder.decode(&[0x80, 0, 0, 0], frame_size, &mut pcm);
        }
    }
}