1use std::fs::File;
2use std::path::Path;
3
4use num::Float;
5use symphonia::core::audio::SampleBuffer;
6use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
7use symphonia::core::errors::Error;
8use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
9use symphonia::core::io::MediaSourceStream;
10use symphonia::core::meta::MetadataOptions;
11use symphonia::core::probe::Hint;
12use thiserror::Error;
13
14use crate::resample::{ResampleError, resample};
15
/// Decoded audio whose samples are stored interleaved, frame by frame.
#[derive(Debug, Clone, PartialEq)]
pub struct Audio<F> {
    /// Samples of all channels, interleaved per frame:
    /// `[f0c0, f0c1, .., f1c0, f1c1, ..]`.
    pub samples_interleaved: Vec<F>,
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of interleaved channels.
    pub num_channels: u16,
}

impl<F> Audio<F> {
    /// Number of frames (samples per channel); 0 when `num_channels` is 0.
    #[must_use]
    pub fn num_frames(&self) -> usize {
        self.samples_interleaved
            .len()
            .checked_div(usize::from(self.num_channels))
            .unwrap_or(0)
    }
}
26
27#[derive(Debug, Error)]
28pub enum ReadError {
29 #[error("could not read file")]
30 Io(#[from] std::io::Error),
31
32 #[error("could not decode audio")]
33 Decode(#[from] symphonia::core::errors::Error),
34
35 #[error("no track found")]
36 NoTrack,
37
38 #[error("no sample rate found")]
39 NoSampleRate,
40
41 #[error("end frame ({end}) must not exceed start frame ({start})")]
42 InvalidFrameRange { start: usize, end: usize },
43
44 #[error("start channel {index} out of bounds (file has {total} channels)")]
45 InvalidChannel { index: usize, total: usize },
46
47 #[error("invalid channel count: {0}")]
48 InvalidChannelCount(usize),
49
50 #[error("resample failed")]
51 Resample(#[from] ResampleError),
52}
53
/// A position within an audio stream, used for the `start` and `stop` bounds of a read.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Position {
    /// No explicit position: the beginning of the stream as a start bound, the end of
    /// the stream as a stop bound.
    #[default]
    Default,
    /// An offset from the beginning of the stream, converted to frames at the file's
    /// native sample rate.
    Time(std::time::Duration),
    /// An absolute frame index at the file's native sample rate.
    Frame(usize),
}
65
66#[derive(Default)]
67pub struct ReadConfig {
68 pub start: Position,
70 pub stop: Position,
72 pub start_channel: Option<usize>,
74 pub num_channels: Option<usize>,
76 pub sample_rate: Option<u32>,
78}
79
80pub fn read<F: Float + rubato::Sample>(
81 path: impl AsRef<Path>,
82 config: ReadConfig,
83) -> Result<Audio<F>, ReadError> {
84 let src = File::open(path.as_ref())?;
85 let mss = MediaSourceStream::new(Box::new(src), Default::default());
86
87 let mut hint = Hint::new();
88 if let Some(ext) = path.as_ref().extension()
89 && let Some(ext_str) = ext.to_str()
90 {
91 hint.with_extension(ext_str);
92 }
93
94 let meta_opts: MetadataOptions = Default::default();
95 let fmt_opts: FormatOptions = Default::default();
96
97 let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
98
99 let mut format = probed.format;
100
101 let track = format
102 .tracks()
103 .iter()
104 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
105 .ok_or(ReadError::NoTrack)?;
106
107 let sample_rate = track
108 .codec_params
109 .sample_rate
110 .ok_or(ReadError::NoSampleRate)?;
111
112 let track_id = track.id;
113
114 let codec_params = track.codec_params.clone();
116 let time_base = track.codec_params.time_base;
117
118 let start_frame = match config.start {
120 Position::Default => 0,
121 Position::Time(duration) => {
122 let secs = duration.as_secs_f64();
123 (secs * sample_rate as f64) as usize
124 }
125 Position::Frame(frame) => frame,
126 };
127
128 let end_frame: Option<usize> = match config.stop {
129 Position::Default => None,
130 Position::Time(duration) => {
131 let secs = duration.as_secs_f64();
132 Some((secs * sample_rate as f64) as usize)
133 }
134 Position::Frame(frame) => Some(frame),
135 };
136
137 if let Some(end_frame) = end_frame
138 && start_frame > end_frame
139 {
140 return Err(ReadError::InvalidFrameRange {
141 start: start_frame,
142 end: end_frame,
143 });
144 }
145
146 if start_frame > sample_rate as usize
152 && let Some(tb) = time_base
153 {
154 let seek_sample = (start_frame as f64 * 0.9) as u64;
156 let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
157
158 let _ = format.seek(
160 SeekMode::Accurate,
161 SeekTo::TimeStamp {
162 ts: seek_ts,
163 track_id,
164 },
165 );
166 }
167
168 let dec_opts: DecoderOptions = Default::default();
169 let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
170
171 let mut sample_buf = None;
172 let mut samples = Vec::new();
173 let mut num_channels = 0usize;
174 let start_channel = config.start_channel;
175
176 let mut current_sample: Option<u64> = None;
178
179 loop {
180 let packet = match format.next_packet() {
181 Ok(packet) => packet,
182 Err(Error::ResetRequired) => {
183 decoder.reset();
184 continue;
185 }
186 Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
187 break;
188 }
189 Err(err) => return Err(err.into()),
190 };
191
192 if packet.track_id() != track_id {
193 continue;
194 }
195
196 let decoded = decoder.decode(&packet)?;
197
198 if current_sample.is_none() {
200 let ts = packet.ts();
201 if let Some(tb) = time_base {
202 current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
204 } else {
205 current_sample = Some(0);
206 }
207 }
208
209 if sample_buf.is_none() {
210 let spec = *decoded.spec();
211 let duration = decoded.capacity() as u64;
212 sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
213
214 num_channels = spec.channels.count();
216
217 let ch_start = start_channel.unwrap_or(0);
219 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
220
221 if ch_start >= num_channels {
222 return Err(ReadError::InvalidChannel {
223 index: ch_start,
224 total: num_channels,
225 });
226 }
227 if ch_count == 0 {
228 return Err(ReadError::InvalidChannelCount(0));
229 }
230 if ch_start + ch_count > num_channels {
231 return Err(ReadError::InvalidChannelCount(ch_count));
232 }
233 }
234
235 if let Some(buf) = &mut sample_buf {
236 buf.copy_interleaved_ref(decoded);
237 let packet_samples = buf.samples();
238
239 let mut pos = current_sample.unwrap_or(0);
240
241 let ch_start = start_channel.unwrap_or(0);
243 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
244 let ch_end = ch_start + ch_count;
245
246 let frames = packet_samples.len() / num_channels;
248
249 for frame_idx in 0..frames {
251 if let Some(end) = end_frame
253 && pos >= end as u64
254 {
255 return Ok(Audio {
256 samples_interleaved: samples,
257 sample_rate,
258 num_channels: ch_count as u16,
259 });
260 }
261
262 if pos >= start_frame as u64 {
264 for ch in ch_start..ch_end {
267 let sample_idx = frame_idx * num_channels + ch;
268 samples.push(F::from(packet_samples[sample_idx]).unwrap());
269 }
270 }
271
272 pos += 1;
273 }
274
275 current_sample = Some(pos);
277 }
278 }
279
280 let ch_start = start_channel.unwrap_or(0);
282 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
283
284 let samples = if let Some(sr_out) = config.sample_rate {
285 resample(&samples, ch_count, sample_rate, sr_out)?
287 } else {
288 samples
289 };
290
291 let actual_sample_rate = config.sample_rate.unwrap_or(sample_rate);
293
294 Ok(Audio {
295 samples_interleaved: samples,
296 sample_rate: actual_sample_rate,
297 num_channels: ch_count as u16,
298 })
299}
300
/// Reads audio like [`read`] and wraps the samples in an [`audio_blocks::Interleaved`]
/// block.
///
/// Returns the block together with its sample rate in Hz.
///
/// # Errors
///
/// Propagates every error returned by [`read`].
#[cfg(feature = "audio-blocks")]
pub fn read_block<F>(
    path: impl AsRef<Path>,
    config: ReadConfig,
) -> Result<(audio_blocks::Interleaved<F>, u32), ReadError>
where
    F: num::Float + 'static + rubato::Sample,
{
    let Audio {
        samples_interleaved,
        sample_rate,
        num_channels,
    } = read(path, config)?;
    let block = audio_blocks::Interleaved::from_slice(&samples_interleaved, num_channels);
    Ok((block, sample_rate))
}
312
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::time::Duration;

    use audio_blocks::{AudioBlock, InterleavedView};

    use super::*;

    /// Mono fixture: 48000 frames at 48 kHz.
    const TEST_1CH: &str = "test_data/test_1ch.wav";
    /// Four-channel fixture: 48000 frames at 48 kHz, one sine per channel.
    const TEST_4CH: &str = "test_data/test_4ch.wav";
    /// Sine frequency (Hz) of each channel in the four-channel fixture.
    const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];

    fn to_block<F: num::Float + 'static>(audio: &Audio<F>) -> InterleavedView<'_, F> {
        InterleavedView::from_slice(&audio.samples_interleaved, audio.num_channels)
    }

    /// Unit-amplitude sine of `freq` Hz evaluated at `frame` for the given sample rate.
    fn sine(freq: f64, frame: usize, sample_rate: f64) -> f32 {
        (2.0 * std::f64::consts::PI * freq * frame as f64 / sample_rate).sin() as f32
    }

    /// Largest absolute deviation of `channel` from a `freq` Hz sine over `frames`.
    fn max_sine_error(
        block: &InterleavedView<'_, f32>,
        channel: u16,
        freq: f64,
        sample_rate: f64,
        frames: Range<usize>,
    ) -> f32 {
        frames
            .map(|frame| (block.sample(channel, frame) - sine(freq, frame, sample_rate)).abs())
            .fold(0.0, f32::max)
    }

    /// Reads `path` with `config` and returns the error, failing the test on success.
    fn read_err(path: &str, config: ReadConfig) -> ReadError {
        match read::<f32>(path, config) {
            Ok(audio) => panic!(
                "expected an error, got {} samples",
                audio.samples_interleaved.len()
            ),
            Err(err) => err,
        }
    }

    #[test]
    fn test_sine_wave_data_integrity() {
        const SAMPLE_RATE: f64 = 48000.0;
        const N_SAMPLES: usize = 48000;

        let audio = read::<f32>(TEST_4CH, ReadConfig::default()).unwrap();
        let block = to_block(&audio);

        assert_eq!(audio.sample_rate, 48000);
        assert_eq!(block.num_frames(), N_SAMPLES);
        assert_eq!(block.num_channels(), 4);

        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
            for frame in 0..N_SAMPLES {
                let expected = sine(freq, frame, SAMPLE_RATE);
                let actual = block.sample(ch as u16, frame);
                assert!(
                    (actual - expected).abs() < 1e-4,
                    "Mismatch at channel {ch}, frame {frame}: expected {expected}, got {actual}"
                );
            }
        }

        // A sub-range must line up with the same sines, offset by its start frame.
        let audio = read::<f32>(
            TEST_4CH,
            ReadConfig {
                start: Position::Frame(24000),
                stop: Position::Frame(24100),
                ..Default::default()
            },
        )
        .unwrap();
        let block = to_block(&audio);
        assert_eq!(block.num_frames(), 100);

        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
            for frame in 0..100 {
                let actual_frame = 24000 + frame;
                let expected = sine(freq, actual_frame, SAMPLE_RATE);
                let actual = block.sample(ch as u16, frame);
                assert!(
                    (actual - expected).abs() < 1e-4,
                    "Offset mismatch at channel {ch}, frame {actual_frame}: expected {expected}, got {actual}"
                );
            }
        }
    }

    #[test]
    fn test_samples_selection() {
        let audio1 = read::<f32>(TEST_1CH, ReadConfig::default()).unwrap();
        let block1 = to_block(&audio1);
        assert_eq!(audio1.sample_rate, 48000);
        assert_eq!(block1.num_frames(), 48000);
        assert_eq!(block1.num_channels(), 1);

        let audio2 = read::<f32>(
            TEST_1CH,
            ReadConfig {
                start: Position::Frame(1100),
                stop: Position::Frame(1200),
                ..Default::default()
            },
        )
        .unwrap();
        let block2 = to_block(&audio2);
        assert_eq!(audio2.sample_rate, 48000);
        assert_eq!(block2.num_frames(), 100);
        assert_eq!(block2.num_channels(), 1);
        assert_eq!(block1.raw_data()[1100..1200], block2.raw_data()[..]);
    }

    #[test]
    fn test_time_selection() {
        let audio1 = read::<f32>(TEST_1CH, ReadConfig::default()).unwrap();
        let block1 = to_block(&audio1);
        assert_eq!(audio1.sample_rate, 48000);
        assert_eq!(block1.num_frames(), 48000);
        assert_eq!(block1.num_channels(), 1);

        let audio2 = read::<f32>(
            TEST_1CH,
            ReadConfig {
                start: Position::Time(Duration::from_secs_f32(0.5)),
                stop: Position::Time(Duration::from_secs_f32(0.6)),
                ..Default::default()
            },
        )
        .unwrap();
        let block2 = to_block(&audio2);

        assert_eq!(audio2.sample_rate, 48000);
        assert_eq!(block2.num_frames(), 4800);
        assert_eq!(block2.num_channels(), 1);
        assert_eq!(block1.raw_data()[24000..28800], block2.raw_data()[..]);
    }

    #[test]
    fn test_channel_selection() {
        let audio1 = read::<f32>(TEST_4CH, ReadConfig::default()).unwrap();
        let block1 = to_block(&audio1);
        assert_eq!(audio1.sample_rate, 48000);
        assert_eq!(block1.num_frames(), 48000);
        assert_eq!(block1.num_channels(), 4);

        let audio2 = read::<f32>(
            TEST_4CH,
            ReadConfig {
                start_channel: Some(1),
                num_channels: Some(2),
                ..Default::default()
            },
        )
        .unwrap();
        let block2 = to_block(&audio2);

        assert_eq!(audio2.sample_rate, 48000);
        assert_eq!(block2.num_frames(), 48000);
        assert_eq!(block2.num_channels(), 2);

        for frame in 0..10 {
            assert_eq!(block2.sample(0, frame), block1.sample(1, frame));
            assert_eq!(block2.sample(1, frame), block1.sample(2, frame));
        }
    }

    #[test]
    fn test_fail_selection() {
        let err = read_err(
            TEST_1CH,
            ReadConfig {
                start: Position::Frame(100),
                stop: Position::Frame(99),
                ..Default::default()
            },
        );
        assert!(
            matches!(err, ReadError::InvalidFrameRange { .. }),
            "unexpected error: {err:?}"
        );

        let err = read_err(
            TEST_1CH,
            ReadConfig {
                start: Position::Time(Duration::from_secs_f32(0.6)),
                stop: Position::Time(Duration::from_secs_f32(0.5)),
                ..Default::default()
            },
        );
        assert!(
            matches!(err, ReadError::InvalidFrameRange { .. }),
            "unexpected error: {err:?}"
        );

        let err = read_err(
            TEST_1CH,
            ReadConfig {
                start_channel: Some(1),
                ..Default::default()
            },
        );
        assert!(
            matches!(err, ReadError::InvalidChannel { .. }),
            "unexpected error: {err:?}"
        );

        let err = read_err(
            TEST_1CH,
            ReadConfig {
                num_channels: Some(0),
                ..Default::default()
            },
        );
        assert!(
            matches!(err, ReadError::InvalidChannelCount(0)),
            "unexpected error: {err:?}"
        );

        let err = read_err(
            TEST_1CH,
            ReadConfig {
                num_channels: Some(2),
                ..Default::default()
            },
        );
        assert!(
            matches!(err, ReadError::InvalidChannelCount(2)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_resample_preserves_frequency() {
        let sr_out: u32 = 22050;

        let audio = read::<f32>(
            TEST_4CH,
            ReadConfig {
                sample_rate: Some(sr_out),
                ..Default::default()
            },
        )
        .unwrap();
        let block = to_block(&audio);

        assert_eq!(audio.sample_rate, sr_out);
        assert_eq!(block.num_channels(), 4);

        // One second of input yields one second of output.
        let expected_frames = 22050;
        assert_eq!(
            block.num_frames(),
            expected_frames,
            "Expected {} frames, got {}",
            expected_frames,
            block.num_frames()
        );

        // The first 100 frames are skipped (presumably resampler warm-up).
        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
            let max_error = max_sine_error(&block, ch as u16, freq, f64::from(sr_out), 100..1100);
            assert!(
                max_error < 0.02,
                "Channel {} ({}Hz): max error {} exceeds threshold",
                ch,
                freq,
                max_error
            );
        }
    }

    #[test]
    fn test_channel_selection_with_resampling() {
        let sr_out: u32 = 22050;

        let audio = read::<f32>(
            TEST_4CH,
            ReadConfig {
                start_channel: Some(1),
                num_channels: Some(2),
                sample_rate: Some(sr_out),
                ..Default::default()
            },
        )
        .unwrap();
        let block = to_block(&audio);

        assert_eq!(audio.num_channels, 2, "Should have 2 channels");
        assert_eq!(
            audio.sample_rate, sr_out,
            "Sample rate should be the resampled rate"
        );

        let expected_frames = 22050;
        assert_eq!(
            block.num_frames(),
            expected_frames,
            "Expected {} frames, got {}",
            expected_frames,
            block.num_frames()
        );

        // Output channels 0 and 1 are source channels 1 and 2.
        let selected_freqs = &FREQUENCIES[1..3];
        for (ch, &freq) in selected_freqs.iter().enumerate() {
            let max_error = max_sine_error(&block, ch as u16, freq, f64::from(sr_out), 100..1100);
            assert!(
                max_error < 0.02,
                "Channel {} ({}Hz): max error {} exceeds threshold",
                ch,
                freq,
                max_error
            );
        }
    }
}