1use std::fs::File;
2use std::path::Path;
3
4use num::Float;
5use symphonia::core::audio::SampleBuffer;
6use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
7use symphonia::core::errors::Error;
8use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
9use symphonia::core::io::MediaSourceStream;
10use symphonia::core::meta::MetadataOptions;
11use symphonia::core::probe::Hint;
12use thiserror::Error;
13
14use crate::resample::{ResampleError, resample};
15
/// Decoded audio held as interleaved samples of type `F`.
///
/// Produced by [`read`]. `PartialEq` is derived so that two results can be compared
/// directly (for example in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct Audio<F> {
    /// Samples stored frame by frame: every frame holds `num_channels` consecutive values.
    pub samples_interleaved: Vec<F>,
    /// Sample rate that `samples_interleaved` is expressed in.
    pub sample_rate: u32,
    /// Number of channels interleaved in `samples_interleaved`.
    pub num_channels: u16,
}
26
27#[derive(Debug, Error)]
28pub enum ReadError {
29 #[error("could not read file")]
30 Io(#[from] std::io::Error),
31
32 #[error("could not decode audio")]
33 Decode(#[from] symphonia::core::errors::Error),
34
35 #[error("no track found")]
36 NoTrack,
37
38 #[error("no sample rate found")]
39 NoSampleRate,
40
41 #[error("end frame ({end}) must not exceed start frame ({start})")]
42 InvalidFrameRange { start: usize, end: usize },
43
44 #[error("start channel {index} out of bounds (file has {total} channels)")]
45 InvalidChannel { index: usize, total: usize },
46
47 #[error("invalid channel count: {0}")]
48 InvalidChannelCount(usize),
49
50 #[error("resample failed")]
51 Resample(#[from] ResampleError),
52}
53
/// A position inside an audio file, used for both ends of the range to read.
///
/// `PartialEq`/`Eq` are derived so that positions can be compared and asserted on.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Position {
    /// No explicit position: the beginning of the file when used as a start,
    /// the end of the file when used as a stop.
    #[default]
    Default,
    /// Time offset from the beginning of the file. It is converted to a frame index
    /// with the file's own sample rate; a fractional frame is truncated.
    Time(std::time::Duration),
    /// Frame index, counted at the file's own sample rate.
    Frame(usize),
}
65
66#[derive(Default)]
67pub struct ReadConfig {
68 pub start: Position,
70 pub stop: Position,
72 pub start_channel: Option<usize>,
74 pub num_channels: Option<usize>,
76 pub sample_rate: Option<u32>,
78}
79
80pub fn read<F: Float + rubato::Sample>(
81 path: impl AsRef<Path>,
82 config: ReadConfig,
83) -> Result<Audio<F>, ReadError> {
84 let src = File::open(path.as_ref())?;
85 let mss = MediaSourceStream::new(Box::new(src), Default::default());
86
87 let mut hint = Hint::new();
88 if let Some(ext) = path.as_ref().extension()
89 && let Some(ext_str) = ext.to_str()
90 {
91 hint.with_extension(ext_str);
92 }
93
94 let meta_opts: MetadataOptions = Default::default();
95 let fmt_opts: FormatOptions = Default::default();
96
97 let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
98
99 let mut format = probed.format;
100
101 let track = format
102 .tracks()
103 .iter()
104 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
105 .ok_or(ReadError::NoTrack)?;
106
107 let sample_rate = track
108 .codec_params
109 .sample_rate
110 .ok_or(ReadError::NoSampleRate)?;
111
112 let track_id = track.id;
113
114 let codec_params = track.codec_params.clone();
116 let time_base = track.codec_params.time_base;
117
118 let start_frame = match config.start {
120 Position::Default => 0,
121 Position::Time(duration) => {
122 let secs = duration.as_secs_f64();
123 (secs * sample_rate as f64) as usize
124 }
125 Position::Frame(frame) => frame,
126 };
127
128 let end_frame: Option<usize> = match config.stop {
129 Position::Default => None,
130 Position::Time(duration) => {
131 let secs = duration.as_secs_f64();
132 Some((secs * sample_rate as f64) as usize)
133 }
134 Position::Frame(frame) => Some(frame),
135 };
136
137 if let Some(end_frame) = end_frame
138 && start_frame > end_frame
139 {
140 return Err(ReadError::InvalidFrameRange {
141 start: start_frame,
142 end: end_frame,
143 });
144 }
145
146 if start_frame > sample_rate as usize
152 && let Some(tb) = time_base
153 {
154 let seek_sample = (start_frame as f64 * 0.9) as u64;
156 let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
157
158 let _ = format.seek(
160 SeekMode::Accurate,
161 SeekTo::TimeStamp {
162 ts: seek_ts,
163 track_id,
164 },
165 );
166 }
167
168 let dec_opts: DecoderOptions = Default::default();
169 let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
170
171 let mut sample_buf = None;
172 let mut samples = Vec::new();
173 let mut num_channels = 0usize;
174 let start_channel = config.start_channel;
175
176 let mut current_sample: Option<u64> = None;
178
179 loop {
180 let packet = match format.next_packet() {
181 Ok(packet) => packet,
182 Err(Error::ResetRequired) => {
183 decoder.reset();
184 continue;
185 }
186 Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
187 break;
188 }
189 Err(err) => return Err(err.into()),
190 };
191
192 if packet.track_id() != track_id {
193 continue;
194 }
195
196 let decoded = decoder.decode(&packet)?;
197
198 if current_sample.is_none() {
200 let ts = packet.ts();
201 if let Some(tb) = time_base {
202 current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
204 } else {
205 current_sample = Some(0);
206 }
207 }
208
209 if sample_buf.is_none() {
210 let spec = *decoded.spec();
211 let duration = decoded.capacity() as u64;
212 sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
213
214 num_channels = spec.channels.count();
216
217 let ch_start = start_channel.unwrap_or(0);
219 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
220
221 if ch_start >= num_channels {
222 return Err(ReadError::InvalidChannel {
223 index: ch_start,
224 total: num_channels,
225 });
226 }
227 if ch_count == 0 {
228 return Err(ReadError::InvalidChannelCount(0));
229 }
230 if ch_start + ch_count > num_channels {
231 return Err(ReadError::InvalidChannelCount(ch_count));
232 }
233 }
234
235 if let Some(buf) = &mut sample_buf {
236 buf.copy_interleaved_ref(decoded);
237 let packet_samples = buf.samples();
238
239 let mut pos = current_sample.unwrap_or(0);
240
241 let ch_start = start_channel.unwrap_or(0);
243 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
244 let ch_end = ch_start + ch_count;
245
246 let frames = packet_samples.len() / num_channels;
248
249 for frame_idx in 0..frames {
251 if let Some(end) = end_frame
253 && pos >= end as u64
254 {
255 return Ok(Audio {
256 samples_interleaved: samples,
257 sample_rate,
258 num_channels: ch_count as u16,
259 });
260 }
261
262 if pos >= start_frame as u64 {
264 for ch in ch_start..ch_end {
267 let sample_idx = frame_idx * num_channels + ch;
268 samples.push(F::from(packet_samples[sample_idx]).unwrap());
269 }
270 }
271
272 pos += 1;
273 }
274
275 current_sample = Some(pos);
277 }
278 }
279
280 let ch_start = start_channel.unwrap_or(0);
282 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
283
284 let samples = if let Some(sr_out) = config.sample_rate
285 && sr_out != sample_rate
286 {
287 resample(&samples, ch_count, sample_rate, sr_out)?
289 } else {
290 samples
291 };
292
293 let actual_sample_rate = config.sample_rate.unwrap_or(sample_rate);
295
296 Ok(Audio {
297 samples_interleaved: samples,
298 sample_rate: actual_sample_rate,
299 num_channels: ch_count as u16,
300 })
301}
302
/// Reads an audio file with [`read`] and hands the result back as an
/// `audio_blocks::Interleaved` block together with its sample rate.
///
/// Only compiled when the `audio-blocks` feature is enabled.
///
/// # Errors
///
/// Forwards every error returned by [`read`].
#[cfg(feature = "audio-blocks")]
pub fn read_block<F: num::Float + 'static + rubato::Sample>(
    path: impl AsRef<Path>,
    config: ReadConfig,
) -> Result<(audio_blocks::Interleaved<F>, u32), ReadError> {
    let Audio {
        samples_interleaved,
        sample_rate,
        num_channels,
    } = read(path, config)?;

    let block = audio_blocks::Interleaved::from_slice(&samples_interleaved, num_channels);
    Ok((block, sample_rate))
}
314
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;
    use std::ops::Range;
    use std::time::Duration;

    use audio_blocks::{AudioBlock, InterleavedView};

    use super::*;

    /// Single-channel fixture.
    const MONO_FILE: &str = "test_data/test_1ch.wav";
    /// Four-channel fixture.
    const QUAD_FILE: &str = "test_data/test_4ch.wav";
    /// Sine frequency the assertions expect on each channel of `QUAD_FILE`.
    const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];

    /// Borrows the decoded samples as an interleaved block view.
    fn to_block<F: num::Float + 'static>(audio: &Audio<F>) -> InterleavedView<'_, F> {
        InterleavedView::from_slice(&audio.samples_interleaved, audio.num_channels)
    }

    /// Reads `path` as `f32`, panicking if the read fails.
    #[track_caller]
    fn load(path: &str, config: ReadConfig) -> Audio<f32> {
        read::<f32>(path, config).unwrap()
    }

    /// Value of a unit sine of `freq` at `frame` for the given sample rate.
    fn sine(freq: f64, frame: usize, sample_rate: f64) -> f32 {
        (2.0 * PI * freq * frame as f64 / sample_rate).sin() as f32
    }

    /// Largest absolute deviation of channel `ch` from the expected sine over `frames`.
    fn max_sine_error(
        block: &InterleavedView<'_, f32>,
        ch: usize,
        freq: f64,
        sample_rate: f64,
        frames: Range<usize>,
    ) -> f32 {
        frames
            .map(|frame| (block.sample(ch as u16, frame) - sine(freq, frame, sample_rate)).abs())
            .fold(0.0, f32::max)
    }

    /// Checks the frame count after resampling and that every channel still carries its sine.
    #[track_caller]
    fn assert_resampled_sines(block: &InterleavedView<'_, f32>, freqs: &[f64], sr_out: u32) {
        let expected_frames = 22050;
        assert_eq!(
            block.num_frames(),
            expected_frames,
            "Expected {} frames, got {}",
            expected_frames,
            block.num_frames()
        );

        // Skip the first frames and compare a window of 1000 frames.
        let start_frame = 100;
        let test_frames = 1000;

        for (ch, &freq) in freqs.iter().enumerate() {
            let max_error = max_sine_error(
                block,
                ch,
                freq,
                f64::from(sr_out),
                start_frame..(start_frame + test_frames),
            );
            assert!(
                max_error < 0.02,
                "Channel {} ({}Hz): max error {} exceeds threshold",
                ch,
                freq,
                max_error
            );
        }
    }

    #[test]
    fn test_sine_wave_data_integrity() {
        const SAMPLE_RATE: f64 = 48000.0;
        const N_SAMPLES: usize = 48000;
        const OFFSET: usize = 24000;

        // Whole file.
        let audio = load(QUAD_FILE, ReadConfig::default());
        let block = to_block(&audio);

        assert_eq!(audio.sample_rate, 48000);
        assert_eq!(block.num_frames(), N_SAMPLES);
        assert_eq!(block.num_channels(), 4);

        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
            for frame in 0..N_SAMPLES {
                let expected = sine(freq, frame, SAMPLE_RATE);
                let actual = block.sample(ch as u16, frame);
                assert!(
                    (actual - expected).abs() < 1e-4,
                    "Mismatch at channel {ch}, frame {frame}: expected {expected}, got {actual}"
                );
            }
        }

        // A short window from the middle of the file.
        let audio = load(
            QUAD_FILE,
            ReadConfig {
                start: Position::Frame(OFFSET),
                stop: Position::Frame(OFFSET + 100),
                ..Default::default()
            },
        );
        let block = to_block(&audio);

        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
            for frame in 0..100 {
                let actual_frame = OFFSET + frame;
                let expected = sine(freq, actual_frame, SAMPLE_RATE);
                let actual = block.sample(ch as u16, frame);
                assert!(
                    (actual - expected).abs() < 1e-4,
                    "Offset mismatch at channel {ch}, frame {actual_frame}: expected {expected}, got {actual}"
                );
            }
        }
    }

    #[test]
    fn test_samples_selection() {
        let full = load(MONO_FILE, ReadConfig::default());
        let full_block = to_block(&full);
        assert_eq!(full.sample_rate, 48000);
        assert_eq!(full_block.num_frames(), 48000);
        assert_eq!(full_block.num_channels(), 1);

        let part = load(
            MONO_FILE,
            ReadConfig {
                start: Position::Frame(1100),
                stop: Position::Frame(1200),
                ..Default::default()
            },
        );
        let part_block = to_block(&part);
        assert_eq!(part.sample_rate, 48000);
        assert_eq!(part_block.num_frames(), 100);
        assert_eq!(part_block.num_channels(), 1);
        assert_eq!(full_block.raw_data()[1100..1200], part_block.raw_data()[..]);
    }

    #[test]
    fn test_time_selection() {
        let full = load(MONO_FILE, ReadConfig::default());
        let full_block = to_block(&full);
        assert_eq!(full.sample_rate, 48000);
        assert_eq!(full_block.num_frames(), 48000);
        assert_eq!(full_block.num_channels(), 1);

        let part = load(
            MONO_FILE,
            ReadConfig {
                start: Position::Time(Duration::from_secs_f32(0.5)),
                stop: Position::Time(Duration::from_secs_f32(0.6)),
                ..Default::default()
            },
        );
        let part_block = to_block(&part);

        assert_eq!(part.sample_rate, 48000);
        assert_eq!(part_block.num_frames(), 4800);
        assert_eq!(part_block.num_channels(), 1);
        assert_eq!(
            full_block.raw_data()[24000..28800],
            part_block.raw_data()[..]
        );
    }

    #[test]
    fn test_channel_selection() {
        let full = load(QUAD_FILE, ReadConfig::default());
        let full_block = to_block(&full);
        assert_eq!(full.sample_rate, 48000);
        assert_eq!(full_block.num_frames(), 48000);
        assert_eq!(full_block.num_channels(), 4);

        let part = load(
            QUAD_FILE,
            ReadConfig {
                start_channel: Some(1),
                num_channels: Some(2),
                ..Default::default()
            },
        );
        let part_block = to_block(&part);

        assert_eq!(part.sample_rate, 48000);
        assert_eq!(part_block.num_frames(), 48000);
        assert_eq!(part_block.num_channels(), 2);

        // Selected channels 0 and 1 must be source channels 1 and 2.
        for frame in 0..10 {
            assert_eq!(part_block.sample(0, frame), full_block.sample(1, frame));
            assert_eq!(part_block.sample(1, frame), full_block.sample(2, frame));
        }
    }

    #[test]
    fn test_fail_selection() {
        // Start frame after stop frame.
        let result = read::<f32>(
            MONO_FILE,
            ReadConfig {
                start: Position::Frame(100),
                stop: Position::Frame(99),
                ..Default::default()
            },
        );
        let Err(ReadError::InvalidFrameRange { .. }) = result else {
            panic!()
        };

        // Start time after stop time.
        let result = read::<f32>(
            MONO_FILE,
            ReadConfig {
                start: Position::Time(Duration::from_secs_f32(0.6)),
                stop: Position::Time(Duration::from_secs_f32(0.5)),
                ..Default::default()
            },
        );
        let Err(ReadError::InvalidFrameRange { .. }) = result else {
            panic!()
        };

        // Start channel that does not exist in a mono file.
        let result = read::<f32>(
            MONO_FILE,
            ReadConfig {
                start_channel: Some(1),
                ..Default::default()
            },
        );
        let Err(ReadError::InvalidChannel { .. }) = result else {
            panic!()
        };

        // Zero channels requested.
        let result = read::<f32>(
            MONO_FILE,
            ReadConfig {
                num_channels: Some(0),
                ..Default::default()
            },
        );
        let Err(ReadError::InvalidChannelCount(0)) = result else {
            panic!()
        };

        // More channels requested than the file has.
        let result = read::<f32>(
            MONO_FILE,
            ReadConfig {
                num_channels: Some(2),
                ..Default::default()
            },
        );
        let Err(ReadError::InvalidChannelCount(2)) = result else {
            panic!()
        };
    }

    #[test]
    fn test_resample_preserves_frequency() {
        let sr_out: u32 = 22050;

        let audio = load(
            QUAD_FILE,
            ReadConfig {
                sample_rate: Some(sr_out),
                ..Default::default()
            },
        );
        let block = to_block(&audio);

        assert_eq!(audio.sample_rate, sr_out);
        assert_eq!(block.num_channels(), 4);

        assert_resampled_sines(&block, &FREQUENCIES, sr_out);
    }

    #[test]
    fn test_channel_selection_with_resampling() {
        let sr_out: u32 = 22050;

        let audio = load(
            QUAD_FILE,
            ReadConfig {
                start_channel: Some(1),
                num_channels: Some(2),
                sample_rate: Some(sr_out),
                ..Default::default()
            },
        );
        let block = to_block(&audio);

        assert_eq!(audio.num_channels, 2, "Should have 2 channels");
        assert_eq!(
            audio.sample_rate, sr_out,
            "Sample rate should be the resampled rate"
        );

        // Channels 1 and 2 of the source carry the second and third frequency.
        assert_resampled_sines(&block, &FREQUENCIES[1..3], sr_out);
    }
}