opus_codec/
multistream.rs

1//! Safe wrappers for the Opus Multistream API (surround and channel-mapped streams)
2
3use crate::bindings::{
4    OPUS_AUTO, OPUS_BANDWIDTH_FULLBAND, OPUS_BANDWIDTH_MEDIUMBAND, OPUS_BANDWIDTH_NARROWBAND,
5    OPUS_BANDWIDTH_SUPERWIDEBAND, OPUS_BANDWIDTH_WIDEBAND, OPUS_BITRATE_MAX,
6    OPUS_GET_BANDWIDTH_REQUEST, OPUS_GET_BITRATE_REQUEST, OPUS_GET_COMPLEXITY_REQUEST,
7    OPUS_GET_DTX_REQUEST, OPUS_GET_FINAL_RANGE_REQUEST, OPUS_GET_FORCE_CHANNELS_REQUEST,
8    OPUS_GET_GAIN_REQUEST, OPUS_GET_IN_DTX_REQUEST, OPUS_GET_INBAND_FEC_REQUEST,
9    OPUS_GET_LAST_PACKET_DURATION_REQUEST, OPUS_GET_LOOKAHEAD_REQUEST,
10    OPUS_GET_MAX_BANDWIDTH_REQUEST, OPUS_GET_PACKET_LOSS_PERC_REQUEST,
11    OPUS_GET_PHASE_INVERSION_DISABLED_REQUEST, OPUS_GET_PITCH_REQUEST,
12    OPUS_GET_SAMPLE_RATE_REQUEST, OPUS_GET_SIGNAL_REQUEST, OPUS_GET_VBR_CONSTRAINT_REQUEST,
13    OPUS_GET_VBR_REQUEST, OPUS_MULTISTREAM_GET_DECODER_STATE_REQUEST,
14    OPUS_MULTISTREAM_GET_ENCODER_STATE_REQUEST, OPUS_RESET_STATE, OPUS_SET_BANDWIDTH_REQUEST,
15    OPUS_SET_BITRATE_REQUEST, OPUS_SET_COMPLEXITY_REQUEST, OPUS_SET_DTX_REQUEST,
16    OPUS_SET_FORCE_CHANNELS_REQUEST, OPUS_SET_GAIN_REQUEST, OPUS_SET_INBAND_FEC_REQUEST,
17    OPUS_SET_MAX_BANDWIDTH_REQUEST, OPUS_SET_PACKET_LOSS_PERC_REQUEST,
18    OPUS_SET_PHASE_INVERSION_DISABLED_REQUEST, OPUS_SET_SIGNAL_REQUEST,
19    OPUS_SET_VBR_CONSTRAINT_REQUEST, OPUS_SET_VBR_REQUEST, OPUS_SIGNAL_MUSIC, OPUS_SIGNAL_VOICE,
20    OpusDecoder, OpusEncoder, OpusMSDecoder, OpusMSEncoder, opus_multistream_decode,
21    opus_multistream_decode_float, opus_multistream_decoder_create, opus_multistream_decoder_ctl,
22    opus_multistream_decoder_destroy, opus_multistream_encode, opus_multistream_encode_float,
23    opus_multistream_encoder_create, opus_multistream_encoder_ctl,
24    opus_multistream_encoder_destroy, opus_multistream_surround_encoder_create,
25};
26use crate::error::{Error, Result};
27use crate::types::{Application, Bandwidth, Bitrate, Channels, Complexity, SampleRate, Signal};
28
/// Describes a multistream channel layout, using libopus's conventions.
///
/// The layout is expressed in terms of *coded channels*. A coupled (stereo)
/// stream contributes two coded channels and an uncoupled (mono) stream
/// contributes one, giving `streams + coupled_streams` coded channels in total.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Mapping<'a> {
    /// Number of interleaved PCM channels: encoder input or decoder output.
    pub channels: u8,
    /// Total number of coded streams, coupled and uncoupled combined.
    ///
    /// This is the value libopus calls `streams`.
    pub streams: u8,
    /// How many of `streams` are coupled (two-channel) streams.
    ///
    /// This must not exceed `streams`.
    pub coupled_streams: u8,
    /// One entry per PCM channel (`mapping.len() == channels`), giving the
    /// coded-channel index for that channel.
    ///
    /// Indices `0..2 * coupled_streams` address the left/right channels of the
    /// coupled streams in order; the following indices address the uncoupled
    /// streams. `255` marks a channel with no coded channel behind it (the
    /// decoder outputs silence for it).
    pub mapping: &'a [u8],
}
41
42impl Mapping<'_> {
43    /// Validate that mapping length matches channels.
44    fn validate(&self) -> Result<()> {
45        let channel_count = usize::from(self.channels);
46        if channel_count == 0 {
47            return Err(Error::BadArg);
48        }
49        if self.mapping.len() != channel_count {
50            return Err(Error::BadArg);
51        }
52
53        let streams = usize::from(self.streams);
54        let coupled = usize::from(self.coupled_streams);
55        if streams + coupled == 0 {
56            return Err(Error::BadArg);
57        }
58        if streams > channel_count {
59            return Err(Error::BadArg);
60        }
61        if coupled > channel_count / 2 {
62            return Err(Error::BadArg);
63        }
64        let total_streams = streams + coupled;
65        let mut assignments = vec![0usize; total_streams];
66        for &entry in self.mapping {
67            if entry == u8::MAX {
68                continue;
69            }
70            let idx = usize::from(entry);
71            if idx >= total_streams {
72                return Err(Error::BadArg);
73            }
74            assignments[idx] += 1;
75            if idx < streams {
76                if assignments[idx] > 1 {
77                    return Err(Error::BadArg);
78                }
79            } else if assignments[idx] > 2 {
80                return Err(Error::BadArg);
81            }
82        }
83        Ok(())
84    }
85}
86
87/// Safe wrapper around `OpusMSEncoder`.
88pub struct MSEncoder {
89    raw: *mut OpusMSEncoder,
90    sample_rate: SampleRate,
91    channels: u8,
92    streams: u8,
93    coupled_streams: u8,
94}
95
// SAFETY: the raw encoder pointer is owned exclusively by `MSEncoder`.
// Every method that passes it to libopus takes `&mut self`, and the `&self`
// accessors only read plain Rust fields, so sharing `&MSEncoder` cannot race on
// the C state.
// NOTE(review): `Send` additionally assumes the libopus encoder state has no
// thread affinity (no thread-local data) — confirm against the linked libopus.
unsafe impl Send for MSEncoder {}
unsafe impl Sync for MSEncoder {}
98
99impl MSEncoder {
100    /// Create a new multistream encoder.
101    ///
102    /// The `mapping.mapping` array describes how input channels are assigned to streams.
103    /// See libopus docs for standard surround layouts.
104    ///
105    /// # Errors
106    /// Returns [`Error::BadArg`] when the mapping dimensions are inconsistent, or
107    /// propagates allocation/configuration failures from libopus.
108    pub fn new(sr: SampleRate, app: Application, mapping: Mapping<'_>) -> Result<Self> {
109        mapping.validate()?;
110        let mut err = 0i32;
111        let enc = unsafe {
112            opus_multistream_encoder_create(
113                sr as i32,
114                i32::from(mapping.channels),
115                i32::from(mapping.streams),
116                i32::from(mapping.coupled_streams),
117                mapping.mapping.as_ptr(),
118                app as i32,
119                std::ptr::addr_of_mut!(err),
120            )
121        };
122        if err != 0 {
123            return Err(Error::from_code(err));
124        }
125        if enc.is_null() {
126            return Err(Error::AllocFail);
127        }
128        Ok(Self {
129            raw: enc,
130            sample_rate: sr,
131            channels: mapping.channels,
132            streams: mapping.streams,
133            coupled_streams: mapping.coupled_streams,
134        })
135    }
136
137    /// Encode interleaved i16 PCM into a multistream Opus packet.
138    ///
139    /// # Errors
140    /// Returns [`Error::InvalidState`] if the encoder handle is invalid, [`Error::BadArg`]
141    /// for buffer mismatches, or the mapped libopus error code.
142    #[allow(clippy::missing_panics_doc)]
143    pub fn encode(
144        &mut self,
145        pcm: &[i16],
146        frame_size_per_ch: usize,
147        out: &mut [u8],
148    ) -> Result<usize> {
149        if self.raw.is_null() {
150            return Err(Error::InvalidState);
151        }
152        if pcm.len() != frame_size_per_ch * self.channels as usize {
153            return Err(Error::BadArg);
154        }
155        if out.is_empty() || out.len() > i32::MAX as usize {
156            return Err(Error::BadArg);
157        }
158        let n = unsafe {
159            opus_multistream_encode(
160                self.raw,
161                pcm.as_ptr(),
162                i32::try_from(frame_size_per_ch).map_err(|_| Error::BadArg)?,
163                out.as_mut_ptr(),
164                i32::try_from(out.len()).map_err(|_| Error::BadArg)?,
165            )
166        };
167        if n < 0 {
168            return Err(Error::from_code(n));
169        }
170        usize::try_from(n).map_err(|_| Error::InternalError)
171    }
172
173    /// Encode interleaved f32 PCM into a multistream Opus packet.
174    ///
175    /// # Errors
176    /// Returns [`Error::InvalidState`] if the encoder handle is invalid, [`Error::BadArg`]
177    /// for buffer mismatches, or the mapped libopus error code.
178    pub fn encode_float(
179        &mut self,
180        pcm: &[f32],
181        frame_size_per_ch: usize,
182        out: &mut [u8],
183    ) -> Result<usize> {
184        if self.raw.is_null() {
185            return Err(Error::InvalidState);
186        }
187        if pcm.len() != frame_size_per_ch * self.channels as usize {
188            return Err(Error::BadArg);
189        }
190        if out.is_empty() || out.len() > i32::MAX as usize {
191            return Err(Error::BadArg);
192        }
193        let n = unsafe {
194            opus_multistream_encode_float(
195                self.raw,
196                pcm.as_ptr(),
197                i32::try_from(frame_size_per_ch).map_err(|_| Error::BadArg)?,
198                out.as_mut_ptr(),
199                i32::try_from(out.len()).map_err(|_| Error::BadArg)?,
200            )
201        };
202        if n < 0 {
203            return Err(Error::from_code(n));
204        }
205        usize::try_from(n).map_err(|_| Error::InternalError)
206    }
207
208    /// Final RNG state from the last encode.
209    ///
210    /// # Errors
211    /// Returns [`Error::InvalidState`] when the encoder handle is null or
212    /// propagates the libopus error.
213    pub fn final_range(&mut self) -> Result<u32> {
214        if self.raw.is_null() {
215            return Err(Error::InvalidState);
216        }
217        let mut v: u32 = 0;
218        let r = unsafe {
219            opus_multistream_encoder_ctl(self.raw, OPUS_GET_FINAL_RANGE_REQUEST as i32, &mut v)
220        };
221        if r != 0 {
222            return Err(Error::from_code(r));
223        }
224        Ok(v)
225    }
226
227    /// Set target bitrate for the encoder.
228    ///
229    /// # Errors
230    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
231    /// reported by libopus.
232    pub fn set_bitrate(&mut self, bitrate: Bitrate) -> Result<()> {
233        self.simple_ctl(OPUS_SET_BITRATE_REQUEST as i32, bitrate.value())
234    }
235
236    /// Query the current bitrate target.
237    ///
238    /// # Errors
239    /// Returns [`Error::InvalidState`] if the encoder handle is null, [`Error::InternalError`]
240    /// if the returned value cannot be represented, or propagates any error reported by
241    /// libopus.
242    pub fn bitrate(&mut self) -> Result<Bitrate> {
243        let v = self.get_int_ctl(OPUS_GET_BITRATE_REQUEST as i32)?;
244        Ok(match v {
245            x if x == OPUS_AUTO => Bitrate::Auto,
246            x if x == OPUS_BITRATE_MAX => Bitrate::Max,
247            other => Bitrate::Custom(other),
248        })
249    }
250
251    /// Set encoder complexity in the range 0..=10.
252    ///
253    /// # Errors
254    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
255    /// reported by libopus.
256    pub fn set_complexity(&mut self, complexity: Complexity) -> Result<()> {
257        self.simple_ctl(
258            OPUS_SET_COMPLEXITY_REQUEST as i32,
259            complexity.value() as i32,
260        )
261    }
262
263    /// Query encoder complexity.
264    ///
265    /// # Errors
266    /// Returns [`Error::InvalidState`] if the encoder handle is null, [`Error::InternalError`]
267    /// if the response is outside the valid range, or propagates any error reported by libopus.
268    pub fn complexity(&mut self) -> Result<Complexity> {
269        let v = self.get_int_ctl(OPUS_GET_COMPLEXITY_REQUEST as i32)?;
270        Ok(Complexity::new(
271            u32::try_from(v).map_err(|_| Error::InternalError)?,
272        ))
273    }
274
275    /// Enable/disable discontinuous transmission (DTX).
276    ///
277    /// # Errors
278    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
279    /// reported by libopus.
280    pub fn set_dtx(&mut self, enabled: bool) -> Result<()> {
281        self.simple_ctl(OPUS_SET_DTX_REQUEST as i32, i32::from(enabled))
282    }
283
284    /// Query whether DTX is enabled.
285    ///
286    /// # Errors
287    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
288    /// reported by libopus.
289    pub fn dtx(&mut self) -> Result<bool> {
290        self.get_bool_ctl(OPUS_GET_DTX_REQUEST as i32)
291    }
292
293    /// Query whether the encoder is currently in DTX.
294    ///
295    /// # Errors
296    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
297    /// reported by libopus.
298    pub fn in_dtx(&mut self) -> Result<bool> {
299        self.get_bool_ctl(OPUS_GET_IN_DTX_REQUEST as i32)
300    }
301
302    /// Enable/disable in-band FEC generation.
303    ///
304    /// # Errors
305    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
306    /// reported by libopus.
307    pub fn set_inband_fec(&mut self, enabled: bool) -> Result<()> {
308        self.simple_ctl(OPUS_SET_INBAND_FEC_REQUEST as i32, i32::from(enabled))
309    }
310
311    /// Query whether in-band FEC is enabled.
312    ///
313    /// # Errors
314    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
315    /// reported by libopus.
316    pub fn inband_fec(&mut self) -> Result<bool> {
317        self.get_bool_ctl(OPUS_GET_INBAND_FEC_REQUEST as i32)
318    }
319
320    /// Set expected packet loss percentage (0..=100).
321    ///
322    /// # Errors
323    /// Returns [`Error::BadArg`] when `perc` is outside `0..=100`, [`Error::InvalidState`] if
324    /// the encoder handle is null, or propagates any error reported by libopus.
325    pub fn set_packet_loss_perc(&mut self, perc: i32) -> Result<()> {
326        if !(0..=100).contains(&perc) {
327            return Err(Error::BadArg);
328        }
329        self.simple_ctl(OPUS_SET_PACKET_LOSS_PERC_REQUEST as i32, perc)
330    }
331
332    /// Query expected packet loss percentage.
333    ///
334    /// # Errors
335    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
336    /// reported by libopus.
337    pub fn packet_loss_perc(&mut self) -> Result<i32> {
338        self.get_int_ctl(OPUS_GET_PACKET_LOSS_PERC_REQUEST as i32)
339    }
340
341    /// Enable/disable variable bitrate.
342    ///
343    /// # Errors
344    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
345    /// reported by libopus.
346    pub fn set_vbr(&mut self, enabled: bool) -> Result<()> {
347        self.simple_ctl(OPUS_SET_VBR_REQUEST as i32, i32::from(enabled))
348    }
349
350    /// Query VBR status.
351    ///
352    /// # Errors
353    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
354    /// reported by libopus.
355    pub fn vbr(&mut self) -> Result<bool> {
356        self.get_bool_ctl(OPUS_GET_VBR_REQUEST as i32)
357    }
358
359    /// Constrain VBR to reduce instantaneous bitrate swings.
360    ///
361    /// # Errors
362    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
363    /// reported by libopus.
364    pub fn set_vbr_constraint(&mut self, constrained: bool) -> Result<()> {
365        self.simple_ctl(
366            OPUS_SET_VBR_CONSTRAINT_REQUEST as i32,
367            i32::from(constrained),
368        )
369    }
370
371    /// Query VBR constraint flag.
372    ///
373    /// # Errors
374    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
375    /// reported by libopus.
376    pub fn vbr_constraint(&mut self) -> Result<bool> {
377        self.get_bool_ctl(OPUS_GET_VBR_CONSTRAINT_REQUEST as i32)
378    }
379
380    /// Set the maximum bandwidth the encoder may use.
381    ///
382    /// # Errors
383    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
384    /// reported by libopus.
385    pub fn set_max_bandwidth(&mut self, bw: Bandwidth) -> Result<()> {
386        self.simple_ctl(OPUS_SET_MAX_BANDWIDTH_REQUEST as i32, bw as i32)
387    }
388
389    /// Query the configured maximum bandwidth.
390    ///
391    /// # Errors
392    /// Returns [`Error::InvalidState`] if the encoder handle is null, [`Error::InternalError`]
393    /// if the value cannot be represented, or propagates any error reported by libopus.
394    pub fn max_bandwidth(&mut self) -> Result<Bandwidth> {
395        self.get_bandwidth_ctl(OPUS_GET_MAX_BANDWIDTH_REQUEST as i32)
396    }
397
398    /// Force a specific output bandwidth (overrides automatic selection).
399    ///
400    /// # Errors
401    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
402    /// reported by libopus.
403    pub fn set_bandwidth(&mut self, bw: Bandwidth) -> Result<()> {
404        self.simple_ctl(OPUS_SET_BANDWIDTH_REQUEST as i32, bw as i32)
405    }
406
407    /// Query the current forced bandwidth, if any.
408    ///
409    /// # Errors
410    /// Returns [`Error::InvalidState`] if the encoder handle is null or [`Error::InternalError`]
411    /// if the value is outside the known set, and propagates any error reported by libopus.
412    pub fn bandwidth(&mut self) -> Result<Bandwidth> {
413        self.get_bandwidth_ctl(OPUS_GET_BANDWIDTH_REQUEST as i32)
414    }
415
416    /// Force mono/stereo output for coupled streams, or `None` for automatic.
417    ///
418    /// # Errors
419    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
420    /// reported by libopus.
421    pub fn set_force_channels(&mut self, channels: Option<Channels>) -> Result<()> {
422        let value = match channels {
423            Some(Channels::Mono) => 1,
424            Some(Channels::Stereo) => 2,
425            None => OPUS_AUTO,
426        };
427        self.simple_ctl(OPUS_SET_FORCE_CHANNELS_REQUEST as i32, value)
428    }
429
430    /// Query forced channel configuration (if any).
431    ///
432    /// # Errors
433    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
434    /// reported by libopus.
435    pub fn force_channels(&mut self) -> Result<Option<Channels>> {
436        let v = self.get_int_ctl(OPUS_GET_FORCE_CHANNELS_REQUEST as i32)?;
437        Ok(match v {
438            1 => Some(Channels::Mono),
439            2 => Some(Channels::Stereo),
440            x if x == OPUS_AUTO => None,
441            _ => None,
442        })
443    }
444
445    /// Hint the type of content being encoded (voice/music).
446    ///
447    /// # Errors
448    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
449    /// reported by libopus.
450    pub fn set_signal(&mut self, signal: Signal) -> Result<()> {
451        self.simple_ctl(OPUS_SET_SIGNAL_REQUEST as i32, signal as i32)
452    }
453
454    /// Query the current signal hint.
455    ///
456    /// # Errors
457    /// Returns [`Error::InvalidState`] if the encoder handle is null, [`Error::InternalError`]
458    /// if the response is not recognized, or propagates any error reported by libopus.
459    pub fn signal(&mut self) -> Result<Signal> {
460        let v = self.get_int_ctl(OPUS_GET_SIGNAL_REQUEST as i32)?;
461        match v {
462            x if x == OPUS_SIGNAL_VOICE as i32 => Ok(Signal::Voice),
463            x if x == OPUS_SIGNAL_MUSIC as i32 => Ok(Signal::Music),
464            _ => Err(Error::InternalError),
465        }
466    }
467
468    /// Query the algorithmic lookahead in samples at 48 kHz.
469    ///
470    /// # Errors
471    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
472    /// reported by libopus.
473    pub fn lookahead(&mut self) -> Result<i32> {
474        self.get_int_ctl(OPUS_GET_LOOKAHEAD_REQUEST as i32)
475    }
476
477    /// Reset the encoder state (retaining configuration).
478    ///
479    /// # Errors
480    /// Returns [`Error::InvalidState`] if the encoder handle is null or propagates any error
481    /// reported by libopus.
482    pub fn reset(&mut self) -> Result<()> {
483        if self.raw.is_null() {
484            return Err(Error::InvalidState);
485        }
486        let r = unsafe { opus_multistream_encoder_ctl(self.raw, OPUS_RESET_STATE as i32) };
487        if r != 0 {
488            return Err(Error::from_code(r));
489        }
490        Ok(())
491    }
492
493    /// Channels of this encoder (interleaved input).
494    #[must_use]
495    pub const fn channels(&self) -> u8 {
496        self.channels
497    }
498    /// Input sampling rate.
499    #[must_use]
500    pub const fn sample_rate(&self) -> SampleRate {
501        self.sample_rate
502    }
503    /// Number of mono streams.
504    #[must_use]
505    pub const fn streams(&self) -> u8 {
506        self.streams
507    }
508    /// Number of coupled streams.
509    #[must_use]
510    pub const fn coupled_streams(&self) -> u8 {
511        self.coupled_streams
512    }
513
514    /// Create a multistream encoder using libopus surround mapping helpers.
515    ///
516    /// # Errors
517    /// Returns [`Error::BadArg`] for invalid channel counts or the mapped libopus
518    /// error when surround initialisation fails.
519    pub fn new_surround(
520        sr: SampleRate,
521        channels: u8,
522        mapping_family: i32,
523        app: Application,
524    ) -> Result<(Self, Vec<u8>)> {
525        if channels == 0 {
526            return Err(Error::BadArg);
527        }
528        let mut err = 0i32;
529        let mut streams = 0i32;
530        let mut coupled = 0i32;
531        let mut mapping = vec![0u8; channels as usize];
532        let enc = unsafe {
533            opus_multistream_surround_encoder_create(
534                sr as i32,
535                i32::from(channels),
536                mapping_family,
537                std::ptr::addr_of_mut!(streams),
538                std::ptr::addr_of_mut!(coupled),
539                mapping.as_mut_ptr(),
540                app as i32,
541                std::ptr::addr_of_mut!(err),
542            )
543        };
544        if err != 0 {
545            return Err(Error::from_code(err));
546        }
547        if enc.is_null() {
548            return Err(Error::AllocFail);
549        }
550        let streams_u8 = u8::try_from(streams).map_err(|_| Error::BadArg)?;
551        let coupled_u8 = u8::try_from(coupled).map_err(|_| Error::BadArg)?;
552        Ok((
553            Self {
554                raw: enc,
555                sample_rate: sr,
556                channels,
557                streams: streams_u8,
558                coupled_streams: coupled_u8,
559            },
560            mapping,
561        ))
562    }
563
564    /// Borrow a pointer to an individual underlying encoder state for CTLs.
565    ///
566    /// # Safety
567    /// Caller must not outlive the multistream encoder and must ensure the
568    /// returned pointer is only used for immediate FFI calls.
569    ///
570    /// # Errors
571    /// Returns [`Error::InvalidState`] if the encoder handle is invalid or propagates the
572    /// libopus error if retrieving the state fails.
573    pub unsafe fn encoder_state_ptr(&mut self, stream_index: i32) -> Result<*mut OpusEncoder> {
574        if self.raw.is_null() {
575            return Err(Error::InvalidState);
576        }
577        let mut state: *mut OpusEncoder = std::ptr::null_mut();
578        let r = unsafe {
579            opus_multistream_encoder_ctl(
580                self.raw,
581                OPUS_MULTISTREAM_GET_ENCODER_STATE_REQUEST as i32,
582                stream_index,
583                &mut state,
584            )
585        };
586        if r != 0 {
587            return Err(Error::from_code(r));
588        }
589        if state.is_null() {
590            return Err(Error::InternalError);
591        }
592        Ok(state)
593    }
594
595    fn simple_ctl(&mut self, req: i32, val: i32) -> Result<()> {
596        if self.raw.is_null() {
597            return Err(Error::InvalidState);
598        }
599        let r = unsafe { opus_multistream_encoder_ctl(self.raw, req, val) };
600        if r != 0 {
601            return Err(Error::from_code(r));
602        }
603        Ok(())
604    }
605
606    fn get_int_ctl(&mut self, req: i32) -> Result<i32> {
607        if self.raw.is_null() {
608            return Err(Error::InvalidState);
609        }
610        let mut v: i32 = 0;
611        let r = unsafe { opus_multistream_encoder_ctl(self.raw, req, &mut v) };
612        if r != 0 {
613            return Err(Error::from_code(r));
614        }
615        Ok(v)
616    }
617
618    fn get_bool_ctl(&mut self, req: i32) -> Result<bool> {
619        Ok(self.get_int_ctl(req)? != 0)
620    }
621
622    fn get_bandwidth_ctl(&mut self, req: i32) -> Result<Bandwidth> {
623        let v = u32::try_from(self.get_int_ctl(req)?).map_err(|_| Error::InternalError)?;
624        match v {
625            x if x == OPUS_BANDWIDTH_NARROWBAND => Ok(Bandwidth::Narrowband),
626            x if x == OPUS_BANDWIDTH_MEDIUMBAND => Ok(Bandwidth::Mediumband),
627            x if x == OPUS_BANDWIDTH_WIDEBAND => Ok(Bandwidth::Wideband),
628            x if x == OPUS_BANDWIDTH_SUPERWIDEBAND => Ok(Bandwidth::SuperWideband),
629            x if x == OPUS_BANDWIDTH_FULLBAND => Ok(Bandwidth::Fullband),
630            _ => Err(Error::InternalError),
631        }
632    }
633}
634
635impl Drop for MSEncoder {
636    fn drop(&mut self) {
637        unsafe { opus_multistream_encoder_destroy(self.raw) }
638    }
639}
640
641/// Safe wrapper around `OpusMSDecoder`.
642pub struct MSDecoder {
643    raw: *mut OpusMSDecoder,
644    sample_rate: SampleRate,
645    channels: u8,
646}
647
// SAFETY: the raw decoder pointer is owned by `MSDecoder`. The decode and CTL
// methods that pass it to libopus take `&mut self`, and the `&self` accessors
// only read plain Rust fields.
// NOTE(review): soundness relies on every libopus call keeping the
// `&mut self` receiver, and (for `Send`) on the libopus decoder state having no
// thread affinity — confirm both when adding methods or upgrading libopus.
unsafe impl Send for MSDecoder {}
unsafe impl Sync for MSDecoder {}
650
651impl MSDecoder {
652    /// Create a new multistream decoder.
653    ///
654    /// # Errors
655    /// Returns [`Error::BadArg`] when the mapping dimensions are inconsistent, or
656    /// propagates allocation/configuration failures from libopus.
657    pub fn new(sr: SampleRate, mapping: Mapping<'_>) -> Result<Self> {
658        mapping.validate()?;
659        let mut err = 0i32;
660        let dec = unsafe {
661            opus_multistream_decoder_create(
662                sr as i32,
663                i32::from(mapping.channels),
664                i32::from(mapping.streams),
665                i32::from(mapping.coupled_streams),
666                mapping.mapping.as_ptr(),
667                std::ptr::addr_of_mut!(err),
668            )
669        };
670        if err != 0 {
671            return Err(Error::from_code(err));
672        }
673        if dec.is_null() {
674            return Err(Error::AllocFail);
675        }
676        Ok(Self {
677            raw: dec,
678            sample_rate: sr,
679            channels: mapping.channels,
680        })
681    }
682
683    /// Decode into interleaved i16 PCM (`frame_size` is per-channel).
684    ///
685    /// # Errors
686    /// Returns [`Error::InvalidState`] if the decoder handle is invalid, [`Error::BadArg`]
687    /// for buffer mismatches, or the mapped libopus error code.
688    pub fn decode(
689        &mut self,
690        packet: &[u8],
691        out: &mut [i16],
692        frame_size_per_ch: usize,
693        fec: bool,
694    ) -> Result<usize> {
695        if self.raw.is_null() {
696            return Err(Error::InvalidState);
697        }
698        if out.len() != frame_size_per_ch * self.channels as usize {
699            return Err(Error::BadArg);
700        }
701        let n = unsafe {
702            opus_multistream_decode(
703                self.raw,
704                if packet.is_empty() {
705                    std::ptr::null()
706                } else {
707                    packet.as_ptr()
708                },
709                if packet.is_empty() {
710                    0
711                } else {
712                    i32::try_from(packet.len()).map_err(|_| Error::BadArg)?
713                },
714                out.as_mut_ptr(),
715                i32::try_from(frame_size_per_ch).map_err(|_| Error::BadArg)?,
716                i32::from(fec),
717            )
718        };
719        if n < 0 {
720            return Err(Error::from_code(n));
721        }
722        usize::try_from(n).map_err(|_| Error::InternalError)
723    }
724
725    /// Decode into interleaved f32 PCM (`frame_size` is per-channel).
726    ///
727    /// # Errors
728    /// Returns [`Error::InvalidState`] if the decoder handle is invalid, [`Error::BadArg`]
729    /// for buffer mismatches, or the mapped libopus error code.
730    pub fn decode_float(
731        &mut self,
732        packet: &[u8],
733        out: &mut [f32],
734        frame_size_per_ch: usize,
735        fec: bool,
736    ) -> Result<usize> {
737        if self.raw.is_null() {
738            return Err(Error::InvalidState);
739        }
740        if out.len() != frame_size_per_ch * self.channels as usize {
741            return Err(Error::BadArg);
742        }
743        let n = unsafe {
744            opus_multistream_decode_float(
745                self.raw,
746                if packet.is_empty() {
747                    std::ptr::null()
748                } else {
749                    packet.as_ptr()
750                },
751                if packet.is_empty() {
752                    0
753                } else {
754                    i32::try_from(packet.len()).map_err(|_| Error::BadArg)?
755                },
756                out.as_mut_ptr(),
757                i32::try_from(frame_size_per_ch).map_err(|_| Error::BadArg)?,
758                i32::from(fec),
759            )
760        };
761        if n < 0 {
762            return Err(Error::from_code(n));
763        }
764        usize::try_from(n).map_err(|_| Error::InternalError)
765    }
766
767    /// Final RNG state from the last decode.
768    ///
769    /// # Errors
770    /// Returns [`Error::InvalidState`] when the decoder handle is null or
771    /// propagates the libopus error.
772    pub fn final_range(&mut self) -> Result<u32> {
773        if self.raw.is_null() {
774            return Err(Error::InvalidState);
775        }
776        let mut v: u32 = 0;
777        let r = unsafe {
778            opus_multistream_decoder_ctl(self.raw, OPUS_GET_FINAL_RANGE_REQUEST as i32, &mut v)
779        };
780        if r != 0 {
781            return Err(Error::from_code(r));
782        }
783        Ok(v)
784    }
785
786    /// Reset the decoder to its initial state.
787    ///
788    /// # Errors
789    /// Returns [`Error::InvalidState`] if the decoder handle is null or propagates any error
790    /// reported by libopus.
791    pub fn reset(&mut self) -> Result<()> {
792        if self.raw.is_null() {
793            return Err(Error::InvalidState);
794        }
795        let r = unsafe { opus_multistream_decoder_ctl(self.raw, OPUS_RESET_STATE as i32) };
796        if r != 0 {
797            return Err(Error::from_code(r));
798        }
799        Ok(())
800    }
801
802    /// Set post-decode gain in Q8 dB units.
803    ///
804    /// # Errors
805    /// Returns [`Error::InvalidState`] if the decoder handle is null or propagates any error
806    /// reported by libopus.
807    pub fn set_gain(&mut self, q8_db: i32) -> Result<()> {
808        self.simple_ctl(OPUS_SET_GAIN_REQUEST as i32, q8_db)
809    }
810
811    /// Query post-decode gain in Q8 dB units.
812    ///
813    /// # Errors
814    /// Returns [`Error::InvalidState`] if the decoder handle is null or propagates any error
815    /// reported by libopus.
816    pub fn gain(&mut self) -> Result<i32> {
817        self.get_int_ctl(OPUS_GET_GAIN_REQUEST as i32)
818    }
819
820    /// Disable or enable phase inversion (CELT stereo decorrelation).
821    ///
822    /// # Errors
823    /// Returns [`Error::InvalidState`] if the decoder handle is null or propagates any error
824    /// reported by libopus.
825    pub fn set_phase_inversion_disabled(&mut self, disabled: bool) -> Result<()> {
826        self.simple_ctl(
827            OPUS_SET_PHASE_INVERSION_DISABLED_REQUEST as i32,
828            i32::from(disabled),
829        )
830    }
831
832    /// Query the phase inversion disabled flag.
833    ///
834    /// # Errors
835    /// Returns [`Error::InvalidState`] if the decoder handle is null or propagates any error
836    /// reported by libopus.
837    pub fn phase_inversion_disabled(&mut self) -> Result<bool> {
838        self.get_bool_ctl(OPUS_GET_PHASE_INVERSION_DISABLED_REQUEST as i32)
839    }
840
841    /// Query decoder output sample rate.
842    ///
843    /// # Errors
844    /// Returns [`Error::InvalidState`] if the decoder handle is null or propagates any error
845    /// reported by libopus.
846    pub fn get_sample_rate(&mut self) -> Result<i32> {
847        self.get_int_ctl(OPUS_GET_SAMPLE_RATE_REQUEST as i32)
848    }
849
850    /// Query the pitch (fundamental period) of the last decoded frame.
851    ///
852    /// # Errors
853    /// Returns [`Error::InvalidState`] if the decoder handle is null or propagates any error
854    /// reported by libopus.
855    pub fn get_pitch(&mut self) -> Result<i32> {
856        self.get_int_ctl(OPUS_GET_PITCH_REQUEST as i32)
857    }
858
859    /// Query the duration (per channel) of the last decoded packet.
860    ///
861    /// # Errors
862    /// Returns [`Error::InvalidState`] if the decoder handle is null or propagates any error
863    /// reported by libopus.
864    pub fn get_last_packet_duration(&mut self) -> Result<i32> {
865        self.get_int_ctl(OPUS_GET_LAST_PACKET_DURATION_REQUEST as i32)
866    }
867
868    /// Output channels (interleaved).
869    #[must_use]
870    pub const fn channels(&self) -> u8 {
871        self.channels
872    }
873    /// Output sample rate.
874    #[must_use]
875    pub const fn sample_rate(&self) -> SampleRate {
876        self.sample_rate
877    }
878
879    /// Create a multistream decoder using libopus surround mapping helpers.
880    ///
881    /// # Errors
882    /// Returns [`Error::BadArg`] for invalid channel counts or the mapped libopus
883    /// error when decoder initialisation fails.
884    pub fn new_surround(
885        sr: SampleRate,
886        channels: u8,
887        mapping_family: i32,
888    ) -> Result<(Self, Vec<u8>, u8, u8)> {
889        if channels == 0 {
890            return Err(Error::BadArg);
891        }
892        let mut err = 0i32;
893        let mut streams = 0i32;
894        let mut coupled = 0i32;
895        let mut mapping = vec![0u8; channels as usize];
896        // libopus exposes surround helper creation only for encoders; callers
897        // should use the returned mapping/stream counts to configure this decoder.
898        let enc = unsafe {
899            opus_multistream_surround_encoder_create(
900                sr as i32,
901                i32::from(channels),
902                mapping_family,
903                std::ptr::addr_of_mut!(streams),
904                std::ptr::addr_of_mut!(coupled),
905                mapping.as_mut_ptr(),
906                Application::Audio as i32,
907                std::ptr::addr_of_mut!(err),
908            )
909        };
910        if !enc.is_null() {
911            unsafe { opus_multistream_encoder_destroy(enc) };
912        }
913        if err != 0 {
914            return Err(Error::from_code(err));
915        }
916        let dec = unsafe {
917            opus_multistream_decoder_create(
918                sr as i32,
919                i32::from(channels),
920                streams,
921                coupled,
922                mapping.as_ptr(),
923                std::ptr::addr_of_mut!(err),
924            )
925        };
926        if err != 0 {
927            return Err(Error::from_code(err));
928        }
929        if dec.is_null() {
930            return Err(Error::AllocFail);
931        }
932        Ok((
933            Self {
934                raw: dec,
935                sample_rate: sr,
936                channels,
937            },
938            mapping,
939            u8::try_from(streams).map_err(|_| Error::BadArg)?,
940            u8::try_from(coupled).map_err(|_| Error::BadArg)?,
941        ))
942    }
943
944    /// Borrow a pointer to an individual underlying decoder state for CTLs.
945    ///
946    /// # Safety
947    /// Caller must not outlive the multistream decoder and must ensure the
948    /// returned pointer is only used for immediate FFI calls.
949    ///
950    /// # Errors
951    /// Returns [`Error::InvalidState`] if the decoder handle is invalid or propagates the
952    /// libopus error when retrieving the per-stream state fails.
953    pub unsafe fn decoder_state_ptr(&mut self, stream_index: i32) -> Result<*mut OpusDecoder> {
954        if self.raw.is_null() {
955            return Err(Error::InvalidState);
956        }
957        let mut state: *mut OpusDecoder = std::ptr::null_mut();
958        let r = unsafe {
959            opus_multistream_decoder_ctl(
960                self.raw,
961                OPUS_MULTISTREAM_GET_DECODER_STATE_REQUEST as i32,
962                stream_index,
963                &mut state,
964            )
965        };
966        if r != 0 {
967            return Err(Error::from_code(r));
968        }
969        if state.is_null() {
970            return Err(Error::InternalError);
971        }
972        Ok(state)
973    }
974
975    fn simple_ctl(&mut self, req: i32, val: i32) -> Result<()> {
976        if self.raw.is_null() {
977            return Err(Error::InvalidState);
978        }
979        let r = unsafe { opus_multistream_decoder_ctl(self.raw, req, val) };
980        if r != 0 {
981            return Err(Error::from_code(r));
982        }
983        Ok(())
984    }
985
986    fn get_int_ctl(&mut self, req: i32) -> Result<i32> {
987        if self.raw.is_null() {
988            return Err(Error::InvalidState);
989        }
990        let mut v: i32 = 0;
991        let r = unsafe { opus_multistream_decoder_ctl(self.raw, req, &mut v) };
992        if r != 0 {
993            return Err(Error::from_code(r));
994        }
995        Ok(v)
996    }
997
998    fn get_bool_ctl(&mut self, req: i32) -> Result<bool> {
999        Ok(self.get_int_ctl(req)? != 0)
1000    }
1001}
1002
1003impl Drop for MSDecoder {
1004    fn drop(&mut self) {
1005        unsafe { opus_multistream_decoder_destroy(self.raw) }
1006    }
1007}
1008
#[cfg(test)]
mod tests {
    // Unit tests for the channel-mapping validator (`Mapping::validate`).
    use super::*;

    // The trailing `u8::MAX` entry marks a dropped output channel (libopus treats 255 as a
    // silent channel); validation is expected to accept it.
    // NOTE(review): `coupled_streams: 2` exceeds `streams: 1`, whereas libopus requires
    // `coupled_streams <= streams`. Confirm whether this crate's `streams` field has a
    // different meaning, or whether this fixture should use a larger `streams`.
    #[test]
    fn mapping_allows_dropped_channels() {
        let mapping = Mapping {
            channels: 6,
            streams: 1,
            coupled_streams: 2,
            mapping: &[0, 1, 1, 2, 2, u8::MAX],
        };
        assert!(mapping.validate().is_ok());
    }

    // Two output channels mapped to the same index (0) are expected to be rejected.
    // NOTE(review): the test above passes with repeated indices 1 and 2, so the exact
    // duplicate rule enforced by `validate` should be confirmed against its implementation.
    #[test]
    fn mapping_rejects_duplicate_mono_assignments() {
        let mapping = Mapping {
            channels: 3,
            streams: 1,
            coupled_streams: 1,
            mapping: &[0, 0, 1],
        };
        assert!(mapping.validate().is_err());
    }
}