Skip to main content

rift_media/
lib.rs

1//! Audio capture, playback, and codec helpers.
2//!
3//! This module wraps CPAL for audio I/O and Opus for encoding/decoding.
4//! It also provides a simple mixer for combining multiple streams.
5
6use std::collections::{HashMap, VecDeque};
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9
10use anyhow::{anyhow, Result};
11use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
12use rift_protocol::CodecId;
13
/// Audio stream parameters shared by capture, playback and the Opus codec.
#[derive(Debug, Clone)]
pub struct AudioConfig {
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Frame duration in milliseconds.
    pub frame_duration_ms: u32,
    /// Channel count (mono/stereo).
    pub channels: u16,
    /// Target bitrate for Opus.
    pub bitrate: u32,
}

impl Default for AudioConfig {
    /// 48 kHz mono audio in 20 ms frames with a 48 kbit/s Opus target.
    fn default() -> Self {
        AudioConfig {
            sample_rate: 48_000,
            frame_duration_ms: 20,
            channels: 1,
            bitrate: 48_000,
        }
    }
}

impl AudioConfig {
    /// Total samples in one frame, counting every channel.
    ///
    /// Per-channel length is `sample_rate * frame_duration_ms / 1000`
    /// (integer division, rounding down), multiplied by `channels`.
    pub fn frame_samples(&self) -> usize {
        let rate = self.sample_rate as usize;
        let millis = self.frame_duration_ms as usize;
        let channel_count = usize::from(self.channels);
        channel_count * (rate * millis / 1000)
    }

