1use std::fs::File;
2use std::path::Path;
3
4use audio_blocks::AudioBlockInterleavedView;
5use num::Float;
6use symphonia::core::audio::SampleBuffer;
7use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
8use symphonia::core::errors::Error;
9use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
10use symphonia::core::io::MediaSourceStream;
11use symphonia::core::meta::MetadataOptions;
12use symphonia::core::probe::Hint;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum AudioReadError {
17 #[error("could not read file")]
18 FileError(#[from] std::io::Error),
19 #[error("could not decode audio")]
20 EncodingError(#[from] symphonia::core::errors::Error),
21 #[error("could not find track in file")]
22 NoTrack,
23 #[error("could not find sample rate in file")]
24 NoSampleRate,
25 #[error("end frame {0} is larger than start frame {1}")]
26 EndFrameLargerThanStartFrame(usize, usize),
27 #[error("start channel {0} invalid, audio file has only {1} channels")]
28 InvalidStartChannel(usize, usize),
29 #[error("invalid number of channels to extract: {0}")]
30 InvalidNumChannels(usize),
31}
32
/// A position inside an audio file, used for the `start` and `stop` fields of
/// [`AudioReadConfig`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Position {
    /// No explicit position: the beginning of the file when used as `start`,
    /// the end of the file when used as `stop`.
    #[default]
    Default,
    /// A time offset from the beginning of the file; converted to a frame
    /// index with the file's sample rate.
    Time(std::time::Duration),
    /// An absolute frame index.
    Frame(usize),
}
44
45#[derive(Default)]
46pub struct AudioReadConfig {
47 pub start: Position,
49 pub stop: Position,
51 pub start_channel: Option<usize>,
53 pub num_channels: Option<usize>,
55}
56
/// Decoded audio as returned by [`audio_read`].
///
/// `interleaved_samples` holds `num_frames * num_channels` samples, stored
/// frame by frame (all channels of frame 0, then all channels of frame 1, ...).
// Trait bounds live on the `impl` blocks that need them, not on the struct.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct AudioData<F> {
    /// Samples in interleaved order.
    pub interleaved_samples: Vec<F>,
    /// Sample rate of the source file in frames per second.
    pub sample_rate: u32,
    /// Number of channels stored in `interleaved_samples`.
    pub num_channels: usize,
    /// Number of frames stored in `interleaved_samples`.
    pub num_frames: usize,
}
64
65impl<F: Float> AudioData<F> {
66 pub fn audio_block(&self) -> AudioBlockInterleavedView<'_, F> {
72 AudioBlockInterleavedView::from_slice(
73 &self.interleaved_samples,
74 self.num_channels as u16,
75 self.num_frames,
76 )
77 }
78}
79
80pub fn audio_read<P: AsRef<Path>, F: Float>(
81 path: P,
82 config: AudioReadConfig,
83) -> Result<AudioData<F>, AudioReadError> {
84 let src = File::open(path.as_ref())?;
85 let mss = MediaSourceStream::new(Box::new(src), Default::default());
86
87 let mut hint = Hint::new();
88 if let Some(ext) = path.as_ref().extension()
89 && let Some(ext_str) = ext.to_str()
90 {
91 hint.with_extension(ext_str);
92 }
93
94 let meta_opts: MetadataOptions = Default::default();
95 let fmt_opts: FormatOptions = Default::default();
96
97 let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
98
99 let mut format = probed.format;
100
101 let track = format
102 .tracks()
103 .iter()
104 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
105 .ok_or(AudioReadError::NoTrack)?;
106
107 let sample_rate = track
108 .codec_params
109 .sample_rate
110 .ok_or(AudioReadError::NoSampleRate)?;
111
112 let track_id = track.id;
113
114 let codec_params = track.codec_params.clone();
116 let time_base = track.codec_params.time_base;
117
118 let start_frame = match config.start {
120 Position::Default => 0,
121 Position::Time(duration) => {
122 let secs = duration.as_secs_f64();
123 (secs * sample_rate as f64) as usize
124 }
125 Position::Frame(frame) => frame,
126 };
127
128 let end_frame: Option<usize> = match config.stop {
129 Position::Default => None,
130 Position::Time(duration) => {
131 let secs = duration.as_secs_f64();
132 Some((secs * sample_rate as f64) as usize)
133 }
134 Position::Frame(frame) => Some(frame),
135 };
136
137 if let Some(end_frame) = end_frame
138 && start_frame > end_frame
139 {
140 return Err(AudioReadError::EndFrameLargerThanStartFrame(
141 end_frame,
142 start_frame,
143 ));
144 }
145
146 if start_frame > sample_rate as usize
148 && let Some(tb) = time_base
149 {
150 let seek_sample = (start_frame as f64 * 0.9) as u64;
152 let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
153
154 let _ = format.seek(
156 SeekMode::Accurate,
157 SeekTo::TimeStamp {
158 ts: seek_ts,
159 track_id,
160 },
161 );
162 }
163
164 let dec_opts: DecoderOptions = Default::default();
165 let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
166
167 let mut sample_buf = None;
168 let mut samples = Vec::new();
169 let mut num_channels = 0usize;
170 let start_channel = config.start_channel;
171
172 let mut current_sample: Option<u64> = None;
174
175 loop {
176 let packet = match format.next_packet() {
177 Ok(packet) => packet,
178 Err(Error::ResetRequired) => {
179 decoder.reset();
180 continue;
181 }
182 Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
183 break;
184 }
185 Err(err) => return Err(err.into()),
186 };
187
188 if packet.track_id() != track_id {
189 continue;
190 }
191
192 let decoded = decoder.decode(&packet)?;
193
194 if current_sample.is_none() {
196 let ts = packet.ts();
197 if let Some(tb) = time_base {
198 current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
200 } else {
201 current_sample = Some(0);
202 }
203 }
204
205 if sample_buf.is_none() {
206 let spec = *decoded.spec();
207 let duration = decoded.capacity() as u64;
208 sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
209
210 num_channels = spec.channels.count();
212
213 let ch_start = start_channel.unwrap_or(0);
215 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
216
217 if ch_start >= num_channels {
218 return Err(AudioReadError::InvalidStartChannel(ch_start, num_channels));
219 }
220 if ch_count == 0 {
221 return Err(AudioReadError::InvalidNumChannels(0));
222 }
223 if ch_start + ch_count > num_channels {
224 return Err(AudioReadError::InvalidNumChannels(ch_count));
225 }
226 }
227
228 if let Some(buf) = &mut sample_buf {
229 buf.copy_interleaved_ref(decoded);
230 let packet_samples = buf.samples();
231
232 let mut pos = current_sample.unwrap_or(0);
233
234 let ch_start = start_channel.unwrap_or(0);
236 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
237 let ch_end = ch_start + ch_count;
238
239 let frames = packet_samples.len() / num_channels;
241
242 for frame_idx in 0..frames {
244 if let Some(end) = end_frame
246 && pos >= end as u64
247 {
248 let num_frames = samples.len() / ch_count;
249 return Ok(AudioData {
250 sample_rate,
251 num_channels: ch_count,
252 num_frames,
253 interleaved_samples: samples,
254 });
255 }
256
257 if pos >= start_frame as u64 {
259 for ch in ch_start..ch_end {
262 let sample_idx = frame_idx * num_channels + ch;
263 samples.push(F::from(packet_samples[sample_idx]).unwrap());
264 }
265 }
266
267 pos += 1;
268 }
269
270 current_sample = Some(pos);
272 }
273 }
274
275 let ch_start = start_channel.unwrap_or(0);
276 let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
277 let num_frames = samples.len() / ch_count;
278
279 Ok(AudioData {
280 sample_rate,
281 num_channels: ch_count,
282 num_frames,
283 interleaved_samples: samples,
284 })
285}
286
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use audio_blocks::AudioBlock;

    use super::*;

    /// Single-channel fixture: one second at 48 kHz.
    const MONO_WAV: &str = "test_data/test_1ch.wav";
    /// Four-channel fixture: one second at 48 kHz.
    const QUAD_WAV: &str = "test_data/test_4ch.wav";

    /// Reads `path` as `f32` audio, panicking if the read fails.
    fn read(path: &str, config: AudioReadConfig) -> AudioData<f32> {
        audio_read(path, config).unwrap()
    }

    /// Attempts to read the mono fixture as `f32` audio with `config`.
    fn try_read_mono(config: AudioReadConfig) -> Result<AudioData<f32>, AudioReadError> {
        audio_read(MONO_WAV, config)
    }

    #[test]
    fn test_samples_selection() {
        let whole = read(MONO_WAV, AudioReadConfig::default());
        let whole_block = whole.audio_block();
        assert_eq!(whole.sample_rate, 48000);
        assert_eq!(whole_block.num_frames(), 48000);
        assert_eq!(whole_block.num_channels(), 1);

        let excerpt_config = AudioReadConfig {
            start: Position::Frame(1100),
            stop: Position::Frame(1200),
            ..Default::default()
        };
        let excerpt = read(MONO_WAV, excerpt_config);
        let excerpt_block = excerpt.audio_block();
        assert_eq!(excerpt.sample_rate, 48000);
        assert_eq!(excerpt_block.num_frames(), 100);
        assert_eq!(excerpt_block.num_channels(), 1);
        assert_eq!(whole_block.raw_data()[1100..1200], excerpt_block.raw_data()[..]);
    }

    #[test]
    fn test_time_selection() {
        let whole = read(MONO_WAV, AudioReadConfig::default());
        let whole_block = whole.audio_block();
        assert_eq!(whole.sample_rate, 48000);
        assert_eq!(whole_block.num_frames(), 48000);
        assert_eq!(whole_block.num_channels(), 1);

        let excerpt_config = AudioReadConfig {
            start: Position::Time(Duration::from_secs_f32(0.5)),
            stop: Position::Time(Duration::from_secs_f32(0.6)),
            ..Default::default()
        };
        let excerpt = read(MONO_WAV, excerpt_config);
        let excerpt_block = excerpt.audio_block();
        assert_eq!(excerpt.sample_rate, 48000);
        assert_eq!(excerpt_block.num_frames(), 4800);
        assert_eq!(excerpt_block.num_channels(), 1);
        assert_eq!(whole_block.raw_data()[24000..28800], excerpt_block.raw_data()[..]);
    }

    #[test]
    fn test_channel_selection() {
        let all_channels = read(QUAD_WAV, AudioReadConfig::default());
        let all_block = all_channels.audio_block();
        assert_eq!(all_channels.sample_rate, 48000);
        assert_eq!(all_block.num_frames(), 48000);
        assert_eq!(all_block.num_channels(), 4);

        let middle_config = AudioReadConfig {
            start_channel: Some(1),
            num_channels: Some(2),
            ..Default::default()
        };
        let middle = read(QUAD_WAV, middle_config);
        let middle_block = middle.audio_block();
        assert_eq!(middle.sample_rate, 48000);
        assert_eq!(middle_block.num_frames(), 48000);
        assert_eq!(middle_block.num_channels(), 2);

        // Channels 1 and 2 of the source become channels 0 and 1 of the excerpt.
        for frame in 0..10 {
            assert_eq!(middle_block.sample(0, frame), all_block.sample(1, frame));
            assert_eq!(middle_block.sample(1, frame), all_block.sample(2, frame));
        }
    }

    #[test]
    fn test_fail_selection() {
        let inverted_frames = try_read_mono(AudioReadConfig {
            start: Position::Frame(100),
            stop: Position::Frame(99),
            ..Default::default()
        });
        assert!(matches!(
            inverted_frames,
            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
        ));

        let inverted_times = try_read_mono(AudioReadConfig {
            start: Position::Time(Duration::from_secs_f32(0.6)),
            stop: Position::Time(Duration::from_secs_f32(0.5)),
            ..Default::default()
        });
        assert!(matches!(
            inverted_times,
            Err(AudioReadError::EndFrameLargerThanStartFrame(_, _))
        ));

        let missing_channel = try_read_mono(AudioReadConfig {
            start_channel: Some(1),
            ..Default::default()
        });
        assert!(matches!(
            missing_channel,
            Err(AudioReadError::InvalidStartChannel(_, _))
        ));

        let no_channels = try_read_mono(AudioReadConfig {
            num_channels: Some(0),
            ..Default::default()
        });
        assert!(matches!(
            no_channels,
            Err(AudioReadError::InvalidNumChannels(0))
        ));

        let too_many_channels = try_read_mono(AudioReadConfig {
            num_channels: Some(2),
            ..Default::default()
        });
        assert!(matches!(
            too_many_channels,
            Err(AudioReadError::InvalidNumChannels(2))
        ));
    }
}