mecomp_analysis/decoder/
mecomp.rs

1//! Implementation of the mecomp decoder, which is rodio/rubato based.
2
3use std::{f32::consts::SQRT_2, fs::File, num::NonZeroUsize, time::Duration};
4
5use object_pool::Pool;
6use rubato::{FastFixedIn, Resampler, ResamplerConstructionError};
7use symphonia::{
8    core::{
9        audio::{AudioBufferRef, SampleBuffer, SignalSpec},
10        codecs::{CODEC_TYPE_NULL, DecoderOptions},
11        errors::Error,
12        formats::{FormatOptions, FormatReader},
13        io::{MediaSourceStream, MediaSourceStreamOptions},
14        meta::MetadataOptions,
15        probe::Hint,
16        units,
17    },
18    default::get_probe,
19};
20
21use crate::{ResampledAudio, SAMPLE_RATE, errors::AnalysisError, errors::AnalysisResult};
22
23use super::Decoder;
24
/// How many `DecodeError`s in a row are tolerated (by skipping the offending packet)
/// before decoding gives up.
const MAX_DECODE_RETRIES: usize = 3;

/// Number of input frames handed to the resampler per processing call.
const CHUNK_SIZE: usize = 4096;
27
28/// Struct used by the symphonia-based bliss decoders to decode audio files
29#[doc(hidden)]
30pub struct SymphoniaSource {
31    decoder: Box<dyn symphonia::core::codecs::Decoder>,
32    current_span_offset: usize,
33    format: Box<dyn FormatReader>,
34    total_duration: Option<Duration>,
35    buffer: SampleBuffer<f32>,
36    spec: SignalSpec,
37}
38
39impl SymphoniaSource {
40    /// Create a new `SymphoniaSource` from a `MediaSourceStream`
41    ///
42    /// # Errors
43    ///
44    /// This function will return an error if the `MediaSourceStream` does not contain any streams, or if the stream
45    /// is not supported by the decoder.
46    pub fn new(mss: MediaSourceStream) -> Result<Self, Error> {
47        Self::init(mss)?.ok_or(Error::DecodeError("No Streams"))
48    }
49
50    fn init(mss: MediaSourceStream) -> symphonia::core::errors::Result<Option<Self>> {
51        let hint = Hint::new();
52        let format_opts = FormatOptions {
53            enable_gapless: true,
54            ..Default::default()
55        };
56        let metadata_opts = MetadataOptions::default();
57        let mut probed_format = get_probe()
58            .format(&hint, mss, &format_opts, &metadata_opts)?
59            .format;
60
61        let Some(stream) = probed_format.default_track() else {
62            return Ok(None);
63        };
64
65        // Select the first supported track
66        let track = probed_format
67            .tracks()
68            .iter()
69            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
70            .ok_or(Error::Unsupported("No track with supported codec"))?;
71
72        let track_id = track.id;
73
74        let mut decoder = symphonia::default::get_codecs()
75            .make(&track.codec_params, &DecoderOptions::default())?;
76        let total_duration = stream
77            .codec_params
78            .time_base
79            .zip(stream.codec_params.n_frames)
80            .map(|(base, spans)| base.calc_time(spans).into());
81
82        let mut decode_errors: usize = 0;
83        let decoded_audio = loop {
84            let current_span = probed_format.next_packet()?;
85
86            // If the packet does not belong to the selected track, skip over it
87            if current_span.track_id() != track_id {
88                continue;
89            }
90
91            match decoder.decode(&current_span) {
92                Ok(audio) => break audio,
93                Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
94                    decode_errors += 1;
95                }
96                Err(e) => return Err(e),
97            }
98        };
99
100        let spec = decoded_audio.spec().to_owned();
101        let buffer = Self::get_buffer(decoded_audio, spec);
102        Ok(Some(Self {
103            decoder,
104            current_span_offset: 0,
105            format: probed_format,
106            total_duration,
107            buffer,
108            spec,
109        }))
110    }
111
112    #[inline]
113    fn get_buffer(decoded: AudioBufferRef<'_>, spec: SignalSpec) -> SampleBuffer<f32> {
114        let duration = units::Duration::from(decoded.capacity() as u64);
115        let mut buffer = SampleBuffer::<f32>::new(duration, spec);
116        buffer.copy_interleaved_ref(decoded);
117        buffer
118    }
119
120    #[inline]
121    #[must_use]
122    pub const fn total_duration(&self) -> Option<Duration> {
123        self.total_duration
124    }
125
126    #[inline]
127    #[must_use]
128    pub const fn sample_rate(&self) -> u32 {
129        self.spec.rate
130    }
131
132    #[inline]
133    #[must_use]
134    pub fn channels(&self) -> usize {
135        self.spec.channels.count()
136    }
137}
138
139impl Iterator for SymphoniaSource {
140    type Item = f32;
141
142    fn size_hint(&self) -> (usize, Option<usize>) {
143        (
144            self.buffer.samples().len(),
145            self.total_duration.map(|dur| {
146                usize::try_from(
147                    (dur.as_secs() + 1)
148                        * u64::from(self.spec.rate)
149                        * self.spec.channels.count() as u64,
150                )
151                .unwrap_or(usize::MAX)
152            }),
153        )
154    }
155
156    fn next(&mut self) -> Option<Self::Item> {
157        if self.current_span_offset < self.buffer.len() {
158            let sample = self.buffer.samples().get(self.current_span_offset);
159            self.current_span_offset += 1;
160
161            return sample.copied();
162        }
163
164        let mut decode_errors = 0;
165        let decoded = loop {
166            let packet = self.format.next_packet().ok()?;
167            match self.decoder.decode(&packet) {
168                // Loop until we get a packet with audio frames. This is necessary because some
169                // formats can have packets with only metadata, particularly when rewinding, in
170                // which case the iterator would otherwise end with `None`.
171                // Note: checking `decoded.frames()` is more reliable than `packet.dur()`, which
172                // can returns non-zero durations for packets without audio frames.
173                Ok(decoded) if decoded.frames() > 0 => break decoded,
174                Ok(_) => {}
175                Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
176                    decode_errors += 1;
177                }
178                Err(_) => return None,
179            }
180        };
181
182        decoded.spec().clone_into(&mut self.spec);
183        self.buffer = Self::get_buffer(decoded, self.spec);
184        self.current_span_offset = 1;
185        self.buffer.samples().first().copied()
186    }
187}
188
189#[allow(clippy::module_name_repetitions)]
190pub struct MecompDecoder<R = FastFixedIn<f32>> {
191    resampler: Pool<Result<R, ResamplerConstructionError>>,
192}
193
194impl MecompDecoder {
195    #[inline]
196    fn generate_resampler() -> Result<FastFixedIn<f32>, ResamplerConstructionError> {
197        FastFixedIn::new(1.0, 10.0, rubato::PolynomialDegree::Cubic, CHUNK_SIZE, 1)
198    }
199
200    /// Create a new `MecompDecoder`
201    ///
202    /// # Errors
203    ///
204    /// This function will return an error if the resampler could not be created.
205    #[inline]
206    pub fn new() -> Result<Self, AnalysisError> {
207        // try to generate a resampler first, so we can return an error if it fails (if it fails, it's likely all future calls will too)
208        let first = Self::generate_resampler()?;
209
210        let pool_size = std::thread::available_parallelism().map_or(1, NonZeroUsize::get);
211        let resampler = Pool::new(pool_size, Self::generate_resampler);
212        resampler.attach(Ok(first));
213
214        Ok(Self { resampler })
215    }
216
217    /// we need to collapse the audio source into one channel
218    /// channels are interleaved, so if we have 2 channels, `[1, 2, 3, 4]` and `[5, 6, 7, 8]`,
219    /// they will be stored as `[1, 5, 2, 6, 3, 7, 4, 8]`
220    ///
221    /// For stereo sound, we can make this mono by averaging the channels and multiplying by the square root of 2,
222    /// This recovers the exact behavior of ffmpeg when converting stereo to mono, however for 2.1 and 5.1 surround sound,
223    /// ffmpeg might be doing something different, and I'm not sure what that is (don't have a 5.1 surround sound file to test with)
224    ///
225    /// TODO: Figure out how ffmpeg does it for 2.1 and 5.1 surround sound, and do it the same way
226    #[inline]
227    #[doc(hidden)]
228    pub fn into_mono_samples(
229        source: Vec<f32>,
230        num_channels: usize,
231    ) -> Result<Vec<f32>, AnalysisError> {
232        match num_channels {
233            // no channels
234            0 => Err(AnalysisError::DecodeError(Error::DecodeError(
235                "The audio source has no channels",
236            ))),
237            // mono
238            1 => Ok(source),
239            // stereo
240            2 => Ok(source
241                .chunks_exact(2)
242                .map(|chunk| (chunk[0] + chunk[1]) * SQRT_2 / 2.)
243                .collect()),
244            // 2.1 or 5.1 surround
245            _ => {
246                log::warn!(
247                    "The audio source has more than 2 channels (might be 2.1 or 5.1 surround sound), will collapse to mono by averaging the channels"
248                );
249
250                #[allow(clippy::cast_precision_loss)]
251                let num_channels_f32 = num_channels as f32;
252                let mono_samples = source
253                    .chunks_exact(num_channels)
254                    .map(|chunk| chunk.iter().sum::<f32>() / num_channels_f32)
255                    .collect();
256
257                Ok(mono_samples)
258            }
259        }
260    }
261
262    /// Resample the given mono samples to 22050 Hz
263    #[inline]
264    #[doc(hidden)]
265    pub fn resample_mono_samples(
266        &self,
267        mut samples: Vec<f32>,
268        sample_rate: u32,
269        total_duration: Duration,
270    ) -> Result<Vec<f32>, AnalysisError> {
271        if sample_rate == SAMPLE_RATE {
272            samples.shrink_to_fit();
273            return Ok(samples);
274        }
275
276        let mut resampled_frames = Vec::with_capacity(
277            usize::try_from((total_duration.as_secs() + 1) * u64::from(SAMPLE_RATE))
278                .unwrap_or(usize::MAX),
279        );
280
281        let (pool, resampler) = self.resampler.pull(Self::generate_resampler).detach();
282        let mut resampler = resampler?;
283        resampler.set_resample_ratio(f64::from(SAMPLE_RATE) / f64::from(sample_rate), false)?;
284
285        let delay = resampler.output_delay();
286
287        let new_length = samples.len() * SAMPLE_RATE as usize / sample_rate as usize;
288        let mut output_buffer = resampler.output_buffer_allocate(true);
289
290        // chunks of frames, each being CHUNKSIZE long.
291        let sample_chunks = samples.chunks_exact(CHUNK_SIZE);
292        let remainder = sample_chunks.remainder();
293
294        for chunk in sample_chunks {
295            debug_assert!(resampler.input_frames_next() == CHUNK_SIZE);
296
297            let (_, output_written) =
298                resampler.process_into_buffer(&[chunk], output_buffer.as_mut_slice(), None)?;
299            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
300        }
301
302        // process the remainder
303        if !remainder.is_empty() {
304            let (_, output_written) = resampler.process_partial_into_buffer(
305                Some(&[remainder]),
306                output_buffer.as_mut_slice(),
307                None,
308            )?;
309            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
310        }
311
312        // flush final samples from resampler
313        if resampled_frames.len() < new_length + delay {
314            let (_, output_written) = resampler.process_partial_into_buffer(
315                Option::<&[&[f32]]>::None,
316                output_buffer.as_mut_slice(),
317                None,
318            )?;
319            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
320        }
321
322        resampler.reset();
323        pool.attach(Ok(resampler));
324
325        Ok(resampled_frames[delay..new_length + delay].to_vec())
326    }
327}
328
329impl Decoder for MecompDecoder {
330    /// A function that should decode and resample a song, optionally
331    /// extracting the song's metadata such as the artist, the album, etc.
332    ///
333    /// The output sample array should be resampled to f32le, one channel, with a sampling rate
334    /// of 22050 Hz. Anything other than that will yield wrong results.
335    #[allow(clippy::missing_inline_in_public_items)]
336    fn decode(&self, path: &std::path::Path) -> AnalysisResult<ResampledAudio> {
337        // open the file
338        let file = File::open(path)?;
339        // create the media source stream
340        let mss = MediaSourceStream::new(Box::new(file), MediaSourceStreamOptions::default());
341
342        let source = SymphoniaSource::new(mss)?;
343
344        // Convert the audio source into a mono channel
345        let sample_rate = source.spec.rate;
346        let Some(total_duration) = source.total_duration else {
347            return Err(AnalysisError::IndeterminantDuration);
348        };
349        let num_channels = source.channels();
350
351        let mono_sample_array =
352            Self::into_mono_samples(source.into_iter().collect(), num_channels)?;
353
354        // then we need to resample the audio source into 22050 Hz
355        let resampled_array =
356            self.resample_mono_samples(mono_sample_array, sample_rate, total_duration)?;
357
358        Ok(ResampledAudio {
359            path: path.to_owned(),
360            samples: resampled_array,
361        })
362    }
363}
364
#[cfg(test)]
mod tests {
    use crate::NUMBER_FEATURES;