    /// Length of one frame as a [`Duration`].
    pub fn frame_duration(&self) -> Duration {
        Duration::from_millis(u64::from(self.frame_duration_ms))
    }
}
49
50/// Encode a single audio frame into the requested codec.
51pub fn encode_frame(codec: CodecId, frame: &[i16], encoder: &mut OpusEncoder) -> Result<Vec<u8>> {
52    match codec {
53        CodecId::Opus => {
54            let mut out = vec![0u8; 4000];
55            let len = encoder.encode_i16(frame, &mut out)?;
56            out.truncate(len);
57            Ok(out)
58        }
59        CodecId::PCM16 => {
60            let mut out = Vec::with_capacity(frame.len() * 2);
61            for sample in frame {
62                out.extend_from_slice(&sample.to_le_bytes());
63            }
64            Ok(out)
65        }
66        CodecId::Experimental(_) => Err(anyhow!("unsupported codec")),
67    }
68}
69
70/// Decode a single audio frame payload into PCM samples.
71pub fn decode_frame(
72    codec: CodecId,
73    payload: &[u8],
74    decoder: &mut OpusDecoder,
75    frame_samples: usize,
76) -> Result<Vec<i16>> {
77    match codec {
78        CodecId::Opus => {
79            let mut out = vec![0i16; frame_samples];
80            let len = decoder.decode_i16(payload, &mut out)?;
81            out.truncate(len);
82            Ok(out)
83        }
84        CodecId::PCM16 => {
85            let mut out = Vec::with_capacity(payload.len() / 2);
86            for chunk in payload.chunks_exact(2) {
87                out.push(i16::from_le_bytes([chunk[0], chunk[1]]));
88            }
89            Ok(out)
90        }
91        CodecId::Experimental(_) => Err(anyhow!("unsupported codec")),
92    }
93}
94
95/// Active audio input stream wrapper.
96pub struct AudioIn {
97    _stream: cpal::Stream,
98}
99
100impl AudioIn {
101    /// Open the default input device and start capturing.
102    pub fn new(config: &AudioConfig) -> Result<(Self, tokio::sync::mpsc::Receiver<Vec<i16>>)> {
103        Self::new_with_device(config, None)
104    }
105
106    /// Open a specific input device by name (or default if None).
107    pub fn new_with_device(
108        config: &AudioConfig,
109        device_name: Option<&str>,
110    ) -> Result<(Self, tokio::sync::mpsc::Receiver<Vec<i16>>)> {
111        let host = cpal::default_host();
112        let device = if let Some(name) = device_name {
113            find_input_device(&host, name)?
114        } else {
115            host.default_input_device()
116                .ok_or_else(|| anyhow!("no default input device"))?
117        };
118
119        let supported_config = device.default_input_config()?;
120        let sample_format = supported_config.sample_format();
121        let mut stream_config: cpal::StreamConfig = supported_config.into();
122        stream_config.channels = config.channels;
123        stream_config.sample_rate = cpal::SampleRate(config.sample_rate);
124
125        let frame_samples = config.frame_samples();
126        let (tx, rx) = tokio::sync::mpsc::channel::<Vec<i16>>(64);
127        let buffer = Arc::new(Mutex::new(Vec::with_capacity(frame_samples * 2)));
128        let buffer_clone = buffer.clone();
129
130        let err_fn = |err| tracing::error!("audio input error: {err}");
131
132        let stream = match sample_format {
133            cpal::SampleFormat::I16 => device.build_input_stream(
134                &stream_config,
135                move |data: &[i16], _| {
136                    audio_in_callback(data, frame_samples, &tx, &buffer_clone);
137                },
138                err_fn,
139                None,
140            )?,
141            cpal::SampleFormat::F32 => device.build_input_stream(
142                &stream_config,
143                move |data: &[f32], _| {
144                    let converted: Vec<i16> = data
145                        .iter()
146                        .map(|s| (s.clamp(-1.0, 1.0) * i16::MAX as f32) as i16)
147                        .collect();
148                    audio_in_callback(&converted, frame_samples, &tx, &buffer_clone);
149                },
150                err_fn,
151                None,
152            )?,
153            cpal::SampleFormat::U16 => device.build_input_stream(
154                &stream_config,
155                move |data: &[u16], _| {
156                    let converted: Vec<i16> = data.iter().map(|s| (*s as i32 - 32768) as i16).collect();
157                    audio_in_callback(&converted, frame_samples, &tx, &buffer_clone);
158                },
159                err_fn,
160                None,
161            )?,
162            _ => return Err(anyhow!("unsupported input sample format")),
163        };
164
165        stream.play()?;
166        Ok((Self { _stream: stream }, rx))
167    }
168}
169
170/// Buffer callback for audio input streams.
171fn audio_in_callback(
172    data: &[i16],
173    frame_samples: usize,
174    tx: &tokio::sync::mpsc::Sender<Vec<i16>>,
175    buffer: &Arc<Mutex<Vec<i16>>>,
176) {
177    let mut buf = buffer.lock().unwrap();
178    buf.extend_from_slice(data);
179    while buf.len() >= frame_samples {
180        let frame: Vec<i16> = buf.drain(..frame_samples).collect();
181        let _ = tx.try_send(frame);
182    }
183}
184
185/// Active audio output stream wrapper.
186pub struct AudioOut {
187    _stream: cpal::Stream,
188    queue: Arc<Mutex<VecDeque<i16>>>,
189    frame_samples: usize,
190    channels: u16,
191}
192
193impl AudioOut {
194    /// Open the default output device and start playback.
195    pub fn new(config: &AudioConfig) -> Result<Self> {
196        Self::new_with_device(config, None)
197    }
198
199    /// Open a specific output device by name (or default if None).
200    pub fn new_with_device(config: &AudioConfig, device_name: Option<&str>) -> Result<Self> {
201        let host = cpal::default_host();
202        let device = if let Some(name) = device_name {
203            find_output_device(&host, name)?
204        } else {
205            host.default_output_device()
206                .ok_or_else(|| anyhow!("no default output device"))?
207        };
208
209        let supported_config = device.default_output_config()?;
210        let sample_format = supported_config.sample_format();
211        let mut stream_config: cpal::StreamConfig = supported_config.into();
212        stream_config.channels = config.channels;
213        stream_config.sample_rate = cpal::SampleRate(config.sample_rate);
214
215        let queue = Arc::new(Mutex::new(VecDeque::with_capacity(config.frame_samples() * 4)));
216        let queue_clone = queue.clone();
217        let frame_samples = config.frame_samples();
218        let channels = config.channels;
219
220        let err_fn = |err| tracing::error!("audio output error: {err}");
221
222        let stream = match sample_format {
223            cpal::SampleFormat::I16 => device.build_output_stream(
224                &stream_config,
225                move |data: &mut [i16], _| {
226                    audio_out_callback_i16(data, &queue_clone);
227                },
228                err_fn,
229                None,
230            )?,
231            cpal::SampleFormat::F32 => device.build_output_stream(
232                &stream_config,
233                move |data: &mut [f32], _| {
234                    audio_out_callback_f32(data, &queue_clone);
235                },
236                err_fn,
237                None,
238            )?,
239            cpal::SampleFormat::U16 => device.build_output_stream(
240                &stream_config,
241                move |data: &mut [u16], _| {
242                    audio_out_callback_u16(data, &queue_clone);
243                },
244                err_fn,
245                None,
246            )?,
247            _ => return Err(anyhow!("unsupported output sample format")),
248        };
249
250        stream.play()?;
251
252        Ok(Self {
253            _stream: stream,
254            queue,
255            frame_samples,
256            channels,
257        })
258    }
259
260    /// Push a PCM frame into the playback queue.
261    pub fn push_frame(&self, frame: &[i16]) {
262        let mut queue = self.queue.lock().unwrap();
263        for sample in frame {
264            queue.push_back(*sample);
265        }
266    }
267
268    /// Number of samples per frame.
269    pub fn frame_samples(&self) -> usize {
270        self.frame_samples
271    }
272
273    /// Channel count.
274    pub fn channels(&self) -> u16 {
275        self.channels
276    }
277
278    /// Samples currently queued for playback.
279    pub fn queued_samples(&self) -> usize {
280        let queue = self.queue.lock().unwrap();
281        queue.len()
282    }
283}
284
285/// Find an input device by name.
286fn find_input_device(host: &cpal::Host, name: &str) -> Result<cpal::Device> {
287    for device in host.input_devices()? {
288        if let Ok(dev_name) = device.name() {
289            if dev_name == name {
290                return Ok(device);
291            }
292        }
293    }
294    Err(anyhow!("input device not found: {}", name))
295}
296
297/// Find an output device by name.
298fn find_output_device(host: &cpal::Host, name: &str) -> Result<cpal::Device> {
299    for device in host.output_devices()? {
300        if let Ok(dev_name) = device.name() {
301            if dev_name == name {
302                return Ok(device);
303            }
304        }
305    }
306    Err(anyhow!("output device not found: {}", name))
307}
308
/// Fill an i16 output buffer from the shared sample queue.
///
/// Samples are taken from the front of the queue in order. Once the queue is
/// exhausted, the remainder of `data` is filled with silence (0).
fn audio_out_callback_i16(data: &mut [i16], queue: &Arc<Mutex<VecDeque<i16>>>) {
    let mut pending = queue.lock().unwrap();
    let available = pending.len().min(data.len());
    let (filled, silent) = data.split_at_mut(available);
    for (slot, sample) in filled.iter_mut().zip(pending.drain(..available)) {
        *slot = sample;
    }
    silent.fill(0);
}
316
/// Fill an f32 output buffer from the shared i16 sample queue.
///
/// Each queued sample is divided by `i16::MAX`, the inverse of the capture
/// path's `* i16::MAX` scaling, and then limited to -1.0 from below. Without
/// that limit, `i16::MIN` would map to about -1.00003, just outside the
/// nominal [-1.0, 1.0] float range. Once the queue is empty, the rest of
/// `data` is silence (0.0).
fn audio_out_callback_f32(data: &mut [f32], queue: &Arc<Mutex<VecDeque<i16>>>) {
    let mut q = queue.lock().unwrap();
    for sample in data.iter_mut() {
        let v = q.pop_front().unwrap_or(0);
        *sample = (f32::from(v) / f32::from(i16::MAX)).max(-1.0);
    }
}
325
/// Fill a u16 output buffer from the shared i16 sample queue.
///
/// Signed samples are shifted into the unsigned range by adding 32768, so
/// silence (0) becomes the midpoint 32768. Slots with no queued sample also
/// receive that midpoint.
fn audio_out_callback_u16(data: &mut [u16], queue: &Arc<Mutex<VecDeque<i16>>>) {
    let mut pending = queue.lock().unwrap();
    data.iter_mut().for_each(|slot| {
        let signed = pending.pop_front().unwrap_or_default();
        *slot = (i32::from(signed) + 32768) as u16;
    });
}
334
335/// Thin wrapper around Opus encoder with Rift-friendly defaults.
336pub struct OpusEncoder {
337    inner: opus::Encoder,
338}
339
340impl OpusEncoder {
341    /// Construct an Opus encoder based on the audio config.
342    pub fn new(config: &AudioConfig) -> Result<Self> {
343        let channels = match config.channels {
344            1 => opus::Channels::Mono,
345            2 => opus::Channels::Stereo,
346            _ => return Err(anyhow!("unsupported channel count")),
347        };
348        let mut encoder = opus::Encoder::new(config.sample_rate, channels, opus::Application::Voip)?;
349        encoder.set_bitrate(opus::Bitrate::Bits(config.bitrate as i32))?;
350        Ok(Self { inner: encoder })
351    }
352
353    /// Encode i16 PCM samples.
354    pub fn encode_i16(&mut self, frame: &[i16], out: &mut [u8]) -> Result<usize> {
355        let len = self.inner.encode(frame, out)?;
356        Ok(len)
357    }
358
359    /// Encode f32 PCM samples.
360    pub fn encode_f32(&mut self, frame: &[f32], out: &mut [u8]) -> Result<usize> {
361        let len = self.inner.encode_float(frame, out)?;
362        Ok(len)
363    }
364
365    /// Update target bitrate.
366    pub fn set_bitrate(&mut self, bitrate: u32) -> Result<()> {
367        self.inner
368            .set_bitrate(opus::Bitrate::Bits(bitrate as i32))?;
369        Ok(())
370    }
371
372    /// Enable or disable in-band FEC.
373    pub fn set_fec(&mut self, enabled: bool) -> Result<()> {
374        self.inner.set_inband_fec(enabled)?;
375        Ok(())
376    }
377
378    /// Set expected packet loss percentage.
379    pub fn set_packet_loss(&mut self, loss_pct: u8) -> Result<()> {
380        let pct = loss_pct.min(100);
381        self.inner.set_packet_loss_perc(pct as i32)?;
382        Ok(())
383    }
384}
385
386/// Thin wrapper around Opus decoder.
387pub struct OpusDecoder {
388    inner: opus::Decoder,
389}
390
391impl OpusDecoder {
392    /// Construct an Opus decoder based on the audio config.
393    pub fn new(config: &AudioConfig) -> Result<Self> {
394        let channels = match config.channels {
395            1 => opus::Channels::Mono,
396            2 => opus::Channels::Stereo,
397            _ => return Err(anyhow!("unsupported channel count")),
398        };
399        let decoder = opus::Decoder::new(config.sample_rate, channels)?;
400        Ok(Self { inner: decoder })
401    }
402
403    /// Decode Opus payload into i16 PCM samples.
404    pub fn decode_i16(&mut self, data: &[u8], out: &mut [i16]) -> Result<usize> {
405        let len = self.inner.decode(data, out, false)?;
406        Ok(len)
407    }
408
409    /// Decode Opus payload into f32 PCM samples.
410    pub fn decode_f32(&mut self, data: &[u8], out: &mut [f32]) -> Result<usize> {
411        let len = self.inner.decode_float(data, out, false)?;
412        Ok(len)
413    }
414}
415
/// One encoded audio frame as emitted by a sender.
#[derive(Debug, Clone)]
pub struct VoiceFrame {
    /// Sequence number, assigned by the sender.
    pub seq: u32,
    /// Time the frame was captured, in milliseconds.
    pub timestamp: u64,
    /// Codec-encoded audio bytes.
    pub payload: Vec<u8>,
}
425
/// Simple audio mixer that combines several PCM streams, keyed by a `u64` id.
///
/// Each mix step pops at most one frame from every ready stream, sums them
/// sample by sample, and divides by the number of contributing streams
/// (averaging). The result is clamped to the i16 range. A stream is forgotten
/// once its queue drains, and a stream that reappears starts prebuffering
/// again.
pub struct AudioMixer {
    /// Samples per output frame.
    frame_samples: usize,
    /// Frames a newly seen stream must queue before it is first mixed.
    prebuffer_frames: usize,
    streams: HashMap<u64, StreamState>,
}

