Skip to main content

ringkernel_audio_fft/
audio_input.rs

1//! Audio input handling for files and device streams.
2//!
3//! This module provides a unified interface for reading audio from:
4//! - WAV files (via hound)
5//! - Real-time audio devices (via cpal)
6
7use std::path::Path;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10
11#[cfg(feature = "device-input")]
12use crossbeam::channel::bounded;
13use crossbeam::channel::{Receiver, Sender};
14#[cfg(feature = "device-input")]
15use tracing::{debug, error, info, warn};
16#[cfg(not(feature = "device-input"))]
17use tracing::{debug, info};
18
19use crate::error::{AudioFftError, Result};
20use crate::messages::AudioFrame;
21
22/// Audio source trait for unified input handling.
23pub trait AudioSource: Send + Sync {
24    /// Get the sample rate in Hz.
25    fn sample_rate(&self) -> u32;
26
27    /// Get the number of channels.
28    fn channels(&self) -> u8;
29
30    /// Get the total number of samples (None for streams).
31    fn total_samples(&self) -> Option<u64>;
32
33    /// Read the next frame of audio.
34    fn read_frame(&mut self, frame_size: usize) -> Result<Option<AudioFrame>>;
35
36    /// Check if the source is exhausted.
37    fn is_exhausted(&self) -> bool;
38
39    /// Reset to the beginning (if supported).
40    fn reset(&mut self) -> Result<()>;
41}
42
/// A WAV file (or in-memory buffer) fully decoded into interleaved `f32` samples.
pub struct FileSource {
    /// Where the samples came from; `"<memory>"` when built from a raw buffer.
    path: String,
    /// Sampling frequency in Hz.
    sample_rate: u32,
    /// Number of interleaved channels.
    channels: u8,
    /// Every decoded sample, interleaved across channels.
    samples: Vec<f32>,
    /// Index into `samples` of the next sample to hand out.
    position: usize,
    /// Number of frames returned since construction or the last reset.
    frame_counter: u64,
}
52
53impl FileSource {
54    /// Open a WAV file for reading.
55    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
56        let path_str = path.as_ref().to_string_lossy().to_string();
57        info!("Opening audio file: {}", path_str);
58
59        let reader = hound::WavReader::open(path.as_ref())
60            .map_err(|e| AudioFftError::file_read(format!("{}: {}", path_str, e)))?;
61
62        let spec = reader.spec();
63        let sample_rate = spec.sample_rate;
64        let channels = spec.channels as u8;
65
66        debug!(
67            "File spec: {} Hz, {} channels, {} bits, {:?}",
68            sample_rate, channels, spec.bits_per_sample, spec.sample_format
69        );
70
71        // Read all samples and convert to f32
72        let samples: Vec<f32> = match spec.sample_format {
73            hound::SampleFormat::Float => reader
74                .into_samples::<f32>()
75                .filter_map(|s| s.ok())
76                .collect(),
77            hound::SampleFormat::Int => {
78                let scale = 1.0 / (1 << (spec.bits_per_sample - 1)) as f32;
79                reader
80                    .into_samples::<i32>()
81                    .filter_map(|s| s.ok())
82                    .map(|s| s as f32 * scale)
83                    .collect()
84            }
85        };
86
87        info!(
88            "Loaded {} samples ({:.2} seconds)",
89            samples.len(),
90            samples.len() as f64 / channels as f64 / sample_rate as f64
91        );
92
93        Ok(Self {
94            path: path_str,
95            sample_rate,
96            channels,
97            samples,
98            position: 0,
99            frame_counter: 0,
100        })
101    }
102
103    /// Get the file path.
104    pub fn path(&self) -> &str {
105        &self.path
106    }
107
108    /// Get the duration in seconds.
109    pub fn duration_secs(&self) -> f64 {
110        self.samples.len() as f64 / self.channels as f64 / self.sample_rate as f64
111    }
112}
113
114impl AudioSource for FileSource {
115    fn sample_rate(&self) -> u32 {
116        self.sample_rate
117    }
118
119    fn channels(&self) -> u8 {
120        self.channels
121    }
122
123    fn total_samples(&self) -> Option<u64> {
124        Some(self.samples.len() as u64 / self.channels as u64)
125    }
126
127    fn read_frame(&mut self, frame_size: usize) -> Result<Option<AudioFrame>> {
128        if self.position >= self.samples.len() {
129            return Ok(None);
130        }
131
132        let samples_to_read =
133            (frame_size * self.channels as usize).min(self.samples.len() - self.position);
134
135        let frame_samples = self.samples[self.position..self.position + samples_to_read].to_vec();
136        let timestamp = self.position as u64 / self.channels as u64;
137
138        self.position += samples_to_read;
139        self.frame_counter += 1;
140
141        Ok(Some(AudioFrame::new(
142            self.frame_counter,
143            self.sample_rate,
144            self.channels,
145            frame_samples,
146            timestamp,
147        )))
148    }
149
150    fn is_exhausted(&self) -> bool {
151        self.position >= self.samples.len()
152    }
153
154    fn reset(&mut self) -> Result<()> {
155        self.position = 0;
156        self.frame_counter = 0;
157        Ok(())
158    }
159}
160
161/// Real-time audio device stream.
162pub struct DeviceStream {
163    sample_rate: u32,
164    channels: u8,
165    receiver: Receiver<Vec<f32>>,
166    buffer: Vec<f32>,
167    frame_counter: Arc<AtomicU64>,
168    running: Arc<AtomicBool>,
169    // Keep the stream alive (only with device-input feature)
170    #[cfg(feature = "device-input")]
171    _stream: Option<cpal::Stream>,
172}
173
174/// Device stream configuration.
175#[derive(Debug, Clone)]
176pub struct DeviceConfig {
177    /// Preferred sample rate (None = use device default).
178    pub sample_rate: Option<u32>,
179    /// Preferred channels (None = use device default).
180    pub channels: Option<u8>,
181    /// Buffer size in samples.
182    pub buffer_size: usize,
183    /// Device name (None = default device).
184    pub device_name: Option<String>,
185}
186
187impl Default for DeviceConfig {
188    fn default() -> Self {
189        Self {
190            sample_rate: None,
191            channels: None,
192            buffer_size: 4096,
193            device_name: None,
194        }
195    }
196}
197
198impl DeviceStream {
199    /// Create a new device stream with the default input device.
200    #[cfg(feature = "device-input")]
201    pub fn new(config: DeviceConfig) -> Result<Self> {
202        use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
203
204        let host = cpal::default_host();
205
206        let device = if let Some(name) = &config.device_name {
207            host.input_devices()
208                .map_err(|e| AudioFftError::device(format!("Failed to enumerate devices: {}", e)))?
209                .find(|d| d.name().map(|n| n.contains(name)).unwrap_or(false))
210                .ok_or_else(|| AudioFftError::device(format!("Device '{}' not found", name)))?
211        } else {
212            host.default_input_device()
213                .ok_or_else(|| AudioFftError::device("No default input device"))?
214        };
215
216        let device_name = device.name().unwrap_or_else(|_| "unknown".to_string());
217        info!("Using input device: {}", device_name);
218
219        let supported_config = device
220            .default_input_config()
221            .map_err(|e| AudioFftError::device(format!("Failed to get device config: {}", e)))?;
222
223        let sample_rate = config
224            .sample_rate
225            .unwrap_or(supported_config.sample_rate().0);
226        let channels = config.channels.unwrap_or(supported_config.channels() as u8);
227
228        debug!("Stream config: {} Hz, {} channels", sample_rate, channels);
229
230        let (sender, receiver) = bounded(64);
231        let running = Arc::new(AtomicBool::new(true));
232        let frame_counter = Arc::new(AtomicU64::new(0));
233
234        let running_clone = running.clone();
235        let sender_clone = sender.clone();
236
237        let stream_config = cpal::StreamConfig {
238            channels: channels as u16,
239            sample_rate: cpal::SampleRate(sample_rate),
240            buffer_size: cpal::BufferSize::Fixed(config.buffer_size as u32),
241        };
242
243        let stream = device
244            .build_input_stream(
245                &stream_config,
246                move |data: &[f32], _: &cpal::InputCallbackInfo| {
247                    if running_clone.load(Ordering::Relaxed) {
248                        if sender_clone.try_send(data.to_vec()).is_err() {
249                            warn!("Audio buffer overflow - dropping samples");
250                        }
251                    }
252                },
253                move |err| {
254                    error!("Audio stream error: {}", err);
255                },
256                None,
257            )
258            .map_err(|e| AudioFftError::device(format!("Failed to build stream: {}", e)))?;
259
260        stream
261            .play()
262            .map_err(|e| AudioFftError::device(format!("Failed to start stream: {}", e)))?;
263
264        info!("Audio device stream started");
265
266        Ok(Self {
267            sample_rate,
268            channels,
269            receiver,
270            buffer: Vec::with_capacity(config.buffer_size * 2),
271            frame_counter,
272            running,
273            _stream: Some(stream),
274        })
275    }
276
277    /// Create a dummy device stream (for testing without audio device).
278    #[cfg(not(feature = "device-input"))]
279    pub fn new(_config: DeviceConfig) -> Result<Self> {
280        Err(AudioFftError::device(
281            "Device input not enabled. Compile with --features device-input",
282        ))
283    }
284
285    /// Create a mock stream for testing.
286    #[cfg(feature = "device-input")]
287    pub fn mock(
288        sample_rate: u32,
289        channels: u8,
290        _sender: Sender<Vec<f32>>,
291        receiver: Receiver<Vec<f32>>,
292    ) -> Self {
293        Self {
294            sample_rate,
295            channels,
296            receiver,
297            buffer: Vec::new(),
298            frame_counter: Arc::new(AtomicU64::new(0)),
299            running: Arc::new(AtomicBool::new(true)),
300            _stream: None,
301        }
302    }
303
304    /// Create a mock stream for testing.
305    #[cfg(not(feature = "device-input"))]
306    pub fn mock(
307        sample_rate: u32,
308        channels: u8,
309        _sender: Sender<Vec<f32>>,
310        receiver: Receiver<Vec<f32>>,
311    ) -> Self {
312        Self {
313            sample_rate,
314            channels,
315            receiver,
316            buffer: Vec::new(),
317            frame_counter: Arc::new(AtomicU64::new(0)),
318            running: Arc::new(AtomicBool::new(true)),
319        }
320    }
321
322    /// Stop the stream.
323    pub fn stop(&self) {
324        self.running.store(false, Ordering::Relaxed);
325    }
326
327    /// Check if the stream is running.
328    pub fn is_running(&self) -> bool {
329        self.running.load(Ordering::Relaxed)
330    }
331}
332
333impl AudioSource for DeviceStream {
334    fn sample_rate(&self) -> u32 {
335        self.sample_rate
336    }
337
338    fn channels(&self) -> u8 {
339        self.channels
340    }
341
342    fn total_samples(&self) -> Option<u64> {
343        None // Stream has no defined length
344    }
345
346    fn read_frame(&mut self, frame_size: usize) -> Result<Option<AudioFrame>> {
347        let required_samples = frame_size * self.channels as usize;
348
349        // Fill buffer from receiver
350        while self.buffer.len() < required_samples {
351            match self.receiver.try_recv() {
352                Ok(samples) => self.buffer.extend(samples),
353                Err(crossbeam::channel::TryRecvError::Empty) => {
354                    // Not enough data yet
355                    if !self.is_running() && self.buffer.is_empty() {
356                        return Ok(None);
357                    }
358                    // Return what we have (might be less than frame_size)
359                    if !self.buffer.is_empty() {
360                        break;
361                    }
362                    // Wait for more data
363                    match self
364                        .receiver
365                        .recv_timeout(std::time::Duration::from_millis(100))
366                    {
367                        Ok(samples) => self.buffer.extend(samples),
368                        Err(_) => {
369                            if !self.is_running() {
370                                return Ok(None);
371                            }
372                            continue;
373                        }
374                    }
375                }
376                Err(crossbeam::channel::TryRecvError::Disconnected) => {
377                    if self.buffer.is_empty() {
378                        return Ok(None);
379                    }
380                    break;
381                }
382            }
383        }
384
385        if self.buffer.is_empty() {
386            return Ok(None);
387        }
388
389        let samples_to_take = required_samples.min(self.buffer.len());
390        let frame_samples: Vec<f32> = self.buffer.drain(..samples_to_take).collect();
391
392        let frame_id = self.frame_counter.fetch_add(1, Ordering::Relaxed);
393        let timestamp = frame_id * frame_size as u64;
394
395        Ok(Some(AudioFrame::new(
396            frame_id,
397            self.sample_rate,
398            self.channels,
399            frame_samples,
400            timestamp,
401        )))
402    }
403
404    fn is_exhausted(&self) -> bool {
405        !self.is_running() && self.buffer.is_empty()
406    }
407
408    fn reset(&mut self) -> Result<()> {
409        self.buffer.clear();
410        self.frame_counter.store(0, Ordering::Relaxed);
411        Ok(())
412    }
413}
414
415impl Drop for DeviceStream {
416    fn drop(&mut self) {
417        self.stop();
418    }
419}
420
421/// Unified audio input that can be either file or device.
422pub enum AudioInput {
423    /// File-based input.
424    File(FileSource),
425    /// Device stream input.
426    Device(DeviceStream),
427}
428
429impl AudioInput {
430    /// Create input from a file.
431    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
432        Ok(Self::File(FileSource::open(path)?))
433    }
434
435    /// Create input from the default audio device.
436    pub fn from_device(config: DeviceConfig) -> Result<Self> {
437        Ok(Self::Device(DeviceStream::new(config)?))
438    }
439
440    /// Create input from raw samples (for testing).
441    pub fn from_samples(samples: Vec<f32>, sample_rate: u32, channels: u8) -> Self {
442        Self::File(FileSource {
443            path: "<memory>".to_string(),
444            sample_rate,
445            channels,
446            samples,
447            position: 0,
448            frame_counter: 0,
449        })
450    }
451}
452
453impl AudioSource for AudioInput {
454    fn sample_rate(&self) -> u32 {
455        match self {
456            Self::File(f) => f.sample_rate(),
457            Self::Device(d) => d.sample_rate(),
458        }
459    }
460
461    fn channels(&self) -> u8 {
462        match self {
463            Self::File(f) => f.channels(),
464            Self::Device(d) => d.channels(),
465        }
466    }
467
468    fn total_samples(&self) -> Option<u64> {
469        match self {
470            Self::File(f) => f.total_samples(),
471            Self::Device(d) => d.total_samples(),
472        }
473    }
474
475    fn read_frame(&mut self, frame_size: usize) -> Result<Option<AudioFrame>> {
476        match self {
477            Self::File(f) => f.read_frame(frame_size),
478            Self::Device(d) => d.read_frame(frame_size),
479        }
480    }
481
482    fn is_exhausted(&self) -> bool {
483        match self {
484            Self::File(f) => f.is_exhausted(),
485            Self::Device(d) => d.is_exhausted(),
486        }
487    }
488
489    fn reset(&mut self) -> Result<()> {
490        match self {
491            Self::File(f) => f.reset(),
492            Self::Device(d) => d.reset(),
493        }
494    }
495}
496
/// In-memory audio buffer that can be post-processed and written to a WAV file.
#[derive(Debug, Clone)]
pub struct AudioOutput {
    /// Sampling frequency, in Hz.
    pub sample_rate: u32,
    /// Channel count; `samples` is interleaved across this many channels.
    pub channels: u8,
    /// Interleaved sample data.
    pub samples: Vec<f32>,
}
507
508impl AudioOutput {
509    /// Create a new empty audio output.
510    pub fn new(sample_rate: u32, channels: u8) -> Self {
511        Self {
512            sample_rate,
513            channels,
514            samples: Vec::new(),
515        }
516    }
517
518    /// Create from existing samples.
519    pub fn from_samples(samples: Vec<f32>, sample_rate: u32, channels: u8) -> Self {
520        Self {
521            sample_rate,
522            channels,
523            samples,
524        }
525    }
526
527    /// Append samples.
528    pub fn append(&mut self, samples: &[f32]) {
529        self.samples.extend_from_slice(samples);
530    }
531
532    /// Get duration in seconds.
533    pub fn duration_secs(&self) -> f64 {
534        self.samples.len() as f64 / self.channels as f64 / self.sample_rate as f64
535    }
536
537    /// Write to a WAV file.
538    pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
539        let spec = hound::WavSpec {
540            channels: self.channels as u16,
541            sample_rate: self.sample_rate,
542            bits_per_sample: 32,
543            sample_format: hound::SampleFormat::Float,
544        };
545
546        let mut writer = hound::WavWriter::create(path.as_ref(), spec)
547            .map_err(|e| AudioFftError::file_write(e.to_string()))?;
548
549        for sample in &self.samples {
550            writer
551                .write_sample(*sample)
552                .map_err(|e| AudioFftError::file_write(e.to_string()))?;
553        }
554
555        writer
556            .finalize()
557            .map_err(|e| AudioFftError::file_write(e.to_string()))?;
558
559        info!(
560            "Wrote {} samples to {}",
561            self.samples.len(),
562            path.as_ref().display()
563        );
564
565        Ok(())
566    }
567
568    /// Normalize the audio to peak at 1.0.
569    pub fn normalize(&mut self) {
570        let max = self.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
571
572        if max > 1e-6 {
573            let scale = 1.0 / max;
574            for sample in &mut self.samples {
575                *sample *= scale;
576            }
577        }
578    }
579
580    /// Apply gain.
581    pub fn apply_gain(&mut self, gain: f32) {
582        for sample in &mut self.samples {
583            *sample *= gain;
584        }
585    }
586}
587
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_audio_input_from_samples() {
        let samples = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
        let mut input = AudioInput::from_samples(samples, 44100, 2);

        assert_eq!(input.sample_rate(), 44100);
        assert_eq!(input.channels(), 2);
        assert_eq!(input.total_samples(), Some(4)); // 8 samples / 2 channels

        let frame = input.read_frame(2).unwrap().unwrap();
        assert_eq!(frame.samples.len(), 4); // 2 samples * 2 channels
        assert!(!input.is_exhausted());

        let frame2 = input.read_frame(2).unwrap().unwrap();
        assert_eq!(frame2.samples.len(), 4);
        assert!(input.is_exhausted());
    }

    #[test]
    fn test_audio_input_short_final_frame() {
        // 3 stereo frames read 2 at a time: the second read gets the single leftover frame.
        let samples = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5];
        let mut input = AudioInput::from_samples(samples, 48000, 2);

        let first = input.read_frame(2).unwrap().unwrap();
        assert_eq!(first.samples.len(), 4);
        let last = input.read_frame(2).unwrap().unwrap();
        assert_eq!(last.samples.len(), 2);

        assert!(input.is_exhausted());
        assert!(input.read_frame(2).unwrap().is_none());
    }

    #[test]
    fn test_audio_input_reset_rewinds() {
        let mut input = AudioInput::from_samples(vec![1.0, 2.0, 3.0, 4.0], 8000, 1);
        let first = input.read_frame(4).unwrap().unwrap();
        assert!(input.is_exhausted());

        input.reset().unwrap();
        assert!(!input.is_exhausted());
        let again = input.read_frame(4).unwrap().unwrap();
        assert!(first.samples == again.samples);
    }

    #[test]
    fn test_device_stream_mock_buffers_chunks() {
        let (tx, rx) = crossbeam::channel::bounded(4);
        let mut stream = DeviceStream::mock(44100, 2, tx.clone(), rx);
        tx.send(vec![0.0f32; 8]).unwrap();

        // 2 samples per channel * 2 channels = 4 interleaved samples per read.
        assert_eq!(stream.read_frame(2).unwrap().unwrap().samples.len(), 4);
        assert_eq!(stream.read_frame(2).unwrap().unwrap().samples.len(), 4);

        // Stopped with nothing buffered: reads end immediately instead of blocking.
        stream.stop();
        assert!(stream.read_frame(2).unwrap().is_none());
        assert!(stream.is_exhausted());
    }

    #[test]
    fn test_audio_output() {
        let mut output = AudioOutput::new(44100, 1);
        output.append(&[0.5, -0.5, 0.25, -0.25]);

        assert_eq!(output.samples.len(), 4);
        assert!((output.duration_secs() - 4.0 / 44100.0).abs() < 1e-6);

        output.normalize();
        assert!((output.samples[0] - 1.0).abs() < 1e-6);
        assert!((output.samples[1] - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_audio_output_gain_and_silence() {
        let mut output = AudioOutput::from_samples(vec![0.5, -0.25], 8000, 1);
        output.apply_gain(2.0);
        assert_eq!(output.samples, vec![1.0, -0.5]);

        // Silent buffers are below the normalize threshold and stay untouched.
        let mut silent = AudioOutput::new(8000, 1);
        silent.append(&[0.0, 0.0]);
        silent.normalize();
        assert_eq!(silent.samples, vec![0.0, 0.0]);
    }
}
622}