Skip to main content

piper_plus/
playback.rs

1//! Real-time audio playback via rodio
2//!
3//! Feature-gated behind `playback` feature flag.
4//! `cargo build --features playback` to enable.
5//!
//! When the `playback` feature is disabled, `RodioPlayer` is not compiled;
//! only [`DummyPlayer`], [`CollectorSink`], and the helper function
//! [`play_audio`] (which then delegates to [`DummyPlayer`]) are available.
9
10use crate::error::PiperError;
11use crate::streaming::AudioSink;
12
13// ---------------------------------------------------------------------------
14// DummyPlayer -- always available, useful for testing / benchmarking
15// ---------------------------------------------------------------------------
16
/// A no-op audio player that discards all samples.
///
/// Useful for testing, benchmarking, or running on systems without an
/// audio output device. Instead of playing audio it only keeps counters
/// about what it received, readable through the accessor methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DummyPlayer {
    /// Total number of samples received across all `write_chunk` calls.
    total_samples: usize,
    /// Number of `write_chunk` calls received.
    chunk_count: usize,
    /// Last sample rate seen (0 if no chunks received yet).
    last_sample_rate: u32,
    /// Whether `finalize` has been called.
    finalized: bool,
}

impl DummyPlayer {
    /// Create a new `DummyPlayer` with all counters at zero and not finalized.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            total_samples: 0,
            chunk_count: 0,
            last_sample_rate: 0,
            finalized: false,
        }
    }

    /// Total number of samples received.
    #[must_use]
    pub const fn total_samples(&self) -> usize {
        self.total_samples
    }

    /// Number of chunks received.
    #[must_use]
    pub const fn chunk_count(&self) -> usize {
        self.chunk_count
    }

    /// Last sample rate seen (0 if no chunks received).
    #[must_use]
    pub const fn last_sample_rate(&self) -> u32 {
        self.last_sample_rate
    }

    /// Whether `finalize` has been called.
    #[must_use]
    pub const fn is_finalized(&self) -> bool {
        self.finalized
    }
}

impl Default for DummyPlayer {
    /// Equivalent to [`DummyPlayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
69
70impl AudioSink for DummyPlayer {
71    fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
72        if self.finalized {
73            return Err(PiperError::Inference(
74                "DummyPlayer: write_chunk called after finalize".to_string(),
75            ));
76        }
77        if sample_rate == 0 {
78            return Err(PiperError::Inference("sample rate must be > 0".to_string()));
79        }
80        self.total_samples += samples.len();
81        self.chunk_count += 1;
82        self.last_sample_rate = sample_rate;
83        Ok(())
84    }
85
86    fn finalize(&mut self) -> Result<(), PiperError> {
87        self.finalized = true;
88        Ok(())
89    }
90}
91
92// ---------------------------------------------------------------------------
93// CollectorSink -- collects all samples for later inspection
94// ---------------------------------------------------------------------------
95
/// An `AudioSink` that collects all samples into an internal buffer.
///
/// Primarily intended for testing -- you can inspect the accumulated
/// samples after synthesis is complete.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CollectorSink {
    /// Every sample received so far, in arrival order.
    samples: Vec<i16>,
    /// Sample rate locked in by the first accepted chunk.
    sample_rate: Option<u32>,
    /// Whether `finalize` has been called.
    finalized: bool,
}

impl CollectorSink {
    /// Create a new empty collector.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            samples: Vec::new(),
            sample_rate: None,
            finalized: false,
        }
    }

    /// Return all collected samples.
    #[must_use]
    pub fn samples(&self) -> &[i16] {
        &self.samples
    }

    /// Return the sample rate (from the first chunk), if any.
    #[must_use]
    pub const fn sample_rate(&self) -> Option<u32> {
        self.sample_rate
    }

    /// Whether `finalize` has been called.
    #[must_use]
    pub const fn is_finalized(&self) -> bool {
        self.finalized
    }

    /// Consume self and return the collected samples.
    #[must_use]
    pub fn into_samples(self) -> Vec<i16> {
        self.samples
    }
}

