Skip to main content

audio_io/
reader.rs

1use std::fs::File;
2use std::path::Path;
3
4use num::Float;
5use symphonia::core::audio::SampleBuffer;
6use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
7use symphonia::core::errors::Error;
8use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
9use symphonia::core::io::MediaSourceStream;
10use symphonia::core::meta::MetadataOptions;
11use symphonia::core::probe::Hint;
12use thiserror::Error;
13
14use crate::resample::resample;
15
/// Decoded audio held as a single interleaved sample buffer.
///
/// Samples are stored frame by frame: frame `i`, channel `c` lives at
/// `samples_interleaved[i * num_channels + c]`.
#[derive(Clone, Debug)]
pub struct Audio<F> {
    /// All channels' samples, interleaved frame by frame.
    pub samples_interleaved: Vec<F>,
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of interleaved channels per frame.
    pub num_channels: u16,
}
26
27#[derive(Debug, Error)]
28pub enum AudioReadError {
29    #[error("could not read file")]
30    FileError(#[from] std::io::Error),
31    #[error("could not decode audio")]
32    EncodingError(#[from] symphonia::core::errors::Error),
33    #[error("could not find track in file")]
34    NoTrack,
35    #[error("could not find sample rate in file")]
36    NoSampleRate,
37    #[error("end frame {0} is larger than start frame {1}")]
38    EndFrameLargerThanStartFrame(usize, usize),
39    #[error("start channel {0} invalid, audio file has only {1} channels")]
40    InvalidStartChannel(usize, usize),
41    #[error("invalid number of channels to extract: {0}")]
42    InvalidNumChannels(usize),
43}
44
/// A point in the audio stream, used both as a start point and as a stop point.
#[derive(Clone, Copy, Debug, Default)]
pub enum Position {
    /// No explicit position: the beginning of the stream when used as a start,
    /// the end of the stream when used as a stop.
    #[default]
    Default,
    /// An offset in time from the beginning of the stream.
    Time(std::time::Duration),
    /// An absolute frame index (one frame holds one sample per channel).
    Frame(usize),
}
56
57#[derive(Default)]
58pub struct AudioReadConfig {
59    /// Where to start reading audio (time or frame-based)
60    pub start: Position,
61    /// Where to stop reading audio (time or frame-based)
62    pub stop: Position,
63    /// Starting channel to extract (0-indexed). None means start from channel 0.
64    pub start_channel: Option<usize>,
65    /// Number of channels to extract. None means extract all remaining channels.
66    pub num_channels: Option<usize>,
67    /// If specified the audio will be resampled to the given sample rate
68    pub sample_rate: Option<u32>,
69}
70
71pub fn audio_read<F: Float + rubato::Sample>(
72    path: impl AsRef<Path>,
73    config: AudioReadConfig,
74) -> Result<Audio<F>, AudioReadError> {
75    let src = File::open(path.as_ref())?;
76    let mss = MediaSourceStream::new(Box::new(src), Default::default());
77
78    let mut hint = Hint::new();
79    if let Some(ext) = path.as_ref().extension()
80        && let Some(ext_str) = ext.to_str()
81    {
82        hint.with_extension(ext_str);
83    }
84
85    let meta_opts: MetadataOptions = Default::default();
86    let fmt_opts: FormatOptions = Default::default();
87
88    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
89
90    let mut format = probed.format;
91
92    let track = format
93        .tracks()
94        .iter()
95        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
96        .ok_or(AudioReadError::NoTrack)?;
97
98    let sample_rate = track
99        .codec_params
100        .sample_rate
101        .ok_or(AudioReadError::NoSampleRate)?;
102
103    let track_id = track.id;
104
105    // Clone codec params before the mutable borrow
106    let codec_params = track.codec_params.clone();
107    let time_base = track.codec_params.time_base;
108
109    // Convert start/stop positions to frame numbers
110    let start_frame = match config.start {
111        Position::Default => 0,
112        Position::Time(duration) => {
113            let secs = duration.as_secs_f64();
114            (secs * sample_rate as f64) as usize
115        }
116        Position::Frame(frame) => frame,
117    };
118
119    let end_frame: Option<usize> = match config.stop {
120        Position::Default => None,
121        Position::Time(duration) => {
122            let secs = duration.as_secs_f64();
123            Some((secs * sample_rate as f64) as usize)
124        }
125        Position::Frame(frame) => Some(frame),
126    };
127
128    if let Some(end_frame) = end_frame
129        && start_frame > end_frame
130    {
131        return Err(AudioReadError::EndFrameLargerThanStartFrame(
132            end_frame,
133            start_frame,
134        ));
135    }
136
137    // If start_frame is large (more than 1 second), use seeking to avoid decoding everything
138    if start_frame > sample_rate as usize
139        && let Some(tb) = time_base
140    {
141        // Seek to 90% of the target to account for keyframe positioning
142        let seek_sample = (start_frame as f64 * 0.9) as u64;
143        let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
144
145        // Try to seek, but don't fail if seeking doesn't work
146        let _ = format.seek(
147            SeekMode::Accurate,
148            SeekTo::TimeStamp {
149                ts: seek_ts,
150                track_id,
151            },
152        );
153    }
154
155    let dec_opts: DecoderOptions = Default::default();
156    let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
157
158    let mut sample_buf = None;
159    let mut samples = Vec::new();
160    let mut num_channels = 0usize;
161    let start_channel = config.start_channel;
162
163    // We'll track exact position by counting samples as we decode
164    let mut current_sample: Option<u64> = None;
165
166    loop {
167        let packet = match format.next_packet() {
168            Ok(packet) => packet,
169            Err(Error::ResetRequired) => {
170                decoder.reset();
171                continue;
172            }
173            Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
174                break;
175            }
176            Err(err) => return Err(err.into()),
177        };
178
179        if packet.track_id() != track_id {
180            continue;
181        }
182
183        let decoded = decoder.decode(&packet)?;
184
185        // Get the timestamp of this packet to know our position
186        if current_sample.is_none() {
187            let ts = packet.ts();
188            if let Some(tb) = time_base {
189                // Convert timestamp to sample position
190                current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
191            } else {
192                current_sample = Some(0);
193            }
194        }
195
196        if sample_buf.is_none() {
197            let spec = *decoded.spec();
198            let duration = decoded.capacity() as u64;
199            sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
200
201            // Get the number of channels from the spec
202            num_channels = spec.channels.count();
203
204            // Validate channel range
205            let ch_start = start_channel.unwrap_or(0);
206            let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
207
208            if ch_start >= num_channels {
209                return Err(AudioReadError::InvalidStartChannel(ch_start, num_channels));
210            }
211            if ch_count == 0 {
212                return Err(AudioReadError::InvalidNumChannels(0));
213            }
214            if ch_start + ch_count > num_channels {
215                return Err(AudioReadError::InvalidNumChannels(ch_count));
216            }
217        }
218
219        if let Some(buf) = &mut sample_buf {
220            buf.copy_interleaved_ref(decoded);
221            let packet_samples = buf.samples();
222
223            let mut pos = current_sample.unwrap_or(0);
224
225            // Determine channel range to extract
226            let ch_start = start_channel.unwrap_or(0);
227            let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
228            let ch_end = ch_start + ch_count;
229
230            // Calculate frames using the ORIGINAL channel count from the file
231            let frames = packet_samples.len() / num_channels;
232
233            // Process all frames, extracting only the requested channel range
234            for frame_idx in 0..frames {
235                // Check if we've reached the end frame
236                if let Some(end) = end_frame
237                    && pos >= end as u64
238                {
239                    return Ok(Audio {
240                        samples_interleaved: samples,
241                        sample_rate,
242                        num_channels: ch_count as u16,
243                    });
244                }
245
246                // Start collecting samples once we reach start_frame
247                if pos >= start_frame as u64 {
248                    // Extract the selected channel range from this frame
249                    // When ch_start=0 and ch_count=num_channels, this extracts all channels
250                    for ch in ch_start..ch_end {
251                        let sample_idx = frame_idx * num_channels + ch;
252                        samples.push(F::from(packet_samples[sample_idx]).unwrap());
253                    }
254                }
255
256                pos += 1;
257            }
258
259            // Update our position tracker
260            current_sample = Some(pos);
261        }
262    }
263
264    let samples = if let Some(sr_out) = config.sample_rate {
265        resample(&samples, num_channels, sample_rate, sr_out).unwrap()
266    } else {
267        samples
268    };
269
270    let ch_start = start_channel.unwrap_or(0);
271    let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
272
273    Ok(Audio {
274        samples_interleaved: samples,
275        sample_rate,
276        num_channels: ch_count as u16,
277    })
278}
279
/// Reads audio exactly like [`audio_read`] and wraps the samples in an
/// `audio_blocks::AudioBlockInterleaved`.
///
/// Returns the block together with the `sample_rate` reported by [`audio_read`].
#[cfg(feature = "audio-blocks")]
pub fn audio_read_block<F: num::Float + 'static + rubato::Sample>(
    path: impl AsRef<Path>,
    config: AudioReadConfig,
) -> Result<(audio_blocks::AudioBlockInterleaved<F>, u32), AudioReadError> {
    let Audio {
        samples_interleaved,
        sample_rate,
        num_channels,
    } = audio_read(path, config)?;
    let block =
        audio_blocks::AudioBlockInterleaved::from_slice(&samples_interleaved, num_channels);
    Ok((block, sample_rate))
}
294
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;
    use std::time::Duration;

    use audio_blocks::{AudioBlock, AudioBlockInterleavedView};

    use super::*;

    /// Mono test file; the tests below expect 48 kHz and 48000 frames.
    const TEST_1CH: &str = "test_data/test_1ch.wav";
    /// 4-channel test file generated by utils/generate_wav.py: 48 kHz, 1 second
    /// (48000 frames), one sine per channel at the frequencies in `FREQUENCIES`.
    const TEST_4CH: &str = "test_data/test_4ch.wav";
    /// Sine frequency (Hz) of each channel of `TEST_4CH`.
    const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];

    fn to_block<F: num::Float + 'static>(audio: &Audio<F>) -> AudioBlockInterleavedView<'_, F> {
        AudioBlockInterleavedView::from_slice(&audio.samples_interleaved, audio.num_channels)
    }

    /// Expected value of a unit-amplitude sine at `freq` Hz, evaluated at `frame`
    /// for a stream sampled at `sample_rate` Hz.
    fn sine(freq: f64, frame: usize, sample_rate: f64) -> f32 {
        (2.0 * PI * freq * frame as f64 / sample_rate).sin() as f32
    }

    /// Verify that the read audio data matches the expected sine wave values,
    /// both for a full read and for a 100-frame window starting at frame 24000.
    #[test]
    fn test_sine_wave_data_integrity() {
        const SAMPLE_RATE: f64 = 48000.0;
        const N_SAMPLES: usize = 48000;

        let audio = audio_read::<f32>(TEST_4CH, AudioReadConfig::default()).unwrap();
        let block = to_block(&audio);

        assert_eq!(audio.sample_rate, 48000);
        assert_eq!(block.num_frames(), N_SAMPLES);
        assert_eq!(block.num_channels(), 4);

        // Verify each channel contains the expected sine wave
        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
            for frame in 0..N_SAMPLES {
                let expected = sine(freq, frame, SAMPLE_RATE);
                let actual = block.sample(ch as u16, frame);
                assert!(
                    (actual - expected).abs() < 1e-4,
                    "Mismatch at channel {ch}, frame {frame}: expected {expected}, got {actual}"
                );
            }
        }

        // Also verify reading with an offset works consistently
        let audio = audio_read::<f32>(
            TEST_4CH,
            AudioReadConfig {
                start: Position::Frame(24000),
                stop: Position::Frame(24100),
                ..Default::default()
            },
        )
        .unwrap();
        let block = to_block(&audio);
        assert_eq!(block.num_frames(), 100);

        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
            for frame in 0..100 {
                let actual_frame = 24000 + frame;
                let expected = sine(freq, actual_frame, SAMPLE_RATE);
                let actual = block.sample(ch as u16, frame);
                assert!(
                    (actual - expected).abs() < 1e-4,
                    "Offset mismatch at channel {ch}, frame {actual_frame}: expected {expected}, got {actual}"
                );
            }
        }
    }

    /// A frame-based window must equal the same slice of a full read.
    #[test]
    fn test_samples_selection() {
        let audio1 = audio_read::<f32>(TEST_1CH, AudioReadConfig::default()).unwrap();
        let block1 = to_block(&audio1);
        assert_eq!(audio1.sample_rate, 48000);
        assert_eq!(block1.num_frames(), 48000);
        assert_eq!(block1.num_channels(), 1);

        let audio2 = audio_read::<f32>(
            TEST_1CH,
            AudioReadConfig {
                start: Position::Frame(1100),
                stop: Position::Frame(1200),
                ..Default::default()
            },
        )
        .unwrap();
        let block2 = to_block(&audio2);
        assert_eq!(audio2.sample_rate, 48000);
        assert_eq!(block2.num_frames(), 100);
        assert_eq!(block2.num_channels(), 1);
        assert_eq!(block1.raw_data()[1100..1200], block2.raw_data()[..]);
    }

    /// A time-based window must equal the corresponding frame slice of a full read.
    #[test]
    fn test_time_selection() {
        let audio1 = audio_read::<f32>(TEST_1CH, AudioReadConfig::default()).unwrap();
        let block1 = to_block(&audio1);
        assert_eq!(audio1.sample_rate, 48000);
        assert_eq!(block1.num_frames(), 48000);
        assert_eq!(block1.num_channels(), 1);

        let audio2 = audio_read::<f32>(
            TEST_1CH,
            AudioReadConfig {
                start: Position::Time(Duration::from_secs_f32(0.5)),
                stop: Position::Time(Duration::from_secs_f32(0.6)),
                ..Default::default()
            },
        )
        .unwrap();
        let block2 = to_block(&audio2);

        assert_eq!(audio2.sample_rate, 48000);
        assert_eq!(block2.num_frames(), 4800);
        assert_eq!(block2.num_channels(), 1);
        assert_eq!(block1.raw_data()[24000..28800], block2.raw_data()[..]);
    }

    /// Extracting channels 1..3 must return exactly those channels of a full read.
    #[test]
    fn test_channel_selection() {
        let audio1 = audio_read::<f32>(TEST_4CH, AudioReadConfig::default()).unwrap();
        let block1 = to_block(&audio1);
        assert_eq!(audio1.sample_rate, 48000);
        assert_eq!(block1.num_frames(), 48000);
        assert_eq!(block1.num_channels(), 4);

        let audio2 = audio_read::<f32>(
            TEST_4CH,
            AudioReadConfig {
                start_channel: Some(1),
                num_channels: Some(2),
                ..Default::default()
            },
        )
        .unwrap();
        let block2 = to_block(&audio2);

        assert_eq!(audio2.sample_rate, 48000);
        assert_eq!(block2.num_frames(), 48000);
        assert_eq!(block2.num_channels(), 2);

        // Verify we extracted channels 1 and 2 (skipping channel 0 and 3)
        for frame in 0..10 {
            assert_eq!(block2.sample(0, frame), block1.sample(1, frame));
            assert_eq!(block2.sample(1, frame), block1.sample(2, frame));
        }
    }

    /// Invalid frame or channel selections must produce the matching error variant.
    #[test]
    fn test_fail_selection() {
        let read = |config: AudioReadConfig| audio_read::<f32>(TEST_1CH, config);

        let result = read(AudioReadConfig {
            start: Position::Frame(100),
            stop: Position::Frame(99),
            ..Default::default()
        });
        assert!(
            matches!(result, Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))),
            "expected EndFrameLargerThanStartFrame, got {result:?}"
        );

        let result = read(AudioReadConfig {
            start: Position::Time(Duration::from_secs_f32(0.6)),
            stop: Position::Time(Duration::from_secs_f32(0.5)),
            ..Default::default()
        });
        assert!(
            matches!(result, Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))),
            "expected EndFrameLargerThanStartFrame, got {result:?}"
        );

        let result = read(AudioReadConfig {
            start_channel: Some(1),
            ..Default::default()
        });
        assert!(
            matches!(result, Err(AudioReadError::InvalidStartChannel(_, _))),
            "expected InvalidStartChannel, got {result:?}"
        );

        let result = read(AudioReadConfig {
            num_channels: Some(0),
            ..Default::default()
        });
        assert!(
            matches!(result, Err(AudioReadError::InvalidNumChannels(0))),
            "expected InvalidNumChannels(0), got {result:?}"
        );

        let result = read(AudioReadConfig {
            num_channels: Some(2),
            ..Default::default()
        });
        assert!(
            matches!(result, Err(AudioReadError::InvalidNumChannels(2))),
            "expected InvalidNumChannels(2), got {result:?}"
        );
    }

    /// Resampling 48 kHz -> 22.05 kHz must keep the frame count proportional and
    /// the per-channel sine frequencies intact.
    #[test]
    fn test_resample_preserves_frequency() {
        let sr_out: u32 = 22050;

        // Read and resample in one step
        let audio = audio_read::<f32>(
            TEST_4CH,
            AudioReadConfig {
                sample_rate: Some(sr_out),
                ..Default::default()
            },
        )
        .unwrap();
        let block = to_block(&audio);

        assert_eq!(audio.sample_rate, 48000); // Original sample rate is preserved in metadata
        assert_eq!(block.num_channels(), 4);

        // Expected frames after resampling: 48000 * (22050/48000) = 22050
        let expected_frames = 22050;
        assert_eq!(
            block.num_frames(),
            expected_frames,
            "Expected {} frames, got {}",
            expected_frames,
            block.num_frames()
        );

        // Verify sine wave frequencies are preserved after resampling
        // Skip first ~100 samples to avoid any edge effects from resampling
        let start_frame = 100;
        let test_frames = 1000;

        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
            let max_error = (start_frame..start_frame + test_frames)
                .map(|frame| (block.sample(ch as u16, frame) - sine(freq, frame, sr_out as f64)).abs())
                .fold(0.0f32, f32::max);
            assert!(
                max_error < 0.02,
                "Channel {} ({}Hz): max error {} exceeds threshold",
                ch,
                freq,
                max_error
            );
        }
    }
}