    use super::{Decoder as DecoderTrait, MecompDecoder as Decoder};
    use adler32::RollingAdler32;
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::path::Path;

    /// Decode `path` and compare the Adler-32 checksum of the little-endian bytes of every
    /// output sample against `expected_hash`.
    fn verify_decoding_output(path: &Path, expected_hash: u32) {
        let song = Decoder::new().unwrap().decode(path).unwrap();

        let mut hasher = RollingAdler32::new();
        song.samples
            .iter()
            .for_each(|sample| hasher.update_buffer(&sample.to_le_bytes()));

        assert_eq!(expected_hash, hasher.hash());
    }

    // The expected hashes were obtained with, e.g.:
    // ffmpeg -i data/s16_stereo_22_5kHz.flac -ar 22050 -ac 1 -c:a pcm_f32le -f hash -hash adler32 -
    #[rstest]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_stereo(Path::new("data/s32_stereo_44_1_kHz.flac"), 0xbbcb_a1cf)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mono(Path::new("data/s32_mono_44_1_kHz.flac"), 0xa0f8_b8af)]
    #[case::decode_stereo(Path::new("data/s16_stereo_22_5kHz.flac"), 0x1d7b_2d6d)]
    #[case::decode_mono(Path::new("data/s16_mono_22_5kHz.flac"), 0x5e01_930b)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mp3(Path::new("data/s32_stereo_44_1_kHz.mp3"), 0x69ca_6906)]
    #[case::decode_wav(Path::new("data/piano.wav"), 0xde83_1e82)]
    fn test_decode(#[case] path: &Path, #[case] expected_hash: u32) {
        verify_decoding_output(path, expected_hash);
    }

