Skip to main content

audio_io/
reader.rs

1use std::fs::File;
2use std::path::Path;
3
4use num::Float;
5use symphonia::core::audio::SampleBuffer;
6use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
7use symphonia::core::errors::Error;
8use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
9use symphonia::core::io::MediaSourceStream;
10use symphonia::core::meta::MetadataOptions;
11use symphonia::core::probe::Hint;
12use thiserror::Error;
13
14use crate::resample::resample;
15
/// Decoded audio stored as one interleaved sample buffer.
///
/// `F` is the sample type of the buffer.
#[derive(Clone, Debug)]
pub struct Audio<F> {
    /// Samples of all channels, interleaved
    pub samples_interleaved: Vec<F>,
    /// Sampling frequency in Hz
    pub sample_rate: u32,
    /// How many channels are interleaved in `samples_interleaved`
    pub num_channels: u16,
}
26
27#[derive(Debug, Error)]
28pub enum AudioReadError {
29    #[error("could not read file")]
30    FileError(#[from] std::io::Error),
31    #[error("could not decode audio")]
32    EncodingError(#[from] symphonia::core::errors::Error),
33    #[error("could not find track in file")]
34    NoTrack,
35    #[error("could not find sample rate in file")]
36    NoSampleRate,
37    #[error("end frame {0} is larger than start frame {1}")]
38    EndFrameLargerThanStartFrame(usize, usize),
39    #[error("start channel {0} invalid, audio file has only {1} channels")]
40    InvalidStartChannel(usize, usize),
41    #[error("invalid number of channels to extract: {0}")]
42    InvalidNumChannels(usize),
43}
44
/// A point in the audio stream, used as either the start or the stop of a read.
#[derive(Clone, Copy, Debug, Default)]
pub enum Position {
    /// No explicit position: the beginning when used as a start, the end of
    /// the stream when used as a stop
    #[default]
    Default,
    /// Offset given as a duration
    Time(std::time::Duration),
    /// Offset given as a frame number (sample position across all channels)
    Frame(usize),
}
56
57#[derive(Default)]
58pub struct AudioReadConfig {
59    /// Where to start reading audio (time or frame-based)
60    pub start: Position,
61    /// Where to stop reading audio (time or frame-based)
62    pub stop: Position,
63    /// Starting channel to extract (0-indexed). None means start from channel 0.
64    pub start_channel: Option<usize>,
65    /// Number of channels to extract. None means extract all remaining channels.
66    pub num_channels: Option<usize>,
67    /// If specified the audio will be resampled to the given sample rate
68    pub sample_rate: Option<u32>,
69}
70
71pub fn audio_read<F: Float + rubato::Sample>(
72    path: impl AsRef<Path>,
73    config: AudioReadConfig,
74) -> Result<Audio<F>, AudioReadError> {
75    let src = File::open(path.as_ref())?;
76    let mss = MediaSourceStream::new(Box::new(src), Default::default());
77
78    let mut hint = Hint::new();
79    if let Some(ext) = path.as_ref().extension()
80        && let Some(ext_str) = ext.to_str()
81    {
82        hint.with_extension(ext_str);
83    }
84
85    let meta_opts: MetadataOptions = Default::default();
86    let fmt_opts: FormatOptions = Default::default();
87
88    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
89
90    let mut format = probed.format;
91
92    let track = format
93        .tracks()
94        .iter()
95        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
96        .ok_or(AudioReadError::NoTrack)?;
97
98    let sample_rate = track
99        .codec_params
100        .sample_rate
101        .ok_or(AudioReadError::NoSampleRate)?;
102
103    let track_id = track.id;
104
105    // Clone codec params before the mutable borrow
106    let codec_params = track.codec_params.clone();
107    let time_base = track.codec_params.time_base;
108
109    // Convert start/stop positions to frame numbers
110    let start_frame = match config.start {
111        Position::Default => 0,
112        Position::Time(duration) => {
113            let secs = duration.as_secs_f64();
114            (secs * sample_rate as f64) as usize
115        }
116        Position::Frame(frame) => frame,
117    };
118
119    let end_frame: Option<usize> = match config.stop {
120        Position::Default => None,
121        Position::Time(duration) => {
122            let secs = duration.as_secs_f64();
123            Some((secs * sample_rate as f64) as usize)
124        }
125        Position::Frame(frame) => Some(frame),
126    };
127
128    if let Some(end_frame) = end_frame
129        && start_frame > end_frame
130    {
131        return Err(AudioReadError::EndFrameLargerThanStartFrame(
132            end_frame,
133            start_frame,
134        ));
135    }
136
137    // Optimization: Use seeking for large offsets to avoid decoding unnecessary data.
138    // For small offsets (< 1 second), we decode from the beginning and discard samples,
139    // which is simpler and avoids seek complexity. This threshold balances simplicity
140    // with performance - seeking has overhead and keyframe alignment issues that make
141    // it inefficient for small offsets.
142    if start_frame > sample_rate as usize
143        && let Some(tb) = time_base
144    {
145        // Seek to 90% of the target to account for keyframe positioning
146        let seek_sample = (start_frame as f64 * 0.9) as u64;
147        let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
148
149        // Try to seek, but don't fail if seeking doesn't work
150        let _ = format.seek(
151            SeekMode::Accurate,
152            SeekTo::TimeStamp {
153                ts: seek_ts,
154                track_id,
155            },
156        );
157    }
158
159    let dec_opts: DecoderOptions = Default::default();
160    let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
161
162    let mut sample_buf = None;
163    let mut samples = Vec::new();
164    let mut num_channels = 0usize;
165    let start_channel = config.start_channel;
166
167    // We'll track exact position by counting samples as we decode
168    let mut current_sample: Option<u64> = None;
169
170    loop {
171        let packet = match format.next_packet() {
172            Ok(packet) => packet,
173            Err(Error::ResetRequired) => {
174                decoder.reset();
175                continue;
176            }
177            Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
178                break;
179            }
180            Err(err) => return Err(err.into()),
181        };
182
183        if packet.track_id() != track_id {
184            continue;
185        }
186
187        let decoded = decoder.decode(&packet)?;
188
189        // Get the timestamp of this packet to know our position
190        if current_sample.is_none() {
191            let ts = packet.ts();
192            if let Some(tb) = time_base {
193                // Convert timestamp to sample position
194                current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
195            } else {
196                current_sample = Some(0);
197            }
198        }
199
200        if sample_buf.is_none() {
201            let spec = *decoded.spec();
202            let duration = decoded.capacity() as u64;
203            sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
204
205            // Get the number of channels from the spec
206            num_channels = spec.channels.count();
207
208            // Validate channel range
209            let ch_start = start_channel.unwrap_or(0);
210            let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
211
212            if ch_start >= num_channels {
213                return Err(AudioReadError::InvalidStartChannel(ch_start, num_channels));
214            }
215            if ch_count == 0 {
216                return Err(AudioReadError::InvalidNumChannels(0));
217            }
218            if ch_start + ch_count > num_channels {
219                return Err(AudioReadError::InvalidNumChannels(ch_count));
220            }
221        }
222
223        if let Some(buf) = &mut sample_buf {
224            buf.copy_interleaved_ref(decoded);
225            let packet_samples = buf.samples();
226
227            let mut pos = current_sample.unwrap_or(0);
228
229            // Determine channel range to extract
230            let ch_start = start_channel.unwrap_or(0);
231            let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
232            let ch_end = ch_start + ch_count;
233
234            // Calculate frames using the ORIGINAL channel count from the file
235            let frames = packet_samples.len() / num_channels;
236
237            // Process all frames, extracting only the requested channel range
238            for frame_idx in 0..frames {
239                // Check if we've reached the end frame
240                if let Some(end) = end_frame
241                    && pos >= end as u64
242                {
243                    return Ok(Audio {
244                        samples_interleaved: samples,
245                        sample_rate,
246                        num_channels: ch_count as u16,
247                    });
248                }
249
250                // Start collecting samples once we reach start_frame
251                if pos >= start_frame as u64 {
252                    // Extract the selected channel range from this frame
253                    // When ch_start=0 and ch_count=num_channels, this extracts all channels
254                    for ch in ch_start..ch_end {
255                        let sample_idx = frame_idx * num_channels + ch;
256                        samples.push(F::from(packet_samples[sample_idx]).unwrap());
257                    }
258                }
259
260                pos += 1;
261            }
262
263            // Update our position tracker
264            current_sample = Some(pos);
265        }
266    }
267
268    // Calculate the actual channel count in the extracted samples
269    let ch_start = start_channel.unwrap_or(0);
270    let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
271
272    let samples = if let Some(sr_out) = config.sample_rate {
273        // Use ch_count (the selected channels) not num_channels (original file channels)
274        resample(&samples, ch_count, sample_rate, sr_out).map_err(|_| AudioReadError::NoTrack)?
275    } else {
276        samples
277    };
278
279    // Return the actual sample rate (resampled if applicable, otherwise original)
280    let actual_sample_rate = config.sample_rate.unwrap_or(sample_rate);
281
282    Ok(Audio {
283        samples_interleaved: samples,
284        sample_rate: actual_sample_rate,
285        num_channels: ch_count as u16,
286    })
287}
288
/// Reads an audio file like [`audio_read`] and returns the samples as an
/// interleaved audio block together with the sample rate of the returned data.
#[cfg(feature = "audio-blocks")]
pub fn audio_read_block<F: num::Float + 'static + rubato::Sample>(
    path: impl AsRef<Path>,
    config: AudioReadConfig,
) -> Result<(audio_blocks::AudioBlockInterleaved<F>, u32), AudioReadError> {
    let Audio {
        samples_interleaved,
        sample_rate,
        num_channels,
    } = audio_read::<F>(path, config)?;

    let block = audio_blocks::AudioBlockInterleaved::from_slice(&samples_interleaved, num_channels);
    Ok((block, sample_rate))
}
303
304#[cfg(test)]
305mod tests {
306    use std::time::Duration;
307
308    use audio_blocks::{AudioBlock, AudioBlockInterleavedView};
309
310    use super::*;
311
312    fn to_block<F: num::Float + 'static>(audio: &Audio<F>) -> AudioBlockInterleavedView<'_, F> {
313        AudioBlockInterleavedView::from_slice(&audio.samples_interleaved, audio.num_channels)
314    }
315
316    /// Verify that the read audio data matches the expected sine wave values.
317    /// The test file was generated by utils/generate_wav.py with these parameters:
318    /// - 4 channels with frequencies: [440, 554.37, 659.25, 880] Hz
319    /// - Sample rate: 48000 Hz
320    /// - Duration: 1 second (48000 samples)
321    #[test]
322    fn test_sine_wave_data_integrity() {
323        const SAMPLE_RATE: f64 = 48000.0;
324        const N_SAMPLES: usize = 48000;
325        const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
326
327        let audio =
328            audio_read::<f32>("test_data/test_4ch.wav", AudioReadConfig::default()).unwrap();
329        let block = to_block(&audio);
330
331        assert_eq!(audio.sample_rate, 48000);
332        assert_eq!(block.num_frames(), N_SAMPLES);
333        assert_eq!(block.num_channels(), 4);
334
335        // Verify each channel contains the expected sine wave
336        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
337            for frame in 0..N_SAMPLES {
338                let expected =
339                    (2.0 * std::f64::consts::PI * freq * frame as f64 / SAMPLE_RATE).sin() as f32;
340                let actual = block.sample(ch as u16, frame);
341                assert!(
342                    (actual - expected).abs() < 1e-4,
343                    "Mismatch at channel {ch}, frame {frame}: expected {expected}, got {actual}"
344                );
345            }
346        }
347
348        // Also verify reading with an offset works consistently
349        let audio = audio_read::<f32>(
350            "test_data/test_4ch.wav",
351            AudioReadConfig {
352                start: Position::Frame(24000),
353                stop: Position::Frame(24100),
354                ..Default::default()
355            },
356        )
357        .unwrap();
358        let block = to_block(&audio);
359
360        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
361            for frame in 0..100 {
362                let actual_frame = 24000 + frame;
363                let expected = (2.0 * std::f64::consts::PI * freq * actual_frame as f64
364                    / SAMPLE_RATE)
365                    .sin() as f32;
366                let actual = block.sample(ch as u16, frame);
367                assert!(
368                    (actual - expected).abs() < 1e-4,
369                    "Offset mismatch at channel {ch}, frame {actual_frame}: expected {expected}, got {actual}"
370                );
371            }
372        }
373    }
374
375    #[test]
376    fn test_samples_selection() {
377        let audio1 =
378            audio_read::<f32>("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();
379        let block1 = to_block(&audio1);
380        assert_eq!(audio1.sample_rate, 48000);
381        assert_eq!(block1.num_frames(), 48000);
382        assert_eq!(block1.num_channels(), 1);
383
384        let audio2 = audio_read::<f32>(
385            "test_data/test_1ch.wav",
386            AudioReadConfig {
387                start: Position::Frame(1100),
388                stop: Position::Frame(1200),
389                ..Default::default()
390            },
391        )
392        .unwrap();
393        let block2 = to_block(&audio2);
394        assert_eq!(audio2.sample_rate, 48000);
395        assert_eq!(block2.num_frames(), 100);
396        assert_eq!(block2.num_channels(), 1);
397        assert_eq!(block1.raw_data()[1100..1200], block2.raw_data()[..]);
398    }
399
400    #[test]
401    fn test_time_selection() {
402        let audio1 =
403            audio_read::<f32>("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();
404        let block1 = to_block(&audio1);
405        assert_eq!(audio1.sample_rate, 48000);
406        assert_eq!(block1.num_frames(), 48000);
407        assert_eq!(block1.num_channels(), 1);
408
409        let audio2 = audio_read::<f32>(
410            "test_data/test_1ch.wav",
411            AudioReadConfig {
412                start: Position::Time(Duration::from_secs_f32(0.5)),
413                stop: Position::Time(Duration::from_secs_f32(0.6)),
414                ..Default::default()
415            },
416        )
417        .unwrap();
418        let block2 = to_block(&audio2);
419
420        assert_eq!(audio2.sample_rate, 48000);
421        assert_eq!(block2.num_frames(), 4800);
422        assert_eq!(block2.num_channels(), 1);
423        assert_eq!(block1.raw_data()[24000..28800], block2.raw_data()[..]);
424    }
425
426    #[test]
427    fn test_channel_selection() {
428        let audio1 =
429            audio_read::<f32>("test_data/test_4ch.wav", AudioReadConfig::default()).unwrap();
430        let block1 = to_block(&audio1);
431        assert_eq!(audio1.sample_rate, 48000);
432        assert_eq!(block1.num_frames(), 48000);
433        assert_eq!(block1.num_channels(), 4);
434
435        let audio2 = audio_read::<f32>(
436            "test_data/test_4ch.wav",
437            AudioReadConfig {
438                start_channel: Some(1),
439                num_channels: Some(2),
440                ..Default::default()
441            },
442        )
443        .unwrap();
444        let block2 = to_block(&audio2);
445
446        assert_eq!(audio2.sample_rate, 48000);
447        assert_eq!(block2.num_frames(), 48000);
448        assert_eq!(block2.num_channels(), 2);
449
450        // Verify we extracted channels 1 and 2 (skipping channel 0 and 3)
451        for frame in 0..10 {
452            assert_eq!(block2.sample(0, frame), block1.sample(1, frame));
453            assert_eq!(block2.sample(1, frame), block1.sample(2, frame));
454        }
455    }
456
457    #[test]
458    fn test_fail_selection() {
459        match audio_read::<f32>(
460            "test_data/test_1ch.wav",
461            AudioReadConfig {
462                start: Position::Frame(100),
463                stop: Position::Frame(99),
464                ..Default::default()
465            },
466        ) {
467            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _)) => (),
468            _ => panic!(),
469        }
470
471        match audio_read::<f32>(
472            "test_data/test_1ch.wav",
473            AudioReadConfig {
474                start: Position::Time(Duration::from_secs_f32(0.6)),
475                stop: Position::Time(Duration::from_secs_f32(0.5)),
476                ..Default::default()
477            },
478        ) {
479            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _)) => (),
480            _ => panic!(),
481        }
482
483        match audio_read::<f32>(
484            "test_data/test_1ch.wav",
485            AudioReadConfig {
486                start_channel: Some(1),
487                ..Default::default()
488            },
489        ) {
490            Err(AudioReadError::InvalidStartChannel(_, _)) => (),
491            _ => panic!(),
492        }
493
494        match audio_read::<f32>(
495            "test_data/test_1ch.wav",
496            AudioReadConfig {
497                num_channels: Some(0),
498                ..Default::default()
499            },
500        ) {
501            Err(AudioReadError::InvalidNumChannels(0)) => (),
502            _ => panic!(),
503        }
504
505        match audio_read::<f32>(
506            "test_data/test_1ch.wav",
507            AudioReadConfig {
508                num_channels: Some(2),
509                ..Default::default()
510            },
511        ) {
512            Err(AudioReadError::InvalidNumChannels(2)) => (),
513            _ => panic!(),
514        }
515    }
516
517    #[test]
518    fn test_resample_preserves_frequency() {
519        const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
520        let sr_out: u32 = 22050;
521
522        // Read and resample in one step
523        let audio = audio_read::<f32>(
524            "test_data/test_4ch.wav",
525            AudioReadConfig {
526                sample_rate: Some(sr_out),
527                ..Default::default()
528            },
529        )
530        .unwrap();
531        let block = to_block(&audio);
532
533        assert_eq!(audio.sample_rate, sr_out); // Resampled sample rate is returned
534        assert_eq!(block.num_channels(), 4);
535
536        // Expected frames after resampling: 48000 * (22050/48000) = 22050
537        let expected_frames = 22050;
538        assert_eq!(
539            block.num_frames(),
540            expected_frames,
541            "Expected {} frames, got {}",
542            expected_frames,
543            block.num_frames()
544        );
545
546        // Verify sine wave frequencies are preserved after resampling
547        // Skip first ~100 samples to avoid any edge effects from resampling
548        let start_frame = 100;
549        let test_frames = 1000;
550
551        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
552            let mut max_error: f32 = 0.0;
553            for frame in start_frame..(start_frame + test_frames) {
554                let expected =
555                    (2.0 * std::f64::consts::PI * freq * frame as f64 / sr_out as f64).sin() as f32;
556                let actual = block.sample(ch as u16, frame);
557                let error = (actual - expected).abs();
558                max_error = max_error.max(error);
559            }
560            assert!(
561                max_error < 0.02,
562                "Channel {} ({}Hz): max error {} exceeds threshold",
563                ch,
564                freq,
565                max_error
566            );
567        }
568    }
569
570    #[test]
571    fn test_channel_selection_with_resampling() {
572        // This test verifies that channel selection combined with resampling works correctly
573        const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
574        let sr_out: u32 = 22050;
575
576        // Read channels 1 and 2 (indices 1 and 2) with resampling
577        let audio = audio_read::<f32>(
578            "test_data/test_4ch.wav",
579            AudioReadConfig {
580                start_channel: Some(1),
581                num_channels: Some(2),
582                sample_rate: Some(sr_out),
583                ..Default::default()
584            },
585        )
586        .unwrap();
587        let block = to_block(&audio);
588
589        assert_eq!(audio.num_channels, 2, "Should have 2 channels");
590        assert_eq!(
591            audio.sample_rate, sr_out,
592            "Sample rate should be the resampled rate"
593        );
594
595        // Expected frames after resampling: 48000 * (22050/48000) = 22050
596        let expected_frames = 22050;
597        assert_eq!(
598            block.num_frames(),
599            expected_frames,
600            "Expected {} frames, got {}",
601            expected_frames,
602            block.num_frames()
603        );
604
605        // Verify that the resampled audio contains the correct frequencies
606        // Channels 1 and 2 should have frequencies 554.37 Hz and 659.25 Hz
607        let selected_freqs = &FREQUENCIES[1..3];
608
609        let start_frame = 100;
610        let test_frames = 1000;
611
612        for (ch, &freq) in selected_freqs.iter().enumerate() {
613            let mut max_error: f32 = 0.0;
614            for frame in start_frame..(start_frame + test_frames) {
615                let expected =
616                    (2.0 * std::f64::consts::PI * freq * frame as f64 / sr_out as f64).sin() as f32;
617                let actual = block.sample(ch as u16, frame);
618                let error = (actual - expected).abs();
619                max_error = max_error.max(error);
620            }
621            assert!(
622                max_error < 0.02,
623                "Channel {} ({}Hz): max error {} exceeds threshold",
624                ch,
625                freq,
626                max_error
627            );
628        }
629    }
630}