1use std::fs::File;
2use std::path::Path;
3
4use num::Float;
5use symphonia::core::audio::SampleBuffer;
6use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
7use symphonia::core::errors::Error;
8use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
9use symphonia::core::io::MediaSourceStream;
10use symphonia::core::meta::MetadataOptions;
11use symphonia::core::probe::Hint;
12use thiserror::Error;
13
14use crate::resample::resample;
15
/// Decoded audio returned by [`audio_read`].
#[derive(Clone, Debug)]
pub struct Audio<F> {
    /// Samples stored frame after frame; within a frame the selected channels
    /// appear in ascending channel order.
    pub samples_interleaved: Vec<F>,
    /// Sample rate of `samples_interleaved` in Hz.
    pub sample_rate: u32,
    /// Number of channels interleaved in `samples_interleaved`.
    pub num_channels: u16,
}
26
27#[derive(Debug, Error)]
28pub enum AudioReadError {
29 #[error("could not read file")]
30 FileError(#[from] std::io::Error),
31 #[error("could not decode audio")]
32 EncodingError(#[from] symphonia::core::errors::Error),
33 #[error("could not find track in file")]
34 NoTrack,
35 #[error("could not find sample rate in file")]
36 NoSampleRate,
37 #[error("end frame {0} is larger than start frame {1}")]
38 EndFrameLargerThanStartFrame(usize, usize),
39 #[error("start channel {0} invalid, audio file has only {1} channels")]
40 InvalidStartChannel(usize, usize),
41 #[error("invalid number of channels to extract: {0}")]
42 InvalidNumChannels(usize),
43}
44
/// A position inside an audio file, used for the start and stop of a read.
///
/// Positions always refer to the file's own sample rate, i.e. they are
/// applied before any resampling.
#[derive(Clone, Copy, Debug, Default)]
pub enum Position {
    /// No explicit position: the beginning of the file when used as a start,
    /// the end of the file when used as a stop.
    #[default]
    Default,
    /// Offset from the beginning of the file. It is converted to a frame
    /// index by multiplying with the file's sample rate and truncating.
    Time(std::time::Duration),
    /// Absolute frame index counted from the beginning of the file.
    Frame(usize),
}
56
57#[derive(Default)]
58pub struct AudioReadConfig {
59 pub start: Position,
61 pub stop: Position,
63 pub start_channel: Option<usize>,
65 pub num_channels: Option<usize>,
67 pub sample_rate: Option<u32>,
69}
70
71pub fn audio_read<F: Float + rubato::Sample>(
72 path: impl AsRef<Path>,
73 config: AudioReadConfig,
74) -> Result<Audio<F>, AudioReadError> {
75 let src = File::open(path.as_ref())?;
76 let mss = MediaSourceStream::new(Box::new(src), Default::default());
77
78 let mut hint = Hint::new();
79 if let Some(ext) = path.as_ref().extension()
80 && let Some(ext_str) = ext.to_str()
81 {
82 hint.with_extension(ext_str);
83 }
84
85 let meta_opts: MetadataOptions = Default::default();
86 let fmt_opts: FormatOptions = Default::default();
87
88 let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
89
90 let mut format = probed.format;
91
92 let track = format
93 .tracks()
94 .iter()
95 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
96 .ok_or(AudioReadError::NoTrack)?;
97
98 let sample_rate = track
99 .codec_params
100 .sample_rate
101 .ok_or(AudioReadError::NoSampleRate)?;
102
103 let track_id = track.id;
104
105 let codec_params = track.codec_params.clone();
107 let time_base = track.codec_params.time_base;
108
109 let start_frame = match config.start {
111 Position::Default => 0,
112 Position::Time(duration) => {
113 let secs = duration.as_secs_f64();
114 (secs * sample_rate as f64) as usize
115 }
116 Position::Frame(frame) => frame,
117 };
118
119 let end_frame: Option<usize> = match config.stop {
120 Position::Default => None,
121 Position::Time(duration) => {
122 let secs = duration.as_secs_f64();
123 Some((secs * sample_rate as f64) as usize)
124 }
125 Position::Frame(frame) => Some(frame),
126 };
127
128 if let Some(end_frame) = end_frame
129 && start_frame > end_frame
130 {
131 return Err(AudioReadError::EndFrameLargerThanStartFrame(
132 end_frame,
133 start_frame,
134 ));
135 }
136
137 if start_frame > sample_rate as usize
143 && let Some(tb) = time_base
144 {
145 let seek_sample = (start_frame as f64 * 0.9) as u64;
147 let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
148
149 let _ = format.seek(
151 SeekMode::Accurate,
152 SeekTo::TimeStamp {
153 ts: seek_ts,
154 track_id,
155 },
156 );
157 }
158
159 let dec_opts: DecoderOptions = Default::default();
160 let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
161
162 let mut sample_buf = None;
163 let mut samples = Vec::new();
164 let mut num_channels = 0usize;
165 let start_channel = config.start_channel;
166
167 let mut current_sample: Option<u64> = None;
169
170 loop {
171 let packet = match format.next_packet() {
172 Ok(packet) => packet,
173 Err(Error::ResetRequired) => {
174 decoder.reset();
175 continue;
176 }
177 Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
178 break;
179 }
180 Err(err) => return Err(err.into()),
181 };
182
183 if packet.track_id() != track_id {
184 continue;
185 }
186
187 let decoded = decoder.decode(&packet)?;
188
189 if current_sample.is_none() {
191 let ts = packet.ts();
192 if let Some(tb) = time_base {
193 current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
195 } else {
196 current_sample = Some(0);
197 }
198 }
199
200 if sample_buf.is_none() {
201 let spec = *decoded.spec();
202 let duration = decoded.capacity() as u64;
203 sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
204
205 num_channels = spec.channels.count();
207
208 let ch_start = start_channel.unwrap_or(0);
210 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
211
212 if ch_start >= num_channels {
213 return Err(AudioReadError::InvalidStartChannel(ch_start, num_channels));
214 }
215 if ch_count == 0 {
216 return Err(AudioReadError::InvalidNumChannels(0));
217 }
218 if ch_start + ch_count > num_channels {
219 return Err(AudioReadError::InvalidNumChannels(ch_count));
220 }
221 }
222
223 if let Some(buf) = &mut sample_buf {
224 buf.copy_interleaved_ref(decoded);
225 let packet_samples = buf.samples();
226
227 let mut pos = current_sample.unwrap_or(0);
228
229 let ch_start = start_channel.unwrap_or(0);
231 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
232 let ch_end = ch_start + ch_count;
233
234 let frames = packet_samples.len() / num_channels;
236
237 for frame_idx in 0..frames {
239 if let Some(end) = end_frame
241 && pos >= end as u64
242 {
243 return Ok(Audio {
244 samples_interleaved: samples,
245 sample_rate,
246 num_channels: ch_count as u16,
247 });
248 }
249
250 if pos >= start_frame as u64 {
252 for ch in ch_start..ch_end {
255 let sample_idx = frame_idx * num_channels + ch;
256 samples.push(F::from(packet_samples[sample_idx]).unwrap());
257 }
258 }
259
260 pos += 1;
261 }
262
263 current_sample = Some(pos);
265 }
266 }
267
268 let ch_start = start_channel.unwrap_or(0);
270 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
271
272 let samples = if let Some(sr_out) = config.sample_rate {
273 resample(&samples, ch_count, sample_rate, sr_out).map_err(|_| AudioReadError::NoTrack)?
275 } else {
276 samples
277 };
278
279 let actual_sample_rate = config.sample_rate.unwrap_or(sample_rate);
281
282 Ok(Audio {
283 samples_interleaved: samples,
284 sample_rate: actual_sample_rate,
285 num_channels: ch_count as u16,
286 })
287}
288
/// Reads an audio file like [`audio_read`] and returns the samples as an
/// interleaved audio block together with their sample rate.
///
/// Only available with the `audio-blocks` feature.
///
/// # Errors
///
/// Returns whatever [`audio_read`] returns for the same arguments.
#[cfg(feature = "audio-blocks")]
pub fn audio_read_block<F: num::Float + 'static + rubato::Sample>(
    path: impl AsRef<Path>,
    config: AudioReadConfig,
) -> Result<(audio_blocks::AudioBlockInterleaved<F>, u32), AudioReadError> {
    let Audio {
        samples_interleaved,
        sample_rate,
        num_channels,
    } = audio_read(path, config)?;

    let block = audio_blocks::AudioBlockInterleaved::from_slice(&samples_interleaved, num_channels);
    Ok((block, sample_rate))
}
303
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;
    use std::ops::Range;
    use std::time::Duration;

    use audio_blocks::{AudioBlock, AudioBlockInterleavedView};

    use super::*;

    /// Mono test file.
    const MONO_FILE: &str = "test_data/test_1ch.wav";
    /// Four-channel test file; the tests expect channel `i` to hold a sine
    /// at `TONES_HZ[i]`.
    const QUAD_FILE: &str = "test_data/test_4ch.wav";
    /// Sine frequencies expected in the channels of `QUAD_FILE`.
    const TONES_HZ: [f64; 4] = [440.0, 554.37, 659.25, 880.0];

    /// Wraps decoded audio in a borrowed interleaved block view.
    fn to_block<F: num::Float + 'static>(audio: &Audio<F>) -> AudioBlockInterleavedView<'_, F> {
        AudioBlockInterleavedView::from_slice(&audio.samples_interleaved, audio.num_channels)
    }

    /// Reads `path` as `f32` samples and panics if reading fails.
    fn read_f32(path: &str, config: AudioReadConfig) -> Audio<f32> {
        audio_read::<f32>(path, config).unwrap()
    }

    /// Value of a unit-amplitude, zero-phase sine at `frame`.
    fn sine_at(freq: f64, frame: usize, sample_rate: f64) -> f32 {
        (2.0 * PI * freq * frame as f64 / sample_rate).sin() as f32
    }

    /// Largest absolute deviation of channel `ch` from the reference sine
    /// over `frames`.
    fn max_sine_error(
        block: &AudioBlockInterleavedView<'_, f32>,
        ch: u16,
        freq: f64,
        sample_rate: f64,
        frames: Range<usize>,
    ) -> f32 {
        frames
            .map(|frame| (block.sample(ch, frame) - sine_at(freq, frame, sample_rate)).abs())
            .fold(0.0, f32::max)
    }

    /// Asserts that the channels of a resampled block still carry `freqs`.
    fn assert_tones_preserved(
        block: &AudioBlockInterleavedView<'_, f32>,
        freqs: &[f64],
        sr_out: u32,
    ) {
        // The first frames are skipped; the resampler output is only checked
        // from frame 100 on.
        let start_frame = 100;
        let test_frames = 1000;

        for (ch, &freq) in freqs.iter().enumerate() {
            let max_error = max_sine_error(
                block,
                ch as u16,
                freq,
                f64::from(sr_out),
                start_frame..(start_frame + test_frames),
            );
            assert!(
                max_error < 0.02,
                "Channel {} ({}Hz): max error {} exceeds threshold",
                ch,
                freq,
                max_error
            );
        }
    }

    /// Asserts the number of frames in `block`.
    fn assert_frame_count(block: &AudioBlockInterleavedView<'_, f32>, expected_frames: usize) {
        assert_eq!(
            block.num_frames(),
            expected_frames,
            "Expected {} frames, got {}",
            expected_frames,
            block.num_frames()
        );
    }

    #[test]
    fn test_sine_wave_data_integrity() {
        const SAMPLE_RATE: f64 = 48000.0;
        const N_SAMPLES: usize = 48000;

        // Whole file: every sample of every channel must match its sine.
        let full = read_f32(QUAD_FILE, AudioReadConfig::default());
        let full_block = to_block(&full);

        assert_eq!(full.sample_rate, 48000);
        assert_eq!(full_block.num_frames(), N_SAMPLES);
        assert_eq!(full_block.num_channels(), 4);

        for (ch, &freq) in TONES_HZ.iter().enumerate() {
            for frame in 0..N_SAMPLES {
                let expected = sine_at(freq, frame, SAMPLE_RATE);
                let actual = full_block.sample(ch as u16, frame);
                assert!(
                    (actual - expected).abs() < 1e-4,
                    "Mismatch at channel {ch}, frame {frame}: expected {expected}, got {actual}"
                );
            }
        }

        // Partial read: the excerpt must line up with the absolute position.
        let excerpt = read_f32(
            QUAD_FILE,
            AudioReadConfig {
                start: Position::Frame(24000),
                stop: Position::Frame(24100),
                ..Default::default()
            },
        );
        let excerpt_block = to_block(&excerpt);

        for (ch, &freq) in TONES_HZ.iter().enumerate() {
            for frame in 0..100 {
                let actual_frame = 24000 + frame;
                let expected = sine_at(freq, actual_frame, SAMPLE_RATE);
                let actual = excerpt_block.sample(ch as u16, frame);
                assert!(
                    (actual - expected).abs() < 1e-4,
                    "Offset mismatch at channel {ch}, frame {actual_frame}: expected {expected}, got {actual}"
                );
            }
        }
    }

    #[test]
    fn test_samples_selection() {
        let whole = read_f32(MONO_FILE, AudioReadConfig::default());
        let whole_block = to_block(&whole);
        assert_eq!(whole.sample_rate, 48000);
        assert_eq!(whole_block.num_frames(), 48000);
        assert_eq!(whole_block.num_channels(), 1);

        let part = read_f32(
            MONO_FILE,
            AudioReadConfig {
                start: Position::Frame(1100),
                stop: Position::Frame(1200),
                ..Default::default()
            },
        );
        let part_block = to_block(&part);
        assert_eq!(part.sample_rate, 48000);
        assert_eq!(part_block.num_frames(), 100);
        assert_eq!(part_block.num_channels(), 1);
        assert_eq!(whole_block.raw_data()[1100..1200], part_block.raw_data()[..]);
    }

    #[test]
    fn test_time_selection() {
        let whole = read_f32(MONO_FILE, AudioReadConfig::default());
        let whole_block = to_block(&whole);
        assert_eq!(whole.sample_rate, 48000);
        assert_eq!(whole_block.num_frames(), 48000);
        assert_eq!(whole_block.num_channels(), 1);

        let part = read_f32(
            MONO_FILE,
            AudioReadConfig {
                start: Position::Time(Duration::from_secs_f32(0.5)),
                stop: Position::Time(Duration::from_secs_f32(0.6)),
                ..Default::default()
            },
        );
        let part_block = to_block(&part);

        assert_eq!(part.sample_rate, 48000);
        assert_eq!(part_block.num_frames(), 4800);
        assert_eq!(part_block.num_channels(), 1);
        assert_eq!(
            whole_block.raw_data()[24000..28800],
            part_block.raw_data()[..]
        );
    }

    #[test]
    fn test_channel_selection() {
        let whole = read_f32(QUAD_FILE, AudioReadConfig::default());
        let whole_block = to_block(&whole);
        assert_eq!(whole.sample_rate, 48000);
        assert_eq!(whole_block.num_frames(), 48000);
        assert_eq!(whole_block.num_channels(), 4);

        let middle = read_f32(
            QUAD_FILE,
            AudioReadConfig {
                start_channel: Some(1),
                num_channels: Some(2),
                ..Default::default()
            },
        );
        let middle_block = to_block(&middle);

        assert_eq!(middle.sample_rate, 48000);
        assert_eq!(middle_block.num_frames(), 48000);
        assert_eq!(middle_block.num_channels(), 2);

        // Output channels 0 and 1 must be file channels 1 and 2.
        for frame in 0..10 {
            assert_eq!(middle_block.sample(0, frame), whole_block.sample(1, frame));
            assert_eq!(middle_block.sample(1, frame), whole_block.sample(2, frame));
        }
    }

    #[test]
    fn test_fail_selection() {
        // Stop frame before start frame.
        let reversed_frames = audio_read::<f32>(
            MONO_FILE,
            AudioReadConfig {
                start: Position::Frame(100),
                stop: Position::Frame(99),
                ..Default::default()
            },
        );
        assert!(matches!(
            reversed_frames,
            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
        ));

        // Stop time before start time.
        let reversed_times = audio_read::<f32>(
            MONO_FILE,
            AudioReadConfig {
                start: Position::Time(Duration::from_secs_f32(0.6)),
                stop: Position::Time(Duration::from_secs_f32(0.5)),
                ..Default::default()
            },
        );
        assert!(matches!(
            reversed_times,
            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
        ));

        // Start channel past the only channel of the file.
        let bad_start_channel = audio_read::<f32>(
            MONO_FILE,
            AudioReadConfig {
                start_channel: Some(1),
                ..Default::default()
            },
        );
        assert!(matches!(
            bad_start_channel,
            Err(AudioReadError::InvalidStartChannel(_, _))
        ));

        // Zero channels requested.
        let no_channels = audio_read::<f32>(
            MONO_FILE,
            AudioReadConfig {
                num_channels: Some(0),
                ..Default::default()
            },
        );
        assert!(matches!(
            no_channels,
            Err(AudioReadError::InvalidNumChannels(0))
        ));

        // More channels requested than the file has.
        let too_many_channels = audio_read::<f32>(
            MONO_FILE,
            AudioReadConfig {
                num_channels: Some(2),
                ..Default::default()
            },
        );
        assert!(matches!(
            too_many_channels,
            Err(AudioReadError::InvalidNumChannels(2))
        ));
    }

    #[test]
    fn test_resample_preserves_frequency() {
        let sr_out: u32 = 22050;

        let audio = read_f32(
            QUAD_FILE,
            AudioReadConfig {
                sample_rate: Some(sr_out),
                ..Default::default()
            },
        );
        let block = to_block(&audio);

        assert_eq!(audio.sample_rate, sr_out);
        assert_eq!(block.num_channels(), 4);

        assert_frame_count(&block, 22050);
        assert_tones_preserved(&block, &TONES_HZ, sr_out);
    }

    #[test]
    fn test_channel_selection_with_resampling() {
        let sr_out: u32 = 22050;

        let audio = read_f32(
            QUAD_FILE,
            AudioReadConfig {
                start_channel: Some(1),
                num_channels: Some(2),
                sample_rate: Some(sr_out),
                ..Default::default()
            },
        );
        let block = to_block(&audio);

        assert_eq!(audio.num_channels, 2, "Should have 2 channels");
        assert_eq!(
            audio.sample_rate, sr_out,
            "Sample rate should be the resampled rate"
        );

        assert_frame_count(&block, 22050);
        // File channels 1 and 2 were selected.
        assert_tones_preserved(&block, &TONES_HZ[1..3], sr_out);
    }
}