Skip to main content

mecomp_analysis/decoder/
mecomp.rs

1//! Implementation of the mecomp decoder, which is rodio/rubato based.
2
3use std::{f32::consts::SQRT_2, fs::File, num::NonZeroUsize, time::Duration};
4
5use object_pool::Pool;
6use rubato::{FastFixedIn, Resampler, ResamplerConstructionError};
7use symphonia::{
8    core::{
9        audio::{AudioBufferRef, SampleBuffer, SignalSpec},
10        codecs::{CODEC_TYPE_NULL, DecoderOptions},
11        errors::Error,
12        formats::{FormatOptions, FormatReader},
13        io::{MediaSourceStream, MediaSourceStreamOptions},
14        meta::MetadataOptions,
15        probe::Hint,
16        units,
17    },
18    default::get_probe,
19};
20
21use crate::{ResampledAudio, SAMPLE_RATE, errors::AnalysisError, errors::AnalysisResult};
22
23use super::Decoder;
24
25const MAX_DECODE_RETRIES: usize = 3;
26const CHUNK_SIZE: usize = 4096;
27
28/// Struct used by the symphonia-based bliss decoders to decode audio files
29#[doc(hidden)]
30pub struct SymphoniaSource {
31    decoder: Box<dyn symphonia::core::codecs::Decoder>,
32    current_span_offset: usize,
33    format: Box<dyn FormatReader>,
34    total_duration: Option<Duration>,
35    buffer: SampleBuffer<f32>,
36    spec: SignalSpec,
37}
38
39impl SymphoniaSource {
40    /// Create a new `SymphoniaSource` from a `MediaSourceStream`
41    ///
42    /// # Errors
43    ///
44    /// This function will return an error if the `MediaSourceStream` does not contain any streams, or if the stream
45    /// is not supported by the decoder.
46    pub fn new(mss: MediaSourceStream) -> Result<Self, Error> {
47        Self::init(mss)?.ok_or(Error::DecodeError("No Streams"))
48    }
49
50    fn init(mss: MediaSourceStream) -> symphonia::core::errors::Result<Option<Self>> {
51        let hint = Hint::new();
52        let format_opts = FormatOptions {
53            enable_gapless: true,
54            ..Default::default()
55        };
56        let metadata_opts = MetadataOptions::default();
57        let mut probed_format = get_probe()
58            .format(&hint, mss, &format_opts, &metadata_opts)?
59            .format;
60
61        let Some(stream) = probed_format.default_track() else {
62            return Ok(None);
63        };
64
65        // Select the first supported track
66        let track = probed_format
67            .tracks()
68            .iter()
69            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
70            .ok_or(Error::Unsupported("No track with supported codec"))?;
71
72        let track_id = track.id;
73
74        let mut decoder = symphonia::default::get_codecs()
75            .make(&track.codec_params, &DecoderOptions::default())?;
76        let total_duration = stream
77            .codec_params
78            .time_base
79            .zip(stream.codec_params.n_frames)
80            .map(|(base, spans)| base.calc_time(spans).into());
81
82        let mut decode_errors: usize = 0;
83        let decoded_audio = loop {
84            let current_span = probed_format.next_packet()?;
85
86            // If the packet does not belong to the selected track, skip over it
87            if current_span.track_id() != track_id {
88                continue;
89            }
90
91            match decoder.decode(&current_span) {
92                Ok(audio) => break audio,
93                Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
94                    decode_errors += 1;
95                }
96                Err(e) => return Err(e),
97            }
98        };
99
100        let spec = decoded_audio.spec().to_owned();
101        let buffer = Self::get_buffer(decoded_audio, spec);
102        Ok(Some(Self {
103            decoder,
104            current_span_offset: 0,
105            format: probed_format,
106            total_duration,
107            buffer,
108            spec,
109        }))
110    }
111
112    #[inline]
113    fn get_buffer(decoded: AudioBufferRef<'_>, spec: SignalSpec) -> SampleBuffer<f32> {
114        let duration = units::Duration::from(decoded.capacity() as u64);
115        let mut buffer = SampleBuffer::<f32>::new(duration, spec);
116        buffer.copy_interleaved_ref(decoded);
117        buffer
118    }
119
120    #[inline]
121    #[must_use]
122    pub const fn total_duration(&self) -> Option<Duration> {
123        self.total_duration
124    }
125
126    #[inline]
127    #[must_use]
128    pub const fn sample_rate(&self) -> u32 {
129        self.spec.rate
130    }
131
132    #[inline]
133    #[must_use]
134    pub fn channels(&self) -> usize {
135        self.spec.channels.count()
136    }
137}
138
139impl Iterator for SymphoniaSource {
140    type Item = f32;
141
142    fn size_hint(&self) -> (usize, Option<usize>) {
143        (
144            self.buffer.samples().len(),
145            self.total_duration.map(|dur| {
146                usize::try_from(
147                    (dur.as_secs() + 1)
148                        * u64::from(self.spec.rate)
149                        * self.spec.channels.count() as u64,
150                )
151                .unwrap_or(usize::MAX)
152            }),
153        )
154    }
155
156    fn next(&mut self) -> Option<Self::Item> {
157        if self.current_span_offset < self.buffer.len() {
158            let sample = self.buffer.samples().get(self.current_span_offset);
159            self.current_span_offset += 1;
160
161            return sample.copied();
162        }
163
164        let mut decode_errors = 0;
165        let decoded = loop {
166            let packet = self.format.next_packet().ok()?;
167            match self.decoder.decode(&packet) {
168                // Loop until we get a packet with audio frames. This is necessary because some
169                // formats can have packets with only metadata, particularly when rewinding, in
170                // which case the iterator would otherwise end with `None`.
171                // Note: checking `decoded.frames()` is more reliable than `packet.dur()`, which
172                // can returns non-zero durations for packets without audio frames.
173                Ok(decoded) if decoded.frames() > 0 => break decoded,
174                Ok(_) => {}
175                Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
176                    decode_errors += 1;
177                }
178                Err(_) => return None,
179            }
180        };
181
182        decoded.spec().clone_into(&mut self.spec);
183        self.buffer = Self::get_buffer(decoded, self.spec);
184        self.current_span_offset = 1;
185        self.buffer.samples().first().copied()
186    }
187}
188
189#[allow(clippy::module_name_repetitions)]
190pub struct MecompDecoder<R = FastFixedIn<f32>> {
191    resampler: Pool<Result<R, ResamplerConstructionError>>,
192}
193
194impl MecompDecoder {
195    #[inline]
196    fn generate_resampler() -> Result<FastFixedIn<f32>, ResamplerConstructionError> {
197        FastFixedIn::new(1.0, 10.0, rubato::PolynomialDegree::Cubic, CHUNK_SIZE, 1)
198    }
199
200    /// Create a new `MecompDecoder`
201    ///
202    /// # Errors
203    ///
204    /// This function will return an error if the resampler could not be created.
205    #[inline]
206    pub fn new() -> Result<Self, AnalysisError> {
207        // try to generate a resampler first, so we can return an error if it fails (if it fails, it's likely all future calls will too)
208        let first = Self::generate_resampler()?;
209
210        let pool_size = std::thread::available_parallelism().map_or(1, NonZeroUsize::get);
211        let resampler = Pool::new(pool_size, Self::generate_resampler);
212        resampler.attach(Ok(first));
213
214        Ok(Self { resampler })
215    }
216
217    /// we need to collapse the audio source into one channel
218    /// channels are interleaved, so if we have 2 channels, `[1, 2, 3, 4]` and `[5, 6, 7, 8]`,
219    /// they will be stored as `[1, 5, 2, 6, 3, 7, 4, 8]`
220    ///
221    /// For stereo sound, we can make this mono by averaging the channels and multiplying by the square root of 2,
222    /// This recovers the exact behavior of ffmpeg when converting stereo to mono, however for 2.1 and 5.1 surround sound,
223    /// ffmpeg might be doing something different, and I'm not sure what that is (don't have a 5.1 surround sound file to test with)
224    ///
225    /// TODO: Figure out how ffmpeg does it for 2.1 and 5.1 surround sound, and do it the same way
226    #[inline]
227    #[doc(hidden)]
228    pub fn into_mono_samples(
229        source: Vec<f32>,
230        num_channels: usize,
231    ) -> Result<Vec<f32>, AnalysisError> {
232        match num_channels {
233            // no channels
234            0 => Err(AnalysisError::DecodeError(Error::DecodeError(
235                "The audio source has no channels",
236            ))),
237            // mono
238            1 => Ok(source),
239            // stereo
240            2 => {
241                let len = source.len() / 2;
242                let mut result = vec![0f32; len];
243                let scale = SQRT_2 * 0.5;
244
245                // process 8 stereo pairs (16 floats) at a time for better SIMD utilization
246                let (src_chunks, src_remainder) = source.as_chunks::<16>();
247                let (dest_chunks, dest_remainder) = result.as_chunks_mut::<8>();
248
249                for (src, dest) in src_chunks.iter().zip(dest_chunks) {
250                    // compiler should auto-vectorize this
251                    for i in 0..8 {
252                        dest[i] = (src[2 * i] + src[2 * i + 1]) * scale;
253                    }
254                }
255
256                // process the remainder
257                for (i, chunk) in src_remainder.chunks_exact(2).enumerate() {
258                    dest_remainder[i] = (chunk[0] + chunk[1]) * scale;
259                }
260
261                Ok(result)
262            }
263            // 2.1 or 5.1 surround
264            _ => {
265                log::warn!(
266                    "The audio source has more than 2 channels (might be 2.1 or 5.1 surround sound), will collapse to mono by averaging the channels"
267                );
268
269                #[allow(clippy::cast_precision_loss)]
270                let num_channels_f32 = num_channels as f32;
271                let mono_samples = source
272                    .chunks_exact(num_channels)
273                    .map(|chunk| chunk.iter().sum::<f32>() / num_channels_f32)
274                    .collect();
275
276                Ok(mono_samples)
277            }
278        }
279    }
280
281    /// Resample the given mono samples to 22050 Hz
282    #[inline]
283    #[doc(hidden)]
284    pub fn resample_mono_samples(
285        &self,
286        mut samples: Vec<f32>,
287        sample_rate: u32,
288    ) -> Result<Vec<f32>, AnalysisError> {
289        if sample_rate == SAMPLE_RATE {
290            samples.shrink_to_fit();
291            return Ok(samples);
292        }
293
294        let resample_ratio = f64::from(SAMPLE_RATE) / f64::from(sample_rate);
295        #[allow(
296            clippy::cast_possible_truncation,
297            clippy::cast_sign_loss,
298            clippy::cast_precision_loss
299        )]
300        let mut resampled_frames = Vec::with_capacity(
301            (samples.len() as f64 * resample_ratio) as usize + SAMPLE_RATE as usize, // add an extra second as a buffer
302        );
303
304        let (pool, resampler) = self.resampler.pull(Self::generate_resampler).detach();
305        let mut resampler = resampler?;
306        resampler.set_resample_ratio(resample_ratio, false)?;
307
308        let delay = resampler.output_delay();
309
310        let new_length = samples.len() * SAMPLE_RATE as usize / sample_rate as usize;
311        let mut output_buffer = resampler.output_buffer_allocate(true);
312
313        // chunks of frames, each being CHUNKSIZE long.
314        let sample_chunks = samples.chunks_exact(CHUNK_SIZE);
315        let remainder = sample_chunks.remainder();
316
317        for chunk in sample_chunks {
318            debug_assert!(resampler.input_frames_next() == CHUNK_SIZE);
319
320            let (_, output_written) =
321                resampler.process_into_buffer(&[chunk], output_buffer.as_mut_slice(), None)?;
322            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
323        }
324
325        // process the remainder
326        if !remainder.is_empty() {
327            let (_, output_written) = resampler.process_partial_into_buffer(
328                Some(&[remainder]),
329                output_buffer.as_mut_slice(),
330                None,
331            )?;
332            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
333        }
334
335        // flush final samples from resampler
336        if resampled_frames.len() < new_length + delay {
337            let (_, output_written) = resampler.process_partial_into_buffer(
338                Option::<&[&[f32]]>::None,
339                output_buffer.as_mut_slice(),
340                None,
341            )?;
342            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
343        }
344
345        resampler.reset();
346        pool.attach(Ok(resampler));
347
348        Ok(resampled_frames[delay..new_length + delay].to_vec())
349    }
350}
351
352impl Decoder for MecompDecoder {
353    /// A function that should decode and resample a song, optionally
354    /// extracting the song's metadata such as the artist, the album, etc.
355    ///
356    /// The output sample array should be resampled to f32le, one channel, with a sampling rate
357    /// of 22050 Hz. Anything other than that will yield wrong results.
358    #[allow(clippy::missing_inline_in_public_items)]
359    fn decode(&self, path: &std::path::Path) -> AnalysisResult<ResampledAudio> {
360        // open the file
361        let file = File::open(path)?;
362        // create the media source stream
363        let mss = MediaSourceStream::new(Box::new(file), MediaSourceStreamOptions::default());
364
365        let source = SymphoniaSource::new(mss)?;
366
367        // Convert the audio source into a mono channel
368        let sample_rate = source.spec.rate;
369        let num_channels = source.channels();
370
371        let mono_sample_array =
372            Self::into_mono_samples(source.into_iter().collect(), num_channels)?;
373
374        // then we need to resample the audio source into 22050 Hz
375        let resampled_array = self.resample_mono_samples(mono_sample_array, sample_rate)?;
376
377        Ok(ResampledAudio {
378            path: path.to_owned(),
379            samples: resampled_array,
380        })
381    }
382}
383
#[cfg(test)]
mod tests {
    use crate::{NUMBER_FEATURES, embeddings::ModelConfig};

