use crate::error::PiperError;
use crate::streaming::AudioSink;
/// Audio sink that keeps no audio data and only records statistics about
/// what was written to it.
///
/// Its [`AudioSink`] implementation counts samples and chunks and remembers
/// the most recent sample rate, which makes it usable wherever a sink is
/// required but the audio itself is not needed.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DummyPlayer {
    /// Sum of the lengths of all chunks accepted so far.
    total_samples: usize,
    /// Number of chunks accepted so far (empty chunks count too).
    chunk_count: usize,
    /// Sample rate of the most recently accepted chunk; `0` until one arrives.
    last_sample_rate: u32,
    /// Set once `finalize` has been called.
    finalized: bool,
}

impl DummyPlayer {
    /// Creates a player with all counters at zero and not yet finalized.
    pub fn new() -> Self {
        Self::default()
    }

    /// Total number of samples accepted across all chunks.
    pub fn total_samples(&self) -> usize {
        self.total_samples
    }

    /// Number of chunks accepted so far.
    pub fn chunk_count(&self) -> usize {
        self.chunk_count
    }

    /// Sample rate of the last accepted chunk, or `0` if none was written.
    pub fn last_sample_rate(&self) -> u32 {
        self.last_sample_rate
    }

    /// Whether `finalize` has been called on this player.
    pub fn is_finalized(&self) -> bool {
        self.finalized
    }
}
impl AudioSink for DummyPlayer {
    /// Records statistics for one chunk without keeping the samples.
    ///
    /// # Errors
    ///
    /// Returns `PiperError::Inference` if the player was already finalized
    /// (checked first) or if `sample_rate` is zero. Nothing is recorded when
    /// an error is returned.
    fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
        // The finalized check deliberately takes precedence over the rate check.
        let rejection = if self.finalized {
            Some("DummyPlayer: write_chunk called after finalize")
        } else if sample_rate == 0 {
            Some("sample rate must be > 0")
        } else {
            None
        };
        if let Some(message) = rejection {
            return Err(PiperError::Inference(message.to_string()));
        }

        self.total_samples += samples.len();
        self.chunk_count += 1;
        self.last_sample_rate = sample_rate;
        Ok(())
    }

    /// Marks the player as finalized; calling it again is harmless.
    fn finalize(&mut self) -> Result<(), PiperError> {
        self.finalized = true;
        Ok(())
    }
}
/// Audio sink that accumulates every written sample in memory.
///
/// Its [`AudioSink`] implementation appends each chunk to an internal buffer
/// and remembers the sample rate of the stream, rejecting chunks whose rate
/// differs from the one already recorded.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CollectorSink {
    /// All samples accepted so far, in write order.
    samples: Vec<i16>,
    /// Sample rate of the accepted chunks; `None` until the first chunk.
    sample_rate: Option<u32>,
    /// Set once `finalize` has been called.
    finalized: bool,
}

impl CollectorSink {
    /// Creates an empty collector with no sample rate and not yet finalized.
    pub fn new() -> Self {
        Self::default()
    }

    /// Borrows the samples collected so far.
    pub fn samples(&self) -> &[i16] {
        &self.samples
    }

    /// Sample rate of the collected audio, or `None` if nothing was written.
    pub fn sample_rate(&self) -> Option<u32> {
        self.sample_rate
    }

    /// Whether `finalize` has been called on this sink.
    pub fn is_finalized(&self) -> bool {
        self.finalized
    }

    /// Consumes the sink and returns the collected samples without copying.
    pub fn into_samples(self) -> Vec<i16> {
        self.samples
    }
}
impl AudioSink for CollectorSink {
    /// Appends one chunk to the collected samples.
    ///
    /// The first accepted chunk fixes the sample rate of the stream; later
    /// chunks must use the same rate.
    ///
    /// # Errors
    ///
    /// Returns `PiperError::Inference` when, in order of precedence, the sink
    /// was already finalized, `sample_rate` is zero, or `sample_rate` differs
    /// from the rate recorded earlier. The buffer is left untouched on error.
    fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
        if self.finalized {
            let message = "CollectorSink: write_chunk called after finalize".to_string();
            return Err(PiperError::Inference(message));
        }
        if sample_rate == 0 {
            let message = "sample rate must be > 0".to_string();
            return Err(PiperError::Inference(message));
        }

