1use std::fs::File;
2use std::path::Path;
3
4use audio_blocks::AudioBlockInterleavedView;
5use num::Float;
6use symphonia::core::audio::SampleBuffer;
7use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
8use symphonia::core::errors::Error;
9use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
10use symphonia::core::io::MediaSourceStream;
11use symphonia::core::meta::MetadataOptions;
12use symphonia::core::probe::Hint;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum AudioReadError {
17 #[error("could not read file")]
18 FileError(#[from] std::io::Error),
19 #[error("could not decode audio")]
20 EncodingError(#[from] symphonia::core::errors::Error),
21 #[error("could not find track in file")]
22 NoTrack,
23 #[error("could not find sample rate in file")]
24 NoSampleRate,
25 #[error("end frame {0} is larger than start frame {1}")]
26 EndFrameLargerThanStartFrame(usize, usize),
27 #[error("start channel {0} invalid, audio file has only {1}")]
28 InvalidStartChannel(usize, usize),
29 #[error("end channel {0} invalid, audio file has only {1}")]
30 InvalidEndChannel(usize, usize),
31 #[error("end channel {0} is larger than start channel {1}")]
32 EndChannelLargerThanStartChannel(usize, usize),
33}
34
/// Position in the file at which reading starts (inclusive).
#[derive(Clone, Copy, Debug, Default)]
pub enum Start {
    /// Start at the first frame of the file.
    #[default]
    Beginning,
    /// Start at the given time offset from the beginning of the file.
    Time(std::time::Duration),
    /// Start at the given frame index.
    Frame(usize),
}
46
/// Position in the file at which reading stops (exclusive).
#[derive(Clone, Copy, Debug, Default)]
pub enum Stop {
    /// Read until the end of the file.
    #[default]
    End,
    /// Stop at the given time offset from the beginning of the file.
    Time(std::time::Duration),
    /// Stop at the given frame index; that frame itself is not returned.
    Frame(usize),
}
58
59#[derive(Default)]
60pub struct AudioReadConfig {
61 pub start: Start,
63 pub stop: Stop,
65 pub first_channel: Option<usize>,
67 pub last_channel: Option<usize>,
69}
70
71#[derive(Default)]
72pub struct AudioData<F: Float + 'static> {
73 pub interleaved_samples: Vec<F>,
74 pub sample_rate: u32,
75 pub num_channels: usize,
76 pub num_frames: usize,
77}
78
79impl<F: Float> AudioData<F> {
80 pub fn audio_block(&self) -> AudioBlockInterleavedView<'_, F> {
86 AudioBlockInterleavedView::from_slice(
87 &self.interleaved_samples,
88 self.num_channels as u16,
89 self.num_frames,
90 )
91 }
92}
93
94pub fn audio_read<P: AsRef<Path>, F: Float>(
95 path: P,
96 config: AudioReadConfig,
97) -> Result<AudioData<F>, AudioReadError> {
98 let src = File::open(path.as_ref())?;
99 let mss = MediaSourceStream::new(Box::new(src), Default::default());
100
101 let mut hint = Hint::new();
102 if let Some(ext) = path.as_ref().extension()
103 && let Some(ext_str) = ext.to_str()
104 {
105 hint.with_extension(ext_str);
106 }
107
108 let meta_opts: MetadataOptions = Default::default();
109 let fmt_opts: FormatOptions = Default::default();
110
111 let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
112
113 let mut format = probed.format;
114
115 let track = format
116 .tracks()
117 .iter()
118 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
119 .ok_or(AudioReadError::NoTrack)?;
120
121 let sample_rate = track
122 .codec_params
123 .sample_rate
124 .ok_or(AudioReadError::NoSampleRate)?;
125
126 let track_id = track.id;
127
128 let codec_params = track.codec_params.clone();
130 let time_base = track.codec_params.time_base;
131
132 let start_frame = match config.start {
134 Start::Beginning => 0,
135 Start::Time(duration) => {
136 let secs = duration.as_secs_f64();
137 (secs * sample_rate as f64) as usize
138 }
139 Start::Frame(frame) => frame,
140 };
141
142 let end_frame: Option<usize> = match config.stop {
143 Stop::End => None,
144 Stop::Time(duration) => {
145 let secs = duration.as_secs_f64();
146 Some((secs * sample_rate as f64) as usize)
147 }
148 Stop::Frame(frame) => Some(frame),
149 };
150
151 if let Some(end_frame) = end_frame
152 && start_frame > end_frame
153 {
154 return Err(AudioReadError::EndFrameLargerThanStartFrame(
155 end_frame,
156 start_frame,
157 ));
158 }
159
160 if start_frame > sample_rate as usize
162 && let Some(tb) = time_base
163 {
164 let seek_sample = (start_frame as f64 * 0.9) as u64;
166 let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
167
168 let _ = format.seek(
170 SeekMode::Accurate,
171 SeekTo::TimeStamp {
172 ts: seek_ts,
173 track_id,
174 },
175 );
176 }
177
178 let dec_opts: DecoderOptions = Default::default();
179 let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
180
181 let mut sample_buf = None;
182 let mut samples = Vec::new();
183 let mut num_channels = 0usize;
184 let start_channel = config.first_channel;
185 let end_channel = config.last_channel;
186
187 let mut current_sample: Option<u64> = None;
189
190 loop {
191 let packet = match format.next_packet() {
192 Ok(packet) => packet,
193 Err(Error::ResetRequired) => {
194 decoder.reset();
195 continue;
196 }
197 Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
198 break;
199 }
200 Err(err) => return Err(err.into()),
201 };
202
203 if packet.track_id() != track_id {
204 continue;
205 }
206
207 let decoded = decoder.decode(&packet)?;
208
209 if current_sample.is_none() {
211 let ts = packet.ts();
212 if let Some(tb) = time_base {
213 current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
215 } else {
216 current_sample = Some(0);
217 }
218 }
219
220 if sample_buf.is_none() {
221 let spec = *decoded.spec();
222 let duration = decoded.capacity() as u64;
223 sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
224
225 num_channels = spec.channels.count();
227
228 if let Some(start_ch) = start_channel
230 && start_ch >= num_channels
231 {
232 return Err(AudioReadError::InvalidStartChannel(start_ch, num_channels));
233 }
234 if let Some(end_ch) = end_channel {
235 if end_ch > num_channels {
236 return Err(AudioReadError::InvalidEndChannel(end_ch, num_channels));
237 }
238 if let Some(start_ch) = start_channel
239 && end_ch <= start_ch
240 {
241 return Err(AudioReadError::EndChannelLargerThanStartChannel(
242 end_ch, start_ch,
243 ));
244 }
245 }
246 }
247
248 if let Some(buf) = &mut sample_buf {
249 buf.copy_interleaved_ref(decoded);
250 let packet_samples = buf.samples();
251
252 let mut pos = current_sample.unwrap_or(0);
253
254 let ch_start = start_channel.unwrap_or(0);
256 let ch_end = end_channel.unwrap_or(num_channels);
257 let num_channels = ch_end - ch_start;
258
259 if ch_start != 0 || ch_end != num_channels {
261 let frames = packet_samples.len() / num_channels;
264
265 for frame_idx in 0..frames {
266 if let Some(end) = end_frame
268 && pos >= end as u64
269 {
270 let num_frames = samples.len() / num_channels;
271 return Ok(AudioData {
272 sample_rate,
273 num_channels,
274 num_frames,
275 interleaved_samples: samples,
276 });
277 }
278
279 if pos >= start_frame as u64 {
281 for ch in ch_start..ch_end {
283 let sample_idx = frame_idx * num_channels + ch;
284 samples.push(F::from(packet_samples[sample_idx]).unwrap());
285 }
286 }
287
288 pos += 1;
289 }
290 } else {
291 let frames = packet_samples.len() / num_channels;
293
294 for frame_idx in 0..frames {
295 if let Some(end) = end_frame
297 && pos >= end as u64
298 {
299 let num_frames = samples.len() / num_channels;
300 return Ok(AudioData {
301 sample_rate,
302 num_channels,
303 num_frames,
304 interleaved_samples: samples,
305 });
306 }
307
308 if pos >= start_frame as u64 {
310 for ch in 0..num_channels {
312 let sample_idx = frame_idx * num_channels + ch;
313 samples.push(F::from(packet_samples[sample_idx]).unwrap());
314 }
315 }
316
317 pos += 1;
318 }
319 }
320
321 current_sample = Some(pos);
323 }
324 }
325
326 let ch_start = start_channel.unwrap_or(0);
327 let ch_end = end_channel.unwrap_or(num_channels);
328 let num_channels = ch_end - ch_start;
329 let num_frames = samples.len() / num_channels;
330
331 Ok(AudioData {
332 sample_rate,
333 num_channels,
334 num_frames,
335 interleaved_samples: samples,
336 })
337}
338
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use audio_blocks::AudioBlock;

    use super::*;

    /// Audio fixture used by every test; the assertions below expect it to
    /// hold 48000 mono frames at 48 kHz.
    const FIXTURE: &str = "test.wav";

    /// Reads the fixture as `f32` samples with the given configuration.
    fn read(config: AudioReadConfig) -> Result<AudioData<f32>, AudioReadError> {
        audio_read(FIXTURE, config)
    }

    /// Reads the whole fixture and checks its overall shape.
    fn read_whole() -> AudioData<f32> {
        let whole = read(AudioReadConfig::default()).unwrap();
        {
            let block = whole.audio_block();
            assert_eq!(whole.sample_rate, 48000);
            assert_eq!(block.num_frames(), 48000);
            assert_eq!(block.num_channels(), 1);
        }
        whole
    }

    #[test]
    fn test_samples_selection() {
        let whole = read_whole();
        let whole_block = whole.audio_block();

        let excerpt = read(AudioReadConfig {
            start: Start::Frame(1100),
            stop: Stop::Frame(1200),
            ..Default::default()
        })
        .unwrap();
        let excerpt_block = excerpt.audio_block();

        assert_eq!(excerpt.sample_rate, 48000);
        assert_eq!(excerpt_block.num_frames(), 100);
        assert_eq!(excerpt_block.num_channels(), 1);
        assert_eq!(
            whole_block.raw_data()[1100..1200],
            excerpt_block.raw_data()[..]
        );
    }

    #[test]
    fn test_time_selection() {
        let whole = read_whole();
        let whole_block = whole.audio_block();

        let excerpt = read(AudioReadConfig {
            start: Start::Time(Duration::from_secs_f32(0.5)),
            stop: Stop::Time(Duration::from_secs_f32(0.6)),
            ..Default::default()
        })
        .unwrap();
        let excerpt_block = excerpt.audio_block();

        assert_eq!(excerpt.sample_rate, 48000);
        assert_eq!(excerpt_block.num_frames(), 4800);
        assert_eq!(excerpt_block.num_channels(), 1);
        assert_eq!(
            whole_block.raw_data()[24000..28800],
            excerpt_block.raw_data()[..]
        );
    }

    #[test]
    fn test_fail_selection() {
        // Frame range given back to front.
        let reversed_frames = read(AudioReadConfig {
            start: Start::Frame(100),
            stop: Stop::Frame(99),
            ..Default::default()
        });
        assert!(matches!(
            reversed_frames,
            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
        ));

        // Time range given back to front.
        let reversed_times = read(AudioReadConfig {
            start: Start::Time(Duration::from_secs_f32(0.6)),
            stop: Stop::Time(Duration::from_secs_f32(0.5)),
            ..Default::default()
        });
        assert!(matches!(
            reversed_times,
            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
        ));

        // The fixture is mono, so channel index 1 does not exist.
        let bad_first_channel = read(AudioReadConfig {
            first_channel: Some(1),
            ..Default::default()
        });
        assert!(matches!(
            bad_first_channel,
            Err(AudioReadError::InvalidStartChannel(_, _))
        ));

        // An end bound of 2 exceeds the single channel of the fixture.
        let bad_last_channel = read(AudioReadConfig {
            last_channel: Some(2),
            ..Default::default()
        });
        assert!(matches!(
            bad_last_channel,
            Err(AudioReadError::InvalidEndChannel(_, _))
        ));
    }
}