Skip to main content

audio_file/
reader.rs

1use std::fs::File;
2use std::path::Path;
3
4use num::Float;
5use symphonia::core::audio::SampleBuffer;
6use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
7use symphonia::core::errors::Error;
8use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
9use symphonia::core::io::MediaSourceStream;
10use symphonia::core::meta::MetadataOptions;
11use symphonia::core::probe::Hint;
12use thiserror::Error;
13
14use crate::resample::{ResampleError, resample};
15
/// Audio data with interleaved samples
///
/// Samples are stored frame by frame: frame `i` occupies
/// `samples_interleaved[i * num_channels..(i + 1) * num_channels]`, holding one
/// sample per channel.
#[derive(Debug, Clone)]
pub struct Audio<F> {
    /// Interleaved audio samples (channel-minor order within each frame)
    pub samples_interleaved: Vec<F>,
    /// Sample rate in Hz
    pub sample_rate: u32,
    /// Number of channels stored in `samples_interleaved`
    pub num_channels: u16,
}
26
27#[derive(Debug, Error)]
28pub enum ReadError {
29    #[error("could not read file")]
30    Io(#[from] std::io::Error),
31
32    #[error("could not decode audio")]
33    Decode(#[from] symphonia::core::errors::Error),
34
35    #[error("no track found")]
36    NoTrack,
37
38    #[error("no sample rate found")]
39    NoSampleRate,
40
41    #[error("end frame ({end}) must not exceed start frame ({start})")]
42    InvalidFrameRange { start: usize, end: usize },
43
44    #[error("start channel {index} out of bounds (file has {total} channels)")]
45    InvalidChannel { index: usize, total: usize },
46
47    #[error("invalid channel count: {0}")]
48    InvalidChannelCount(usize),
49
50    #[error("resample failed")]
51    Resample(#[from] ResampleError),
52}
53
/// Position in the audio stream (for start or stop points)
///
/// The same type is used for [`ReadConfig::start`] and [`ReadConfig::stop`];
/// the meaning of [`Position::Default`] depends on which of the two it is used for.
#[derive(Default, Debug, Clone, Copy)]
pub enum Position {
    /// Start from beginning or read until the end (depending on context)
    #[default]
    Default,
    /// Specific time offset; converted to a frame index by multiplying with the
    /// file's sample rate and truncating towards zero
    Time(std::time::Duration),
    /// Specific frame number (a frame holds one sample for every channel, so
    /// this is independent of the channel count)
    Frame(usize),
}
65
66#[derive(Default)]
67pub struct ReadConfig {
68    /// Where to start reading audio (time or frame-based)
69    pub start: Position,
70    /// Where to stop reading audio (time or frame-based)
71    pub stop: Position,
72    /// Starting channel to extract (0-indexed). None means start from channel 0.
73    pub start_channel: Option<usize>,
74    /// Number of channels to extract. None means extract all remaining channels.
75    pub num_channels: Option<usize>,
76    /// If specified the audio will be resampled to the given sample rate
77    pub sample_rate: Option<u32>,
78}
79
80pub fn read<F: Float + rubato::Sample>(
81    path: impl AsRef<Path>,
82    config: ReadConfig,
83) -> Result<Audio<F>, ReadError> {
84    let src = File::open(path.as_ref())?;
85    let mss = MediaSourceStream::new(Box::new(src), Default::default());
86
87    let mut hint = Hint::new();
88    if let Some(ext) = path.as_ref().extension()
89        && let Some(ext_str) = ext.to_str()
90    {
91        hint.with_extension(ext_str);
92    }
93
94    let meta_opts: MetadataOptions = Default::default();
95    let fmt_opts: FormatOptions = Default::default();
96
97    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
98
99    let mut format = probed.format;
100
101    let track = format
102        .tracks()
103        .iter()
104        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
105        .ok_or(ReadError::NoTrack)?;
106
107    let sample_rate = track
108        .codec_params
109        .sample_rate
110        .ok_or(ReadError::NoSampleRate)?;
111
112    let track_id = track.id;
113
114    // Clone codec params before the mutable borrow
115    let codec_params = track.codec_params.clone();
116    let time_base = track.codec_params.time_base;
117
118    // Convert start/stop positions to frame numbers
119    let start_frame = match config.start {
120        Position::Default => 0,
121        Position::Time(duration) => {
122            let secs = duration.as_secs_f64();
123            (secs * sample_rate as f64) as usize
124        }
125        Position::Frame(frame) => frame,
126    };
127
128    let end_frame: Option<usize> = match config.stop {
129        Position::Default => None,
130        Position::Time(duration) => {
131            let secs = duration.as_secs_f64();
132            Some((secs * sample_rate as f64) as usize)
133        }
134        Position::Frame(frame) => Some(frame),
135    };
136
137    if let Some(end_frame) = end_frame
138        && start_frame > end_frame
139    {
140        return Err(ReadError::InvalidFrameRange {
141            start: start_frame,
142            end: end_frame,
143        });
144    }
145
146    // Optimization: Use seeking for large offsets to avoid decoding unnecessary data.
147    // For small offsets (< 1 second), we decode from the beginning and discard samples,
148    // which is simpler and avoids seek complexity. This threshold balances simplicity
149    // with performance - seeking has overhead and keyframe alignment issues that make
150    // it inefficient for small offsets.
151    if start_frame > sample_rate as usize
152        && let Some(tb) = time_base
153    {
154        // Seek to 90% of the target to account for keyframe positioning
155        let seek_sample = (start_frame as f64 * 0.9) as u64;
156        let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
157
158        // Try to seek, but don't fail if seeking doesn't work
159        let _ = format.seek(
160            SeekMode::Accurate,
161            SeekTo::TimeStamp {
162                ts: seek_ts,
163                track_id,
164            },
165        );
166    }
167
168    let dec_opts: DecoderOptions = Default::default();
169    let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
170
171    let mut sample_buf = None;
172    let mut samples = Vec::new();
173    let mut num_channels = 0usize;
174    let start_channel = config.start_channel;
175
176    // We'll track exact position by counting samples as we decode
177    let mut current_sample: Option<u64> = None;
178
179    loop {
180        let packet = match format.next_packet() {
181            Ok(packet) => packet,
182            Err(Error::ResetRequired) => {
183                decoder.reset();
184                continue;
185            }
186            Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
187                break;
188            }
189            Err(err) => return Err(err.into()),
190        };
191
192        if packet.track_id() != track_id {
193            continue;
194        }
195
196        let decoded = decoder.decode(&packet)?;
197
198        // Get the timestamp of this packet to know our position
199        if current_sample.is_none() {
200            let ts = packet.ts();
201            if let Some(tb) = time_base {
202                // Convert timestamp to sample position
203                current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
204            } else {
205                current_sample = Some(0);
206            }
207        }
208
209        if sample_buf.is_none() {
210            let spec = *decoded.spec();
211            let duration = decoded.capacity() as u64;
212            sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
213
214            // Get the number of channels from the spec
215            num_channels = spec.channels.count();
216
217            // Validate channel range
218            let ch_start = start_channel.unwrap_or(0);
219            let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
220
221            if ch_start >= num_channels {
222                return Err(ReadError::InvalidChannel {
223                    index: ch_start,
224                    total: num_channels,
225                });
226            }
227            if ch_count == 0 {
228                return Err(ReadError::InvalidChannelCount(0));
229            }
230            if ch_start + ch_count > num_channels {
231                return Err(ReadError::InvalidChannelCount(ch_count));
232            }
233        }
234
235        if let Some(buf) = &mut sample_buf {
236            buf.copy_interleaved_ref(decoded);
237            let packet_samples = buf.samples();
238
239            let mut pos = current_sample.unwrap_or(0);
240
241            // Determine channel range to extract
242            let ch_start = start_channel.unwrap_or(0);
243            let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
244            let ch_end = ch_start + ch_count;
245
246            // Calculate frames using the ORIGINAL channel count from the file
247            let frames = packet_samples.len() / num_channels;
248
249            // Process all frames, extracting only the requested channel range
250            for frame_idx in 0..frames {
251                // Check if we've reached the end frame
252                if let Some(end) = end_frame
253                    && pos >= end as u64
254                {
255                    return Ok(Audio {
256                        samples_interleaved: samples,
257                        sample_rate,
258                        num_channels: ch_count as u16,
259                    });
260                }
261
262                // Start collecting samples once we reach start_frame
263                if pos >= start_frame as u64 {
264                    // Extract the selected channel range from this frame
265                    // When ch_start=0 and ch_count=num_channels, this extracts all channels
266                    for ch in ch_start..ch_end {
267                        let sample_idx = frame_idx * num_channels + ch;
268                        samples.push(F::from(packet_samples[sample_idx]).unwrap());
269                    }
270                }
271
272                pos += 1;
273            }
274
275            // Update our position tracker
276            current_sample = Some(pos);
277        }
278    }
279
280    // Calculate the actual channel count in the extracted samples
281    let ch_start = start_channel.unwrap_or(0);
282    let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
283
284    let samples = if let Some(sr_out) = config.sample_rate
285        && sr_out != sample_rate
286    {
287        // Use ch_count (the selected channels) not num_channels (original file channels)
288        resample(&samples, ch_count, sample_rate, sr_out)?
289    } else {
290        samples
291    };
292
293    // Return the actual sample rate (resampled if applicable, otherwise original)
294    let actual_sample_rate = config.sample_rate.unwrap_or(sample_rate);
295
296    Ok(Audio {
297        samples_interleaved: samples,
298        sample_rate: actual_sample_rate,
299        num_channels: ch_count as u16,
300    })
301}
302
/// Reads an audio file like [`read`] and returns it as an owned interleaved
/// `audio_blocks` block together with its sample rate in Hz.
///
/// # Errors
///
/// Propagates every error returned by [`read`].
#[cfg(feature = "audio-blocks")]
pub fn read_block<F: num::Float + 'static + rubato::Sample>(
    path: impl AsRef<Path>,
    config: ReadConfig,
) -> Result<(audio_blocks::Interleaved<F>, u32), ReadError> {
    let Audio {
        samples_interleaved,
        sample_rate,
        num_channels,
    } = read(path, config)?;

    let block = audio_blocks::Interleaved::from_slice(&samples_interleaved, num_channels);
    Ok((block, sample_rate))
}
314
315#[cfg(test)]
316mod tests {
317    use std::time::Duration;
318
319    use audio_blocks::{AudioBlock, InterleavedView};
320
321    use super::*;
322
    /// Wraps the interleaved samples of `audio` in a borrowed `InterleavedView`
    /// so the tests can query frames and channels through the block API.
    fn to_block<F: num::Float + 'static>(audio: &Audio<F>) -> InterleavedView<'_, F> {
        InterleavedView::from_slice(&audio.samples_interleaved, audio.num_channels)
    }
326
327    /// Verify that the read audio data matches the expected sine wave values.
328    /// The test file was generated by utils/generate_wav.py with these parameters:
329    /// - 4 channels with frequencies: [440, 554.37, 659.25, 880] Hz
330    /// - Sample rate: 48000 Hz
331    /// - Duration: 1 second (48000 samples)
332    #[test]
333    fn test_sine_wave_data_integrity() {
334        const SAMPLE_RATE: f64 = 48000.0;
335        const N_SAMPLES: usize = 48000;
336        const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
337
338        let audio = read::<f32>("test_data/test_4ch.wav", ReadConfig::default()).unwrap();
339        let block = to_block(&audio);
340
341        assert_eq!(audio.sample_rate, 48000);
342        assert_eq!(block.num_frames(), N_SAMPLES);
343        assert_eq!(block.num_channels(), 4);
344
345        // Verify each channel contains the expected sine wave
346        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
347            for frame in 0..N_SAMPLES {
348                let expected =
349                    (2.0 * std::f64::consts::PI * freq * frame as f64 / SAMPLE_RATE).sin() as f32;
350                let actual = block.sample(ch as u16, frame);
351                assert!(
352                    (actual - expected).abs() < 1e-4,
353                    "Mismatch at channel {ch}, frame {frame}: expected {expected}, got {actual}"
354                );
355            }
356        }
357
358        // Also verify reading with an offset works consistently
359        let audio = read::<f32>(
360            "test_data/test_4ch.wav",
361            ReadConfig {
362                start: Position::Frame(24000),
363                stop: Position::Frame(24100),
364                ..Default::default()
365            },
366        )
367        .unwrap();
368        let block = to_block(&audio);
369
370        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
371            for frame in 0..100 {
372                let actual_frame = 24000 + frame;
373                let expected = (2.0 * std::f64::consts::PI * freq * actual_frame as f64
374                    / SAMPLE_RATE)
375                    .sin() as f32;
376                let actual = block.sample(ch as u16, frame);
377                assert!(
378                    (actual - expected).abs() < 1e-4,
379                    "Offset mismatch at channel {ch}, frame {actual_frame}: expected {expected}, got {actual}"
380                );
381            }
382        }
383    }
384
385    #[test]
386    fn test_samples_selection() {
387        let audio1 = read::<f32>("test_data/test_1ch.wav", ReadConfig::default()).unwrap();
388        let block1 = to_block(&audio1);
389        assert_eq!(audio1.sample_rate, 48000);
390        assert_eq!(block1.num_frames(), 48000);
391        assert_eq!(block1.num_channels(), 1);
392
393        let audio2 = read::<f32>(
394            "test_data/test_1ch.wav",
395            ReadConfig {
396                start: Position::Frame(1100),
397                stop: Position::Frame(1200),
398                ..Default::default()
399            },
400        )
401        .unwrap();
402        let block2 = to_block(&audio2);
403        assert_eq!(audio2.sample_rate, 48000);
404        assert_eq!(block2.num_frames(), 100);
405        assert_eq!(block2.num_channels(), 1);
406        assert_eq!(block1.raw_data()[1100..1200], block2.raw_data()[..]);
407    }
408
409    #[test]
410    fn test_time_selection() {
411        let audio1 = read::<f32>("test_data/test_1ch.wav", ReadConfig::default()).unwrap();
412        let block1 = to_block(&audio1);
413        assert_eq!(audio1.sample_rate, 48000);
414        assert_eq!(block1.num_frames(), 48000);
415        assert_eq!(block1.num_channels(), 1);
416
417        let audio2 = read::<f32>(
418            "test_data/test_1ch.wav",
419            ReadConfig {
420                start: Position::Time(Duration::from_secs_f32(0.5)),
421                stop: Position::Time(Duration::from_secs_f32(0.6)),
422                ..Default::default()
423            },
424        )
425        .unwrap();
426        let block2 = to_block(&audio2);
427
428        assert_eq!(audio2.sample_rate, 48000);
429        assert_eq!(block2.num_frames(), 4800);
430        assert_eq!(block2.num_channels(), 1);
431        assert_eq!(block1.raw_data()[24000..28800], block2.raw_data()[..]);
432    }
433
434    #[test]
435    fn test_channel_selection() {
436        let audio1 = read::<f32>("test_data/test_4ch.wav", ReadConfig::default()).unwrap();
437        let block1 = to_block(&audio1);
438        assert_eq!(audio1.sample_rate, 48000);
439        assert_eq!(block1.num_frames(), 48000);
440        assert_eq!(block1.num_channels(), 4);
441
442        let audio2 = read::<f32>(
443            "test_data/test_4ch.wav",
444            ReadConfig {
445                start_channel: Some(1),
446                num_channels: Some(2),
447                ..Default::default()
448            },
449        )
450        .unwrap();
451        let block2 = to_block(&audio2);
452
453        assert_eq!(audio2.sample_rate, 48000);
454        assert_eq!(block2.num_frames(), 48000);
455        assert_eq!(block2.num_channels(), 2);
456
457        // Verify we extracted channels 1 and 2 (skipping channel 0 and 3)
458        for frame in 0..10 {
459            assert_eq!(block2.sample(0, frame), block1.sample(1, frame));
460            assert_eq!(block2.sample(1, frame), block1.sample(2, frame));
461        }
462    }
463
464    #[test]
465    fn test_fail_selection() {
466        match read::<f32>(
467            "test_data/test_1ch.wav",
468            ReadConfig {
469                start: Position::Frame(100),
470                stop: Position::Frame(99),
471                ..Default::default()
472            },
473        ) {
474            Err(ReadError::InvalidFrameRange { start: _, end: _ }) => (),
475            _ => panic!(),
476        }
477
478        match read::<f32>(
479            "test_data/test_1ch.wav",
480            ReadConfig {
481                start: Position::Time(Duration::from_secs_f32(0.6)),
482                stop: Position::Time(Duration::from_secs_f32(0.5)),
483                ..Default::default()
484            },
485        ) {
486            Err(ReadError::InvalidFrameRange { start: _, end: _ }) => (),
487            _ => panic!(),
488        }
489
490        match read::<f32>(
491            "test_data/test_1ch.wav",
492            ReadConfig {
493                start_channel: Some(1),
494                ..Default::default()
495            },
496        ) {
497            Err(ReadError::InvalidChannel { index: _, total: _ }) => (),
498            _ => panic!(),
499        }
500
501        match read::<f32>(
502            "test_data/test_1ch.wav",
503            ReadConfig {
504                num_channels: Some(0),
505                ..Default::default()
506            },
507        ) {
508            Err(ReadError::InvalidChannelCount(0)) => (),
509            _ => panic!(),
510        }
511
512        match read::<f32>(
513            "test_data/test_1ch.wav",
514            ReadConfig {
515                num_channels: Some(2),
516                ..Default::default()
517            },
518        ) {
519            Err(ReadError::InvalidChannelCount(2)) => (),
520            _ => panic!(),
521        }
522    }
523
524    #[test]
525    fn test_resample_preserves_frequency() {
526        const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
527        let sr_out: u32 = 22050;
528
529        // Read and resample in one step
530        let audio = read::<f32>(
531            "test_data/test_4ch.wav",
532            ReadConfig {
533                sample_rate: Some(sr_out),
534                ..Default::default()
535            },
536        )
537        .unwrap();
538        let block = to_block(&audio);
539
540        assert_eq!(audio.sample_rate, sr_out); // Resampled sample rate is returned
541        assert_eq!(block.num_channels(), 4);
542
543        // Expected frames after resampling: 48000 * (22050/48000) = 22050
544        let expected_frames = 22050;
545        assert_eq!(
546            block.num_frames(),
547            expected_frames,
548            "Expected {} frames, got {}",
549            expected_frames,
550            block.num_frames()
551        );
552
553        // Verify sine wave frequencies are preserved after resampling
554        // Skip first ~100 samples to avoid any edge effects from resampling
555        let start_frame = 100;
556        let test_frames = 1000;
557
558        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
559            let mut max_error: f32 = 0.0;
560            for frame in start_frame..(start_frame + test_frames) {
561                let expected =
562                    (2.0 * std::f64::consts::PI * freq * frame as f64 / sr_out as f64).sin() as f32;
563                let actual = block.sample(ch as u16, frame);
564                let error = (actual - expected).abs();
565                max_error = max_error.max(error);
566            }
567            assert!(
568                max_error < 0.02,
569                "Channel {} ({}Hz): max error {} exceeds threshold",
570                ch,
571                freq,
572                max_error
573            );
574        }
575    }
576
577    #[test]
578    fn test_channel_selection_with_resampling() {
579        // This test verifies that channel selection combined with resampling works correctly
580        const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
581        let sr_out: u32 = 22050;
582
583        // Read channels 1 and 2 (indices 1 and 2) with resampling
584        let audio = read::<f32>(
585            "test_data/test_4ch.wav",
586            ReadConfig {
587                start_channel: Some(1),
588                num_channels: Some(2),
589                sample_rate: Some(sr_out),
590                ..Default::default()
591            },
592        )
593        .unwrap();
594        let block = to_block(&audio);
595
596        assert_eq!(audio.num_channels, 2, "Should have 2 channels");
597        assert_eq!(
598            audio.sample_rate, sr_out,
599            "Sample rate should be the resampled rate"
600        );
601
602        // Expected frames after resampling: 48000 * (22050/48000) = 22050
603        let expected_frames = 22050;
604        assert_eq!(
605            block.num_frames(),
606            expected_frames,
607            "Expected {} frames, got {}",
608            expected_frames,
609            block.num_frames()
610        );
611
612        // Verify that the resampled audio contains the correct frequencies
613        // Channels 1 and 2 should have frequencies 554.37 Hz and 659.25 Hz
614        let selected_freqs = &FREQUENCIES[1..3];
615
616        let start_frame = 100;
617        let test_frames = 1000;
618
619        for (ch, &freq) in selected_freqs.iter().enumerate() {
620            let mut max_error: f32 = 0.0;
621            for frame in start_frame..(start_frame + test_frames) {
622                let expected =
623                    (2.0 * std::f64::consts::PI * freq * frame as f64 / sr_out as f64).sin() as f32;
624                let actual = block.sample(ch as u16, frame);
625                let error = (actual - expected).abs();
626                max_error = max_error.max(error);
627            }
628            assert!(
629                max_error < 0.02,
630                "Channel {} ({}Hz): max error {} exceeds threshold",
631                ch,
632                freq,
633                max_error
634            );
635        }
636    }
637}