use std::sync::Mutex;
use bytes::Bytes;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use tokio::sync::mpsc;
use crate::error::{AudioError, AudioResult};
use crate::frame::AudioFrame;
use super::device::{AudioDevice, CaptureConfig};
/// Stream of captured PCM frames handed to the consumer; produced by
/// [`AudioCapture::start_capture`]. Bounded (capacity 64) — frames are
/// dropped, not buffered indefinitely, when the consumer falls behind.
pub type AudioStream = mpsc::Receiver<AudioFrame>;
/// RAII holder that keeps the cpal input stream alive. The inner stream is
/// never accessed after construction — it exists only so that dropping the
/// wrapper stops capture.
struct SyncStream(#[allow(dead_code)] cpal::Stream);
// SAFETY: `cpal::Stream` is deliberately `!Send`/`!Sync` because some audio
// backends require stream operations to happen on the creation thread. Here
// the wrapped stream is never used after construction; the only cross-thread
// operation is Drop. NOTE(review): confirm that dropping a cpal stream from a
// different thread is sound on every backend this project targets — this
// assertion is platform-dependent.
unsafe impl Sync for SyncStream {}
unsafe impl Send for SyncStream {}
/// Microphone capture backed by cpal.
///
/// Owns the live input stream between `start_capture` and `stop_capture`
/// (or drop); dropping this struct stops any active capture.
pub struct AudioCapture {
    // Some(_) while a capture stream is running; None when idle.
    stream: Option<SyncStream>,
}
impl AudioCapture {
    /// Creates an idle capture handle with no active stream.
    pub fn new() -> Self {
        Self { stream: None }
    }

    /// Enumerates the default host's input devices.
    ///
    /// Shared by [`Self::list_input_devices`] and [`Self::start_capture`] so
    /// both paths report enumeration failures with the same message.
    ///
    /// # Errors
    /// Returns [`AudioError::Device`] when the host cannot enumerate devices
    /// (e.g. missing audio drivers).
    fn input_devices() -> AudioResult<impl Iterator<Item = cpal::Device>> {
        cpal::default_host().input_devices().map_err(|e| {
            AudioError::Device(format!(
                "failed to enumerate input devices: {e}. Check that audio drivers are installed."
            ))
        })
    }

    /// Lists the available input devices, skipping any whose name cannot be
    /// read (unnamed devices would be impossible to select by id).
    ///
    /// # Errors
    /// Returns [`AudioError::Device`] when device enumeration fails.
    pub fn list_input_devices() -> AudioResult<Vec<AudioDevice>> {
        let devices = Self::input_devices()?;
        Ok(devices
            .filter_map(|device| {
                let name = device.name().unwrap_or_default();
                // Device id and display name are both the cpal device name.
                (!name.is_empty()).then(|| AudioDevice::new(name.clone(), name))
            })
            .collect())
    }

    /// Starts capturing from the named input device and returns a channel of
    /// PCM frames (i16 little-endian, interleaved).
    ///
    /// Any previously running stream is replaced (and therefore stopped).
    ///
    /// # Errors
    /// Returns [`AudioError::Device`] when the device cannot be found or the
    /// stream cannot be opened/started; propagates `config.validate()` errors.
    pub fn start_capture(
        &mut self,
        device_id: &str,
        config: &CaptureConfig,
    ) -> AudioResult<AudioStream> {
        config.validate()?;
        let device = Self::input_devices()?
            .find(|d| d.name().unwrap_or_default() == device_id)
            .ok_or_else(|| {
                AudioError::Device(format!(
                    "input device not found: '{device_id}'. Use list_input_devices() to see available devices."
                ))
            })?;
        let stream_config = cpal::StreamConfig {
            channels: config.channels as u16,
            sample_rate: cpal::SampleRate(config.sample_rate),
            buffer_size: cpal::BufferSize::Default,
        };
        let (tx, rx) = mpsc::channel::<AudioFrame>(64);
        // Number of i16 samples (all channels interleaved) per emitted frame.
        let samples_per_frame = (config.sample_rate as usize
            * config.channels as usize
            * config.frame_duration_ms as usize)
            / 1000;
        let sample_rate = config.sample_rate;
        let channels = config.channels;
        // The data callback is `FnMut` and cpal invokes it from a single audio
        // thread, so a plain Vec moved into the closure is sufficient — no
        // Mutex, and no lock or poisoning panic path in the real-time callback.
        let mut buffer: Vec<i16> = Vec::with_capacity(samples_per_frame);
        let cpal_stream = device
            .build_input_stream(
                &stream_config,
                move |data: &[f32], _: &cpal::InputCallbackInfo| {
                    for &sample in data {
                        // Convert f32 in [-1.0, 1.0] to i16 PCM.
                        let clamped = sample.clamp(-1.0, 1.0);
                        buffer.push((clamped * i16::MAX as f32) as i16);
                        if buffer.len() >= samples_per_frame {
                            let pcm_bytes: Vec<u8> = buffer
                                .drain(..samples_per_frame)
                                .flat_map(|s| s.to_le_bytes())
                                .collect();
                            let frame =
                                AudioFrame::new(Bytes::from(pcm_bytes), sample_rate, channels);
                            // try_send: never block the audio thread; drop the
                            // frame instead when the consumer lags.
                            if tx.try_send(frame).is_err() {
                                tracing::warn!(
                                    "audio capture channel full — frame dropped. Consumer may be too slow."
                                );
                            }
                        }
                    }
                },
                move |err| {
                    tracing::error!("cpal input stream error: {err}");
                },
                None,
            )
            .map_err(|e| {
                AudioError::Device(format!("failed to open input stream on '{device_id}': {e}"))
            })?;
        cpal_stream.play().map_err(|e| {
            AudioError::Device(format!("failed to start input stream on '{device_id}': {e}"))
        })?;
        self.stream = Some(SyncStream(cpal_stream));
        Ok(rx)
    }

    /// Stops any active capture by dropping the stream. Idempotent: calling
    /// with no active stream is a no-op.
    pub fn stop_capture(&mut self) {
        self.stream = None;
    }
}
impl Default for AudioCapture {
fn default() -> Self {
Self::new()
}
}
// Compile-time proof that `AudioCapture` is `Send + Sync` (this relies on the
// unsafe impls for `SyncStream`); fails to compile if either bound is lost.
#[allow(dead_code)]
const _: fn() = || {
    fn require_send_sync<T: Send + Sync>() {}
    require_send_sync::<AudioCapture>();
};
#[cfg(test)]
mod tests {
    use super::*;

    /// Stopping with no active stream must be a no-op, no matter how many
    /// times it is called.
    #[test]
    fn test_stop_capture_idempotent() {
        let mut capture = AudioCapture::new();
        for _ in 0..2 {
            capture.stop_capture();
            assert!(capture.stream.is_none());
        }
    }

    /// `Default` must produce the same idle state as `new()`.
    #[test]
    fn test_audio_capture_default() {
        assert!(AudioCapture::default().stream.is_none());
    }

    /// `AudioCapture` must remain usable across threads.
    #[test]
    fn test_audio_capture_is_send_sync() {
        fn assert_bounds<T: Send + Sync>() {}
        assert_bounds::<AudioCapture>();
    }
}