    use super::{Decoder as DecoderTrait, MecompDecoder as Decoder};
    use adler32::RollingAdler32;
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::{
        collections::HashMap,
        path::{Path, PathBuf},
        sync::mpsc,
    };

    /// Decode `path` and compare the Adler-32 hash of its little-endian samples to `expected_hash`.
    fn verify_decoding_output(path: &Path, expected_hash: u32) {
        let decoder = Decoder::new().unwrap();
        let song = decoder.decode(path).unwrap();
        let mut hasher = RollingAdler32::new();
        for sample in &song.samples {
            hasher.update_buffer(&sample.to_le_bytes());
        }

        assert_eq!(expected_hash, hasher.hash());
    }

    /// Collect every `.wav`, `.flac` and `.mp3` file in the crate's `data` directory.
    ///
    /// Files without an extension (or with a non-UTF-8 one) are skipped instead of causing a
    /// panic.
    fn music_files_in_data_dir() -> Vec<PathBuf> {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("data")
            .read_dir()
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .filter(|p| {
                p.is_file()
                    && p.extension()
                        .and_then(|ext| ext.to_str())
                        .is_some_and(|ext| matches!(ext, "wav" | "flac" | "mp3"))
            })
            .collect()
    }

    // expected hash obtained through
    // ffmpeg -i data/s16_stereo_22_5kHz.flac -ar 22050 -ac 1 -c:a pcm_f32le -f hash -hash adler32 -
    #[rstest]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_stereo(Path::new("data/s32_stereo_44_1_kHz.flac"), 0xbbcb_a1cf)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mono(Path::new("data/s32_mono_44_1_kHz.flac"), 0xa0f8_b8af)]
    #[case::decode_stereo(Path::new("data/s16_stereo_22_5kHz.flac"), 0x1d7b_2d6d)]
    #[case::decode_mono(Path::new("data/s16_mono_22_5kHz.flac"), 0x5e01_930b)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mp3(Path::new("data/s32_stereo_44_1_kHz.mp3"), 0x69ca_6906)]
    #[case::decode_wav(Path::new("data/piano.wav"), 0xde83_1e82)]
    fn test_decode(#[case] path: &Path, #[case] expected_hash: u32) {
        verify_decoding_output(path, expected_hash);
    }

