1use hound::{WavReader, WavWriter, WavSpec, SampleFormat};
2use std::path::Path;
3use thiserror::Error;
4use crate::signal_processing::{to_mono, resample};
5use ndarray::ShapeError;
6use rayon::prelude::*;
7use std::sync::mpsc::{channel, Receiver};
8use std::io::Cursor;
9
10#[derive(Error, Debug)]
15pub enum AudioError {
16 #[error("WAV open failed: {0}")]
18 OpenError(#[from] hound::Error),
19
20 #[error("Unsupported WAV sample format")]
22 UnsupportedFormat,
23
24 #[error("Offset or duration out of bounds")]
26 InvalidRange,
27
28 #[error("I/O error: {0}")]
30 IoError(#[from] std::io::Error),
31
32 #[error("Hound processing error: {0}")]
34 HoundError(hound::Error),
35
36 #[error("Resampling error: {0}")]
38 ResampleError(#[from] crate::signal_processing::resampling::ResampleError),
39
40 #[error("Stream processing error")]
42 StreamError,
43
44 #[error("Shape mismatch: {0}")]
46 ShapeError(#[from] ShapeError),
47
48 #[error("Insufficient sample count: {0}")]
50 InsufficientData(String),
51
52 #[error("Invalid parameter: {0}")]
54 InvalidInput(String),
55
56 #[error("Computation error: {0}")]
58 ComputationFailed(String),
59
60 #[error("File not found: {0}")]
62 FileNotFound(String),
63}
64
/// Decoded audio held in memory.
#[derive(Debug, Clone)]
pub struct AudioData {
    /// Sample values stored as 32-bit floats.
    pub samples: Vec<f32>,
    /// Sample rate associated with `samples`.
    pub sample_rate: u32,
    /// Channel count associated with `samples`.
    pub channels: u16,
}

impl AudioData {
    /// Bundles a sample buffer with its sample rate and channel count.
    ///
    /// The arguments are stored as given; no validation is performed.
    pub fn new(samples: Vec<f32>, sample_rate: u32, channels: u16) -> Self {
        AudioData {
            samples,
            sample_rate,
            channels,
        }
    }
}
108
109pub fn load<P: AsRef<Path>>(
141 path: P,
142 sr: Option<u32>,
143 mono: Option<bool>,
144 offset: Option<f32>,
145 duration: Option<f32>,
146) -> Result<AudioData, AudioError> {
147 let path = path.as_ref();
148 if !path.exists() {
149 return Err(AudioError::FileNotFound(path.to_string_lossy().into_owned()));
150 }
151
152 let wav_data = std::fs::read(&path)?;
153 let mut reader = WavReader::new(Cursor::new(wav_data))?;
154 let spec = reader.spec();
155 let sample_rate = spec.sample_rate;
156
157 let start = (offset.unwrap_or(0.0) * sample_rate as f32) as usize;
158 let len = duration.map(|d| (d * sample_rate as f32) as usize);
159
160 let samples: Vec<f32> = match spec.sample_format {
161 SampleFormat::Float => reader.samples::<f32>()
162 .skip(start)
163 .take(len.unwrap_or(usize::MAX))
164 .map(|s| s.unwrap())
165 .collect(),
166 SampleFormat::Int => reader.samples::<i16>()
167 .skip(start)
168 .take(len.unwrap_or(usize::MAX))
169 .map(|s| s.unwrap() as f32 / i16::MAX as f32)
170 .collect(),
171 };
172
173 if start >= samples.len() && !samples.is_empty() {
174 return Err(AudioError::InvalidRange);
175 }
176
177 let mut samples = samples;
178 let channels = spec.channels as usize;
179 if channels > 1 && mono.unwrap_or(true) {
180 samples = to_mono(&samples, channels);
181 }
182
183 let final_samples = if let Some(target_samplerate) = sr {
184 if target_samplerate != sample_rate {
185 resample(&samples, sample_rate, target_samplerate)?
186 } else {
187 samples
188 }
189 } else {
190 samples
191 };
192
193 Ok(AudioData::new(final_samples, sr.unwrap_or(sample_rate), if mono.unwrap_or(true) { 1 } else { spec.channels }))
194}
195
196pub fn export<P: AsRef<Path>>(path: P, audio_data: &AudioData) -> Result<(), AudioError> {
224 let spec = WavSpec {
225 channels: audio_data.channels,
226 sample_rate: audio_data.sample_rate,
227 bits_per_sample: 32,
228 sample_format: SampleFormat::Float,
229 };
230
231 let mut buffer = Vec::with_capacity(audio_data.samples.len() * 4 + 44); let mut writer = WavWriter::new(Cursor::new(&mut buffer), spec)?;
233 for &sample in &audio_data.samples {
234 writer.write_sample(sample)?;
235 }
236 writer.finalize()?;
237 std::fs::write(path, buffer)?;
238 Ok(())
239}
240
241pub fn stream<P: AsRef<Path>>(
274 path: P,
275 block_length: usize,
276 frame_length: usize,
277 hop_length: Option<usize>,
278) -> Result<impl Iterator<Item = Vec<f32>>, AudioError> {
279 let path = path.as_ref();
280 if !path.exists() {
281 return Err(AudioError::FileNotFound(path.to_string_lossy().into_owned()));
282 }
283
284 let wav_data = std::fs::read(&path)?;
285 let mut reader = WavReader::new(Cursor::new(wav_data))?;
286 let spec = reader.spec();
287 let hop = hop_length.unwrap_or(frame_length);
288
289 let samples: Vec<f32> = match spec.sample_format {
290 SampleFormat::Float => reader.samples::<f32>().map(|s| s.unwrap()).collect(),
291 SampleFormat::Int => reader.samples::<i16>().map(|s| s.unwrap() as f32 / i16::MAX as f32).collect(),
292 };
293
294 let indices: Vec<usize> = (0..samples.len()).step_by(hop).take(block_length).collect();
295 let blocks: Vec<Vec<f32>> = indices
296 .into_par_iter()
297 .map(|i| {
298 let end = (i + frame_length).min(samples.len());
299 let mut block = Vec::with_capacity(frame_length);
300 block.extend_from_slice(&samples[i..end]);
301 block.resize(frame_length, 0.0);
302 block
303 })
304 .collect();
305
306 Ok(blocks.into_iter())
307}
308
309pub fn stream_lazy<P: AsRef<Path>>(
344 path: P,
345 block_length: usize,
346 frame_length: usize,
347 hop_length: Option<usize>,
348) -> Result<Receiver<Vec<f32>>, AudioError> {
349 let path = path.as_ref();
350 if !path.exists() {
351 return Err(AudioError::FileNotFound(path.to_string_lossy().into_owned()));
352 }
353
354 let wav_data = std::fs::read(&path)?;
355 let mut reader = WavReader::new(Cursor::new(wav_data))?;
356 let spec = reader.spec();
357 let hop = hop_length.unwrap_or(frame_length);
358
359 let (tx, rx) = channel();
360 std::thread::spawn(move || {
361 let samples_iter: Box<dyn Iterator<Item = Result<f32, _>>> = match spec.sample_format {
362 SampleFormat::Float => Box::new(reader.samples::<f32>()),
363 SampleFormat::Int => Box::new(reader.samples::<i16>().map(|s| s.map(|v| v as f32 / i16::MAX as f32))),
364 };
365
366 let mut chunk = Vec::with_capacity(frame_length * block_length);
367 let mut block_count = 0;
368
369 for sample in samples_iter {
370 let sample = sample.unwrap_or(0.0);
371 chunk.push(sample);
372
373 if chunk.len() >= frame_length && (chunk.len() % hop == 0 || chunk.len() >= frame_length * block_length) {
374 let indices: Vec<usize> = (0..chunk.len())
375 .step_by(hop)
376 .take(block_length - block_count)
377 .collect();
378 let drain_to = indices.last().map_or(0, |&i| (i + hop).min(chunk.len()));
379
380 let blocks: Vec<Vec<f32>> = indices
381 .into_par_iter()
382 .map(|i| {
383 let end = (i + frame_length).min(chunk.len());
384 let mut block = Vec::with_capacity(frame_length);
385 block.extend_from_slice(&chunk[i..end]);
386 block.resize(frame_length, 0.0);
387 block
388 })
389 .collect();
390
391 for block in blocks {
392 if tx.send(block).is_err() {
393 return;
394 }
395 block_count += 1;
396 if block_count >= block_length {
397 return;
398 }
399 }
400 chunk.drain(..drain_to);
401 }
402 }
403
404 if !chunk.is_empty() && block_count < block_length {
405 chunk.resize(frame_length, 0.0);
406 let _ = tx.send(chunk);
407 }
408 });
409
410 Ok(rx)
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// A WAV path that is unique to one test and deleted when dropped.
    ///
    /// The test harness runs tests in parallel, so each test needs its own
    /// file; sharing one path lets tests delete each other's fixtures.
    struct TempWav(PathBuf);

    impl TempWav {
        fn new(tag: &str) -> Self {
            let file_name = format!("audio_io_test_{}_{}.wav", std::process::id(), tag);
            TempWav(std::env::temp_dir().join(file_name))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempWav {
        fn drop(&mut self) {
            // Best-effort cleanup; the file may never have been created.
            let _ = fs::remove_file(&self.0);
        }
    }

    fn create_test_wav() -> AudioData {
        AudioData::new(vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5], 44100, 1)
    }

    /// Exports the standard fixture to a fresh temporary file.
    fn write_fixture(tag: &str) -> (TempWav, AudioData) {
        let wav = TempWav::new(tag);
        let audio = create_test_wav();
        export(wav.path(), &audio).unwrap();
        (wav, audio)
    }

    #[test]
    fn test_load() {
        let (wav, audio) = write_fixture("load");
        let loaded = load(wav.path(), None, Some(true), None, None).unwrap();
        assert_eq!(loaded.samples, audio.samples);
    }

    #[test]
    fn test_load_segment() {
        let (wav, _audio) = write_fixture("load_segment");
        let loaded = load(
            wav.path(),
            None,
            Some(true),
            Some(0.00004535147),
            Some(0.00004535148),
        )
        .unwrap();
        assert_eq!(loaded.samples, vec![0.1, 0.2]);
    }

    #[test]
    fn test_export() {
        let (wav, audio) = write_fixture("export");
        let loaded = load(wav.path(), None, Some(true), None, None).unwrap();
        assert_eq!(loaded.samples, audio.samples);
    }

    #[test]
    fn test_stream() {
        let (wav, _audio) = write_fixture("stream");
        let blocks: Vec<_> = stream(wav.path(), 3, 2, Some(2)).unwrap().collect();
        assert_eq!(blocks, vec![vec![0.0, 0.1], vec![0.2, 0.3], vec![0.4, 0.5]]);
    }

    #[test]
    fn test_stream_lazy() {
        let (wav, _audio) = write_fixture("stream_lazy");
        let rx = stream_lazy(wav.path(), 3, 2, Some(2)).unwrap();
        let blocks: Vec<_> = rx.into_iter().collect();
        assert_eq!(blocks, vec![vec![0.0, 0.1], vec![0.2, 0.3], vec![0.4, 0.5]]);
    }

    #[test]
    fn test_load_file_not_found() {
        // Never created, so the path is guaranteed to be missing.
        let missing = TempWav::new("load_missing");
        let result = load(missing.path(), None, Some(true), None, None);
        assert!(matches!(result, Err(AudioError::FileNotFound(_))));
    }

    #[test]
    fn test_stream_file_not_found() {
        let missing = TempWav::new("stream_missing");
        let result = stream(missing.path(), 3, 2, Some(2));
        assert!(matches!(result, Err(AudioError::FileNotFound(_))));
    }

    #[test]
    fn test_stream_lazy_file_not_found() {
        let missing = TempWav::new("stream_lazy_missing");
        let result = stream_lazy(missing.path(), 3, 2, Some(2));
        assert!(matches!(result, Err(AudioError::FileNotFound(_))));
    }
}