use crate::{AudioFeatures, AudioFormat, AudioInput, MullamaError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, RwLock};
use tokio_stream::Stream as TokioStream;
#[cfg(feature = "streaming-audio")]
use {
cpal::{
traits::{DeviceTrait, HostTrait, StreamTrait},
BufferSize, Device, Host, SampleRate, Stream as CpalStream, StreamConfig,
},
ringbuf::{HeapConsumer, HeapProducer, HeapRb},
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioStreamConfig {
pub sample_rate: u32,
pub channels: u16,
pub buffer_size: usize,
pub format: AudioFormat,
pub enable_noise_reduction: bool,
pub enable_voice_detection: bool,
pub enable_agc: bool,
pub vad_threshold: f32,
pub noise_gate_threshold: f32,
pub max_latency_ms: u32,
pub device_preference: DevicePreference,
pub ring_buffer_capacity: usize,
}
impl Default for AudioStreamConfig {
fn default() -> Self {
Self {
sample_rate: 16000, channels: 1, buffer_size: 512, format: wav_audio_format(),
enable_noise_reduction: true,
enable_voice_detection: true,
enable_agc: true,
vad_threshold: 0.3,
noise_gate_threshold: -45.0,
max_latency_ms: 50,
device_preference: DevicePreference::Default,
ring_buffer_capacity: 8,
}
}
}
impl AudioStreamConfig {
pub fn new() -> Self {
Self::default()
}
pub fn sample_rate(mut self, rate: u32) -> Self {
self.sample_rate = rate;
self
}
pub fn channels(mut self, channels: u16) -> Self {
self.channels = channels;
self
}
pub fn buffer_size(mut self, size: usize) -> Self {
self.buffer_size = size;
self
}
pub fn enable_noise_reduction(mut self, enable: bool) -> Self {
self.enable_noise_reduction = enable;
self
}
pub fn enable_voice_detection(mut self, enable: bool) -> Self {
self.enable_voice_detection = enable;
self
}
pub fn vad_threshold(mut self, threshold: f32) -> Self {
self.vad_threshold = threshold.clamp(0.0, 1.0);
self
}
pub fn max_latency_ms(mut self, latency: u32) -> Self {
self.max_latency_ms = latency;
self
}
}
fn wav_audio_format() -> AudioFormat {
AudioFormat {
container: "wav".to_string(),
codec: "pcm".to_string(),
bit_depth: 16,
bitrate: None,
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DevicePreference {
Default,
ByName(String),
LowestLatency,
HighestQuality,
}
#[derive(Debug, Clone)]
pub struct AudioChunk {
pub samples: Vec<f32>,
pub channels: u16,
pub sample_rate: u32,
pub timestamp: Instant,
pub duration: Duration,
pub voice_detected: bool,
pub signal_level: f32,
pub features: Option<AudioFeatures>,
}
impl AudioChunk {
pub fn new(samples: Vec<f32>, channels: u16, sample_rate: u32) -> Self {
let duration =
Duration::from_secs_f32(samples.len() as f32 / (sample_rate * channels as u32) as f32);
let signal_level = calculate_rms(&samples);
Self {
samples,
channels,
sample_rate,
timestamp: Instant::now(),
duration,
voice_detected: false,
signal_level,
features: None,
}
}
pub fn to_audio_input(&self) -> AudioInput {
AudioInput {
samples: self.samples.clone(),
sample_rate: self.sample_rate,
channels: self.channels as u32,
duration: self.duration.as_secs_f32(),
format: wav_audio_format(),
transcript: None,
metadata: HashMap::new(),
}
}
pub fn apply_noise_gate(&mut self, threshold_db: f32) {
let threshold_linear = db_to_linear(threshold_db);
if self.signal_level < threshold_linear {
let sample_count = self.samples.len() as f32;
for (i, sample) in self.samples.iter_mut().enumerate() {
let fade_factor = (1.0 - (i as f32 / sample_count)) * 0.1;
*sample *= fade_factor;
}
}
}
pub fn detect_voice_activity(&mut self, threshold: f32) {
let energy = self.signal_level;
let zcr = calculate_zero_crossing_rate(&self.samples);
self.voice_detected = energy > threshold && zcr < 0.5;
}
}
pub struct StreamingAudioProcessor {
config: AudioStreamConfig,
#[cfg(feature = "streaming-audio")]
host: Host,
#[cfg(feature = "streaming-audio")]
input_device: Option<Device>,
#[cfg(feature = "streaming-audio")]
output_device: Option<Device>,
#[cfg(feature = "streaming-audio")]
stream: Option<CpalStream>,
#[cfg(feature = "streaming-audio")]
audio_producer: Option<HeapProducer<f32>>,
#[cfg(feature = "streaming-audio")]
audio_consumer: Option<HeapConsumer<f32>>,
is_recording: Arc<RwLock<bool>>,
chunk_sender: Option<mpsc::UnboundedSender<AudioChunk>>,
metrics: Arc<RwLock<StreamingMetrics>>,
}
#[derive(Debug, Clone, Default)]
pub struct StreamingMetrics {
pub chunks_processed: u64,
pub total_samples: u64,
pub dropped_chunks: u64,
pub average_latency_ms: f32,
pub max_latency_ms: f32,
pub cpu_usage: f32,
pub memory_usage_mb: f32,
pub voice_activity_time: Duration,
pub noise_events: u64,
}
impl StreamingAudioProcessor {
pub fn new(config: AudioStreamConfig) -> Result<Self, MullamaError> {
#[cfg(feature = "streaming-audio")]
{
let host = cpal::default_host();
Ok(Self {
config,
host,
input_device: None,
output_device: None,
stream: None,
audio_producer: None,
audio_consumer: None,
is_recording: Arc::new(RwLock::new(false)),
chunk_sender: None,
metrics: Arc::new(RwLock::new(StreamingMetrics::default())),
})
}
#[cfg(not(feature = "streaming-audio"))]
{
Err(MullamaError::FeatureNotAvailable(
"streaming-audio feature not enabled".to_string(),
))
}
}
#[cfg(feature = "streaming-audio")]
pub async fn initialize(&mut self) -> Result<(), MullamaError> {
self.input_device = Some(self.select_input_device()?);
self.output_device = Some(self.select_output_device()?);
let buffer_capacity = self.config.ring_buffer_capacity * self.config.buffer_size;
let ring_buffer = HeapRb::<f32>::new(buffer_capacity);
let (producer, consumer) = ring_buffer.split();
self.audio_producer = Some(producer);
self.audio_consumer = Some(consumer);
Ok(())
}
#[cfg(not(feature = "streaming-audio"))]
pub async fn initialize(&mut self) -> Result<(), MullamaError> {
Err(MullamaError::FeatureNotAvailable(
"streaming-audio feature not enabled".to_string(),
))
}
pub async fn start_capture(&mut self) -> Result<AudioStream, MullamaError> {
#[cfg(feature = "streaming-audio")]
{
if self.input_device.is_none() {
self.initialize().await?;
}
let (chunk_sender, chunk_receiver) = mpsc::unbounded_channel();
self.chunk_sender = Some(chunk_sender.clone());
let stream_config = self.build_stream_config()?;
let mut producer = self.audio_producer.take().ok_or_else(|| {
MullamaError::StreamingError("Ring buffer not initialized".to_string())
})?;
let config = self.config.clone();
let metrics = self.metrics.clone();
let is_recording = self.is_recording.clone();
let stream = self.input_device.as_ref().unwrap().build_input_stream(
&stream_config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let recording = is_recording.try_read().map(|guard| *guard).unwrap_or(false);
if recording {
Self::process_audio_callback(
data,
&mut producer,
&config,
&chunk_sender,
&metrics,
);
}
},
move |err| {
eprintln!("Audio stream error: {}", err);
},
None,
)?;
stream.play()?;
self.stream = Some(stream);
*self.is_recording.write().await = true;
Ok(AudioStream::new(chunk_receiver, self.metrics.clone()))
}
#[cfg(not(feature = "streaming-audio"))]
{
Err(MullamaError::FeatureNotAvailable(
"streaming-audio feature not enabled".to_string(),
))
}
}
pub async fn stop_capture(&mut self) -> Result<(), MullamaError> {
*self.is_recording.write().await = false;
#[cfg(feature = "streaming-audio")]
{
if let Some(stream) = self.stream.take() {
drop(stream);
}
}
Ok(())
}
pub async fn process_chunk(&self, chunk: &AudioChunk) -> Result<AudioChunk, MullamaError> {
let mut processed_chunk = chunk.clone();
if self.config.enable_noise_reduction {
processed_chunk.apply_noise_gate(self.config.noise_gate_threshold);
}
if self.config.enable_voice_detection {
processed_chunk.detect_voice_activity(self.config.vad_threshold);
}
let mut metrics = self.metrics.write().await;
metrics.chunks_processed += 1;
metrics.total_samples += chunk.samples.len() as u64;
if processed_chunk.voice_detected {
metrics.voice_activity_time += chunk.duration;
}
Ok(processed_chunk)
}
pub async fn metrics(&self) -> StreamingMetrics {
self.metrics.read().await.clone()
}
#[cfg(feature = "streaming-audio")]
pub fn list_input_devices(&self) -> Result<Vec<String>, MullamaError> {
let devices: Result<Vec<String>, _> = self
.host
.input_devices()?
.map(|device| device.name().map_err(MullamaError::from))
.collect();
devices
}
#[cfg(not(feature = "streaming-audio"))]
pub fn list_input_devices(&self) -> Result<Vec<String>, MullamaError> {
Err(MullamaError::FeatureNotAvailable(
"streaming-audio feature not enabled".to_string(),
))
}
#[cfg(feature = "streaming-audio")]
fn select_input_device(&self) -> Result<Device, MullamaError> {
match &self.config.device_preference {
DevicePreference::Default => self.host.default_input_device().ok_or_else(|| {
MullamaError::AudioError("No default input device available".to_string())
}),
DevicePreference::ByName(name) => self
.host
.input_devices()?
.find(|device| device.name().map(|n| n == *name).unwrap_or(false))
.ok_or_else(|| {
MullamaError::AudioError(format!("Input device '{}' not found", name))
}),
DevicePreference::LowestLatency => {
self.select_input_device()
}
DevicePreference::HighestQuality => {
self.select_input_device()
}
}
}
#[cfg(feature = "streaming-audio")]
fn select_output_device(&self) -> Result<Device, MullamaError> {
self.host.default_output_device().ok_or_else(|| {
MullamaError::AudioError("No default output device available".to_string())
})
}
#[cfg(feature = "streaming-audio")]
fn build_stream_config(&self) -> Result<StreamConfig, MullamaError> {
let device = self.input_device.as_ref().unwrap();
let supported_configs = device.supported_input_configs()?;
let _config = supported_configs
.filter(|config| config.channels() == self.config.channels)
.find(|config| {
config.min_sample_rate() <= SampleRate(self.config.sample_rate)
&& config.max_sample_rate() >= SampleRate(self.config.sample_rate)
})
.ok_or_else(|| {
MullamaError::AudioError("No compatible audio configuration found".to_string())
})?;
Ok(StreamConfig {
channels: self.config.channels,
sample_rate: SampleRate(self.config.sample_rate),
buffer_size: BufferSize::Fixed(self.config.buffer_size as u32),
})
}
#[cfg(feature = "streaming-audio")]
fn process_audio_callback(
data: &[f32],
producer: &mut HeapProducer<f32>,
config: &AudioStreamConfig,
sender: &mpsc::UnboundedSender<AudioChunk>,
metrics: &Arc<RwLock<StreamingMetrics>>,
) {
let samples_written = producer.push_slice(data);
if samples_written < data.len() {
if let Ok(mut m) = metrics.try_write() {
m.dropped_chunks += 1;
}
}
if data.len() >= config.buffer_size {
let chunk = AudioChunk::new(data.to_vec(), config.channels, config.sample_rate);
let _ = sender.send(chunk);
}
}
}
pub struct AudioStream {
receiver: mpsc::UnboundedReceiver<AudioChunk>,
metrics: Arc<RwLock<StreamingMetrics>>,
}
impl AudioStream {
fn new(
receiver: mpsc::UnboundedReceiver<AudioChunk>,
metrics: Arc<RwLock<StreamingMetrics>>,
) -> Self {
Self { receiver, metrics }
}
pub async fn next(&mut self) -> Option<AudioChunk> {
self.receiver.recv().await
}
pub async fn metrics(&self) -> StreamingMetrics {
self.metrics.read().await.clone()
}
}
impl TokioStream for AudioStream {
type Item = AudioChunk;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.receiver.poll_recv(cx)
}
}
fn calculate_rms(samples: &[f32]) -> f32 {
if samples.is_empty() {
return 0.0;
}
let sum_squares: f32 = samples.iter().map(|&x| x * x).sum();
(sum_squares / samples.len() as f32).sqrt()
}
fn calculate_zero_crossing_rate(samples: &[f32]) -> f32 {
if samples.len() < 2 {
return 0.0;
}
let zero_crossings = samples
.windows(2)
.filter(|window| (window[0] >= 0.0) != (window[1] >= 0.0))
.count();
zero_crossings as f32 / (samples.len() - 1) as f32
}
fn db_to_linear(db: f32) -> f32 {
10.0_f32.powf(db / 20.0)
}
#[cfg(feature = "streaming-audio")]
impl From<cpal::DevicesError> for MullamaError {
fn from(err: cpal::DevicesError) -> Self {
MullamaError::AudioError(format!("Device enumeration error: {}", err))
}
}
#[cfg(feature = "streaming-audio")]
impl From<cpal::BuildStreamError> for MullamaError {
fn from(err: cpal::BuildStreamError) -> Self {
MullamaError::AudioError(format!("Stream build error: {}", err))
}
}
#[cfg(feature = "streaming-audio")]
impl From<cpal::PlayStreamError> for MullamaError {
fn from(err: cpal::PlayStreamError) -> Self {
MullamaError::AudioError(format!("Stream play error: {}", err))
}
}
#[cfg(feature = "streaming-audio")]
impl From<cpal::SupportedStreamConfigsError> for MullamaError {
fn from(err: cpal::SupportedStreamConfigsError) -> Self {
MullamaError::AudioError(format!("Stream config error: {}", err))
}
}
#[cfg(feature = "streaming-audio")]
impl From<cpal::DeviceNameError> for MullamaError {
fn from(err: cpal::DeviceNameError) -> Self {
MullamaError::AudioError(format!("Device name error: {}", err))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_audio_chunk_creation() {
let samples = vec![0.1, -0.2, 0.3, -0.4];
let chunk = AudioChunk::new(samples.clone(), 1, 44100);
assert_eq!(chunk.samples, samples);
assert_eq!(chunk.channels, 1);
assert_eq!(chunk.sample_rate, 44100);
assert!(chunk.signal_level > 0.0);
}
#[test]
fn test_rms_calculation() {
let samples = vec![1.0, -1.0, 1.0, -1.0];
let rms = calculate_rms(&samples);
assert!((rms - 1.0).abs() < 0.001);
}
#[test]
fn test_zero_crossing_rate() {
let samples = vec![1.0, -1.0, 1.0, -1.0];
let zcr = calculate_zero_crossing_rate(&samples);
assert!((zcr - 1.0).abs() < 0.001);
}
#[test]
fn test_config_builder() {
let config = AudioStreamConfig::new()
.sample_rate(48000)
.channels(2)
.buffer_size(1024)
.vad_threshold(0.5);
assert_eq!(config.sample_rate, 48000);
assert_eq!(config.channels, 2);
assert_eq!(config.buffer_size, 1024);
assert_eq!(config.vad_threshold, 0.5);
}
}