    #[test]
    fn test_dont_panic_no_channel_layout() {
        let path = Path::new("data/no_channel.wav");
        Decoder::new().unwrap().decode(path).unwrap();
    }

    #[test]
    fn test_decode_right_capacity_vec() {
        let path = Path::new("data/s16_mono_22_5kHz.flac");
        let song = Decoder::new().unwrap().decode(path).unwrap();
        let sample_array = song.samples;
        assert_eq!(
            sample_array.len(), // + SAMPLE_RATE as usize, // The + SAMPLE_RATE is because bliss-rs would add an extra second as a buffer, we don't need to because we know the exact length of the song
            sample_array.capacity()
        );

        let path = Path::new("data/s32_stereo_44_1_kHz.flac");
        let song = Decoder::new().unwrap().decode(path).unwrap();
        let sample_array = song.samples;
        assert_eq!(
            sample_array.len(), // + SAMPLE_RATE as usize,
            sample_array.capacity()
        );

        // NOTE: originally used the .ogg file, but it was failing to decode with `DecodeError(IoError("end of stream"))`
        let path = Path::new("data/capacity_fix.wav");
        let song = Decoder::new().unwrap().decode(path).unwrap();
        let sample_array = song.samples;
        assert_eq!(
            sample_array.len(), // + SAMPLE_RATE as usize,
            sample_array.capacity()
        );
    }

