use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use anyhow::Context;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{SampleFormat, SampleRate, StreamConfig};
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use super::AudioChunk;
/// Sample rate in Hz used both for the requested input stream configuration
/// and for the header of WAV data produced by `encode_wav`.
const SAMPLE_RATE: u32 = 16_000;
/// Channel count (mono) used both for the requested input stream
/// configuration and for the header of WAV data produced by `encode_wav`.
const CHANNELS: u16 = 1;
/// Owner of a running audio capture session.
///
/// Created by `AudioCaptureHandle::start`, which spawns the capture thread.
/// Dropping the handle sets the stop flag and joins that thread.
pub struct AudioCaptureHandle {
/// Receiving end of the audio chunk channel. `None` once it has been handed
/// out through `take_receiver` or consumed by `stop_and_collect`.
receiver: Option<mpsc::UnboundedReceiver<AudioChunk>>,
/// Flag shared with the capture thread; storing `true` asks it to shut down.
stop_signal: Arc<AtomicBool>,
/// Join handle of the capture thread. `None` after it has been taken for
/// joining.
thread_handle: Option<std::thread::JoinHandle<()>>,
}
// SAFETY: the handle stores only a channel receiver, an `Arc<AtomicBool>` and
// a thread `JoinHandle`. The cpal stream is created and dropped inside the
// capture thread (`setup_and_run`) and is never stored in this struct.
// NOTE(review): if `AudioChunk` is itself `Send`, the compiler would derive
// `Send` for this struct automatically and this manual impl is redundant —
// confirm and consider removing it.
unsafe impl Send for AudioCaptureHandle {}
impl Drop for AudioCaptureHandle {
    /// Signals the capture thread to stop and blocks until it has exited.
    ///
    /// If the thread handle has already been taken, only the stop flag is
    /// set. A panic raised inside the capture thread is discarded here rather
    /// than propagated out of `drop`.
    fn drop(&mut self) {
        self.stop_signal.store(true, Ordering::Release);

        if let Some(worker) = self.thread_handle.take() {
            // The join result only carries a panic payload; there is nothing
            // useful to do with it while dropping.
            let _ = worker.join();
        }
    }
}
impl AudioCaptureHandle {
    /// Spawns the `whisrs-audio` capture thread and waits until it reports
    /// whether the input stream could be started.
    ///
    /// # Errors
    ///
    /// Returns an error if the thread cannot be spawned, if the thread exits
    /// before reporting its initialisation result, or if stream setup fails.
    /// In the latter two cases the thread is joined before the error is
    /// returned, so a failed call never leaves a detached thread behind.
    pub fn start() -> anyhow::Result<Self> {
        let (tx, rx) = mpsc::unbounded_channel::<AudioChunk>();
        let stop_signal = Arc::new(AtomicBool::new(false));
        let stop_clone = Arc::clone(&stop_signal);
        // Channel on which the capture thread reports the outcome of its
        // setup phase.
        let (init_tx, init_rx) = std::sync::mpsc::channel::<anyhow::Result<()>>();

        let thread_handle = std::thread::Builder::new()
            .name("whisrs-audio".into())
            .spawn(move || run_capture(tx, stop_clone, init_tx))
            .context("failed to spawn audio capture thread")?;

        // A receive error means the sender was dropped without a report,
        // i.e. the thread ended before finishing setup.
        let init_result = init_rx
            .recv()
            .map_err(|_| anyhow::anyhow!("audio capture thread exited unexpectedly"))
            .and_then(|reported| reported);

        if let Err(err) = init_result {
            // On both failure paths the thread is already on its way out
            // (it either sent the error or dropped the sender), so joining
            // returns promptly and guarantees it is gone before we return.
            thread_handle.join().ok();
            return Err(err);
        }

        Ok(Self {
            receiver: Some(rx),
            stop_signal,
            thread_handle: Some(thread_handle),
        })
    }

