audio_io/
reader.rs

1use std::fs::File;
2use std::path::Path;
3
4use audio_blocks::AudioBlockInterleavedView;
5use num::Float;
6use symphonia::core::audio::SampleBuffer;
7use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
8use symphonia::core::errors::Error;
9use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
10use symphonia::core::io::MediaSourceStream;
11use symphonia::core::meta::MetadataOptions;
12use symphonia::core::probe::Hint;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum AudioReadError {
17    #[error("could not read file")]
18    FileError(#[from] std::io::Error),
19    #[error("could not decode audio")]
20    EncodingError(#[from] symphonia::core::errors::Error),
21    #[error("could not find track in file")]
22    NoTrack,
23    #[error("could not find sample rate in file")]
24    NoSampleRate,
25    #[error("end frame {0} is larger than start frame {1}")]
26    EndFrameLargerThanStartFrame(usize, usize),
27    #[error("start channel {0} invalid, audio file has only {1}")]
28    InvalidStartChannel(usize, usize),
29    #[error("end channel {0} invalid, audio file has only {1}")]
30    InvalidEndChannel(usize, usize),
31    #[error("end channel {0} is larger than start channel {1}")]
32    EndChannelLargerThanStartChannel(usize, usize),
33}
34
/// Position in the audio stream at which reading begins (inclusive).
#[derive(Debug, Clone, Copy, Default)]
pub enum Start {
    /// Begin with the very first frame of the stream.
    #[default]
    Beginning,
    /// Begin at this time offset; converted to a frame index using the
    /// track's sample rate.
    Time(std::time::Duration),
    /// Begin at this frame index (a frame holds one sample for every channel).
    Frame(usize),
}
46
/// Position in the audio stream at which reading ends (exclusive).
#[derive(Debug, Clone, Copy, Default)]
pub enum Stop {
    /// Keep reading until the stream is exhausted.
    #[default]
    End,
    /// End at this time offset; converted to a frame index using the
    /// track's sample rate.
    Time(std::time::Duration),
    /// End at this frame index (a frame holds one sample for every channel).
    Frame(usize),
}
58
59#[derive(Default)]
60pub struct AudioReadConfig {
61    /// Where to start reading audio (time or frame-based)
62    pub start: Start,
63    /// Where to stop reading audio (time or frame-based)
64    pub stop: Stop,
65    /// First channel to extract (0-indexed). None means start from channel 0.
66    pub first_channel: Option<usize>,
67    /// Last channel to extract (exclusive). None means extract to the last channel.
68    pub last_channel: Option<usize>,
69}
70
71#[derive(Default)]
72pub struct AudioData<F: Float + 'static> {
73    pub interleaved_samples: Vec<F>,
74    pub sample_rate: u32,
75    pub num_channels: usize,
76    pub num_frames: usize,
77}
78
79impl<F: Float> AudioData<F> {
80    // Convert into audio block, which makes it easy to access
81    // channels and frames or convert into any other layout.
82    // See [audio-blocks](https://crates.io/crates/audio-blocks) for more info.
83    //
84    // Does not allocate or copy memory!
85    pub fn audio_block(&self) -> AudioBlockInterleavedView<'_, F> {
86        AudioBlockInterleavedView::from_slice(
87            &self.interleaved_samples,
88            self.num_channels as u16,
89            self.num_frames,
90        )
91    }
92}
93
94pub fn audio_read<P: AsRef<Path>, F: Float>(
95    path: P,
96    config: AudioReadConfig,
97) -> Result<AudioData<F>, AudioReadError> {
98    let src = File::open(path.as_ref())?;
99    let mss = MediaSourceStream::new(Box::new(src), Default::default());
100
101    let mut hint = Hint::new();
102    if let Some(ext) = path.as_ref().extension()
103        && let Some(ext_str) = ext.to_str()
104    {
105        hint.with_extension(ext_str);
106    }
107
108    let meta_opts: MetadataOptions = Default::default();
109    let fmt_opts: FormatOptions = Default::default();
110
111    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
112
113    let mut format = probed.format;
114
115    let track = format
116        .tracks()
117        .iter()
118        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
119        .ok_or(AudioReadError::NoTrack)?;
120
121    let sample_rate = track
122        .codec_params
123        .sample_rate
124        .ok_or(AudioReadError::NoSampleRate)?;
125
126    let track_id = track.id;
127
128    // Clone codec params before the mutable borrow
129    let codec_params = track.codec_params.clone();
130    let time_base = track.codec_params.time_base;
131
132    // Convert Start/Stop to frame numbers
133    let start_frame = match config.start {
134        Start::Beginning => 0,
135        Start::Time(duration) => {
136            let secs = duration.as_secs_f64();
137            (secs * sample_rate as f64) as usize
138        }
139        Start::Frame(frame) => frame,
140    };
141
142    let end_frame: Option<usize> = match config.stop {
143        Stop::End => None,
144        Stop::Time(duration) => {
145            let secs = duration.as_secs_f64();
146            Some((secs * sample_rate as f64) as usize)
147        }
148        Stop::Frame(frame) => Some(frame),
149    };
150
151    if let Some(end_frame) = end_frame
152        && start_frame > end_frame
153    {
154        return Err(AudioReadError::EndFrameLargerThanStartFrame(
155            end_frame,
156            start_frame,
157        ));
158    }
159
160    // If start_frame is large (more than 1 second), use seeking to avoid decoding everything
161    if start_frame > sample_rate as usize
162        && let Some(tb) = time_base
163    {
164        // Seek to 90% of the target to account for keyframe positioning
165        let seek_sample = (start_frame as f64 * 0.9) as u64;
166        let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
167
168        // Try to seek, but don't fail if seeking doesn't work
169        let _ = format.seek(
170            SeekMode::Accurate,
171            SeekTo::TimeStamp {
172                ts: seek_ts,
173                track_id,
174            },
175        );
176    }
177
178    let dec_opts: DecoderOptions = Default::default();
179    let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
180
181    let mut sample_buf = None;
182    let mut samples = Vec::new();
183    let mut num_channels = 0usize;
184    let start_channel = config.first_channel;
185    let end_channel = config.last_channel;
186
187    // We'll track exact position by counting samples as we decode
188    let mut current_sample: Option<u64> = None;
189
190    loop {
191        let packet = match format.next_packet() {
192            Ok(packet) => packet,
193            Err(Error::ResetRequired) => {
194                decoder.reset();
195                continue;
196            }
197            Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
198                break;
199            }
200            Err(err) => return Err(err.into()),
201        };
202
203        if packet.track_id() != track_id {
204            continue;
205        }
206
207        let decoded = decoder.decode(&packet)?;
208
209        // Get the timestamp of this packet to know our position
210        if current_sample.is_none() {
211            let ts = packet.ts();
212            if let Some(tb) = time_base {
213                // Convert timestamp to sample position
214                current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
215            } else {
216                current_sample = Some(0);
217            }
218        }
219
220        if sample_buf.is_none() {
221            let spec = *decoded.spec();
222            let duration = decoded.capacity() as u64;
223            sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
224
225            // Get the number of channels from the spec
226            num_channels = spec.channels.count();
227
228            // Validate channel range
229            if let Some(start_ch) = start_channel
230                && start_ch >= num_channels
231            {
232                return Err(AudioReadError::InvalidStartChannel(start_ch, num_channels));
233            }
234            if let Some(end_ch) = end_channel {
235                if end_ch > num_channels {
236                    return Err(AudioReadError::InvalidEndChannel(end_ch, num_channels));
237                }
238                if let Some(start_ch) = start_channel
239                    && end_ch <= start_ch
240                {
241                    return Err(AudioReadError::EndChannelLargerThanStartChannel(
242                        end_ch, start_ch,
243                    ));
244                }
245            }
246        }
247
248        if let Some(buf) = &mut sample_buf {
249            buf.copy_interleaved_ref(decoded);
250            let packet_samples = buf.samples();
251
252            let mut pos = current_sample.unwrap_or(0);
253
254            // Determine channel range to extract
255            let ch_start = start_channel.unwrap_or(0);
256            let ch_end = end_channel.unwrap_or(num_channels);
257            let num_channels = ch_end - ch_start;
258
259            // Process samples based on whether we're filtering channels
260            if ch_start != 0 || ch_end != num_channels {
261                // Channel filtering: samples are interleaved [L, R, L, R, ...] for stereo
262                // We need to extract only the requested channel range
263                let frames = packet_samples.len() / num_channels;
264
265                for frame_idx in 0..frames {
266                    // Check if we've reached the end frame
267                    if let Some(end) = end_frame
268                        && pos >= end as u64
269                    {
270                        let num_frames = samples.len() / num_channels;
271                        return Ok(AudioData {
272                            sample_rate,
273                            num_channels,
274                            num_frames,
275                            interleaved_samples: samples,
276                        });
277                    }
278
279                    // Start collecting samples once we reach start_frame
280                    if pos >= start_frame as u64 {
281                        // Extract only the selected channel range from this frame
282                        for ch in ch_start..ch_end {
283                            let sample_idx = frame_idx * num_channels + ch;
284                            samples.push(F::from(packet_samples[sample_idx]).unwrap());
285                        }
286                    }
287
288                    pos += 1;
289                }
290            } else {
291                // No channel filtering: collect all samples
292                let frames = packet_samples.len() / num_channels;
293
294                for frame_idx in 0..frames {
295                    // Check if we've reached the end frame
296                    if let Some(end) = end_frame
297                        && pos >= end as u64
298                    {
299                        let num_frames = samples.len() / num_channels;
300                        return Ok(AudioData {
301                            sample_rate,
302                            num_channels,
303                            num_frames,
304                            interleaved_samples: samples,
305                        });
306                    }
307
308                    // Start collecting samples once we reach start_frame
309                    if pos >= start_frame as u64 {
310                        // Collect all channels from this frame
311                        for ch in 0..num_channels {
312                            let sample_idx = frame_idx * num_channels + ch;
313                            samples.push(F::from(packet_samples[sample_idx]).unwrap());
314                        }
315                    }
316
317                    pos += 1;
318                }
319            }
320
321            // Update our position tracker
322            current_sample = Some(pos);
323        }
324    }
325
326    let ch_start = start_channel.unwrap_or(0);
327    let ch_end = end_channel.unwrap_or(num_channels);
328    let num_channels = ch_end - ch_start;
329    let num_frames = samples.len() / num_channels;
330
331    Ok(AudioData {
332        sample_rate,
333        num_channels,
334        num_frames,
335        interleaved_samples: samples,
336    })
337}
338
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use audio_blocks::AudioBlock;

    use super::*;

    /// Fixture read by every test. Judging by the assertions below, it is a
    /// 1 s, 48 kHz, mono recording.
    const FIXTURE: &str = "test.wav";

    /// Reads the fixture as `f32` samples using `config`.
    fn read_fixture(config: AudioReadConfig) -> Result<AudioData<f32>, AudioReadError> {
        audio_read(FIXTURE, config)
    }

    #[test]
    fn test_samples_selection() {
        let whole = read_fixture(AudioReadConfig::default()).unwrap();
        let whole_block = whole.audio_block();
        assert_eq!(whole.sample_rate, 48000);
        assert_eq!(whole_block.num_frames(), 48000);
        assert_eq!(whole_block.num_channels(), 1);

        let part = read_fixture(AudioReadConfig {
            start: Start::Frame(1100),
            stop: Stop::Frame(1200),
            ..Default::default()
        })
        .unwrap();
        let part_block = part.audio_block();
        assert_eq!(part.sample_rate, 48000);
        assert_eq!(part_block.num_frames(), 100);
        assert_eq!(part_block.num_channels(), 1);
        assert_eq!(whole_block.raw_data()[1100..1200], part_block.raw_data()[..]);
    }

    #[test]
    fn test_time_selection() {
        let whole = read_fixture(AudioReadConfig::default()).unwrap();
        let whole_block = whole.audio_block();
        assert_eq!(whole.sample_rate, 48000);
        assert_eq!(whole_block.num_frames(), 48000);
        assert_eq!(whole_block.num_channels(), 1);

        let part = read_fixture(AudioReadConfig {
            start: Start::Time(Duration::from_secs_f32(0.5)),
            stop: Stop::Time(Duration::from_secs_f32(0.6)),
            ..Default::default()
        })
        .unwrap();
        let part_block = part.audio_block();
        assert_eq!(part.sample_rate, 48000);
        assert_eq!(part_block.num_frames(), 4800);
        assert_eq!(part_block.num_channels(), 1);
        assert_eq!(whole_block.raw_data()[24000..28800], part_block.raw_data()[..]);
    }

    #[test]
    fn test_fail_selection() {
        let inverted_frames = read_fixture(AudioReadConfig {
            start: Start::Frame(100),
            stop: Stop::Frame(99),
            ..Default::default()
        });
        assert!(matches!(
            inverted_frames,
            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
        ));

        let inverted_times = read_fixture(AudioReadConfig {
            start: Start::Time(Duration::from_secs_f32(0.6)),
            stop: Stop::Time(Duration::from_secs_f32(0.5)),
            ..Default::default()
        });
        assert!(matches!(
            inverted_times,
            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
        ));

        let missing_first_channel = read_fixture(AudioReadConfig {
            first_channel: Some(1),
            ..Default::default()
        });
        assert!(matches!(
            missing_first_channel,
            Err(AudioReadError::InvalidStartChannel(_, _))
        ));

        let missing_last_channel = read_fixture(AudioReadConfig {
            last_channel: Some(2),
            ..Default::default()
        });
        assert!(matches!(
            missing_last_channel,
            Err(AudioReadError::InvalidEndChannel(_, _))
        ));
    }
}