    const PATH_AND_EXPECTED_ANALYSIS: (&str, [f32; NUMBER_FEATURES]) = (
        "data/s16_mono_22_5kHz.flac",
        [
            0.384_638_9,
            -0.849_141_,
            -0.754_810_45,
            -0.879_074_8,
            -0.632_582_66,
            -0.725_895_9,
            -0.775_737_9,
            -0.814_672_6,
            0.271_672_6,
            0.257_790_57,
            -0.342_925_13,
            -0.628_034_23,
            -0.280_950_96,
            0.086_864_59,
            0.244_460_82,
            -0.572_325_7,
            0.232_920_65,
            0.199_811_46,
            -0.585_944_06,
            -0.067_842_96,
            -0.060_007_63,
            -0.584_857_17,
            -0.078_803_78,
        ],
    );

    #[test]
    fn test_analyze() {
        let (path, expected_analysis) = PATH_AND_EXPECTED_ANALYSIS;
        let analysis = Decoder::new()
            .unwrap()
            .analyze_path(Path::new(path))
            .unwrap();
        for (x, y) in analysis.as_vec().iter().zip(expected_analysis) {
            assert!(
                1e-5 > (x - y).abs(),
                "Expected {x} to be within 1e-5 of {y}, but it was not"
            );
        }
    }