        match self.sample_rate {
            // A rate was recorded before and this chunk disagrees with it.
            Some(prev) if prev != sample_rate => Err(PiperError::Inference(format!(
                "sample rate mismatch: expected {prev}, got {sample_rate}"
            ))),
            _ => {
                self.sample_rate = Some(sample_rate);
                self.samples.extend_from_slice(samples);
                Ok(())
            }
        }
    }

    /// Marks the sink as finalized; calling it again is harmless.
    fn finalize(&mut self) -> Result<(), PiperError> {
        self.finalized = true;
        Ok(())
    }
}
/// Audio sink that queues chunks on a `rodio::Sink` for playback.
///
/// Only compiled when the `playback` feature is enabled.
#[cfg(feature = "playback")]
pub struct RodioPlayer {
// Output stream obtained from `rodio::OutputStream::try_default`; it is stored but never read here.
// NOTE(review): presumably kept so the output device stays open while `sink` plays — confirm against the rodio docs.
_stream: rodio::OutputStream,
// Playback queue that every written chunk is appended to.
sink: rodio::Sink,
// When set, chunks arriving at a different rate are resampled to this rate before being queued.
target_sample_rate: Option<u32>,
// Set once `finalize` has been called; further writes are rejected.
finalized: bool,
}
#[cfg(feature = "playback")]
impl RodioPlayer {
/// Opens the default audio output and creates a player with no target
/// sample rate, so chunks are queued at whatever rate they are written with.
///
/// # Errors
///
/// Returns `PiperError::Inference` if the default output stream cannot be
/// opened or the playback sink cannot be created on it.
pub fn new() -> Result<Self, PiperError> {
    let (_stream, handle) = match rodio::OutputStream::try_default() {
        Ok(opened) => opened,
        Err(e) => {
            return Err(PiperError::Inference(format!(
                "failed to open audio output: {e}"
            )));
        }
    };
    let sink = match rodio::Sink::try_new(&handle) {
        Ok(created) => created,
        Err(e) => {
            return Err(PiperError::Inference(format!(
                "failed to create audio sink: {e}"
            )));
        }
    };
    Ok(Self {
        _stream,
        sink,
        target_sample_rate: None,
        finalized: false,
    })
}
pub fn with_sample_rate(target_sample_rate: u32) -> Result<Self, PiperError> {
if target_sample_rate == 0 {
return Err(PiperError::Inference(
"target sample rate must be > 0".to_string(),
));
}
let (_stream, stream_handle) = rodio::OutputStream::try_default()
.map_err(|e| PiperError::Inference(format!("failed to open audio output: {e}")))?;
let sink = rodio::Sink::try_new(&stream_handle)
.map_err(|e| PiperError::Inference(format!("failed to create audio sink: {e}")))?;
Ok(Self {
_stream,
sink,
target_sample_rate: Some(target_sample_rate),
finalized: false,
})
}
/// Delegates to `sleep_until_end` on the underlying `rodio::Sink`.
///
/// NOTE(review): presumably blocks the caller until every queued chunk has
/// finished playing — confirm against the rodio documentation.
pub fn wait_until_done(&self) {
self.sink.sleep_until_end();
}
/// Resamples mono `samples` from `src_rate` to `dst_rate` using linear
/// interpolation between neighbouring input samples.
///
/// The output holds `ceil(samples.len() * dst_rate / src_rate)` samples.
/// Positions past the last input sample repeat that sample. Interpolated
/// values are truncated towards zero when converted back to `i16`.
///
/// Returns a copy of the input when the rates are equal or the input is
/// empty, and an empty vector when exactly one of the rates is zero (no
/// meaningful conversion exists in that case).
fn linear_resample(samples: &[i16], src_rate: u32, dst_rate: u32) -> Vec<i16> {
    if src_rate == dst_rate || samples.is_empty() {
        return samples.to_vec();
    }
    // A zero source rate would otherwise request an unbounded output length
    // and panic while allocating; a zero target rate yields no samples.
    if src_rate == 0 || dst_rate == 0 {
        return Vec::new();
    }

    // Exact integer ceiling of len * dst / src. The floating-point form can
    // round up across an integer boundary and produce one position too many.
    let scaled = samples.len() as u128 * u128::from(dst_rate);
    let divisor = u128::from(src_rate);
    let exact_len = scaled / divisor + u128::from(scaled % divisor != 0);
    let out_len = exact_len.min(usize::MAX as u128) as usize;

    let last = samples.len() - 1;
    let step = f64::from(src_rate) / f64::from(dst_rate);

    (0..out_len)
        .map(|i| {
            let src_pos = i as f64 * step;
            // Clamp so that rounding in `src_pos` can never index past the end.
            let idx = (src_pos as usize).min(last);
            let frac = src_pos - idx as f64;
            let s0 = f64::from(samples[idx]);
            // At the final input sample there is no right neighbour: hold it.
            let s1 = samples.get(idx + 1).map_or(s0, |&next| f64::from(next));
            let interpolated = s0 + frac * (s1 - s0);
            interpolated.clamp(-32768.0, 32767.0) as i16
        })
        .collect()
}
}
#[cfg(feature = "playback")]
impl AudioSink for RodioPlayer {
    /// Queues one mono chunk for playback on the underlying sink.
    ///
    /// When a target sample rate is configured and differs from
    /// `sample_rate`, the chunk is linearly resampled to the target first.
    /// Empty chunks are accepted and ignored.
    ///
    /// # Errors
    ///
    /// Returns `PiperError::Inference` if the player was already finalized
    /// (checked first) or if `sample_rate` is zero.
    fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
        if self.finalized {
            let message = "RodioPlayer: write_chunk called after finalize".to_string();
            return Err(PiperError::Inference(message));
        }
        if sample_rate == 0 {
            let message = "sample rate must be > 0".to_string();
            return Err(PiperError::Inference(message));
        }
        if samples.is_empty() {
            return Ok(());
        }