impl Default for CollectorSink {
    /// Equivalent to [`CollectorSink::new`].
    fn default() -> Self {
        Self::new()
    }
}
142
143impl AudioSink for CollectorSink {
144    fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
145        if self.finalized {
146            return Err(PiperError::Inference(
147                "CollectorSink: write_chunk called after finalize".to_string(),
148            ));
149        }
150        if sample_rate == 0 {
151            return Err(PiperError::Inference("sample rate must be > 0".to_string()));
152        }
153        // Detect sample rate mismatch across chunks
154        if let Some(prev) = self.sample_rate
155            && prev != sample_rate
156        {
157            return Err(PiperError::Inference(format!(
158                "sample rate mismatch: expected {prev}, got {sample_rate}"
159            )));
160        }
161        self.sample_rate = Some(sample_rate);
162        self.samples.extend_from_slice(samples);
163        Ok(())
164    }
165
166    fn finalize(&mut self) -> Result<(), PiperError> {
167        self.finalized = true;
168        Ok(())
169    }
170}
171
172// ---------------------------------------------------------------------------
173// RodioPlayer -- feature-gated behind "playback"
174// ---------------------------------------------------------------------------
175
/// Real-time audio player using rodio.
///
/// Plays mono audio chunks through the default audio output device.
/// Feature-gated behind the `playback` Cargo feature.
///
/// `write_chunk` appends each chunk to the underlying `rodio::Sink` and
/// returns; call [`RodioPlayer::wait_until_done`] to block until all queued
/// audio has finished playing.
///
/// # Example (requires `--features playback`)
///
/// ```ignore
/// use piper_plus::playback::RodioPlayer;
/// use piper_plus::streaming::AudioSink;
///
/// let mut player = RodioPlayer::new()?;
/// player.write_chunk(&samples, 22050)?;
/// player.finalize()?;
/// player.wait_until_done();
/// ```
#[cfg(feature = "playback")]
pub struct RodioPlayer {
    /// Must be kept alive for the duration of playback -- dropping it
    /// stops audio output.
    ///
    /// NOTE: struct fields are dropped in declaration order, so this stream
    /// is dropped before `sink`; keep that in mind before reordering fields.
    _stream: rodio::OutputStream,
    /// The playback queue that incoming chunks are appended to.
    sink: rodio::Sink,
    /// Target sample rate for the output device. If `Some`, incoming
    /// audio is resampled to this rate.  If `None`, audio is played
    /// at its native sample rate.
    target_sample_rate: Option<u32>,
    /// Whether `finalize` has been called.
    finalized: bool,
}
206
#[cfg(feature = "playback")]
impl RodioPlayer {
    /// Create a new player using the default output device.
    ///
    /// Audio is played at whatever sample rate each chunk declares.
    ///
    /// # Errors
    ///
    /// Returns [`PiperError::Inference`] if the default output device cannot
    /// be opened or a playback sink cannot be created on it.
    pub fn new() -> Result<Self, PiperError> {
        Self::open_default(None)
    }

    /// Create a player that resamples all incoming audio to
    /// `target_sample_rate` before sending it to the output device.
    ///
    /// # Errors
    ///
    /// Returns [`PiperError::Inference`] if `target_sample_rate` is 0 (checked
    /// before any device is opened), or if the default output device or its
    /// playback sink cannot be opened.
    pub fn with_sample_rate(target_sample_rate: u32) -> Result<Self, PiperError> {
        if target_sample_rate == 0 {
            return Err(PiperError::Inference(
                "target sample rate must be > 0".to_string(),
            ));
        }
        Self::open_default(Some(target_sample_rate))
    }

    /// Open the default output device, attach a fresh sink to it, and build
    /// a non-finalized player with the given resampling target.
    fn open_default(target_sample_rate: Option<u32>) -> Result<Self, PiperError> {
        let (stream, stream_handle) = rodio::OutputStream::try_default()
            .map_err(|e| PiperError::Inference(format!("failed to open audio output: {e}")))?;

        let sink = rodio::Sink::try_new(&stream_handle)
            .map_err(|e| PiperError::Inference(format!("failed to create audio sink: {e}")))?;

        Ok(Self {
            _stream: stream,
            sink,
            target_sample_rate,
            finalized: false,
        })
    }

    /// Block until all queued audio has finished playing.
    pub fn wait_until_done(&self) {
        self.sink.sleep_until_end();
    }

