1use std::fs::File;
2use std::path::Path;
3
4use num::Float;
5use symphonia::core::audio::SampleBuffer;
6use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
7use symphonia::core::errors::Error;
8use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
9use symphonia::core::io::MediaSourceStream;
10use symphonia::core::meta::MetadataOptions;
11use symphonia::core::probe::Hint;
12use thiserror::Error;
13
14use crate::resample::resample;
15
/// Decoded audio as returned by [`audio_read`].
#[derive(Clone, Debug)]
pub struct Audio<F> {
    /// Samples of all channels, interleaved frame by frame
    /// (`[frame0_ch0, frame0_ch1, …, frame1_ch0, …]`).
    pub samples_interleaved: Vec<F>,
    /// Sample rate in frames per second.
    pub sample_rate: u32,
    /// Number of channels interleaved in `samples_interleaved`.
    pub num_channels: u16,
}
26
27#[derive(Debug, Error)]
28pub enum AudioReadError {
29 #[error("could not read file")]
30 FileError(#[from] std::io::Error),
31 #[error("could not decode audio")]
32 EncodingError(#[from] symphonia::core::errors::Error),
33 #[error("could not find track in file")]
34 NoTrack,
35 #[error("could not find sample rate in file")]
36 NoSampleRate,
37 #[error("end frame {0} is larger than start frame {1}")]
38 EndFrameLargerThanStartFrame(usize, usize),
39 #[error("start channel {0} invalid, audio file has only {1} channels")]
40 InvalidStartChannel(usize, usize),
41 #[error("invalid number of channels to extract: {0}")]
42 InvalidNumChannels(usize),
43}
44
/// A position inside an audio file, used for the `start` and `stop` fields of
/// [`AudioReadConfig`].
#[derive(Clone, Copy, Debug, Default)]
pub enum Position {
    /// No explicit position: the beginning of the file for `start`, the end for `stop`.
    #[default]
    Default,
    /// An offset from the beginning of the file, converted to frames using the file's
    /// sample rate.
    Time(std::time::Duration),
    /// A frame index (one frame holds one sample per channel).
    Frame(usize),
}
56
57#[derive(Default)]
58pub struct AudioReadConfig {
59 pub start: Position,
61 pub stop: Position,
63 pub start_channel: Option<usize>,
65 pub num_channels: Option<usize>,
67 pub sample_rate: Option<u32>,
69}
70
71pub fn audio_read<F: Float + rubato::Sample>(
72 path: impl AsRef<Path>,
73 config: AudioReadConfig,
74) -> Result<Audio<F>, AudioReadError> {
75 let src = File::open(path.as_ref())?;
76 let mss = MediaSourceStream::new(Box::new(src), Default::default());
77
78 let mut hint = Hint::new();
79 if let Some(ext) = path.as_ref().extension()
80 && let Some(ext_str) = ext.to_str()
81 {
82 hint.with_extension(ext_str);
83 }
84
85 let meta_opts: MetadataOptions = Default::default();
86 let fmt_opts: FormatOptions = Default::default();
87
88 let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
89
90 let mut format = probed.format;
91
92 let track = format
93 .tracks()
94 .iter()
95 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
96 .ok_or(AudioReadError::NoTrack)?;
97
98 let sample_rate = track
99 .codec_params
100 .sample_rate
101 .ok_or(AudioReadError::NoSampleRate)?;
102
103 let track_id = track.id;
104
105 let codec_params = track.codec_params.clone();
107 let time_base = track.codec_params.time_base;
108
109 let start_frame = match config.start {
111 Position::Default => 0,
112 Position::Time(duration) => {
113 let secs = duration.as_secs_f64();
114 (secs * sample_rate as f64) as usize
115 }
116 Position::Frame(frame) => frame,
117 };
118
119 let end_frame: Option<usize> = match config.stop {
120 Position::Default => None,
121 Position::Time(duration) => {
122 let secs = duration.as_secs_f64();
123 Some((secs * sample_rate as f64) as usize)
124 }
125 Position::Frame(frame) => Some(frame),
126 };
127
128 if let Some(end_frame) = end_frame
129 && start_frame > end_frame
130 {
131 return Err(AudioReadError::EndFrameLargerThanStartFrame(
132 end_frame,
133 start_frame,
134 ));
135 }
136
137 if start_frame > sample_rate as usize
139 && let Some(tb) = time_base
140 {
141 let seek_sample = (start_frame as f64 * 0.9) as u64;
143 let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
144
145 let _ = format.seek(
147 SeekMode::Accurate,
148 SeekTo::TimeStamp {
149 ts: seek_ts,
150 track_id,
151 },
152 );
153 }
154
155 let dec_opts: DecoderOptions = Default::default();
156 let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
157
158 let mut sample_buf = None;
159 let mut samples = Vec::new();
160 let mut num_channels = 0usize;
161 let start_channel = config.start_channel;
162
163 let mut current_sample: Option<u64> = None;
165
166 loop {
167 let packet = match format.next_packet() {
168 Ok(packet) => packet,
169 Err(Error::ResetRequired) => {
170 decoder.reset();
171 continue;
172 }
173 Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
174 break;
175 }
176 Err(err) => return Err(err.into()),
177 };
178
179 if packet.track_id() != track_id {
180 continue;
181 }
182
183 let decoded = decoder.decode(&packet)?;
184
185 if current_sample.is_none() {
187 let ts = packet.ts();
188 if let Some(tb) = time_base {
189 current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
191 } else {
192 current_sample = Some(0);
193 }
194 }
195
196 if sample_buf.is_none() {
197 let spec = *decoded.spec();
198 let duration = decoded.capacity() as u64;
199 sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
200
201 num_channels = spec.channels.count();
203
204 let ch_start = start_channel.unwrap_or(0);
206 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
207
208 if ch_start >= num_channels {
209 return Err(AudioReadError::InvalidStartChannel(ch_start, num_channels));
210 }
211 if ch_count == 0 {
212 return Err(AudioReadError::InvalidNumChannels(0));
213 }
214 if ch_start + ch_count > num_channels {
215 return Err(AudioReadError::InvalidNumChannels(ch_count));
216 }
217 }
218
219 if let Some(buf) = &mut sample_buf {
220 buf.copy_interleaved_ref(decoded);
221 let packet_samples = buf.samples();
222
223 let mut pos = current_sample.unwrap_or(0);
224
225 let ch_start = start_channel.unwrap_or(0);
227 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
228 let ch_end = ch_start + ch_count;
229
230 let frames = packet_samples.len() / num_channels;
232
233 for frame_idx in 0..frames {
235 if let Some(end) = end_frame
237 && pos >= end as u64
238 {
239 return Ok(Audio {
240 samples_interleaved: samples,
241 sample_rate,
242 num_channels: ch_count as u16,
243 });
244 }
245
246 if pos >= start_frame as u64 {
248 for ch in ch_start..ch_end {
251 let sample_idx = frame_idx * num_channels + ch;
252 samples.push(F::from(packet_samples[sample_idx]).unwrap());
253 }
254 }
255
256 pos += 1;
257 }
258
259 current_sample = Some(pos);
261 }
262 }
263
264 let samples = if let Some(sr_out) = config.sample_rate {
265 resample(&samples, num_channels, sample_rate, sr_out).unwrap()
266 } else {
267 samples
268 };
269
270 let ch_start = start_channel.unwrap_or(0);
271 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
272
273 Ok(Audio {
274 samples_interleaved: samples,
275 sample_rate,
276 num_channels: ch_count as u16,
277 })
278}
279
280#[cfg(feature = "audio-blocks")]
281pub fn audio_read_block<F: num::Float + 'static + rubato::Sample>(
282 path: impl AsRef<Path>,
283 config: AudioReadConfig,
284) -> Result<(audio_blocks::AudioBlockInterleaved<F>, u32), AudioReadError> {
285 let audio = audio_read(path, config)?;
286 Ok((
287 audio_blocks::AudioBlockInterleaved::from_slice(
288 &audio.samples_interleaved,
289 audio.num_channels,
290 ),
291 audio.sample_rate,
292 ))
293}
294
295#[cfg(test)]
296mod tests {
297 use std::time::Duration;
298
299 use audio_blocks::{AudioBlock, AudioBlockInterleavedView};
300
301 use super::*;
302
303 fn to_block<F: num::Float + 'static>(audio: &Audio<F>) -> AudioBlockInterleavedView<'_, F> {
304 AudioBlockInterleavedView::from_slice(&audio.samples_interleaved, audio.num_channels)
305 }
306
307 #[test]
313 fn test_sine_wave_data_integrity() {
314 const SAMPLE_RATE: f64 = 48000.0;
315 const N_SAMPLES: usize = 48000;
316 const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
317
318 let audio =
319 audio_read::<f32>("test_data/test_4ch.wav", AudioReadConfig::default()).unwrap();
320 let block = to_block(&audio);
321
322 assert_eq!(audio.sample_rate, 48000);
323 assert_eq!(block.num_frames(), N_SAMPLES);
324 assert_eq!(block.num_channels(), 4);
325
326 for (ch, &freq) in FREQUENCIES.iter().enumerate() {
328 for frame in 0..N_SAMPLES {
329 let expected =
330 (2.0 * std::f64::consts::PI * freq * frame as f64 / SAMPLE_RATE).sin() as f32;
331 let actual = block.sample(ch as u16, frame);
332 assert!(
333 (actual - expected).abs() < 1e-4,
334 "Mismatch at channel {ch}, frame {frame}: expected {expected}, got {actual}"
335 );
336 }
337 }
338
339 let audio = audio_read::<f32>(
341 "test_data/test_4ch.wav",
342 AudioReadConfig {
343 start: Position::Frame(24000),
344 stop: Position::Frame(24100),
345 ..Default::default()
346 },
347 )
348 .unwrap();
349 let block = to_block(&audio);
350
351 for (ch, &freq) in FREQUENCIES.iter().enumerate() {
352 for frame in 0..100 {
353 let actual_frame = 24000 + frame;
354 let expected = (2.0 * std::f64::consts::PI * freq * actual_frame as f64
355 / SAMPLE_RATE)
356 .sin() as f32;
357 let actual = block.sample(ch as u16, frame);
358 assert!(
359 (actual - expected).abs() < 1e-4,
360 "Offset mismatch at channel {ch}, frame {actual_frame}: expected {expected}, got {actual}"
361 );
362 }
363 }
364 }
365
366 #[test]
367 fn test_samples_selection() {
368 let audio1 =
369 audio_read::<f32>("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();
370 let block1 = to_block(&audio1);
371 assert_eq!(audio1.sample_rate, 48000);
372 assert_eq!(block1.num_frames(), 48000);
373 assert_eq!(block1.num_channels(), 1);
374
375 let audio2 = audio_read::<f32>(
376 "test_data/test_1ch.wav",
377 AudioReadConfig {
378 start: Position::Frame(1100),
379 stop: Position::Frame(1200),
380 ..Default::default()
381 },
382 )
383 .unwrap();
384 let block2 = to_block(&audio2);
385 assert_eq!(audio2.sample_rate, 48000);
386 assert_eq!(block2.num_frames(), 100);
387 assert_eq!(block2.num_channels(), 1);
388 assert_eq!(block1.raw_data()[1100..1200], block2.raw_data()[..]);
389 }
390
391 #[test]
392 fn test_time_selection() {
393 let audio1 =
394 audio_read::<f32>("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();
395 let block1 = to_block(&audio1);
396 assert_eq!(audio1.sample_rate, 48000);
397 assert_eq!(block1.num_frames(), 48000);
398 assert_eq!(block1.num_channels(), 1);
399
400 let audio2 = audio_read::<f32>(
401 "test_data/test_1ch.wav",
402 AudioReadConfig {
403 start: Position::Time(Duration::from_secs_f32(0.5)),
404 stop: Position::Time(Duration::from_secs_f32(0.6)),
405 ..Default::default()
406 },
407 )
408 .unwrap();
409 let block2 = to_block(&audio2);
410
411 assert_eq!(audio2.sample_rate, 48000);
412 assert_eq!(block2.num_frames(), 4800);
413 assert_eq!(block2.num_channels(), 1);
414 assert_eq!(block1.raw_data()[24000..28800], block2.raw_data()[..]);
415 }
416
417 #[test]
418 fn test_channel_selection() {
419 let audio1 =
420 audio_read::<f32>("test_data/test_4ch.wav", AudioReadConfig::default()).unwrap();
421 let block1 = to_block(&audio1);
422 assert_eq!(audio1.sample_rate, 48000);
423 assert_eq!(block1.num_frames(), 48000);
424 assert_eq!(block1.num_channels(), 4);
425
426 let audio2 = audio_read::<f32>(
427 "test_data/test_4ch.wav",
428 AudioReadConfig {
429 start_channel: Some(1),
430 num_channels: Some(2),
431 ..Default::default()
432 },
433 )
434 .unwrap();
435 let block2 = to_block(&audio2);
436
437 assert_eq!(audio2.sample_rate, 48000);
438 assert_eq!(block2.num_frames(), 48000);
439 assert_eq!(block2.num_channels(), 2);
440
441 for frame in 0..10 {
443 assert_eq!(block2.sample(0, frame), block1.sample(1, frame));
444 assert_eq!(block2.sample(1, frame), block1.sample(2, frame));
445 }
446 }
447
448 #[test]
449 fn test_fail_selection() {
450 match audio_read::<f32>(
451 "test_data/test_1ch.wav",
452 AudioReadConfig {
453 start: Position::Frame(100),
454 stop: Position::Frame(99),
455 ..Default::default()
456 },
457 ) {
458 Err(AudioReadError::EndFrameLargerThanStartFrame(_, _)) => (),
459 _ => panic!(),
460 }
461
462 match audio_read::<f32>(
463 "test_data/test_1ch.wav",
464 AudioReadConfig {
465 start: Position::Time(Duration::from_secs_f32(0.6)),
466 stop: Position::Time(Duration::from_secs_f32(0.5)),
467 ..Default::default()
468 },
469 ) {
470 Err(AudioReadError::EndFrameLargerThanStartFrame(_, _)) => (),
471 _ => panic!(),
472 }
473
474 match audio_read::<f32>(
475 "test_data/test_1ch.wav",
476 AudioReadConfig {
477 start_channel: Some(1),
478 ..Default::default()
479 },
480 ) {
481 Err(AudioReadError::InvalidStartChannel(_, _)) => (),
482 _ => panic!(),
483 }
484
485 match audio_read::<f32>(
486 "test_data/test_1ch.wav",
487 AudioReadConfig {
488 num_channels: Some(0),
489 ..Default::default()
490 },
491 ) {
492 Err(AudioReadError::InvalidNumChannels(0)) => (),
493 _ => panic!(),
494 }
495
496 match audio_read::<f32>(
497 "test_data/test_1ch.wav",
498 AudioReadConfig {
499 num_channels: Some(2),
500 ..Default::default()
501 },
502 ) {
503 Err(AudioReadError::InvalidNumChannels(2)) => (),
504 _ => panic!(),
505 }
506 }
507
508 #[test]
509 fn test_resample_preserves_frequency() {
510 const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
511 let sr_out: u32 = 22050;
512
513 let audio = audio_read::<f32>(
515 "test_data/test_4ch.wav",
516 AudioReadConfig {
517 sample_rate: Some(sr_out),
518 ..Default::default()
519 },
520 )
521 .unwrap();
522 let block = to_block(&audio);
523
524 assert_eq!(audio.sample_rate, 48000); assert_eq!(block.num_channels(), 4);
526
527 let expected_frames = 22050;
529 assert_eq!(
530 block.num_frames(),
531 expected_frames,
532 "Expected {} frames, got {}",
533 expected_frames,
534 block.num_frames()
535 );
536
537 let start_frame = 100;
540 let test_frames = 1000;
541
542 for (ch, &freq) in FREQUENCIES.iter().enumerate() {
543 let mut max_error: f32 = 0.0;
544 for frame in start_frame..(start_frame + test_frames) {
545 let expected =
546 (2.0 * std::f64::consts::PI * freq * frame as f64 / sr_out as f64).sin() as f32;
547 let actual = block.sample(ch as u16, frame);
548 let error = (actual - expected).abs();
549 max_error = max_error.max(error);
550 }
551 assert!(
552 max_error < 0.02,
553 "Channel {} ({}Hz): max error {} exceeds threshold",
554 ch,
555 freq,
556 max_error
557 );
558 }
559 }
560}