    const RESAMPLED_PATH_AND_EXPECTED_ANALYSIS: (&str, [f32; NUMBER_FEATURES]) = (
        "data/s32_stereo_44_1_kHz.flac",
        [
            0.38463664,
            -0.85172224,
            -0.7607465,
            -0.8857495,
            -0.63906085,
            -0.73908424,
            -0.7890965,
            -0.8191868,
            0.33856833,
            0.3246863,
            -0.34292227,
            -0.62803173,
            -0.2809453,
            0.08687115,
            0.2444489,
            -0.5723239,
            0.23292565,
            0.19979525,
            -0.58593845,
            -0.06783122,
            -0.060014784,
            -0.5848569,
            -0.07879859,
        ],
    );

    #[test]
    fn test_analyze_resampled() {
        let (path, expected_analysis) = RESAMPLED_PATH_AND_EXPECTED_ANALYSIS;
        let analysis = Decoder::new()
            .unwrap()
            .analyze_path(Path::new(path))
            .unwrap();

        for (x, y) in analysis.as_vec().iter().zip(expected_analysis) {
            assert!(
                0.1 > (x - y).abs(),
                "Expected {x} to be within 0.1 of {y}, but it was not"
            );
        }
    }

    #[test]
    fn test_analyze_paths() {
        // get the paths to every music file in the "data" directory
        let paths = music_files_in_data_dir();

        // track which paths we analyzed
        let mut analyzed_paths = HashMap::new();
        for path in &paths {
            analyzed_paths.insert(path.clone(), false);
        }

        // ensure we *can* analyze these without errors
        let decoder = Decoder::new().unwrap();
        let mut count = 0;
        let expected = paths.len();
        let (tx, rx) = mpsc::channel();
        let handle = std::thread::spawn(move || decoder.analyze_paths(&paths, tx));
        for (path, analysis) in rx {
            count += 1;
            assert!(analysis.is_ok(), "Failed to analyze {path:?}: {analysis:?}",);
            assert_eq!(
                analyzed_paths.insert(path.clone(), true),
                Some(false),
                "Analyzed the same path twice: {path:?}"
            );
        }

        assert_eq!(count, expected);
        assert!(
            analyzed_paths.values().all(|&v| v),
            "Not all paths were analyzed: {analyzed_paths:?}"
        );

        handle.join().unwrap().unwrap();
    }

    #[test]
    fn test_process_paths() {
        // get the paths to every music file in the "data" directory
        let paths = music_files_in_data_dir();

        // track which paths we analyzed
        let mut analyzed_paths = HashMap::new();
        for path in &paths {
            analyzed_paths.insert(path.clone(), false);
        }

        // ensure we *can* analyze these without errors
        let decoder = Decoder::new().unwrap();
        let model_config = ModelConfig::default();
        let (tx, rx) = mpsc::sync_channel(4);

        // spawn a thread to process the songs
        let paths_clone = paths.clone();
        std::thread::spawn(move || decoder.process_songs(&paths_clone, tx, model_config));

        let mut count = 0;
        for (path, analysis, embedding) in rx {
            count += 1;
            assert!(analysis.is_ok(), "Failed to analyze {path:?}: {analysis:?}");
            assert!(embedding.is_ok(), "Failed to embed {path:?}: {embedding:?}");
            assert_eq!(
                analyzed_paths.insert(path.clone(), true),
                Some(false),
                "Analyzed the same path twice: {path:?}"
            );
        }
        assert_eq!(count, paths.len());
        assert!(
            analyzed_paths.values().all(|&v| v),
            "Not all paths were analyzed: {analyzed_paths:?}"
        );
    }
}