        // Only resample when a target is set and it differs from the input rate.
        let resample_to = self
            .target_sample_rate
            .filter(|&target| target != sample_rate);
        let (play_samples, play_rate) = if let Some(target) = resample_to {
            (Self::linear_resample(samples, sample_rate, target), target)
        } else {
            (samples.to_vec(), sample_rate)
        };

        // Channel count 1: the buffer is queued as mono audio.
        self.sink
            .append(rodio::buffer::SamplesBuffer::new(1, play_rate, play_samples));
        Ok(())
    }

    /// Marks the player as finalized so later writes are rejected.
    fn finalize(&mut self) -> Result<(), PiperError> {
        self.finalized = true;
        Ok(())
    }
}
/// Writes `samples` to a freshly created sink in one chunk and finalizes it.
///
/// With the `playback` feature enabled the sink is a [`RodioPlayer`] and the
/// call also invokes `wait_until_done` before returning; without the feature
/// a [`DummyPlayer`] is used.
///
/// # Errors
///
/// Returns `PiperError::Inference` if `sample_rate` is zero, and propagates
/// any error from creating the player, writing the chunk or finalizing.
pub fn play_audio(samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
    if sample_rate == 0 {
        return Err(PiperError::Inference("sample rate must be > 0".to_string()));
    }

    // Exactly one of these two bindings exists in any given build.
    #[cfg(feature = "playback")]
    let mut player = RodioPlayer::new()?;
    #[cfg(not(feature = "playback"))]
    let mut player = DummyPlayer::new();

    player.write_chunk(samples, sample_rate)?;
    player.finalize()?;

    #[cfg(feature = "playback")]
    player.wait_until_done();

    Ok(())
}
#[cfg(test)]
mod tests {
    //! Unit tests for the sinks in this file and for `play_audio`.
    //!
    //! With the `playback` feature enabled, the `play_audio_*` tests go
    //! through `RodioPlayer::new`, which opens the default audio output.
    use super::*;

    // ---------------------------------------------------------------
    // DummyPlayer
    // ---------------------------------------------------------------

    #[test]
    fn dummy_player_initial_state() {
        let player = DummyPlayer::new();
        assert_eq!(player.total_samples(), 0);
        assert_eq!(player.chunk_count(), 0);
        assert_eq!(player.last_sample_rate(), 0);
        assert!(!player.is_finalized());
    }

    #[test]
    fn dummy_player_single_chunk() {
        let mut player = DummyPlayer::new();
        let samples = vec![100i16, 200, 300];
        player.write_chunk(&samples, 22050).unwrap();
        assert_eq!(player.total_samples(), 3);
        assert_eq!(player.chunk_count(), 1);
        assert_eq!(player.last_sample_rate(), 22050);
    }

    #[test]
    fn dummy_player_multiple_chunks() {
        let mut player = DummyPlayer::new();
        player.write_chunk(&[1, 2, 3], 22050).unwrap();
        player.write_chunk(&[4, 5], 44100).unwrap();
        player.write_chunk(&[6], 16000).unwrap();
        assert_eq!(player.total_samples(), 6);
        assert_eq!(player.chunk_count(), 3);
        assert_eq!(player.last_sample_rate(), 16000);
    }

    #[test]
    fn dummy_player_finalize() {
        let mut player = DummyPlayer::new();
        player.write_chunk(&[1, 2], 22050).unwrap();
        assert!(!player.is_finalized());
        player.finalize().unwrap();
        assert!(player.is_finalized());
    }