/// Pending frames for one stream.
struct StreamState {
    queue: VecDeque<Vec<i16>>,
    /// Queue length required before the stream is first mixed; reset to 0
    /// once that length is reached.
    prebuffer: usize,
}

impl AudioMixer {
    /// Create a mixer that mixes streams as soon as they have a frame.
    pub fn new(frame_samples: usize) -> Self {
        Self::with_prebuffer(frame_samples, 0)
    }

    /// Create a mixer whose new streams wait until `prebuffer_frames`
    /// frames are queued before contributing.
    pub fn with_prebuffer(frame_samples: usize, prebuffer_frames: usize) -> Self {
        Self {
            frame_samples,
            prebuffer_frames,
            streams: HashMap::new(),
        }
    }

    /// Append `frame` to the queue for `stream_id`, creating the stream on
    /// first use.
    ///
    /// Frames need not be exactly `frame_samples` long. Extra samples are
    /// ignored when mixing, and missing ones count as silence.
    pub fn push(&mut self, stream_id: u64, frame: Vec<i16>) {
        let prebuffer = self.prebuffer_frames;
        self.streams
            .entry(stream_id)
            .or_insert_with(|| StreamState {
                queue: VecDeque::new(),
                prebuffer,
            })
            .queue
            .push_back(frame);
    }