    /// Decoding a file without a channel layout must succeed rather than panic.
    #[test]
    fn test_dont_panic_no_channel_layout() {
        let decoder = Decoder::new().unwrap();
        decoder.decode(Path::new("data/no_channel.wav")).unwrap();
    }

    /// The returned sample vector must not carry spare capacity.
    /// (bliss-rs would add an extra second as a buffer; that is not needed here because the
    /// exact length of the song is known.)
    #[test]
    fn test_decode_right_capacity_vec() {
        let files = [
            "data/s16_mono_22_5kHz.flac",
            "data/s32_stereo_44_1_kHz.flac",
            // NOTE: originally used the .ogg file, but it was failing to decode with `DecodeError(IoError("end of stream"))`
            "data/capacity_fix.wav",
        ];

        for file in files {
            let sample_array = Decoder::new()
                .unwrap()
                .decode(Path::new(file))
                .unwrap()
                .samples;
            assert_eq!(sample_array.len(), sample_array.capacity());
        }
    }

    /// A file that needs no resampling, together with its expected analysis.
    const PATH_AND_EXPECTED_ANALYSIS: (&str, [f64; NUMBER_FEATURES]) = (
        "data/s16_mono_22_5kHz.flac",
        [
            0.384_638_9,
            -0.849_141,
            -0.754_810_45,
            -0.879_074_8,
            -0.632_582_66,
            -0.725_895_9,
            -0.775_737_9,
            -0.814_672_6,
            0.271_672_6,
            0.257_790_57,
            -0.342_925_13,
            -0.628_034_23,
            -0.280_950_96,
            0.086_864_59,
            0.244_460_82,
            -0.572_325_7,
            0.232_920_65,
            0.199_811_46,
            -0.585_944_06,
            -0.067_842_96,
            -0.060_007_63,
            -0.584_857_17,
            -0.078_803_78,
        ],
    );

