1use std::path::Path;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10
11#[cfg(feature = "device-input")]
12use crossbeam::channel::bounded;
13use crossbeam::channel::{Receiver, Sender};
14#[cfg(feature = "device-input")]
15use tracing::{debug, error, info, warn};
16#[cfg(not(feature = "device-input"))]
17use tracing::{debug, info};
18
19use crate::error::{AudioFftError, Result};
20use crate::messages::AudioFrame;
21
22pub trait AudioSource: Send + Sync {
24 fn sample_rate(&self) -> u32;
26
27 fn channels(&self) -> u8;
29
30 fn total_samples(&self) -> Option<u64>;
32
33 fn read_frame(&mut self, frame_size: usize) -> Result<Option<AudioFrame>>;
35
36 fn is_exhausted(&self) -> bool;
38
39 fn reset(&mut self) -> Result<()>;
41}
42
/// An [`AudioSource`] backed by a fully decoded, in-memory buffer of interleaved samples.
///
/// Built from a WAV file by [`FileSource::open`] or from raw samples by
/// [`AudioInput::from_samples`].
pub struct FileSource {
    /// Origin of the samples (a file path, or `"<memory>"` for raw buffers).
    path: String,
    /// Sample rate in Hz.
    sample_rate: u32,
    /// Number of interleaved channels.
    channels: u8,
    /// Every decoded sample, interleaved by channel.
    samples: Vec<f32>,
    /// Index into `samples` of the next sample to hand out.
    position: usize,
    /// Number of frames returned since construction or the last reset.
    frame_counter: u64,
}
52
53impl FileSource {
54 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
56 let path_str = path.as_ref().to_string_lossy().to_string();
57 info!("Opening audio file: {}", path_str);
58
59 let reader = hound::WavReader::open(path.as_ref())
60 .map_err(|e| AudioFftError::file_read(format!("{}: {}", path_str, e)))?;
61
62 let spec = reader.spec();
63 let sample_rate = spec.sample_rate;
64 let channels = spec.channels as u8;
65
66 debug!(
67 "File spec: {} Hz, {} channels, {} bits, {:?}",
68 sample_rate, channels, spec.bits_per_sample, spec.sample_format
69 );
70
71 let samples: Vec<f32> = match spec.sample_format {
73 hound::SampleFormat::Float => reader
74 .into_samples::<f32>()
75 .filter_map(|s| s.ok())
76 .collect(),
77 hound::SampleFormat::Int => {
78 let scale = 1.0 / (1 << (spec.bits_per_sample - 1)) as f32;
79 reader
80 .into_samples::<i32>()
81 .filter_map(|s| s.ok())
82 .map(|s| s as f32 * scale)
83 .collect()
84 }
85 };
86
87 info!(
88 "Loaded {} samples ({:.2} seconds)",
89 samples.len(),
90 samples.len() as f64 / channels as f64 / sample_rate as f64
91 );
92
93 Ok(Self {
94 path: path_str,
95 sample_rate,
96 channels,
97 samples,
98 position: 0,
99 frame_counter: 0,
100 })
101 }
102
103 pub fn path(&self) -> &str {
105 &self.path
106 }
107
108 pub fn duration_secs(&self) -> f64 {
110 self.samples.len() as f64 / self.channels as f64 / self.sample_rate as f64
111 }
112}
113
114impl AudioSource for FileSource {
115 fn sample_rate(&self) -> u32 {
116 self.sample_rate
117 }
118
119 fn channels(&self) -> u8 {
120 self.channels
121 }
122
123 fn total_samples(&self) -> Option<u64> {
124 Some(self.samples.len() as u64 / self.channels as u64)
125 }
126
127 fn read_frame(&mut self, frame_size: usize) -> Result<Option<AudioFrame>> {
128 if self.position >= self.samples.len() {
129 return Ok(None);
130 }
131
132 let samples_to_read =
133 (frame_size * self.channels as usize).min(self.samples.len() - self.position);
134
135 let frame_samples = self.samples[self.position..self.position + samples_to_read].to_vec();
136 let timestamp = self.position as u64 / self.channels as u64;
137
138 self.position += samples_to_read;
139 self.frame_counter += 1;
140
141 Ok(Some(AudioFrame::new(
142 self.frame_counter,
143 self.sample_rate,
144 self.channels,
145 frame_samples,
146 timestamp,
147 )))
148 }
149
150 fn is_exhausted(&self) -> bool {
151 self.position >= self.samples.len()
152 }
153
154 fn reset(&mut self) -> Result<()> {
155 self.position = 0;
156 self.frame_counter = 0;
157 Ok(())
158 }
159}
160
161pub struct DeviceStream {
163 sample_rate: u32,
164 channels: u8,
165 receiver: Receiver<Vec<f32>>,
166 buffer: Vec<f32>,
167 frame_counter: Arc<AtomicU64>,
168 running: Arc<AtomicBool>,
169 #[cfg(feature = "device-input")]
171 _stream: Option<cpal::Stream>,
172}
173
/// Options for opening a `DeviceStream`. Fields left as `None` fall back to device defaults.
#[derive(Debug, Clone)]
pub struct DeviceConfig {
    /// Requested sample rate in Hz; `None` uses the device's default input rate.
    pub sample_rate: Option<u32>,
    /// Requested channel count; `None` uses the device's default input channel count.
    pub channels: Option<u8>,
    /// Requested capture buffer size, passed to cpal as a fixed buffer size.
    pub buffer_size: usize,
    /// Substring matched against input device names; `None` selects the default device.
    pub device_name: Option<String>,
}