    /// Hands out the receiving end of the audio chunk channel.
    ///
    /// Returns `None` if the receiver has already been taken.
    pub fn take_receiver(&mut self) -> Option<mpsc::UnboundedReceiver<AudioChunk>> {
        self.receiver.take()
    }

    /// Asks the capture thread to stop without waiting for it to exit.
    pub fn stop(&mut self) {
        self.stop_signal.store(true, Ordering::Release);
    }

    /// Stops capture, waits for the capture thread to exit and returns every
    /// sample still queued in the channel, concatenated in arrival order.
    ///
    /// If the receiver was previously removed with `take_receiver`, the
    /// returned vector is empty.
    ///
    /// # Errors
    ///
    /// Returns an error if the blocking task used to join the capture thread
    /// fails. A panic inside the capture thread itself is ignored.
    pub async fn stop_and_collect(mut self) -> anyhow::Result<Vec<i16>> {
        self.stop_signal.store(true, Ordering::Release);

        if let Some(handle) = self.thread_handle.take() {
            // Joining blocks, so run it on tokio's blocking pool rather than
            // on the async executor.
            tokio::task::spawn_blocking(move || {
                handle.join().ok();
            })
            .await?;
        }

        let mut all_samples = Vec::new();
        if let Some(mut rx) = self.receiver.take() {
            // Close first so the drain loop ends once the queue is empty.
            rx.close();
            while let Ok(chunk) = rx.try_recv() {
                all_samples.extend_from_slice(&chunk);
            }
        }

        info!("captured {} audio samples", all_samples.len());
        Ok(all_samples)
    }
}
/// Entry point of the capture thread.
///
/// Runs `setup_and_run` and, if it returns an error, forwards that error over
/// `init_tx`. A failure to deliver the error is ignored.
fn run_capture(
    tx: mpsc::UnboundedSender<AudioChunk>,
    stop_signal: Arc<AtomicBool>,
    init_tx: std::sync::mpsc::Sender<anyhow::Result<()>>,
) {
    if let Err(setup_err) = setup_and_run(tx, stop_signal, &init_tx) {
        // Nothing more can be done if the receiving side is already gone.
        let _ = init_tx.send(Err(setup_err));
    }
}
/// Opens the default input device, streams i16 samples into `tx` and blocks
/// until `stop_signal` becomes `true`.
///
/// Once the stream is playing, `Ok(())` is sent on `init_tx`; the function
/// then polls `stop_signal` every 50 ms and drops the stream when it is set.
///
/// # Errors
///
/// Returns an error if there is no default input device, if the supported
/// input configurations cannot be queried, or if the stream cannot be built
/// or started. Nothing is sent on `init_tx` in those cases.
fn setup_and_run(
    tx: mpsc::UnboundedSender<AudioChunk>,
    stop_signal: Arc<AtomicBool>,
    init_tx: &std::sync::mpsc::Sender<anyhow::Result<()>>,
) -> anyhow::Result<()> {
    let host = cpal::default_host();
    let device = host
        .default_input_device()
        .ok_or_else(|| anyhow::anyhow!("no default audio input device found"))?;
    let device_name = device.name().unwrap_or_else(|_| "unknown".into());
    info!("using audio input device: {device_name}");

    let config = StreamConfig {
        channels: CHANNELS,
        sample_rate: SampleRate(SAMPLE_RATE),
        buffer_size: cpal::BufferSize::Default,
    };

    // Check whether any advertised configuration covers the requested
    // channel count, sample rate and sample format. A miss is only logged;
    // the stream is still requested below.
    let found_match = device
        .supported_input_configs()
        .context("failed to query supported input configs")?
        .any(|range| {
            range.channels() == CHANNELS
                && range.min_sample_rate().0 <= SAMPLE_RATE
                && range.max_sample_rate().0 >= SAMPLE_RATE
                && range.sample_format() == SampleFormat::I16
        });
    if !found_match {
        warn!(
            "device may not natively support {SAMPLE_RATE}Hz mono i16; \
             cpal will attempt conversion"
        );
    }

    let err_callback = |err: cpal::StreamError| {
        error!("audio stream error: {err}");
    };

    let stream = device
        .build_input_stream(
            &config,
            move |data: &[i16], _info: &cpal::InputCallbackInfo| {
                // A send error means the receiving half is closed or dropped.
                // The chunk is deliberately discarded in that case.
                let _ = tx.send(data.to_vec());
            },
            err_callback,
            None,
        )
        .context("failed to build audio input stream")?;
    stream.play().context("failed to start audio stream")?;
    debug!("audio capture started at {SAMPLE_RATE}Hz mono i16");

    // Report successful start-up; a missing listener is not an error here.
    init_tx.send(Ok(())).ok();

    // Keep this thread (and therefore the stream it owns) alive until a stop
    // is requested.
    while !stop_signal.load(Ordering::Acquire) {
        std::thread::sleep(std::time::Duration::from_millis(50));
    }

    debug!("audio capture stopping");
    drop(stream);
    Ok(())
}
/// Encodes `samples` as an in-memory WAV file using `CHANNELS`, `SAMPLE_RATE`
/// and 16-bit integer samples.
///
/// An empty slice is accepted; the writer is still created and finalized
/// with no samples written.
///
/// # Errors
///
/// Returns an error if the WAV writer cannot be created, if writing a sample
/// fails, or if finalizing the file fails.
pub fn encode_wav(samples: &[i16]) -> anyhow::Result<Vec<u8>> {
    // Size of a canonical RIFF/WAVE PCM header. Used purely as a capacity
    // hint, so being off by a few bytes only costs one reallocation.
    const HEADER_SIZE_HINT: usize = 44;

    let spec = hound::WavSpec {
        channels: CHANNELS,
        sample_rate: SAMPLE_RATE,
        bits_per_sample: 16,
        sample_format: hound::SampleFormat::Int,
    };

    // The payload size is known up front, so reserve it once instead of
    // letting the buffer grow repeatedly while samples are written.
    let capacity = std::mem::size_of_val(samples) + HEADER_SIZE_HINT;
    let mut cursor = std::io::Cursor::new(Vec::with_capacity(capacity));
    {
        // Scope the writer so its mutable borrow of `cursor` ends before the
        // buffer is taken out below.
        let mut writer =
            hound::WavWriter::new(&mut cursor, spec).context("failed to create WAV writer")?;
        for &sample in samples {
            writer
                .write_sample(sample)
                .context("failed to write WAV sample")?;
        }
        writer.finalize().context("failed to finalize WAV")?;
    }
    Ok(cursor.into_inner())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a short repeating ramp and checks that decoding it yields the
    /// expected header fields and exactly the original samples.
    #[test]
    fn encode_wav_produces_valid_output() {
        let samples: Vec<i16> = (0..1600).map(|i| (i % 256) as i16).collect();

        let wav = encode_wav(&samples).unwrap();
        assert_eq!(&wav[..4], b"RIFF");

        let reader = hound::WavReader::new(std::io::Cursor::new(&wav)).unwrap();
        let header = reader.spec();
        assert_eq!(header.channels, 1);
        assert_eq!(header.sample_rate, 16_000);
        assert_eq!(header.bits_per_sample, 16);

        let decoded = reader
            .into_samples::<i16>()
            .collect::<Result<Vec<i16>, _>>()
            .unwrap();
        assert_eq!(decoded.len(), 1600);
        assert_eq!(decoded, samples);
    }

    /// Encoding an empty slice still succeeds and starts with the RIFF magic.
    #[test]
    fn encode_wav_empty_samples() {
        let wav = encode_wav(&[]).unwrap();
        assert_eq!(&wav[..4], b"RIFF");
    }
}