dasp_rs/core/
io.rs

1use hound::{WavReader, WavWriter, WavSpec, SampleFormat};
2use std::path::Path;
3use thiserror::Error;
4use crate::signal_processing::{to_mono, resample};
5use ndarray::ShapeError;
6use rayon::prelude::*;
7use std::sync::mpsc::{channel, Receiver};
8use std::io::Cursor;
9
10/// Enumerates error conditions for WAV-based audio operations.
11///
12/// Variants encapsulate specific failure modes encountered during file I/O, format parsing,
13/// or signal processing, with detailed diagnostics for DSP pipeline debugging.
14#[derive(Error, Debug)]
15pub enum AudioError {
16    /// WAV file open failure, typically due to invalid path or corrupted header.
17    #[error("WAV open failed: {0}")]
18    OpenError(#[from] hound::Error),
19    
20    /// Unsupported WAV sample format (only PCM 16-bit int and 32-bit float are supported).
21    #[error("Unsupported WAV sample format")]
22    UnsupportedFormat,
23    
24    /// Offset or duration exceeds sample bounds.
25    #[error("Offset or duration out of bounds")]
26    InvalidRange,
27    
28    /// General I/O error outside `hound` operations (e.g., filesystem issues).
29    #[error("I/O error: {0}")]
30    IoError(#[from] std::io::Error),
31    
32    /// `hound`-specific error during sample read/write.
33    #[error("Hound processing error: {0}")]
34    HoundError(hound::Error),
35    
36    /// Resampling failure from `signal_processing::resampling`.
37    #[error("Resampling error: {0}")]
38    ResampleError(#[from] crate::signal_processing::resampling::ResampleError),
39    
40    /// Streaming operation failure (e.g., channel disconnect).
41    #[error("Stream processing error")]
42    StreamError,
43    
44    /// Array shape mismatch from `ndarray` operations.
45    #[error("Shape mismatch: {0}")]
46    ShapeError(#[from] ShapeError),
47    
48    /// Insufficient samples for requested operation.
49    #[error("Insufficient sample count: {0}")]
50    InsufficientData(String),
51    
52    /// Invalid parameter (e.g., negative offset).
53    #[error("Invalid parameter: {0}")]
54    InvalidInput(String),
55    
56    /// Numerical computation failure (e.g., overflow).
57    #[error("Computation error: {0}")]
58    ComputationFailed(String),
59
60    /// File not found at the specified path.
61    #[error("File not found: {0}")]
62    FileNotFound(String),
63}
64
/// In-memory audio buffer shared by the WAV I/O functions of this module.
///
/// Holds the decoded samples as `f32` together with the metadata needed to interpret them.
/// Multi-channel data is stored interleaved (e.g. `[L1, R1, L2, R2, ...]` for stereo).
#[derive(Debug, Clone)]
pub struct AudioData {
    /// Sample values; interleaved when `channels > 1`.
    pub samples: Vec<f32>,
    /// Sampling frequency in Hz.
    pub sample_rate: u32,
    /// Channel count (1 = mono, 2 = stereo).
    pub channels: u16,
}

impl AudioData {
    /// Bundles a sample buffer with its sample rate and channel count.
    ///
    /// No validation is performed; the three values are stored exactly as given.
    ///
    /// # Parameters
    /// - `samples`: Sample buffer (interleaved for multi-channel audio)
    /// - `sample_rate`: Sample rate in Hz
    /// - `channels`: Number of channels
    ///
    /// # Returns
    /// A new `AudioData` owning `samples`.
    ///
    /// # Example
    /// ```
    /// use crate::core::AudioData;
    /// let audio = AudioData::new(
    ///     vec![0.5, -0.3, 0.8], // 3 mono samples
    ///     44100,                // 44.1 kHz
    ///     1                     // Mono
    /// );
    /// assert_eq!(audio.samples.len(), 3);
    /// assert_eq!(audio.sample_rate, 44100);
    /// assert_eq!(audio.channels, 1);
    /// ```
    pub fn new(samples: Vec<f32>, sample_rate: u32, channels: u16) -> Self {
        AudioData {
            samples,
            sample_rate,
            channels,
        }
    }
}
108
109/// Loads WAV file into `AudioData` with optional DSP transformations.
110///
111/// Reads WAV data in-memory via `Cursor`, supporting 16-bit PCM and 32-bit float formats.
112/// Applies resampling, mono conversion, and sample trimming as specified.
113///
114/// # Parameters
115/// - `path`: WAV file path (`AsRef<Path>`)
116/// - `sr`: Target sample rate (Hz); `None` retains source rate
117/// - `mono`: Convert to mono if `Some(true)`; `None` defaults to `true`
118/// - `offset`: Start time (seconds); `None` defaults to 0.0
119/// - `duration`: Segment length (seconds); `None` takes full length
120///
121/// # Returns
122/// - `Ok(AudioData)`: Processed audio data
123/// - `Err(AudioError)`: Failure due to I/O, format, or parameter errors
124///
125/// # Errors
126/// - `AudioError::FileNotFound`: The specified file does not exist
127/// - `AudioError::InvalidRange`: Offset/duration exceeds file length
128/// - `AudioError::OpenError`: Invalid WAV file or corrupted header
129///
130/// # Examples
131/// ```
132/// use crate::core::{load, AudioData};
133/// // Load entire file as mono at original sample rate
134/// let audio = load("track.wav", None, Some(true), None, None)?;
135/// 
136/// // Load 5-second segment starting at 2 seconds, resampled to 16kHz
137/// let segment = load("track.wav", Some(16000), Some(true), Some(2.0), Some(5.0))?;
138/// # Ok::<(), crate::core::AudioError>(())
139/// ```
140pub fn load<P: AsRef<Path>>(
141    path: P,
142    sr: Option<u32>,
143    mono: Option<bool>,
144    offset: Option<f32>,
145    duration: Option<f32>,
146) -> Result<AudioData, AudioError> {
147    let path = path.as_ref();
148    if !path.exists() {
149        return Err(AudioError::FileNotFound(path.to_string_lossy().into_owned()));
150    }
151
152    let wav_data = std::fs::read(&path)?;
153    let mut reader = WavReader::new(Cursor::new(wav_data))?;
154    let spec = reader.spec();
155    let sample_rate = spec.sample_rate;
156
157    let start = (offset.unwrap_or(0.0) * sample_rate as f32) as usize;
158    let len = duration.map(|d| (d * sample_rate as f32) as usize);
159
160    let samples: Vec<f32> = match spec.sample_format {
161        SampleFormat::Float => reader.samples::<f32>()
162            .skip(start)
163            .take(len.unwrap_or(usize::MAX))
164            .map(|s| s.unwrap())
165            .collect(),
166        SampleFormat::Int => reader.samples::<i16>()
167            .skip(start)
168            .take(len.unwrap_or(usize::MAX))
169            .map(|s| s.unwrap() as f32 / i16::MAX as f32)
170            .collect(),
171    };
172
173    if start >= samples.len() && !samples.is_empty() {
174        return Err(AudioError::InvalidRange);
175    }
176
177    let mut samples = samples;
178    let channels = spec.channels as usize;
179    if channels > 1 && mono.unwrap_or(true) {
180        samples = to_mono(&samples, channels);
181    }
182
183    let final_samples = if let Some(target_samplerate) = sr {
184        if target_samplerate != sample_rate {
185            resample(&samples, sample_rate, target_samplerate)?
186        } else {
187            samples
188        }
189    } else {
190        samples
191    };
192
193    Ok(AudioData::new(final_samples, sr.unwrap_or(sample_rate), if mono.unwrap_or(true) { 1 } else { spec.channels }))
194}
195
196/// Exports `AudioData` to a WAV file using in-memory buffering.
197///
198/// Writes 32-bit float WAV data via `Cursor`, committing to disk in a single operation.
199///
200/// # Parameters
201/// - `path`: Output WAV file path (`AsRef<Path>`)
202/// - `audio_data`: Source `AudioData` reference
203///
204/// # Returns
205/// - `Ok(())`: Successful write
206/// - `Err(AudioError)`: I/O or format error
207///
208/// # Errors
209/// - `AudioError::IoError`: Failed to write to filesystem
210/// - `AudioError::HoundError`: WAV format encoding error
211///
212/// # Notes
213/// - Automatically clamps samples to `[-1.0, 1.0]` range
214/// - Preserves channel count and sample rate metadata
215///
216/// # Example
217/// ```
218/// use crate::core::{AudioData, export};
219/// let audio = AudioData::new(vec![0.1, 0.2, 0.3], 44100, 1);
220/// export("output.wav", &audio)?;
221/// # Ok::<(), crate::core::AudioError>(())
222/// ```
223pub fn export<P: AsRef<Path>>(path: P, audio_data: &AudioData) -> Result<(), AudioError> {
224    let spec = WavSpec {
225        channels: audio_data.channels,
226        sample_rate: audio_data.sample_rate,
227        bits_per_sample: 32,
228        sample_format: SampleFormat::Float,
229    };
230
231    let mut buffer = Vec::with_capacity(audio_data.samples.len() * 4 + 44); // Rough WAV size estimate
232    let mut writer = WavWriter::new(Cursor::new(&mut buffer), spec)?;
233    for &sample in &audio_data.samples {
234        writer.write_sample(sample)?;
235    }
236    writer.finalize()?;
237    std::fs::write(path, buffer)?;
238    Ok(())
239}
240
241/// Generates an iterator over WAV sample blocks with parallel processing.
242///
243/// Splits WAV data into fixed-size blocks, processed in parallel using `rayon`.
244///
245/// # Parameters
246/// - `path`: WAV file path (`AsRef<Path>`).
247/// - `block_length`: Maximum block count.
248/// - `frame_length`: Samples per block.
249/// - `hop_length`: Step size between blocks; `None` uses `frame_length`.
250///
251/// # Returns
252/// - `Ok(impl Iterator<Item = Vec<f32>>)`: Block iterator.
253/// - `Err(AudioError)`: I/O or format error.
254///
255/// # Errors
256/// - `AudioError::FileNotFound`: The specified file does not exist
257/// - `AudioError::OpenError`: Invalid WAV file or corrupted header
258///
259/// # Example
260/// ```
261/// use crate::core::stream;
262/// let stream = stream("audio.wav", 100, 4096, None)?;
263/// for block in stream {
264///     // Process each 4096-sample block
265///     println!("Block size: {}", block.len());
266/// }
267/// # Ok::<(), crate::core::AudioError>(())
268/// ```
269///
270/// # Performance
271/// - Uses `rayon` thread pool for parallel block processing
272/// - Best for offline processing of <1GB files
273pub fn stream<P: AsRef<Path>>(
274    path: P,
275    block_length: usize,
276    frame_length: usize,
277    hop_length: Option<usize>,
278) -> Result<impl Iterator<Item = Vec<f32>>, AudioError> {
279    let path = path.as_ref();
280    if !path.exists() {
281        return Err(AudioError::FileNotFound(path.to_string_lossy().into_owned()));
282    }
283
284    let wav_data = std::fs::read(&path)?;
285    let mut reader = WavReader::new(Cursor::new(wav_data))?;
286    let spec = reader.spec();
287    let hop = hop_length.unwrap_or(frame_length);
288
289    let samples: Vec<f32> = match spec.sample_format {
290        SampleFormat::Float => reader.samples::<f32>().map(|s| s.unwrap()).collect(),
291        SampleFormat::Int => reader.samples::<i16>().map(|s| s.unwrap() as f32 / i16::MAX as f32).collect(),
292    };
293
294    let indices: Vec<usize> = (0..samples.len()).step_by(hop).take(block_length).collect();
295    let blocks: Vec<Vec<f32>> = indices
296        .into_par_iter()
297        .map(|i| {
298            let end = (i + frame_length).min(samples.len());
299            let mut block = Vec::with_capacity(frame_length);
300            block.extend_from_slice(&samples[i..end]);
301            block.resize(frame_length, 0.0);
302            block
303        })
304        .collect();
305
306    Ok(blocks.into_iter())
307}
308
309/// Streams WAV sample blocks lazily with parallel chunk processing.
310///
311/// Processes WAV data incrementally in a separate thread, generating blocks in parallel
312/// within chunks to minimize memory footprint.
313///
314/// # Parameters
315/// - `path`: WAV file path (`AsRef<Path>`).
316/// - `block_length`: Maximum block count.
317/// - `frame_length`: Samples per block.
318/// - `hop_length`: Step size between blocks; `None` uses `frame_length`.
319///
320/// # Returns
321/// - `Ok(Receiver<Vec<f32>>)`: Channel receiver for blocks.
322/// - `Err(AudioError)`: I/O or streaming error.
323///
324/// # Errors
325/// - `AudioError::FileNotFound`: The specified file does not exist
326/// - `AudioError::OpenError`: Invalid WAV file or corrupted header
327/// - `AudioError::StreamError`: Channel communication failure
328///
329/// # Example
330/// ```
331/// use crate::core::stream_lazy;
332/// let rx = stream_lazy("audio.wav", 1000, 1024, Some(512))?;
333/// while let Ok(block) = rx.recv() {
334///     // Process each 1024-sample block with 50% overlap
335///     println!("Received block of {} samples", block.len());
336/// }
337/// # Ok::<(), crate::core::AudioError>(())
338/// ```
339///
340/// # Performance
341/// - Background thread for file reading
342/// - Memory-efficient for files >1GB
343pub fn stream_lazy<P: AsRef<Path>>(
344    path: P,
345    block_length: usize,
346    frame_length: usize,
347    hop_length: Option<usize>,
348) -> Result<Receiver<Vec<f32>>, AudioError> {
349    let path = path.as_ref();
350    if !path.exists() {
351        return Err(AudioError::FileNotFound(path.to_string_lossy().into_owned()));
352    }
353
354    let wav_data = std::fs::read(&path)?;
355    let mut reader = WavReader::new(Cursor::new(wav_data))?;
356    let spec = reader.spec();
357    let hop = hop_length.unwrap_or(frame_length);
358
359    let (tx, rx) = channel();
360    std::thread::spawn(move || {
361        let samples_iter: Box<dyn Iterator<Item = Result<f32, _>>> = match spec.sample_format {
362            SampleFormat::Float => Box::new(reader.samples::<f32>()),
363            SampleFormat::Int => Box::new(reader.samples::<i16>().map(|s| s.map(|v| v as f32 / i16::MAX as f32))),
364        };
365
366        let mut chunk = Vec::with_capacity(frame_length * block_length);
367        let mut block_count = 0;
368
369        for sample in samples_iter {
370            let sample = sample.unwrap_or(0.0);
371            chunk.push(sample);
372
373            if chunk.len() >= frame_length && (chunk.len() % hop == 0 || chunk.len() >= frame_length * block_length) {
374                let indices: Vec<usize> = (0..chunk.len())
375                    .step_by(hop)
376                    .take(block_length - block_count)
377                    .collect();
378                let drain_to = indices.last().map_or(0, |&i| (i + hop).min(chunk.len()));
379
380                let blocks: Vec<Vec<f32>> = indices
381                    .into_par_iter()
382                    .map(|i| {
383                        let end = (i + frame_length).min(chunk.len());
384                        let mut block = Vec::with_capacity(frame_length);
385                        block.extend_from_slice(&chunk[i..end]);
386                        block.resize(frame_length, 0.0);
387                        block
388                    })
389                    .collect();
390
391                for block in blocks {
392                    if tx.send(block).is_err() {
393                        return;
394                    }
395                    block_count += 1;
396                    if block_count >= block_length {
397                        return;
398                    }
399                }
400                chunk.drain(..drain_to);
401            }
402        }
403
404        if !chunk.is_empty() && block_count < block_length {
405            chunk.resize(frame_length, 0.0);
406            let _ = tx.send(chunk);
407        }
408    });
409
410    Ok(rx)
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;

    /// Scratch WAV path unique to one test.
    ///
    /// The test harness runs tests in parallel, so each test gets its own file in the
    /// system temp directory; the file is removed when the guard is dropped, including
    /// when an assertion fails.
    struct TempWav(PathBuf);

    impl TempWav {
        /// Builds a path from the test name and process id and clears any stale file.
        fn new(test_name: &str) -> Self {
            let file_name = format!("dasp_rs_io_{}_{}.wav", test_name, std::process::id());
            let path = std::env::temp_dir().join(file_name);
            let _ = fs::remove_file(&path);
            TempWav(path)
        }

        /// Path of the scratch file.
        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempWav {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may never have been created.
            let _ = fs::remove_file(&self.0);
        }
    }

    /// Six-sample mono fixture at 44.1 kHz.
    fn create_test_wav() -> AudioData {
        AudioData::new(vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5], 44100, 1)
    }

    #[test]
    fn test_load() {
        let wav = TempWav::new("test_load");
        let audio = create_test_wav();
        export(wav.path(), &audio).unwrap();
        let loaded = load(wav.path(), None, Some(true), None, None).unwrap();
        assert_eq!(loaded.samples, audio.samples);
    }

    #[test]
    fn test_load_segment() {
        let wav = TempWav::new("test_load_segment");
        let audio = create_test_wav();
        export(wav.path(), &audio).unwrap();
        // Offset just under two sample periods truncates to sample 1; the duration just
        // over two sample periods truncates to 2 samples.
        let loaded = load(
            wav.path(),
            None,
            Some(true),
            Some(0.00004535147),
            Some(0.00004535148),
        )
        .unwrap();
        assert_eq!(loaded.samples, vec![0.1, 0.2]);
    }

    #[test]
    fn test_export() {
        let wav = TempWav::new("test_export");
        let audio = create_test_wav();
        export(wav.path(), &audio).unwrap();
        let loaded = load(wav.path(), None, Some(true), None, None).unwrap();
        assert_eq!(loaded.samples, audio.samples);
    }

    #[test]
    fn test_stream() {
        let wav = TempWav::new("test_stream");
        let audio = create_test_wav();
        export(wav.path(), &audio).unwrap();
        let blocks: Vec<_> = stream(wav.path(), 3, 2, Some(2)).unwrap().collect();
        assert_eq!(blocks, vec![vec![0.0, 0.1], vec![0.2, 0.3], vec![0.4, 0.5]]);
    }

    #[test]
    fn test_stream_lazy() {
        let wav = TempWav::new("test_stream_lazy");
        let audio = create_test_wav();
        export(wav.path(), &audio).unwrap();
        let rx = stream_lazy(wav.path(), 3, 2, Some(2)).unwrap();
        let blocks: Vec<_> = rx.into_iter().collect();
        assert_eq!(blocks, vec![vec![0.0, 0.1], vec![0.2, 0.3], vec![0.4, 0.5]]);
    }

    #[test]
    fn test_load_file_not_found() {
        // The guard only reserves a unique name; the file is never created.
        let missing = TempWav::new("test_load_file_not_found");
        let result = load(missing.path(), None, Some(true), None, None);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), AudioError::FileNotFound(_)));
    }

    #[test]
    fn test_stream_file_not_found() {
        let missing = TempWav::new("test_stream_file_not_found");
        let result = stream(missing.path(), 3, 2, Some(2));
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(matches!(e, AudioError::FileNotFound(_)));
        }
    }

    #[test]
    fn test_stream_lazy_file_not_found() {
        let missing = TempWav::new("test_stream_lazy_file_not_found");
        let result = stream_lazy(missing.path(), 3, 2, Some(2));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), AudioError::FileNotFound(_)));
    }
}