    /// Mix the next frame, discarding the activity flag.
    pub fn mix_next(&mut self) -> Vec<i16> {
        self.mix_next_with_activity().0
    }

    /// Mix the next frame and report whether any stream contributed to it.
    ///
    /// With no contributing streams, the frame is all zeros and the flag is
    /// `false`.
    pub fn mix_next_with_activity(&mut self) -> (Vec<i16>, bool) {
        let mut active = 0usize;
        let mut mix = vec![0i32; self.frame_samples];

        for state in self.streams.values_mut() {
            if state.prebuffer > 0 {
                if state.queue.len() < state.prebuffer {
                    // Still filling the jitter buffer; skip this tick.
                    continue;
                }
                state.prebuffer = 0;
            }

            if let Some(frame) = state.queue.pop_front() {
                active += 1;
                // `zip` stops at the shorter side. A frame longer than
                // `frame_samples` is truncated instead of indexing past the
                // end of `mix` and panicking. A shorter frame leaves the
                // tail untouched (silent).
                for (acc, &sample) in mix.iter_mut().zip(&frame) {
                    *acc += i32::from(sample);
                }
            }
        }

        // Drop streams that have drained. Streams still prebuffering always
        // have `prebuffer > 0`, so they are kept.
        self.streams
            .retain(|_, state| !(state.queue.is_empty() && state.prebuffer == 0));

        let scale = active.max(1) as i32;
        let frame = mix
            .into_iter()
            .map(|v| (v / scale).clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16)
            .collect();
        (frame, active > 0)
    }
}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// `len` samples of `amplitude * sin(i * step)`, truncated to i16.
    fn sine_frame(len: usize, step: f32, amplitude: f32) -> Vec<i16> {
        (0..len)
            .map(|i| ((i as f32 * step).sin() * amplitude) as i16)
            .collect()
    }

    #[test]
    fn audio_config_default_frame_samples() {
        // 48 kHz * 20 ms / 1000 * 1 channel = 960 samples.
        assert_eq!(AudioConfig::default().frame_samples(), 960);
    }

    #[test]
    fn audio_config_stereo_frame_samples() {
        let stereo = AudioConfig {
            sample_rate: 48_000,
            frame_duration_ms: 20,
            channels: 2,
            bitrate: 48_000,
        };
        // 48 kHz * 20 ms / 1000 * 2 channels = 1920 samples.
        assert_eq!(stereo.frame_samples(), 1920);
    }

    #[test]
    fn audio_config_different_sample_rates() {
        // (sample rate, bitrate, expected samples) for 20 ms mono frames.
        let cases = [(8_000, 12_000, 160), (16_000, 24_000, 320)];
        for &(sample_rate, bitrate, expected) in &cases {
            let config = AudioConfig {
                sample_rate,
                frame_duration_ms: 20,
                channels: 1,
                bitrate,
            };
            assert_eq!(config.frame_samples(), expected);
        }
    }

    #[test]
    fn audio_config_frame_duration() {
        let config = AudioConfig {
            sample_rate: 48_000,
            frame_duration_ms: 10,
            channels: 1,
            bitrate: 48_000,
        };
        assert_eq!(config.frame_duration(), Duration::from_millis(10));
    }

    #[test]
    fn pcm16_encode_decode_roundtrip() {
        let config = AudioConfig::default();
        let samples = config.frame_samples();
        let mut encoder = OpusEncoder::new(&config).unwrap();
        let mut decoder = OpusDecoder::new(&config).unwrap();
        let frame = sine_frame(samples, 0.1, 10000.0);

        // PCM16 is lossless, so the round trip must be exact.
        let bytes = encode_frame(CodecId::PCM16, &frame, &mut encoder).unwrap();
        let restored = decode_frame(CodecId::PCM16, &bytes, &mut decoder, samples).unwrap();
        assert_eq!(restored, frame);
    }

    #[test]
    fn opus_encode_decode_roundtrip() {
        let config = AudioConfig::default();
        let samples = config.frame_samples();
        let mut encoder = OpusEncoder::new(&config).unwrap();
        let mut decoder = OpusDecoder::new(&config).unwrap();
        let frame = sine_frame(samples, 0.05, 5000.0);

        let packet = encode_frame(CodecId::Opus, &frame, &mut encoder).unwrap();
        let restored = decode_frame(CodecId::Opus, &packet, &mut decoder, samples).unwrap();

        // Opus is lossy; only the sample count can be checked exactly.
        assert_eq!(restored.len(), samples);
    }

    #[test]
    fn opus_encoder_bitrate_change() {
        let mut encoder = OpusEncoder::new(&AudioConfig::default()).unwrap();
        for &bitrate in &[24_000u32, 96_000] {
            encoder.set_bitrate(bitrate).unwrap();
        }
    }

    #[test]
    fn opus_encoder_fec_setting() {
        let mut encoder = OpusEncoder::new(&AudioConfig::default()).unwrap();
        for &enabled in &[true, false] {
            encoder.set_fec(enabled).unwrap();
        }
    }

    #[test]
    fn opus_encoder_packet_loss_setting() {
        let mut encoder = OpusEncoder::new(&AudioConfig::default()).unwrap();
        // 255 exceeds 100 and must be clamped rather than rejected.
        for &pct in &[0u8, 10, 100, 255] {
            encoder.set_packet_loss(pct).unwrap();
        }
    }

    #[test]
    fn mixer_single_stream() {
        let mut mixer = AudioMixer::new(4);
        mixer.push(1, vec![100, 200, 300, 400]);
        assert_eq!(
            mixer.mix_next_with_activity(),
            (vec![100, 200, 300, 400], true)
        );
    }

    #[test]
    fn mixer_multiple_streams_averaging() {
        let mut mixer = AudioMixer::new(4);
        for id in 1..=2u64 {
            mixer.push(id, vec![100, 200, 300, 400]);
        }
        // Two identical streams summed and divided by 2 reproduce the input.
        assert_eq!(
            mixer.mix_next_with_activity(),
            (vec![100, 200, 300, 400], true)
        );
    }

    #[test]
    fn mixer_empty_returns_zeros() {
        let mut mixer = AudioMixer::new(4);
        assert_eq!(mixer.mix_next_with_activity(), (vec![0; 4], false));
    }

    #[test]
    fn mixer_stream_removal_when_empty() {
        let mut mixer = AudioMixer::new(4);
        mixer.push(1, vec![100, 200, 300, 400]);

        // The first mix consumes the only queued frame...
        assert!(mixer.mix_next_with_activity().1);
        // ...so the second finds no active streams.
        assert_eq!(mixer.mix_next_with_activity(), (vec![0; 4], false));
    }

    #[test]
    fn mixer_prebuffer() {
        let mut mixer = AudioMixer::with_prebuffer(4, 2);

        // One frame is below the prebuffer threshold: nothing is mixed yet.
        mixer.push(1, vec![100, 200, 300, 400]);
        assert_eq!(mixer.mix_next_with_activity(), (vec![0; 4], false));

        // A second frame satisfies the prebuffer; the oldest frame plays first.
        mixer.push(1, vec![500, 600, 700, 800]);
        assert_eq!(
            mixer.mix_next_with_activity(),
            (vec![100, 200, 300, 400], true)
        );
    }

    #[test]
    fn mixer_clipping_protection() {
        let mut mixer = AudioMixer::new(2);
        mixer.push(1, vec![i16::MAX; 2]);
        mixer.push(2, vec![i16::MAX; 2]);

        // The raw sum would overflow i16, but averaging keeps it at i16::MAX.
        assert_eq!(mixer.mix_next(), vec![i16::MAX; 2]);
    }

    #[test]
    fn voice_frame_construction() {
        let frame = VoiceFrame {
            seq: 42,
            timestamp: 1234567890,
            payload: vec![1, 2, 3, 4],
        };

        assert_eq!(frame.seq, 42);
        assert_eq!(frame.timestamp, 1234567890);
        assert_eq!(frame.payload, vec![1, 2, 3, 4]);
    }
}