    /// Resample mono `samples` from `src_rate` to `dst_rate` using linear
    /// interpolation.  Good enough for real-time preview; not
    /// production-grade.
    ///
    /// Returns a copy of the input when the rates are equal or the input is
    /// empty. Otherwise the output has `ceil(len * dst_rate / src_rate)`
    /// samples. Both rates must be non-zero; callers validate this.
    fn linear_resample(samples: &[i16], src_rate: u32, dst_rate: u32) -> Vec<i16> {
        if src_rate == dst_rate || samples.is_empty() {
            return samples.to_vec();
        }

        let src = u64::from(src_rate);
        let dst = u64::from(dst_rate);
        let last = samples.len() - 1;
        // ceil(len * dst / src), computed exactly. Deriving it from a float
        // ratio could round up one sample too many and index past the end.
        let out_len = (samples.len() as u64 * dst).div_ceil(src);

        (0..out_len)
            .map(|i| {
                // Output sample `i` sits at source position i * src / dst.
                // Splitting that into integer index and remainder is exact,
                // and i < out_len guarantees idx <= last.
                let pos = i * src;
                let idx = (pos / dst) as usize;
                let frac = (pos % dst) as f64 / dst as f64;

                let s0 = f64::from(samples[idx]);
                // Past the final sample, hold its value instead of extrapolating.
                let s1 = f64::from(samples[(idx + 1).min(last)]);

                (s0 + frac * (s1 - s0)).clamp(-32768.0, 32767.0) as i16
            })
            .collect()
    }
}
288
#[cfg(feature = "playback")]
impl AudioSink for RodioPlayer {
    /// Queue one mono chunk on the rodio sink for playback.
    ///
    /// If a target rate was configured via [`RodioPlayer::with_sample_rate`]
    /// and it differs from `sample_rate`, the chunk is linearly resampled
    /// first. Empty chunks pass validation and are then ignored.
    ///
    /// # Errors
    ///
    /// Returns [`PiperError::Inference`] if `finalize` was already called
    /// (checked first), or if `sample_rate` is 0.
    fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
        let rejection = if self.finalized {
            Some("RodioPlayer: write_chunk called after finalize")
        } else if sample_rate == 0 {
            Some("sample rate must be > 0")
        } else {
            None
        };
        if let Some(reason) = rejection {
            return Err(PiperError::Inference(reason.to_string()));
        }
        if samples.is_empty() {
            return Ok(());
        }

        // Resample only when a target is set and actually differs.
        let source = match self.target_sample_rate.filter(|&target| target != sample_rate) {
            Some(target) => rodio::buffer::SamplesBuffer::new(
                1,
                target,
                Self::linear_resample(samples, sample_rate, target),
            ),
            None => rodio::buffer::SamplesBuffer::new(1, sample_rate, samples.to_vec()),
        };
        self.sink.append(source);

        Ok(())
    }

    /// Mark the player as finished so later `write_chunk` calls fail.
    ///
    /// Does not wait for queued audio; use [`RodioPlayer::wait_until_done`].
    fn finalize(&mut self) -> Result<(), PiperError> {
        self.finalized = true;
        Ok(())
    }
}
323
324// ---------------------------------------------------------------------------
325// Helper function
326// ---------------------------------------------------------------------------
327
328/// Play audio synchronously through the default output device.
329///
330/// When the `playback` feature is enabled this uses [`RodioPlayer`].
331/// Otherwise it falls back to [`DummyPlayer`] (no-op).
332///
333/// Returns an error if `sample_rate` is 0 or (with `playback` enabled)
334/// if the audio device cannot be opened.
335pub fn play_audio(samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
336    if sample_rate == 0 {
337        return Err(PiperError::Inference("sample rate must be > 0".to_string()));
338    }
339
340    #[cfg(feature = "playback")]
341    {
342        let mut player = RodioPlayer::new()?;
343        player.write_chunk(samples, sample_rate)?;
344        player.finalize()?;
345        player.wait_until_done();
346        Ok(())
347    }
348
349    #[cfg(not(feature = "playback"))]
350    {
351        let mut player = DummyPlayer::new();
352        player.write_chunk(samples, sample_rate)?;
353        player.finalize()?;
354        Ok(())
355    }
356}
357
358// ---------------------------------------------------------------------------
359// Tests
360// ---------------------------------------------------------------------------
361
#[cfg(test)]
mod tests {
    use super::*;

    // -- DummyPlayer tests --------------------------------------------------

