mecomp_analysis/decoder/
mecomp.rs

1//! Implementation of the mecomp decoder, which is rodio/rubato based.
2
3use std::{f32::consts::SQRT_2, fs::File, num::NonZeroUsize, time::Duration};
4
5use object_pool::Pool;
6use rubato::{FastFixedIn, Resampler, ResamplerConstructionError};
7use symphonia::{
8    core::{
9        audio::{AudioBufferRef, SampleBuffer, SignalSpec},
10        codecs::{DecoderOptions, CODEC_TYPE_NULL},
11        errors::Error,
12        formats::{FormatOptions, FormatReader},
13        io::{MediaSourceStream, MediaSourceStreamOptions},
14        meta::MetadataOptions,
15        probe::Hint,
16        units,
17    },
18    default::get_probe,
19};
20
21use crate::{errors::AnalysisError, errors::AnalysisResult, ResampledAudio, SAMPLE_RATE};
22
23use super::Decoder;
24
/// How many `DecodeError`s a single packet-decoding loop tolerates (by moving on to the next
/// packet) before it stops retrying.
const MAX_DECODE_RETRIES: usize = 3;
/// Number of input frames passed to the resampler per processing call; also the chunk size
/// the resampler is constructed with.
const CHUNK_SIZE: usize = 4096;
27
28/// Struct used by the symphonia-based bliss decoders to decode audio files
29#[doc(hidden)]
30pub struct SymphoniaSource {
31    decoder: Box<dyn symphonia::core::codecs::Decoder>,
32    current_span_offset: usize,
33    format: Box<dyn FormatReader>,
34    total_duration: Option<Duration>,
35    buffer: SampleBuffer<f32>,
36    spec: SignalSpec,
37}
38
39impl SymphoniaSource {
40    /// Create a new `SymphoniaSource` from a `MediaSourceStream`
41    ///
42    /// # Errors
43    ///
44    /// This function will return an error if the `MediaSourceStream` does not contain any streams, or if the stream
45    /// is not supported by the decoder.
46    pub fn new(mss: MediaSourceStream) -> Result<Self, Error> {
47        Self::init(mss)?.ok_or(Error::DecodeError("No Streams"))
48    }
49
50    fn init(mss: MediaSourceStream) -> symphonia::core::errors::Result<Option<Self>> {
51        let hint = Hint::new();
52        let format_opts = FormatOptions::default();
53        let metadata_opts = MetadataOptions::default();
54        let mut probed_format = get_probe()
55            .format(&hint, mss, &format_opts, &metadata_opts)?
56            .format;
57
58        let Some(stream) = probed_format.default_track() else {
59            return Ok(None);
60        };
61
62        // Select the first supported track
63        let track = probed_format
64            .tracks()
65            .iter()
66            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
67            .ok_or(Error::Unsupported("No track with supported codec"))?;
68
69        let track_id = track.id;
70
71        let mut decoder = symphonia::default::get_codecs()
72            .make(&track.codec_params, &DecoderOptions::default())?;
73        let total_duration = stream
74            .codec_params
75            .time_base
76            .zip(stream.codec_params.n_frames)
77            .map(|(base, spans)| base.calc_time(spans).into());
78
79        let mut decode_errors: usize = 0;
80        let decoded_audio = loop {
81            let current_span = probed_format.next_packet()?;
82
83            // If the packet does not belong to the selected track, skip over it
84            if current_span.track_id() != track_id {
85                continue;
86            }
87
88            match decoder.decode(&current_span) {
89                Ok(audio) => break audio,
90                Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
91                    decode_errors += 1;
92                }
93                Err(e) => return Err(e),
94            }
95        };
96
97        let spec = decoded_audio.spec().to_owned();
98        let buffer = Self::get_buffer(decoded_audio, spec);
99        Ok(Some(Self {
100            decoder,
101            current_span_offset: 0,
102            format: probed_format,
103            total_duration,
104            buffer,
105            spec,
106        }))
107    }
108
109    #[inline]
110    fn get_buffer(decoded: AudioBufferRef, spec: SignalSpec) -> SampleBuffer<f32> {
111        let duration = units::Duration::from(decoded.capacity() as u64);
112        let mut buffer = SampleBuffer::<f32>::new(duration, spec);
113        buffer.copy_interleaved_ref(decoded);
114        buffer
115    }
116
117    #[inline]
118    #[must_use]
119    pub const fn total_duration(&self) -> Option<Duration> {
120        self.total_duration
121    }
122
123    #[inline]
124    #[must_use]
125    pub const fn sample_rate(&self) -> u32 {
126        self.spec.rate
127    }
128
129    #[inline]
130    #[must_use]
131    pub fn channels(&self) -> usize {
132        self.spec.channels.count()
133    }
134}
135
136impl Iterator for SymphoniaSource {
137    type Item = f32;
138
139    fn size_hint(&self) -> (usize, Option<usize>) {
140        (
141            self.buffer.samples().len(),
142            self.total_duration.map(|dur| {
143                (usize::try_from(dur.as_secs()).unwrap_or(usize::MAX) + 1)
144                    * self.spec.rate as usize
145                    * self.spec.channels.count()
146            }),
147        )
148    }
149
150    fn next(&mut self) -> Option<Self::Item> {
151        if self.current_span_offset < self.buffer.len() {
152            let sample = self.buffer.samples().get(self.current_span_offset);
153            self.current_span_offset += 1;
154
155            return sample.copied();
156        }
157
158        let mut decode_errors = 0;
159        let decoded = loop {
160            let packet = self.format.next_packet().ok()?;
161            match self.decoder.decode(&packet) {
162                // Loop until we get a packet with audio frames. This is necessary because some
163                // formats can have packets with only metadata, particularly when rewinding, in
164                // which case the iterator would otherwise end with `None`.
165                // Note: checking `decoded.frames()` is more reliable than `packet.dur()`, which
166                // can returns non-zero durations for packets without audio frames.
167                Ok(decoded) if decoded.frames() > 0 => break decoded,
168                Ok(_) => {}
169                Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
170                    decode_errors += 1;
171                }
172                Err(_) => return None,
173            }
174        };
175
176        decoded.spec().clone_into(&mut self.spec);
177        self.buffer = Self::get_buffer(decoded, self.spec);
178        self.current_span_offset = 1;
179        self.buffer.samples().first().copied()
180    }
181}
182
183#[allow(clippy::module_name_repetitions)]
184pub struct MecompDecoder<R = FastFixedIn<f32>> {
185    resampler: Pool<Result<R, ResamplerConstructionError>>,
186}
187
188impl MecompDecoder {
189    #[inline]
190    fn generate_resampler() -> Result<FastFixedIn<f32>, ResamplerConstructionError> {
191        FastFixedIn::new(1.0, 10.0, rubato::PolynomialDegree::Cubic, CHUNK_SIZE, 1)
192    }
193
194    /// Create a new `MecompDecoder`
195    ///
196    /// # Errors
197    ///
198    /// This function will return an error if the resampler could not be created.
199    #[inline]
200    pub fn new() -> Result<Self, AnalysisError> {
201        // try to generate a resampler first, so we can return an error if it fails (if it fails, it's likely all future calls will too)
202        let first = Self::generate_resampler()?;
203
204        let pool_size = std::thread::available_parallelism().map_or(1, NonZeroUsize::get);
205        let resampler = Pool::new(pool_size, Self::generate_resampler);
206        resampler.attach(Ok(first));
207
208        Ok(Self { resampler })
209    }
210
211    /// we need to collapse the audio source into one channel
212    /// channels are interleaved, so if we have 2 channels, `[1, 2, 3, 4]` and `[5, 6, 7, 8]`,
213    /// they will be stored as `[1, 5, 2, 6, 3, 7, 4, 8]`
214    ///
215    /// For stereo sound, we can make this mono by averaging the channels and multiplying by the square root of 2,
216    /// This recovers the exact behavior of ffmpeg when converting stereo to mono, however for 2.1 and 5.1 surround sound,
217    /// ffmpeg might be doing something different, and I'm not sure what that is (don't have a 5.1 surround sound file to test with)
218    ///
219    /// TODO: Figure out how ffmpeg does it for 2.1 and 5.1 surround sound, and do it the same way
220    #[inline]
221    #[doc(hidden)]
222    pub fn into_mono_samples(
223        source: Vec<f32>,
224        num_channels: usize,
225    ) -> Result<Vec<f32>, AnalysisError> {
226        match num_channels {
227            // no channels
228            0 => Err(AnalysisError::DecodeError(Error::DecodeError(
229                "The audio source has no channels",
230            ))),
231            // mono
232            1 => Ok(source),
233            // stereo
234            2 => Ok(source
235                .chunks_exact(2)
236                .map(|chunk| (chunk[0] + chunk[1]) * SQRT_2 / 2.)
237                .collect()),
238            // 2.1 or 5.1 surround
239            _ => {
240                log::warn!("The audio source has more than 2 channels (might be 2.1 or 5.1 surround sound), will collapse to mono by averaging the channels");
241
242                #[allow(clippy::cast_precision_loss)]
243                let num_channels_f32 = num_channels as f32;
244                let mono_samples = source
245                    .chunks_exact(num_channels)
246                    .map(|chunk| chunk.iter().sum::<f32>() / num_channels_f32)
247                    .collect();
248
249                Ok(mono_samples)
250            }
251        }
252    }
253
254    /// Resample the given mono samples to 22050 Hz
255    #[inline]
256    #[doc(hidden)]
257    pub fn resample_mono_samples(
258        &self,
259        mut samples: Vec<f32>,
260        sample_rate: u32,
261        total_duration: Duration,
262    ) -> Result<Vec<f32>, AnalysisError> {
263        if sample_rate == SAMPLE_RATE {
264            samples.shrink_to_fit();
265            return Ok(samples);
266        }
267
268        let mut resampled_frames = Vec::with_capacity(
269            (usize::try_from(total_duration.as_secs()).unwrap_or(usize::MAX) + 1)
270                * SAMPLE_RATE as usize,
271        );
272
273        let (pool, resampler) = self.resampler.pull(Self::generate_resampler).detach();
274        let mut resampler = resampler?;
275        resampler.set_resample_ratio(f64::from(SAMPLE_RATE) / f64::from(sample_rate), false)?;
276
277        let delay = resampler.output_delay();
278
279        let new_length = samples.len() * SAMPLE_RATE as usize / sample_rate as usize;
280        let mut output_buffer = resampler.output_buffer_allocate(true);
281
282        // chunks of frames, each being CHUNKSIZE long.
283        let sample_chunks = samples.chunks_exact(CHUNK_SIZE);
284        let remainder = sample_chunks.remainder();
285
286        for chunk in sample_chunks {
287            debug_assert!(resampler.input_frames_next() == CHUNK_SIZE);
288
289            let (_, output_written) =
290                resampler.process_into_buffer(&[chunk], output_buffer.as_mut_slice(), None)?;
291            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
292        }
293
294        // process the remainder
295        if !remainder.is_empty() {
296            let (_, output_written) = resampler.process_partial_into_buffer(
297                Some(&[remainder]),
298                output_buffer.as_mut_slice(),
299                None,
300            )?;
301            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
302        }
303
304        // flush final samples from resampler
305        if resampled_frames.len() < new_length + delay {
306            let (_, output_written) = resampler.process_partial_into_buffer(
307                Option::<&[&[f32]]>::None,
308                output_buffer.as_mut_slice(),
309                None,
310            )?;
311            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
312        }
313
314        resampler.reset();
315        pool.attach(Ok(resampler));
316
317        Ok(resampled_frames[delay..new_length + delay].to_vec())
318    }
319}
320
321impl Decoder for MecompDecoder {
322    /// A function that should decode and resample a song, optionally
323    /// extracting the song's metadata such as the artist, the album, etc.
324    ///
325    /// The output sample array should be resampled to f32le, one channel, with a sampling rate
326    /// of 22050 Hz. Anything other than that will yield wrong results.
327    #[allow(clippy::missing_inline_in_public_items)]
328    fn decode(&self, path: &std::path::Path) -> AnalysisResult<ResampledAudio> {
329        // open the file
330        let file = File::open(path)?;
331        // create the media source stream
332        let mss = MediaSourceStream::new(Box::new(file), MediaSourceStreamOptions::default());
333
334        let source = SymphoniaSource::new(mss)?;
335
336        // Convert the audio source into a mono channel
337        let sample_rate = source.spec.rate;
338        let Some(total_duration) = source.total_duration else {
339            return Err(AnalysisError::IndeterminantDuration);
340        };
341        let num_channels = source.channels();
342
343        let mono_sample_array =
344            Self::into_mono_samples(source.into_iter().collect(), num_channels)?;
345
346        // then we need to resample the audio source into 22050 Hz
347        let resampled_array =
348            self.resample_mono_samples(mono_sample_array, sample_rate, total_duration)?;
349
350        Ok(ResampledAudio {
351            path: path.to_owned(),
352            samples: resampled_array,
353        })
354    }
355}
356
#[cfg(test)]
mod tests {
    use crate::NUMBER_FEATURES;