    /// Every analysed feature must match its expected value to within 1e-5.
    #[test]
    fn test_analyze() {
        let (path, expected_analysis) = PATH_AND_EXPECTED_ANALYSIS;
        let decoder = Decoder::new().unwrap();
        let analysis = decoder.analyze_path(Path::new(path)).unwrap();

        for (x, y) in analysis.as_vec().iter().zip(expected_analysis) {
            assert!(
                (x - y).abs() < 1e-5,
                "Expected {x} to be within 1e-5 of {y}, but it was not"
            );
        }
    }

    /// A file that has to be resampled, together with its expected analysis.
    const RESAMPLED_PATH_AND_EXPECTED_ANALYSIS: (&str, [f64; NUMBER_FEATURES]) = (
        "data/s32_stereo_44_1_kHz.flac",
        [
            0.38463664,
            -0.85172224,
            -0.7607465,
            -0.8857495,
            -0.63906085,
            -0.73908424,
            -0.7890965,
            -0.8191868,
            0.33856833,
            0.3246863,
            -0.34292227,
            -0.62803173,
            -0.2809453,
            0.08687115,
            0.2444489,
            -0.5723239,
            0.23292565,
            0.19979525,
            -0.58593845,
            -0.06783122,
            -0.060014784,
            -0.5848569,
            -0.07879859,
        ],
    );

    /// After resampling, every analysed feature only has to be within 0.1 of its expected value.
    #[test]
    fn test_analyze_resampled() {
        let (path, expected_analysis) = RESAMPLED_PATH_AND_EXPECTED_ANALYSIS;
        let decoder = Decoder::new().unwrap();
        let analysis = decoder.analyze_path(Path::new(path)).unwrap();

        for (x, y) in analysis.as_vec().iter().zip(expected_analysis) {
            assert!(
                (x - y).abs() < 0.1,
                "Expected {x} to be within 0.1 of {y}, but it was not"
            );
        }
    }
}