creek_decode_wav/
lib.rs

1use std::fs::File;
2use std::path::PathBuf;
3
4use symphonia::core::audio::SampleBuffer;
5use symphonia::core::codecs::{CodecParameters, Decoder as SymphDecoder, DecoderOptions};
6use symphonia::core::errors::Error;
7use symphonia::core::formats::{FormatOptions, FormatReader, SeekMode, SeekTo};
8use symphonia::core::io::MediaSourceStream;
9use symphonia::core::meta::MetadataOptions;
10use symphonia::core::probe::Hint;
11use symphonia::core::units::Duration;
12
13use creek_core::{DataBlock, Decoder, FileInfo};
14
15mod error;
16pub use error::OpenError;
17
18pub struct SymphoniaDecoder {
19    reader: Box<dyn FormatReader>,
20    decoder: Box<dyn SymphDecoder>,
21
22    smp_buf: SampleBuffer<f32>,
23    curr_smp_buf_i: usize,
24
25    num_frames: usize,
26    num_channels: usize,
27    sample_rate: Option<u32>,
28    block_size: usize,
29
30    current_frame: usize,
31    reset_smp_buffer: bool,
32}
33
34impl Decoder for SymphoniaDecoder {
35    type T = f32;
36    type FileParams = CodecParameters;
37    type OpenError = OpenError;
38    type FatalError = Error;
39    type AdditionalOpts = ();
40
41    const DEFAULT_BLOCK_SIZE: usize = 16384;
42    const DEFAULT_NUM_CACHE_BLOCKS: usize = 0;
43    const DEFAULT_NUM_LOOK_AHEAD_BLOCKS: usize = 8;
44
45    fn new(
46        file: PathBuf,
47        start_frame: usize,
48        block_size: usize,
49        _additional_opts: Self::AdditionalOpts,
50    ) -> Result<(Self, FileInfo<Self::FileParams>), Self::OpenError> {
51        // Create a hint to help the format registry guess what format reader is appropriate.
52        let mut hint = Hint::new();
53
54        // Provide the file extension as a hint.
55        if let Some(extension) = file.extension() {
56            if let Some(extension_str) = extension.to_str() {
57                hint.with_extension(extension_str);
58            }
59        }
60
61        let source = Box::new(File::open(file)?);
62
63        // Create the media source stream using the boxed media source from above.
64        let mss = MediaSourceStream::new(source, Default::default());
65
66        // Use the default options for metadata and format readers.
67        let format_opts: FormatOptions = Default::default();
68        let metadata_opts: MetadataOptions = Default::default();
69
70        let probed =
71            symphonia::default::get_probe().format(&hint, mss, &format_opts, &metadata_opts)?;
72
73        let mut reader = probed.format;
74
75        let decoder_opts = DecoderOptions {
76            ..Default::default()
77        };
78
79        let params = {
80            // Get the default stream.
81            let stream = reader
82                .default_track()
83                .ok_or_else(|| OpenError::NoDefaultTrack)?;
84
85            stream.codec_params.clone()
86        };
87
88        let num_frames = params.n_frames.ok_or_else(|| OpenError::NoNumFrames)? as usize;
89        let num_channels = (params.channels.ok_or_else(|| OpenError::NoNumChannels)?).count();
90        let sample_rate = params.sample_rate;
91
92        // Seek the reader to the requested position.
93        if start_frame != 0 {
94            let seconds = start_frame as f64 / f64::from(sample_rate.unwrap_or(44100));
95
96            reader.seek(
97                SeekMode::Accurate,
98                SeekTo::Time {
99                    time: seconds.into(),
100                    track_id: None,
101                }
102            )?;
103        }
104
105        // Create a decoder for the stream.
106        let mut decoder = symphonia::default::get_codecs().make(&params, &decoder_opts)?;
107
108        // Decode the first packet to get the signal specification.
109        let smp_buf = loop {
110            match decoder.decode(&reader.next_packet()?) {
111                Ok(decoded) => {
112                    // Get the buffer spec.
113                    let spec = *decoded.spec();
114
115                    // Get the buffer capacity.
116                    let capacity = Duration::from(decoded.capacity() as u64);
117
118                    let mut smp_buf = SampleBuffer::<f32>::new(capacity, spec);
119
120                    smp_buf.copy_interleaved_ref(decoded);
121
122                    break smp_buf;
123                }
124                Err(Error::DecodeError(e)) => {
125                    // Decode errors are not fatal. Send a warning and try to decode the next packet.
126
127                    println!("{}", e);
128
129                    // TODO: print warning.
130
131                    continue;
132                }
133                Err(e) => {
134                    // Errors other than decode errors are fatal.
135                    return Err(e.into());
136                }
137            }
138        };
139
140        let file_info = FileInfo {
141            params,
142            num_frames,
143            num_channels: num_channels as u16,
144            sample_rate: sample_rate.map(|s| s as u32),
145        };
146
147        Ok((
148            Self {
149                reader,
150                decoder,
151
152                smp_buf,
153                curr_smp_buf_i: 0,
154
155                num_frames,
156                num_channels,
157                sample_rate,
158                block_size,
159
160                current_frame: start_frame,
161                reset_smp_buffer: false,
162            },
163            file_info,
164        ))
165    }
166
167    fn seek(&mut self, frame: usize) -> Result<(), Self::FatalError> {
168        if frame >= self.num_frames {
169            // Do nothing if out of range.
170            self.current_frame = self.num_frames;
171
172            return Ok(());
173        }
174
175        self.current_frame = frame;
176
177        let seconds = self.current_frame as f64 / f64::from(self.sample_rate.unwrap_or(44100));
178
179        match self.reader.seek(
180            SeekMode::Accurate,
181            SeekTo::Time {
182                time: seconds.into(),
183                track_id: None,
184            }
185        ) {
186            Ok(_res) => {}
187            Err(e) => {
188                return Err(e);
189            }
190        }
191
192        self.reset_smp_buffer = true;
193        self.curr_smp_buf_i = 0;
194
195        /*
196        let decoder_opts = DecoderOptions {
197            verify: false,
198            ..Default::default()
199        };
200
201        self.decoder.close();
202        self.decoder = symphonia::default::get_codecs()
203            .make(self.decoder.codec_params(), &decoder_opts)?;
204            */
205
206        Ok(())
207    }
208
209    unsafe fn decode(
210        &mut self,
211        data_block: &mut DataBlock<Self::T>,
212    ) -> Result<(), Self::FatalError> {
213        if self.current_frame >= self.num_frames {
214            // Do nothing if reached the end of the file.
215            return Ok(());
216        }
217
218        let mut reached_end_of_file = false;
219
220        let mut block_start = 0;
221        while block_start < self.block_size {
222            let num_frames_to_cpy = if self.reset_smp_buffer {
223                // Get new data first.
224                self.reset_smp_buffer = false;
225                0
226            } else if self.smp_buf.len() < self.num_channels {
227                // Get new data first.
228                0
229            } else {
230                // Find the maximum amount of frames that can be copied.
231                (self.block_size - block_start)
232                    .min((self.smp_buf.len() - self.curr_smp_buf_i) / self.num_channels)
233            };
234
235            if num_frames_to_cpy != 0 {
236                if self.num_channels == 1 {
237                    // Mono, no need to deinterleave.
238                    data_block.block[0][block_start..block_start + num_frames_to_cpy]
239                        .copy_from_slice(
240                            &self.smp_buf.samples()
241                                [self.curr_smp_buf_i..self.curr_smp_buf_i + num_frames_to_cpy],
242                        );
243                } else if self.num_channels == 2 {
244                    // Provide efficient stereo deinterleaving.
245
246                    let smp_buf = &self.smp_buf.samples()
247                        [self.curr_smp_buf_i..self.curr_smp_buf_i + (num_frames_to_cpy * 2)];
248
249                    let (block1, block2) = data_block.block.split_at_mut(1);
250                    let block1 = &mut block1[0][block_start..block_start + num_frames_to_cpy];
251                    let block2 = &mut block2[0][block_start..block_start + num_frames_to_cpy];
252
253                    for i in 0..num_frames_to_cpy {
254                        block1[i] = smp_buf[i * 2];
255                        block2[i] = smp_buf[(i * 2) + 1];
256                    }
257                } else {
258                    let smp_buf = &self.smp_buf.samples()[self.curr_smp_buf_i
259                        ..self.curr_smp_buf_i + (num_frames_to_cpy * self.num_channels)];
260
261                    for i in 0..num_frames_to_cpy {
262                        for (ch, block) in data_block.block.iter_mut().enumerate() {
263                            block[block_start + i] = smp_buf[(i * self.num_channels) + ch];
264                        }
265                    }
266                }
267
268                block_start += num_frames_to_cpy;
269
270                self.curr_smp_buf_i += num_frames_to_cpy * self.num_channels;
271                if self.curr_smp_buf_i >= self.smp_buf.len() {
272                    self.reset_smp_buffer = true;
273                }
274            } else {
275                // Decode more packets.
276
277                loop {
278                    match self.reader.next_packet() {
279                        Ok(packet) => {
280                            match self.decoder.decode(&packet) {
281                                Ok(decoded) => {
282                                    self.smp_buf.copy_interleaved_ref(decoded);
283                                    self.curr_smp_buf_i = 0;
284                                    break;
285                                }
286                                Err(Error::DecodeError(e)) => {
287                                    // Decode errors are not fatal. Print a message and try to decode the next packet as
288                                    // usual.
289
290                                    println!("{}", e);
291
292                                    // TODO: print warning.
293
294                                    continue;
295                                }
296                                Err(e) => {
297                                    // Errors other than decode errors are fatal.
298                                    return Err(e);
299                                }
300                            }
301                        }
302                        Err(e) => {
303                            if let Error::IoError(io_error) = &e {
304                                if io_error.kind() == std::io::ErrorKind::UnexpectedEof {
305                                    // End of file, stop decoding.
306                                    reached_end_of_file = true;
307                                    block_start = self.block_size;
308                                    break;
309                                } else {
310                                    return Err(e);
311                                }
312                            } else {
313                                return Err(e);
314                            }
315                        }
316                    }
317                }
318            }
319        }
320
321        if reached_end_of_file {
322            self.current_frame = self.num_frames;
323        } else {
324            self.current_frame += self.block_size;
325        }
326
327        Ok(())
328    }
329
330    fn current_frame(&self) -> usize {
331        self.current_frame
332    }
333}
334
335impl Drop for SymphoniaDecoder {
336    fn drop(&mut self) {
337        let _ = self.decoder.finalize();
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens each fixture and checks the metadata reported in the returned `FileInfo`.
    #[test]
    fn decoder_new() {
        // (file, num_channels, num_frames, sample_rate)
        let files = [
            ("./test_files/wav_u8_mono.wav", 1, 1323000, Some(44100)),
            ("./test_files/wav_i16_mono.wav", 1, 1323000, Some(44100)),
            ("./test_files/wav_i24_mono.wav", 1, 1323000, Some(44100)),
            ("./test_files/wav_i32_mono.wav", 1, 1323000, Some(44100)),
            ("./test_files/wav_f32_mono.wav", 1, 1323000, Some(44100)),
            ("./test_files/wav_i24_stereo.wav", 2, 1323000, Some(44100)),
            // Other formats, not enabled yet:
            //"./test_files/ogg_mono.ogg",
            //"./test_files/ogg_stereo.ogg",
            //"./test_files/mp3_constant_mono.mp3",
            //"./test_files/mp3_constant_stereo.mp3",
            //"./test_files/mp3_variable_mono.mp3",
            //"./test_files/mp3_variable_stereo.mp3",
        ];

        for &(path, num_channels, num_frames, sample_rate) in files.iter() {
            let (_, file_info) = SymphoniaDecoder::new(
                path.into(),
                0,
                SymphoniaDecoder::DEFAULT_BLOCK_SIZE,
                (),
            )
            .unwrap_or_else(|e| panic!("{}: {}", path, e));

            assert_eq!(file_info.num_channels, num_channels, "num_channels for {}", path);
            assert_eq!(file_info.num_frames, num_frames, "num_frames for {}", path);
            assert_eq!(file_info.sample_rate, sample_rate, "sample_rate for {}", path);
        }
    }
}
379}