    use super::{Decoder as DecoderTrait, MecompDecoder as Decoder};
    use adler32::RollingAdler32;
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::path::Path;

    /// Decode `path` and check that the Adler-32 checksum over the little-endian bytes of
    /// all decoded samples equals `expected_hash`.
    fn verify_decoding_output(path: &Path, expected_hash: u32) {
        let song = Decoder::new().unwrap().decode(path).unwrap();

        let mut checksum = RollingAdler32::new();
        song.samples
            .iter()
            .for_each(|sample| checksum.update_buffer(&sample.to_le_bytes()));

        assert_eq!(expected_hash, checksum.hash());
    }

    // The expected hashes were obtained with ffmpeg, e.g.
    // ffmpeg -i data/s16_stereo_22_5kHz.flac -ar 22050 -ac 1 -c:a pcm_f32le -f hash -hash adler32 -
    //
    // Each `#[ignore]` sits directly above the resampling case it is meant to skip.
    #[rstest]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_stereo(Path::new("data/s32_stereo_44_1_kHz.flac"), 0xbbcb_a1cf)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mono(Path::new("data/s32_mono_44_1_kHz.flac"), 0xa0f8_b8af)]
    #[case::decode_stereo(Path::new("data/s16_stereo_22_5kHz.flac"), 0x1d7b_2d6d)]
    #[case::decode_mono(Path::new("data/s16_mono_22_5kHz.flac"), 0x5e01_930b)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mp3(Path::new("data/s32_stereo_44_1_kHz.mp3"), 0x69ca_6906)]
    #[case::decode_wav(Path::new("data/piano.wav"), 0xde83_1e82)]
    fn test_decode(#[case] path: &Path, #[case] expected_hash: u32) {
        verify_decoding_output(path, expected_hash);
    }