impl Default for DeviceConfig {
    /// Uses the device defaults for everything except the buffer size, which is 4096.
    fn default() -> Self {
        const DEFAULT_BUFFER_SIZE: usize = 4096;
        DeviceConfig {
            buffer_size: DEFAULT_BUFFER_SIZE,
            device_name: None,
            channels: None,
            sample_rate: None,
        }
    }
}
197
198impl DeviceStream {
199 #[cfg(feature = "device-input")]
201 pub fn new(config: DeviceConfig) -> Result<Self> {
202 use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
203
204 let host = cpal::default_host();
205
206 let device = if let Some(name) = &config.device_name {
207 host.input_devices()
208 .map_err(|e| AudioFftError::device(format!("Failed to enumerate devices: {}", e)))?
209 .find(|d| d.name().map(|n| n.contains(name)).unwrap_or(false))
210 .ok_or_else(|| AudioFftError::device(format!("Device '{}' not found", name)))?
211 } else {
212 host.default_input_device()
213 .ok_or_else(|| AudioFftError::device("No default input device"))?
214 };
215
216 let device_name = device.name().unwrap_or_else(|_| "unknown".to_string());
217 info!("Using input device: {}", device_name);
218
219 let supported_config = device
220 .default_input_config()
221 .map_err(|e| AudioFftError::device(format!("Failed to get device config: {}", e)))?;
222
223 let sample_rate = config
224 .sample_rate
225 .unwrap_or(supported_config.sample_rate().0);
226 let channels = config.channels.unwrap_or(supported_config.channels() as u8);
227
228 debug!("Stream config: {} Hz, {} channels", sample_rate, channels);
229
230 let (sender, receiver) = bounded(64);
231 let running = Arc::new(AtomicBool::new(true));
232 let frame_counter = Arc::new(AtomicU64::new(0));
233
234 let running_clone = running.clone();
235 let sender_clone = sender.clone();
236
237 let stream_config = cpal::StreamConfig {
238 channels: channels as u16,
239 sample_rate: cpal::SampleRate(sample_rate),
240 buffer_size: cpal::BufferSize::Fixed(config.buffer_size as u32),
241 };
242
243 let stream = device
244 .build_input_stream(
245 &stream_config,
246 move |data: &[f32], _: &cpal::InputCallbackInfo| {
247 if running_clone.load(Ordering::Relaxed) {
248 if sender_clone.try_send(data.to_vec()).is_err() {
249 warn!("Audio buffer overflow - dropping samples");
250 }
251 }
252 },
253 move |err| {
254 error!("Audio stream error: {}", err);
255 },
256 None,
257 )
258 .map_err(|e| AudioFftError::device(format!("Failed to build stream: {}", e)))?;
259
260 stream
261 .play()
262 .map_err(|e| AudioFftError::device(format!("Failed to start stream: {}", e)))?;
263
264 info!("Audio device stream started");
265
266 Ok(Self {
267 sample_rate,
268 channels,
269 receiver,
270 buffer: Vec::with_capacity(config.buffer_size * 2),
271 frame_counter,
272 running,
273 _stream: Some(stream),
274 })
275 }
276
277 #[cfg(not(feature = "device-input"))]
279 pub fn new(_config: DeviceConfig) -> Result<Self> {
280 Err(AudioFftError::device(
281 "Device input not enabled. Compile with --features device-input",
282 ))
283 }
284
285 #[cfg(feature = "device-input")]
287 pub fn mock(
288 sample_rate: u32,
289 channels: u8,
290 _sender: Sender<Vec<f32>>,
291 receiver: Receiver<Vec<f32>>,
292 ) -> Self {
293 Self {
294 sample_rate,
295 channels,
296 receiver,
297 buffer: Vec::new(),
298 frame_counter: Arc::new(AtomicU64::new(0)),
299 running: Arc::new(AtomicBool::new(true)),
300 _stream: None,
301 }
302 }
303
304 #[cfg(not(feature = "device-input"))]
306 pub fn mock(
307 sample_rate: u32,
308 channels: u8,
309 _sender: Sender<Vec<f32>>,
310 receiver: Receiver<Vec<f32>>,
311 ) -> Self {
312 Self {
313 sample_rate,
314 channels,
315 receiver,
316 buffer: Vec::new(),
317 frame_counter: Arc::new(AtomicU64::new(0)),
318 running: Arc::new(AtomicBool::new(true)),
319 }
320 }
321
322 pub fn stop(&self) {
324 self.running.store(false, Ordering::Relaxed);
325 }
326
327 pub fn is_running(&self) -> bool {
329 self.running.load(Ordering::Relaxed)
330 }
331}
332
333impl AudioSource for DeviceStream {
334 fn sample_rate(&self) -> u32 {
335 self.sample_rate
336 }
337
338 fn channels(&self) -> u8 {
339 self.channels
340 }
341
342 fn total_samples(&self) -> Option<u64> {
343 None }
345
346 fn read_frame(&mut self, frame_size: usize) -> Result<Option<AudioFrame>> {
347 let required_samples = frame_size * self.channels as usize;
348
349 while self.buffer.len() < required_samples {
351 match self.receiver.try_recv() {
352 Ok(samples) => self.buffer.extend(samples),
353 Err(crossbeam::channel::TryRecvError::Empty) => {
354 if !self.is_running() && self.buffer.is_empty() {
356 return Ok(None);
357 }
358 if !self.buffer.is_empty() {
360 break;
361 }
362 match self
364 .receiver
365 .recv_timeout(std::time::Duration::from_millis(100))
366 {
367 Ok(samples) => self.buffer.extend(samples),
368 Err(_) => {
369 if !self.is_running() {
370 return Ok(None);
371 }
372 continue;
373 }
374 }
375 }
376 Err(crossbeam::channel::TryRecvError::Disconnected) => {
377 if self.buffer.is_empty() {
378 return Ok(None);
379 }
380 break;
381 }
382 }
383 }
384
385 if self.buffer.is_empty() {
386 return Ok(None);
387 }
388
389 let samples_to_take = required_samples.min(self.buffer.len());
390 let frame_samples: Vec<f32> = self.buffer.drain(..samples_to_take).collect();
391
392 let frame_id = self.frame_counter.fetch_add(1, Ordering::Relaxed);
393 let timestamp = frame_id * frame_size as u64;
394
395 Ok(Some(AudioFrame::new(
396 frame_id,
397 self.sample_rate,
398 self.channels,
399 frame_samples,
400 timestamp,
401 )))
402 }
403
404 fn is_exhausted(&self) -> bool {
405 !self.is_running() && self.buffer.is_empty()
406 }
407
408 fn reset(&mut self) -> Result<()> {
409 self.buffer.clear();
410 self.frame_counter.store(0, Ordering::Relaxed);
411 Ok(())
412 }
413}
414
415impl Drop for DeviceStream {
416 fn drop(&mut self) {
417 self.stop();
418 }
419}
420
421pub enum AudioInput {
423 File(FileSource),
425 Device(DeviceStream),
427}
428
429impl AudioInput {
430 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
432 Ok(Self::File(FileSource::open(path)?))
433 }
434
435 pub fn from_device(config: DeviceConfig) -> Result<Self> {
437 Ok(Self::Device(DeviceStream::new(config)?))
438 }
439
440 pub fn from_samples(samples: Vec<f32>, sample_rate: u32, channels: u8) -> Self {
442 Self::File(FileSource {
443 path: "<memory>".to_string(),
444 sample_rate,
445 channels,
446 samples,
447 position: 0,
448 frame_counter: 0,
449 })
450 }
451}
452
453impl AudioSource for AudioInput {
454 fn sample_rate(&self) -> u32 {
455 match self {
456 Self::File(f) => f.sample_rate(),
457 Self::Device(d) => d.sample_rate(),
458 }
459 }
460
461 fn channels(&self) -> u8 {
462 match self {
463 Self::File(f) => f.channels(),
464 Self::Device(d) => d.channels(),
465 }
466 }
467
468 fn total_samples(&self) -> Option<u64> {
469 match self {
470 Self::File(f) => f.total_samples(),
471 Self::Device(d) => d.total_samples(),
472 }
473 }
474
475 fn read_frame(&mut self, frame_size: usize) -> Result<Option<AudioFrame>> {
476 match self {
477 Self::File(f) => f.read_frame(frame_size),
478 Self::Device(d) => d.read_frame(frame_size),
479 }
480 }
481
482 fn is_exhausted(&self) -> bool {
483 match self {
484 Self::File(f) => f.is_exhausted(),
485 Self::Device(d) => d.is_exhausted(),
486 }
487 }
488
489 fn reset(&mut self) -> Result<()> {
490 match self {
491 Self::File(f) => f.reset(),
492 Self::Device(d) => d.reset(),
493 }
494 }
495}
496
/// A growable buffer of interleaved `f32` samples that can be post-processed and saved as WAV.
#[derive(Debug, Clone)]
pub struct AudioOutput {
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of interleaved channels.
    pub channels: u8,
    /// Interleaved sample data.
    pub samples: Vec<f32>,
}
507
508impl AudioOutput {
509 pub fn new(sample_rate: u32, channels: u8) -> Self {
511 Self {
512 sample_rate,
513 channels,
514 samples: Vec::new(),
515 }
516 }
517
518 pub fn from_samples(samples: Vec<f32>, sample_rate: u32, channels: u8) -> Self {
520 Self {
521 sample_rate,
522 channels,
523 samples,
524 }
525 }
526
527 pub fn append(&mut self, samples: &[f32]) {
529 self.samples.extend_from_slice(samples);
530 }
531
532 pub fn duration_secs(&self) -> f64 {
534 self.samples.len() as f64 / self.channels as f64 / self.sample_rate as f64
535 }
536
537 pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
539 let spec = hound::WavSpec {
540 channels: self.channels as u16,
541 sample_rate: self.sample_rate,
542 bits_per_sample: 32,
543 sample_format: hound::SampleFormat::Float,
544 };
545
546 let mut writer = hound::WavWriter::create(path.as_ref(), spec)
547 .map_err(|e| AudioFftError::file_write(e.to_string()))?;
548
549 for sample in &self.samples {
550 writer
551 .write_sample(*sample)
552 .map_err(|e| AudioFftError::file_write(e.to_string()))?;
553 }
554
555 writer
556 .finalize()
557 .map_err(|e| AudioFftError::file_write(e.to_string()))?;
558
559 info!(
560 "Wrote {} samples to {}",
561 self.samples.len(),
562 path.as_ref().display()
563 );
564
565 Ok(())
566 }
567
568 pub fn normalize(&mut self) {
570 let max = self.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
571
572 if max > 1e-6 {
573 let scale = 1.0 / max;
574 for sample in &mut self.samples {
575 *sample *= scale;
576 }
577 }
578 }
579
580 pub fn apply_gain(&mut self, gain: f32) {
582 for sample in &mut self.samples {
583 *sample *= gain;
584 }
585 }
586}
587
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_audio_input_from_samples() {
        let interleaved = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
        let mut input = AudioInput::from_samples(interleaved, 44100, 2);

        assert_eq!(input.sample_rate(), 44100);
        assert_eq!(input.channels(), 2);
        // 8 interleaved samples over 2 channels is 4 samples per channel.
        assert_eq!(input.total_samples(), Some(4));

        let first = input.read_frame(2).unwrap().expect("first frame");
        assert_eq!(first.samples.len(), 4);
        assert!(!input.is_exhausted());

        let second = input.read_frame(2).unwrap().expect("second frame");
        assert_eq!(second.samples.len(), 4);
        assert!(input.is_exhausted());
    }

    #[test]
    fn test_audio_output() {
        let mut output = AudioOutput::new(44100, 1);
        output.append(&[0.5, -0.5, 0.25, -0.25]);

        assert_eq!(output.samples.len(), 4);
        let expected_secs = 4.0 / 44100.0;
        assert!((output.duration_secs() - expected_secs).abs() < 1e-6);

        output.normalize();
        let tolerance = 1e-6;
        assert!((output.samples[0] - 1.0).abs() < tolerance);
        assert!((output.samples[1] + 1.0).abs() < tolerance);
    }
}