audio_io/
reader.rs

1use std::fs::File;
2use std::path::Path;
3
4use audio_blocks::AudioBlockInterleavedView;
5use num::Float;
6use symphonia::core::audio::SampleBuffer;
7use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
8use symphonia::core::errors::Error;
9use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
10use symphonia::core::io::MediaSourceStream;
11use symphonia::core::meta::MetadataOptions;
12use symphonia::core::probe::Hint;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum AudioReadError {
17    #[error("could not read file")]
18    FileError(#[from] std::io::Error),
19    #[error("could not decode audio")]
20    EncodingError(#[from] symphonia::core::errors::Error),
21    #[error("could not find track in file")]
22    NoTrack,
23    #[error("could not find sample rate in file")]
24    NoSampleRate,
25    #[error("end frame {0} is larger than start frame {1}")]
26    EndFrameLargerThanStartFrame(usize, usize),
27    #[error("start channel {0} invalid, audio file has only {1} channels")]
28    InvalidStartChannel(usize, usize),
29    #[error("invalid number of channels to extract: {0}")]
30    InvalidNumChannels(usize),
31}
32
/// Position in the audio stream, used for both the start and the stop point of a read.
///
/// Positions can be compared with `==`, which is convenient in tests and when
/// checking whether a configuration still uses its default range.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Position {
    /// Start from the beginning or read until the end (depending on context).
    #[default]
    Default,
    /// Specific time offset from the beginning of the stream.
    Time(std::time::Duration),
    /// Specific frame index. One frame holds one sample for every channel, so
    /// this index does not depend on the channel count.
    Frame(usize),
}
44
45#[derive(Default)]
46pub struct AudioReadConfig {
47    /// Where to start reading audio (time or frame-based)
48    pub start: Position,
49    /// Where to stop reading audio (time or frame-based)
50    pub stop: Position,
51    /// Starting channel to extract (0-indexed). None means start from channel 0.
52    pub start_channel: Option<usize>,
53    /// Number of channels to extract. None means extract all remaining channels.
54    pub num_channels: Option<usize>,
55}
56
57#[derive(Default)]
58pub struct AudioData<F: Float + 'static> {
59    pub interleaved_samples: Vec<F>,
60    pub sample_rate: u32,
61    pub num_channels: usize,
62    pub num_frames: usize,
63}
64
65impl<F: Float> AudioData<F> {
66    // Convert into audio block, which makes it easy to access
67    // channels and frames or convert into any other layout.
68    // See [audio-blocks](https://crates.io/crates/audio-blocks) for more info.
69    //
70    // Does not allocate or copy memory!
71    pub fn audio_block(&self) -> AudioBlockInterleavedView<'_, F> {
72        AudioBlockInterleavedView::from_slice(
73            &self.interleaved_samples,
74            self.num_channels as u16,
75            self.num_frames,
76        )
77    }
78}
79
80pub fn audio_read<P: AsRef<Path>, F: Float>(
81    path: P,
82    config: AudioReadConfig,
83) -> Result<AudioData<F>, AudioReadError> {
84    let src = File::open(path.as_ref())?;
85    let mss = MediaSourceStream::new(Box::new(src), Default::default());
86
87    let mut hint = Hint::new();
88    if let Some(ext) = path.as_ref().extension()
89        && let Some(ext_str) = ext.to_str()
90    {
91        hint.with_extension(ext_str);
92    }
93
94    let meta_opts: MetadataOptions = Default::default();
95    let fmt_opts: FormatOptions = Default::default();
96
97    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
98
99    let mut format = probed.format;
100
101    let track = format
102        .tracks()
103        .iter()
104        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
105        .ok_or(AudioReadError::NoTrack)?;
106
107    let sample_rate = track
108        .codec_params
109        .sample_rate
110        .ok_or(AudioReadError::NoSampleRate)?;
111
112    let track_id = track.id;
113
114    // Clone codec params before the mutable borrow
115    let codec_params = track.codec_params.clone();
116    let time_base = track.codec_params.time_base;
117
118    // Convert start/stop positions to frame numbers
119    let start_frame = match config.start {
120        Position::Default => 0,
121        Position::Time(duration) => {
122            let secs = duration.as_secs_f64();
123            (secs * sample_rate as f64) as usize
124        }
125        Position::Frame(frame) => frame,
126    };
127
128    let end_frame: Option<usize> = match config.stop {
129        Position::Default => None,
130        Position::Time(duration) => {
131            let secs = duration.as_secs_f64();
132            Some((secs * sample_rate as f64) as usize)
133        }
134        Position::Frame(frame) => Some(frame),
135    };
136
137    if let Some(end_frame) = end_frame
138        && start_frame > end_frame
139    {
140        return Err(AudioReadError::EndFrameLargerThanStartFrame(
141            end_frame,
142            start_frame,
143        ));
144    }
145
146    // If start_frame is large (more than 1 second), use seeking to avoid decoding everything
147    if start_frame > sample_rate as usize
148        && let Some(tb) = time_base
149    {
150        // Seek to 90% of the target to account for keyframe positioning
151        let seek_sample = (start_frame as f64 * 0.9) as u64;
152        let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
153
154        // Try to seek, but don't fail if seeking doesn't work
155        let _ = format.seek(
156            SeekMode::Accurate,
157            SeekTo::TimeStamp {
158                ts: seek_ts,
159                track_id,
160            },
161        );
162    }
163
164    let dec_opts: DecoderOptions = Default::default();
165    let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
166
167    let mut sample_buf = None;
168    let mut samples = Vec::new();
169    let mut num_channels = 0usize;
170    let start_channel = config.start_channel;
171
172    // We'll track exact position by counting samples as we decode
173    let mut current_sample: Option<u64> = None;
174
175    loop {
176        let packet = match format.next_packet() {
177            Ok(packet) => packet,
178            Err(Error::ResetRequired) => {
179                decoder.reset();
180                continue;
181            }
182            Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
183                break;
184            }
185            Err(err) => return Err(err.into()),
186        };
187
188        if packet.track_id() != track_id {
189            continue;
190        }
191
192        let decoded = decoder.decode(&packet)?;
193
194        // Get the timestamp of this packet to know our position
195        if current_sample.is_none() {
196            let ts = packet.ts();
197            if let Some(tb) = time_base {
198                // Convert timestamp to sample position
199                current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
200            } else {
201                current_sample = Some(0);
202            }
203        }
204
205        if sample_buf.is_none() {
206            let spec = *decoded.spec();
207            let duration = decoded.capacity() as u64;
208            sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
209
210            // Get the number of channels from the spec
211            num_channels = spec.channels.count();
212
213            // Validate channel range
214            let ch_start = start_channel.unwrap_or(0);
215            let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
216
217            if ch_start >= num_channels {
218                return Err(AudioReadError::InvalidStartChannel(ch_start, num_channels));
219            }
220            if ch_count == 0 {
221                return Err(AudioReadError::InvalidNumChannels(0));
222            }
223            if ch_start + ch_count > num_channels {
224                return Err(AudioReadError::InvalidNumChannels(ch_count));
225            }
226        }
227
228        if let Some(buf) = &mut sample_buf {
229            buf.copy_interleaved_ref(decoded);
230            let packet_samples = buf.samples();
231
232            let mut pos = current_sample.unwrap_or(0);
233
234            // Determine channel range to extract
235            let ch_start = start_channel.unwrap_or(0);
236            let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
237            let ch_end = ch_start + ch_count;
238
239            // Calculate frames using the ORIGINAL channel count from the file
240            let frames = packet_samples.len() / num_channels;
241
242            // Process all frames, extracting only the requested channel range
243            for frame_idx in 0..frames {
244                // Check if we've reached the end frame
245                if let Some(end) = end_frame
246                    && pos >= end as u64
247                {
248                    let num_frames = samples.len() / ch_count;
249                    return Ok(AudioData {
250                        sample_rate,
251                        num_channels: ch_count,
252                        num_frames,
253                        interleaved_samples: samples,
254                    });
255                }
256
257                // Start collecting samples once we reach start_frame
258                if pos >= start_frame as u64 {
259                    // Extract the selected channel range from this frame
260                    // When ch_start=0 and ch_count=num_channels, this extracts all channels
261                    for ch in ch_start..ch_end {
262                        let sample_idx = frame_idx * num_channels + ch;
263                        samples.push(F::from(packet_samples[sample_idx]).unwrap());
264                    }
265                }
266
267                pos += 1;
268            }
269
270            // Update our position tracker
271            current_sample = Some(pos);
272        }
273    }
274
275    let ch_start = start_channel.unwrap_or(0);
276    let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
277    let num_frames = samples.len() / ch_count;
278
279    Ok(AudioData {
280        sample_rate,
281        num_channels: ch_count,
282        num_frames,
283        interleaved_samples: samples,
284    })
285}
286
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use audio_blocks::AudioBlock;

    use super::*;

    /// A frame-based start/stop selection returns exactly that slice of the file.
    #[test]
    fn test_samples_selection() {
        let data1: AudioData<f32> =
            audio_read("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();
        let block1 = data1.audio_block();
        assert_eq!(data1.sample_rate, 48000);
        assert_eq!(block1.num_frames(), 48000);
        assert_eq!(block1.num_channels(), 1);

        let data2: AudioData<f32> = audio_read(
            "test_data/test_1ch.wav",
            AudioReadConfig {
                start: Position::Frame(1100),
                stop: Position::Frame(1200),
                ..Default::default()
            },
        )
        .unwrap();
        let block2 = data2.audio_block();
        assert_eq!(data2.sample_rate, 48000);
        assert_eq!(block2.num_frames(), 100);
        assert_eq!(block2.num_channels(), 1);
        assert_eq!(block1.raw_data()[1100..1200], block2.raw_data()[..]);
    }

    /// A time-based start/stop selection is converted to frames via the sample rate.
    #[test]
    fn test_time_selection() {
        let data1: AudioData<f32> =
            audio_read("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();
        let block1 = data1.audio_block();
        assert_eq!(data1.sample_rate, 48000);
        assert_eq!(block1.num_frames(), 48000);
        assert_eq!(block1.num_channels(), 1);

        let data2: AudioData<f32> = audio_read(
            "test_data/test_1ch.wav",
            AudioReadConfig {
                start: Position::Time(Duration::from_secs_f32(0.5)),
                stop: Position::Time(Duration::from_secs_f32(0.6)),
                ..Default::default()
            },
        )
        .unwrap();

        let block2 = data2.audio_block();
        assert_eq!(data2.sample_rate, 48000);
        assert_eq!(block2.num_frames(), 4800);
        assert_eq!(block2.num_channels(), 1);
        assert_eq!(block1.raw_data()[24000..28800], block2.raw_data()[..]);
    }

    /// Selecting a channel sub-range keeps all frames but only those channels.
    #[test]
    fn test_channel_selection() {
        let data1: AudioData<f32> =
            audio_read("test_data/test_4ch.wav", AudioReadConfig::default()).unwrap();
        let block1 = data1.audio_block();
        assert_eq!(data1.sample_rate, 48000);
        assert_eq!(block1.num_frames(), 48000);
        assert_eq!(block1.num_channels(), 4);

        let data2: AudioData<f32> = audio_read(
            "test_data/test_4ch.wav",
            AudioReadConfig {
                start_channel: Some(1),
                num_channels: Some(2),
                ..Default::default()
            },
        )
        .unwrap();

        let block2 = data2.audio_block();
        assert_eq!(data2.sample_rate, 48000);
        assert_eq!(block2.num_frames(), 48000);
        assert_eq!(block2.num_channels(), 2);

        // Verify we extracted channels 1 and 2 (skipping channel 0 and 3)
        for frame in 0..10 {
            assert_eq!(block2.sample(0, frame), block1.sample(1, frame));
            assert_eq!(block2.sample(1, frame), block1.sample(2, frame));
        }
    }

    /// Invalid frame ranges and channel selections are rejected with the matching error.
    #[test]
    fn test_fail_selection() {
        let result = audio_read::<_, f32>(
            "test_data/test_1ch.wav",
            AudioReadConfig {
                start: Position::Frame(100),
                stop: Position::Frame(99),
                ..Default::default()
            },
        );
        assert!(
            matches!(
                result,
                Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
            ),
            "a stop frame before the start frame must be rejected"
        );

        let result = audio_read::<_, f32>(
            "test_data/test_1ch.wav",
            AudioReadConfig {
                start: Position::Time(Duration::from_secs_f32(0.6)),
                stop: Position::Time(Duration::from_secs_f32(0.5)),
                ..Default::default()
            },
        );
        assert!(
            matches!(
                result,
                Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
            ),
            "a stop time before the start time must be rejected"
        );

        let result = audio_read::<_, f32>(
            "test_data/test_1ch.wav",
            AudioReadConfig {
                start_channel: Some(1),
                ..Default::default()
            },
        );
        assert!(
            matches!(result, Err(AudioReadError::InvalidStartChannel(_, _))),
            "start channel 1 does not exist in a mono file"
        );

        let result = audio_read::<_, f32>(
            "test_data/test_1ch.wav",
            AudioReadConfig {
                num_channels: Some(0),
                ..Default::default()
            },
        );
        assert!(
            matches!(result, Err(AudioReadError::InvalidNumChannels(0))),
            "extracting zero channels must be rejected"
        );

        let result = audio_read::<_, f32>(
            "test_data/test_1ch.wav",
            AudioReadConfig {
                num_channels: Some(2),
                ..Default::default()
            },
        );
        assert!(
            matches!(result, Err(AudioReadError::InvalidNumChannels(2))),
            "extracting two channels from a mono file must be rejected"
        );
    }
}