    #[test]
    fn dummy_player_initial_state() {
        let player = DummyPlayer::new();
        assert_eq!(player.total_samples(), 0);
        assert_eq!(player.chunk_count(), 0);
        assert_eq!(player.last_sample_rate(), 0);
        assert!(!player.is_finalized());
    }

    #[test]
    fn dummy_player_single_chunk() {
        let mut player = DummyPlayer::new();
        let samples = vec![100i16, 200, 300];
        player.write_chunk(&samples, 22050).unwrap();

        assert_eq!(player.total_samples(), 3);
        assert_eq!(player.chunk_count(), 1);
        assert_eq!(player.last_sample_rate(), 22050);
    }

    #[test]
    fn dummy_player_multiple_chunks() {
        let mut player = DummyPlayer::new();
        player.write_chunk(&[1, 2, 3], 22050).unwrap();
        player.write_chunk(&[4, 5], 44100).unwrap();
        player.write_chunk(&[6], 16000).unwrap();

        assert_eq!(player.total_samples(), 6);
        assert_eq!(player.chunk_count(), 3);
        assert_eq!(player.last_sample_rate(), 16000);
    }

    #[test]
    fn dummy_player_finalize() {
        let mut player = DummyPlayer::new();
        player.write_chunk(&[1, 2], 22050).unwrap();
        assert!(!player.is_finalized());

        player.finalize().unwrap();
        assert!(player.is_finalized());
    }