    #[test]
    fn dummy_player_write_after_finalize_errors() {
        let mut player = DummyPlayer::new();
        player.finalize().unwrap();
        let result = player.write_chunk(&[1], 22050);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("after finalize"),
            "error message should mention finalize"
        );
    }

    #[test]
    fn dummy_player_zero_sample_rate_errors() {
        let mut player = DummyPlayer::new();
        let result = player.write_chunk(&[1, 2], 0);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("sample rate"),
            "error message should mention sample rate"
        );
    }

    #[test]
    fn dummy_player_empty_chunk() {
        // An empty chunk still counts as a chunk and updates the rate.
        let mut player = DummyPlayer::new();
        player.write_chunk(&[], 22050).unwrap();
        assert_eq!(player.total_samples(), 0);
        assert_eq!(player.chunk_count(), 1);
        assert_eq!(player.last_sample_rate(), 22050);
    }

    #[test]
    fn dummy_player_default_trait() {
        let player = DummyPlayer::default();
        assert_eq!(player.total_samples(), 0);
        assert!(!player.is_finalized());
    }

    // ---------------------------------------------------------------
    // CollectorSink
    // ---------------------------------------------------------------

    #[test]
    fn collector_sink_collects_samples() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[10, 20, 30], 22050).unwrap();
        sink.write_chunk(&[40, 50], 22050).unwrap();
        assert_eq!(sink.samples(), &[10, 20, 30, 40, 50]);
        assert_eq!(sink.sample_rate(), Some(22050));
    }

    #[test]
    fn collector_sink_sample_rate_mismatch_errors() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[1], 22050).unwrap();
        let result = sink.write_chunk(&[2], 44100);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("mismatch"),
            "error message should mention mismatch"
        );
    }

    #[test]
    fn collector_sink_write_after_finalize_errors() {
        let mut sink = CollectorSink::new();
        sink.finalize().unwrap();
        let result = sink.write_chunk(&[1], 22050);
        assert!(result.is_err());
    }

    #[test]
    fn collector_sink_into_samples() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[7, 8, 9], 16000).unwrap();
        sink.finalize().unwrap();
        let data = sink.into_samples();
        assert_eq!(data, vec![7, 8, 9]);
    }

    #[test]
    fn collector_sink_empty() {
        let sink = CollectorSink::new();
        assert!(sink.samples().is_empty());
        assert_eq!(sink.sample_rate(), None);
        assert!(!sink.is_finalized());
    }

    #[test]
    fn collector_sink_zero_sample_rate_errors() {
        let mut sink = CollectorSink::new();
        let result = sink.write_chunk(&[1], 0);
        assert!(result.is_err());
    }

    #[test]
    fn collector_sink_default_trait() {
        let sink = CollectorSink::default();
        assert!(sink.samples().is_empty());
        assert!(!sink.is_finalized());
    }

    // ---------------------------------------------------------------
    // play_audio
    // ---------------------------------------------------------------

    #[test]
    fn play_audio_zero_sample_rate_errors() {
        let result = play_audio(&[1, 2, 3], 0);
        assert!(result.is_err());
    }

    #[test]
    fn play_audio_empty_samples_ok() {
        let result = play_audio(&[], 22050);
        assert!(result.is_ok());
    }

    #[test]
    fn play_audio_normal_samples_ok() {
        let samples: Vec<i16> = (0..100).map(|i| (i * 100) as i16).collect();
        let result = play_audio(&samples, 22050);
        assert!(result.is_ok());
    }

    // ---------------------------------------------------------------
    // Additional edge cases
    // ---------------------------------------------------------------

    #[test]
    fn dummy_player_double_finalize_is_idempotent() {
        let mut player = DummyPlayer::new();
        player.write_chunk(&[1, 2, 3], 22050).unwrap();
        player.finalize().unwrap();
        assert!(player.is_finalized());
        player.finalize().unwrap();
        assert!(player.is_finalized());
    }

    #[test]
    fn dummy_player_large_sample_count() {
        let mut player = DummyPlayer::new();
        let samples: Vec<i16> = vec![42; 1_000_000];
        player.write_chunk(&samples, 22050).unwrap();
        assert_eq!(player.total_samples(), 1_000_000);
        assert_eq!(player.chunk_count(), 1);
        assert_eq!(player.last_sample_rate(), 22050);
    }

    #[test]
    fn collector_sink_double_finalize_is_idempotent() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[10, 20], 44100).unwrap();
        sink.finalize().unwrap();
        assert!(sink.is_finalized());
        sink.finalize().unwrap();
        assert!(sink.is_finalized());
    }

    #[test]
    fn collector_sink_multiple_different_sample_rates_errors() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[1, 2, 3], 22050).unwrap();
        assert_eq!(sink.sample_rate(), Some(22050));

        let result = sink.write_chunk(&[4, 5], 44100);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("mismatch"),
            "error should mention mismatch, got: {err_msg}"
        );
        assert!(
            err_msg.contains("22050"),
            "error should mention expected rate 22050, got: {err_msg}"
        );
        assert!(
            err_msg.contains("44100"),
            "error should mention actual rate 44100, got: {err_msg}"
        );

        // A rejected chunk must not change the collected samples.
        let result2 = sink.write_chunk(&[6], 16000);
        assert!(result2.is_err());
        assert_eq!(sink.samples(), &[1, 2, 3]);
    }

    #[test]
    fn collector_sink_into_samples_ownership() {
        let mut sink = CollectorSink::new();
        sink.write_chunk(&[100, 200, 300], 16000).unwrap();
        sink.write_chunk(&[400, 500], 16000).unwrap();
        sink.finalize().unwrap();
        let owned: Vec<i16> = sink.into_samples();
        assert_eq!(owned, vec![100, 200, 300, 400, 500]);
        assert_eq!(owned.len(), 5);
    }

    #[test]
    fn play_audio_various_sample_rates() {
        let samples: Vec<i16> = (0..64).collect();
        for &rate in &[8000u32, 16000, 22050, 44100] {
            let result = play_audio(&samples, rate);
            assert!(
                result.is_ok(),
                "play_audio should succeed at sample rate {rate}"
            );
        }
    }

    // ---------------------------------------------------------------
    // RodioPlayer (only built with the `playback` feature)
    // ---------------------------------------------------------------

    #[cfg(feature = "playback")]
    mod rodio_tests {
        use super::super::*;

        #[test]
        fn rodio_player_zero_target_rate_errors() {
            let result = RodioPlayer::with_sample_rate(0);
            assert!(result.is_err());
            assert!(
                result.unwrap_err().to_string().contains("sample rate"),
                "error message should mention sample rate"
            );
        }

        #[test]
        fn rodio_linear_resample_same_rate() {
            let input = vec![100i16, 200, 300, 400];
            let output = RodioPlayer::linear_resample(&input, 22050, 22050);
            assert_eq!(input, output);
        }

        #[test]
        fn rodio_linear_resample_empty() {
            let output = RodioPlayer::linear_resample(&[], 22050, 44100);
            assert!(output.is_empty());
        }

        #[test]
        fn rodio_linear_resample_upsample() {
            let input = vec![0i16, 1000, 0, -1000];
            let output = RodioPlayer::linear_resample(&input, 100, 200);
            assert_eq!(
                output.len(),
                input.len() * 2,
                "doubling the rate should double the sample count"
            );
            // Midpoints are interpolated; the final position holds the last sample.
            assert_eq!(output, vec![0i16, 500, 1000, 500, 0, -500, -1000, -1000]);
        }

        #[test]
        fn rodio_linear_resample_downsample() {
            let input: Vec<i16> = (0..1000).map(|i| (i % 256) as i16).collect();
            let output = RodioPlayer::linear_resample(&input, 44100, 22050);
            assert!(
                output.len() < input.len(),
                "downsampled output should have fewer samples"
            );
        }

        #[test]
        fn rodio_linear_resample_preserves_length_ratio() {
            let input_len = 22050;
            let input: Vec<i16> = (0..input_len as i16).collect();
            let output = RodioPlayer::linear_resample(&input, 22050, 48000);

            let expected_len = ((input_len as f64) * (48000.0 / 22050.0)).ceil() as usize;
            assert!(
                (output.len() as isize - expected_len as isize).unsigned_abs() <= 1,
                "expected ~{expected_len} samples, got {}",
                output.len()
            );

            let ratio = output.len() as f64 / input_len as f64;
            let expected_ratio = 48000.0 / 22050.0;
            assert!(
                (ratio - expected_ratio).abs() < 0.01,
                "sample count ratio {ratio:.4} should be close to {expected_ratio:.4}"
            );
        }

        #[test]
        fn rodio_linear_resample_boundary_values() {
            let input = vec![i16::MIN, i16::MAX, i16::MIN, i16::MAX, 0];
            let output = RodioPlayer::linear_resample(&input, 22050, 48000);

            // ceil(5 * 48000 / 22050) = 11 output samples.
            assert_eq!(
                output.len(),
                11,
                "resampled output should contain 11 samples"
            );
            assert_eq!(
                output[0],
                i16::MIN,
                "first output sample should be i16::MIN"
            );
            // The last position lies past the final input sample, which is held.
            assert_eq!(
                output.last().copied(),
                Some(0),
                "last output sample should repeat the final input sample"
            );
        }
    }
}