    /// Decoding a file without channel layout information must succeed.
    #[test]
    fn test_dont_panic_no_channel_layout() {
        let decoder = Decoder::new().unwrap();
        decoder.decode(Path::new("data/no_channel.wav")).unwrap();
    }

    /// The decoded sample vector must not carry spare capacity: its capacity has to equal its
    /// length (bliss-rs added an extra second as a buffer; that is not needed here because the
    /// exact length of the song is known).
    #[test]
    fn test_decode_right_capacity_vec() {
        // NOTE: the last entry originally used the .ogg file, but it was failing to decode
        // with `DecodeError(IoError("end of stream"))`
        let files = [
            "data/s16_mono_22_5kHz.flac",
            "data/s32_stereo_44_1_kHz.flac",
            "data/capacity_fix.wav",
        ];

        for file in files {
            let sample_array = Decoder::new()
                .unwrap()
                .decode(Path::new(file))
                .unwrap()
                .samples;
            assert_eq!(sample_array.len(), sample_array.capacity());
        }
    }

    /// A reference file together with the analysis values expected for it.
    const PATH_AND_EXPECTED_ANALYSIS: (&str, [f64; NUMBER_FEATURES]) = (
        "data/s16_mono_22_5kHz.flac",
        [
            0.3846389,
            -0.849141,
            -0.75481045,
            -0.8790748,
            -0.63258266,
            -0.7258959,
            -0.775738,
            -0.8146726,
            0.2716726,
            0.25779057,
            -0.35661936,
            -0.63578653,
            -0.29593682,
            0.06421304,
            0.21852458,
            -0.581239,
            -0.9466835,
            -0.9481153,
            -0.9820945,
            -0.95968974,
        ],
    );

    /// Every analysed feature has to be within 0.01 of its reference value.
    #[test]
    fn test_analyze() {
        let (path, expected_analysis) = PATH_AND_EXPECTED_ANALYSIS;
        let decoder = Decoder::new().unwrap();
        let analysis = decoder.analyze_path(Path::new(path)).unwrap();

        for (x, y) in analysis.as_vec().iter().zip(expected_analysis) {
            assert!(0.01 > (x - y).abs());
        }
    }
}