    #[test]
    fn dummy_player_write_after_finalize_errors() {
        let mut player = DummyPlayer::new();
        player.finalize().unwrap();

        let result = player.write_chunk(&[1], 22050);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("after finalize"),
            "error message should mention finalize"
        );
    }

    #[test]
    fn dummy_player_zero_sample_rate_errors() {
        let mut player = DummyPlayer::new();
        let result = player.write_chunk(&[1, 2], 0);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("sample rate"),
            "error message should mention sample rate"
        );
    }

    #[test]
    fn dummy_player_empty_chunk() {
        let mut player = DummyPlayer::new();
        player.write_chunk(&[], 22050).unwrap();

        assert_eq!(player.total_samples(), 0);
        assert_eq!(player.chunk_count(), 1);
        assert_eq!(player.last_sample_rate(), 22050);
    }

    #[test]
    fn dummy_player_default_trait() {
        let player = DummyPlayer::default();
        assert_eq!(player.total_samples(), 0);
        assert!(!player.is_finalized());
    }

    // -- CollectorSink tests ------------------------------------------------

    #[test]
    fn collector_sink_collects_samples() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[10, 20, 30], 22050).unwrap();
        sink.write_chunk(&[40, 50], 22050).unwrap();

        assert_eq!(sink.samples(), &[10, 20, 30, 40, 50]);
        assert_eq!(sink.sample_rate(), Some(22050));
    }

    #[test]
    fn collector_sink_sample_rate_mismatch_errors() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[1], 22050).unwrap();

        let result = sink.write_chunk(&[2], 44100);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("mismatch"),
            "error message should mention mismatch"
        );
    }

    #[test]
    fn collector_sink_write_after_finalize_errors() {
        let mut sink = CollectorSink::new();
        sink.finalize().unwrap();

        let result = sink.write_chunk(&[1], 22050);
        assert!(result.is_err());
    }

    #[test]
    fn collector_sink_into_samples() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[7, 8, 9], 16000).unwrap();
        sink.finalize().unwrap();

        let data = sink.into_samples();
        assert_eq!(data, vec![7, 8, 9]);
    }

    #[test]
    fn collector_sink_empty() {
        let sink = CollectorSink::new();
        assert!(sink.samples().is_empty());
        assert_eq!(sink.sample_rate(), None);
        assert!(!sink.is_finalized());
    }

    #[test]
    fn collector_sink_zero_sample_rate_errors() {
        let mut sink = CollectorSink::new();
        let result = sink.write_chunk(&[1], 0);
        assert!(result.is_err());
    }

    #[test]
    fn collector_sink_default_trait() {
        let sink = CollectorSink::default();
        assert!(sink.samples().is_empty());
        assert!(!sink.is_finalized());
    }

    // -- play_audio helper tests --------------------------------------------

    #[test]
    fn play_audio_zero_sample_rate_errors() {
        // Rejected before any device is opened, so this is safe with or
        // without the `playback` feature.
        let result = play_audio(&[1, 2, 3], 0);
        assert!(result.is_err());
    }

    // The success-path play_audio tests below target the DummyPlayer
    // fallback. With `playback` enabled they would open (and play through)
    // a real audio device, which CI machines typically lack, so they are
    // compiled only without that feature.

    #[cfg(not(feature = "playback"))]
    #[test]
    fn play_audio_empty_samples_ok() {
        // Without the playback feature, this goes through DummyPlayer
        // and should succeed.
        let result = play_audio(&[], 22050);
        assert!(result.is_ok());
    }

    #[cfg(not(feature = "playback"))]
    #[test]
    fn play_audio_normal_samples_ok() {
        // Without the playback feature this is a no-op via DummyPlayer.
        let samples: Vec<i16> = (0..100).map(|i| (i * 100) as i16).collect();
        let result = play_audio(&samples, 22050);
        assert!(result.is_ok());
    }

    // -- DummyPlayer additional tests ----------------------------------------

    #[test]
    fn dummy_player_double_finalize_is_idempotent() {
        let mut player = DummyPlayer::new();
        player.write_chunk(&[1, 2, 3], 22050).unwrap();
        player.finalize().unwrap();
        assert!(player.is_finalized());

        // Second finalize should also succeed (idempotent)
        player.finalize().unwrap();
        assert!(player.is_finalized());
    }

    #[test]
    fn dummy_player_large_sample_count() {
        let mut player = DummyPlayer::new();
        let samples: Vec<i16> = vec![42; 1_000_000];
        player.write_chunk(&samples, 22050).unwrap();

        assert_eq!(player.total_samples(), 1_000_000);
        assert_eq!(player.chunk_count(), 1);
        assert_eq!(player.last_sample_rate(), 22050);
    }

    // -- CollectorSink additional tests --------------------------------------

    #[test]
    fn collector_sink_double_finalize_is_idempotent() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[10, 20], 44100).unwrap();
        sink.finalize().unwrap();
        assert!(sink.is_finalized());

        // Second finalize should also succeed (idempotent)
        sink.finalize().unwrap();
        assert!(sink.is_finalized());
    }

    #[test]
    fn collector_sink_multiple_different_sample_rates_errors() {
        let mut sink = CollectorSink::new();

        // First chunk at 22050 sets the rate
        sink.write_chunk(&[1, 2, 3], 22050).unwrap();
        assert_eq!(sink.sample_rate(), Some(22050));

        // Second chunk at 44100 must fail with mismatch error
        let result = sink.write_chunk(&[4, 5], 44100);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("mismatch"),
            "error should mention mismatch, got: {err_msg}"
        );
        assert!(
            err_msg.contains("22050"),
            "error should mention expected rate 22050, got: {err_msg}"
        );
        assert!(
            err_msg.contains("44100"),
            "error should mention actual rate 44100, got: {err_msg}"
        );

        // Third chunk at 16000 must also fail (first rate still locked at 22050)
        let result2 = sink.write_chunk(&[6], 16000);
        assert!(result2.is_err());

        // Verify only the first chunk's samples were collected
        assert_eq!(sink.samples(), &[1, 2, 3]);
    }

    #[test]
    fn collector_sink_into_samples_ownership() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[100, 200, 300], 16000).unwrap();
        sink.write_chunk(&[400, 500], 16000).unwrap();
        sink.finalize().unwrap();

        // into_samples consumes self and returns owned Vec
        let owned: Vec<i16> = sink.into_samples();
        assert_eq!(owned, vec![100, 200, 300, 400, 500]);
        assert_eq!(owned.len(), 5);

        // After into_samples, `sink` is moved -- cannot be used.
        // (This is enforced at compile time, no runtime assertion needed.)
    }

    // -- play_audio with various sample rates --------------------------------

    #[cfg(not(feature = "playback"))]
    #[test]
    fn play_audio_various_sample_rates() {
        // Without the `playback` feature, play_audio uses DummyPlayer.
        // All valid sample rates should succeed.
        let samples: Vec<i16> = (0..64).collect();

        for &rate in &[8000u32, 16000, 22050, 44100] {
            let result = play_audio(&samples, rate);
            assert!(
                result.is_ok(),
                "play_audio should succeed at sample rate {rate}"
            );
        }
    }

    // -- RodioPlayer tests (feature-gated) -----------------------------------
    // These tests cover RodioPlayer's argument validation and its pure
    // resampling helper.  Actual audio output is not tested here because CI
    // environments typically lack an audio device.

    #[cfg(feature = "playback")]
    mod rodio_tests {
        use super::super::*;

        #[test]
        fn rodio_player_zero_target_rate_errors() {
            let result = RodioPlayer::with_sample_rate(0);
            assert!(result.is_err());
            assert!(
                result.unwrap_err().to_string().contains("sample rate"),
                "error message should mention sample rate"
            );
        }

        #[test]
        fn rodio_linear_resample_same_rate() {
            let input = vec![100i16, 200, 300, 400];
            let output = RodioPlayer::linear_resample(&input, 22050, 22050);
            assert_eq!(input, output);
        }

        #[test]
        fn rodio_linear_resample_empty() {
            let output = RodioPlayer::linear_resample(&[], 22050, 44100);
            assert!(output.is_empty());
        }

        #[test]
        fn rodio_linear_resample_upsample() {
            // 100 Hz -> 200 Hz exactly doubles the number of samples; odd
            // outputs are midpoints, and the final one holds the last input.
            let input = vec![0i16, 1000, 0, -1000];
            let output = RodioPlayer::linear_resample(&input, 100, 200);
            assert_eq!(
                output.len(),
                2 * input.len(),
                "2x upsampling should double the sample count"
            );
            assert_eq!(output, vec![0, 500, 1000, 500, 0, -500, -1000, -1000]);
        }

        #[test]
        fn rodio_linear_resample_downsample() {
            // 44100 Hz -> 22050 Hz keeps every second sample exactly.
            let input: Vec<i16> = (0..1000).map(|i| (i % 256) as i16).collect();
            let output = RodioPlayer::linear_resample(&input, 44100, 22050);
            assert_eq!(
                output.len(),
                input.len() / 2,
                "2x downsampling should halve the sample count"
            );
            for (k, &sample) in output.iter().enumerate() {
                assert_eq!(sample, input[2 * k], "output[{k}] should equal input[{}]", 2 * k);
            }
        }

        #[test]
        fn rodio_linear_resample_preserves_length_ratio() {
            // Upsample 22050 -> 48000: output length should be
            // ceil(input_len * 48000/22050)
            let input_len = 22050; // 1 second of audio at 22050 Hz
            let input: Vec<i16> = (0..input_len as i16).collect();
            let output = RodioPlayer::linear_resample(&input, 22050, 48000);

            let expected_len = ((input_len as f64) * (48000.0 / 22050.0)).ceil() as usize;
            // Allow +/- 1 sample tolerance for rounding
            assert!(
                (output.len() as isize - expected_len as isize).unsigned_abs() <= 1,
                "expected ~{expected_len} samples, got {}",
                output.len()
            );

            // Verify ratio is approximately correct
            let ratio = output.len() as f64 / input_len as f64;
            let expected_ratio = 48000.0 / 22050.0;
            assert!(
                (ratio - expected_ratio).abs() < 0.01,
                "sample count ratio {ratio:.4} should be close to {expected_ratio:.4}"
            );
        }

        #[test]
        fn rodio_linear_resample_boundary_values() {
            // Interpolating between the i16 extremes must not overflow or wrap.
            let input = vec![i16::MIN, i16::MAX, i16::MIN, i16::MAX, 0];

            // Non-integer ratio: the first output sample maps exactly onto
            // the first input sample.
            let output = RodioPlayer::linear_resample(&input, 22050, 48000);
            assert!(!output.is_empty(), "resampled output should not be empty");
            assert_eq!(
                output[0],
                i16::MIN,
                "first output sample should be i16::MIN"
            );

            // Exact 2x upsampling: even outputs reproduce the inputs
            // (including both extremes), and the midpoints between MIN and
            // MAX land near zero -- a wrapping bug would produce garbage here.
            let doubled = RodioPlayer::linear_resample(&input, 100, 200);
            assert_eq!(doubled.len(), 2 * input.len());
            for (k, &expected) in input.iter().enumerate() {
                assert_eq!(
                    doubled[2 * k],
                    expected,
                    "doubled[{}] should equal input[{k}]",
                    2 * k
                );
            }
            for &mid in &[doubled[1], doubled[3], doubled[5]] {
                assert!(
                    mid.abs() <= 1,
                    "midpoint between i16::MIN and i16::MAX should be ~0, got {doubled:?}"
                );
            }
        }
    }
}