mecomp_analysis/decoder/
mecomp.rs

1//! Implementation of the mecomp decoder, which is rodio/rubato based.
2
3use std::{f32::consts::SQRT_2, fs::File, num::NonZeroUsize, time::Duration};
4
5use object_pool::Pool;
6use rubato::{FastFixedIn, Resampler, ResamplerConstructionError};
7use symphonia::{
8    core::{
9        audio::{AudioBufferRef, SampleBuffer, SignalSpec},
10        codecs::{CODEC_TYPE_NULL, DecoderOptions},
11        errors::Error,
12        formats::{FormatOptions, FormatReader},
13        io::{MediaSourceStream, MediaSourceStreamOptions},
14        meta::MetadataOptions,
15        probe::Hint,
16        units,
17    },
18    default::get_probe,
19};
20
21use crate::{ResampledAudio, SAMPLE_RATE, errors::AnalysisError, errors::AnalysisResult};
22
23use super::Decoder;
24
/// Maximum number of `Error::DecodeError`s that are tolerated (by skipping the offending packet)
/// while searching for the next decodable packet; any further decode error ends the search.
const MAX_DECODE_RETRIES: usize = 3;
/// Number of input frames fed to the resampler per call; the resampler is constructed with this
/// chunk size and the mono samples are split into chunks of this length.
const CHUNK_SIZE: usize = 4096;
27
/// Struct used by the symphonia-based bliss decoders to decode audio files
///
/// It yields the decoded audio as interleaved `f32` samples through its `Iterator`
/// implementation, decoding one packet at a time.
#[doc(hidden)]
pub struct SymphoniaSource {
    /// Codec decoder created for the selected track.
    decoder: Box<dyn symphonia::core::codecs::Decoder>,
    /// Index into `buffer` of the next sample to hand out.
    current_span_offset: usize,
    /// Format reader that packets are pulled from.
    format: Box<dyn FormatReader>,
    /// Duration computed from the default track's time base and frame count, when both are known.
    total_duration: Option<Duration>,
    /// Interleaved samples of the most recently decoded packet.
    buffer: SampleBuffer<f32>,
    /// Signal specification (sample rate, channels) of the most recently decoded packet.
    spec: SignalSpec,
}
38
39impl SymphoniaSource {
40    /// Create a new `SymphoniaSource` from a `MediaSourceStream`
41    ///
42    /// # Errors
43    ///
44    /// This function will return an error if the `MediaSourceStream` does not contain any streams, or if the stream
45    /// is not supported by the decoder.
46    pub fn new(mss: MediaSourceStream) -> Result<Self, Error> {
47        Self::init(mss)?.ok_or(Error::DecodeError("No Streams"))
48    }
49
50    fn init(mss: MediaSourceStream) -> symphonia::core::errors::Result<Option<Self>> {
51        let hint = Hint::new();
52        let format_opts = FormatOptions::default();
53        let metadata_opts = MetadataOptions::default();
54        let mut probed_format = get_probe()
55            .format(&hint, mss, &format_opts, &metadata_opts)?
56            .format;
57
58        let Some(stream) = probed_format.default_track() else {
59            return Ok(None);
60        };
61
62        // Select the first supported track
63        let track = probed_format
64            .tracks()
65            .iter()
66            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
67            .ok_or(Error::Unsupported("No track with supported codec"))?;
68
69        let track_id = track.id;
70
71        let mut decoder = symphonia::default::get_codecs()
72            .make(&track.codec_params, &DecoderOptions::default())?;
73        let total_duration = stream
74            .codec_params
75            .time_base
76            .zip(stream.codec_params.n_frames)
77            .map(|(base, spans)| base.calc_time(spans).into());
78
79        let mut decode_errors: usize = 0;
80        let decoded_audio = loop {
81            let current_span = probed_format.next_packet()?;
82
83            // If the packet does not belong to the selected track, skip over it
84            if current_span.track_id() != track_id {
85                continue;
86            }
87
88            match decoder.decode(&current_span) {
89                Ok(audio) => break audio,
90                Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
91                    decode_errors += 1;
92                }
93                Err(e) => return Err(e),
94            }
95        };
96
97        let spec = decoded_audio.spec().to_owned();
98        let buffer = Self::get_buffer(decoded_audio, spec);
99        Ok(Some(Self {
100            decoder,
101            current_span_offset: 0,
102            format: probed_format,
103            total_duration,
104            buffer,
105            spec,
106        }))
107    }
108
109    #[inline]
110    fn get_buffer(decoded: AudioBufferRef, spec: SignalSpec) -> SampleBuffer<f32> {
111        let duration = units::Duration::from(decoded.capacity() as u64);
112        let mut buffer = SampleBuffer::<f32>::new(duration, spec);
113        buffer.copy_interleaved_ref(decoded);
114        buffer
115    }
116
117    #[inline]
118    #[must_use]
119    pub const fn total_duration(&self) -> Option<Duration> {
120        self.total_duration
121    }
122
123    #[inline]
124    #[must_use]
125    pub const fn sample_rate(&self) -> u32 {
126        self.spec.rate
127    }
128
129    #[inline]
130    #[must_use]
131    pub fn channels(&self) -> usize {
132        self.spec.channels.count()
133    }
134}
135
136impl Iterator for SymphoniaSource {
137    type Item = f32;
138
139    fn size_hint(&self) -> (usize, Option<usize>) {
140        (
141            self.buffer.samples().len(),
142            self.total_duration.map(|dur| {
143                (usize::try_from(dur.as_secs()).unwrap_or(usize::MAX) + 1)
144                    * self.spec.rate as usize
145                    * self.spec.channels.count()
146            }),
147        )
148    }
149
150    fn next(&mut self) -> Option<Self::Item> {
151        if self.current_span_offset < self.buffer.len() {
152            let sample = self.buffer.samples().get(self.current_span_offset);
153            self.current_span_offset += 1;
154
155            return sample.copied();
156        }
157
158        let mut decode_errors = 0;
159        let decoded = loop {
160            let packet = self.format.next_packet().ok()?;
161            match self.decoder.decode(&packet) {
162                // Loop until we get a packet with audio frames. This is necessary because some
163                // formats can have packets with only metadata, particularly when rewinding, in
164                // which case the iterator would otherwise end with `None`.
165                // Note: checking `decoded.frames()` is more reliable than `packet.dur()`, which
166                // can returns non-zero durations for packets without audio frames.
167                Ok(decoded) if decoded.frames() > 0 => break decoded,
168                Ok(_) => {}
169                Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
170                    decode_errors += 1;
171                }
172                Err(_) => return None,
173            }
174        };
175
176        decoded.spec().clone_into(&mut self.spec);
177        self.buffer = Self::get_buffer(decoded, self.spec);
178        self.current_span_offset = 1;
179        self.buffer.samples().first().copied()
180    }
181}
182
/// Decoder that reads audio files through `SymphoniaSource`, collapses them to mono and
/// resamples them to `SAMPLE_RATE` using pooled resamplers.
#[allow(clippy::module_name_repetitions)]
pub struct MecompDecoder<R = FastFixedIn<f32>> {
    /// Pool of resamplers (or the errors produced while constructing them) that are pulled for
    /// the duration of a resampling call.
    resampler: Pool<Result<R, ResamplerConstructionError>>,
}
187
188impl MecompDecoder {
189    #[inline]
190    fn generate_resampler() -> Result<FastFixedIn<f32>, ResamplerConstructionError> {
191        FastFixedIn::new(1.0, 10.0, rubato::PolynomialDegree::Cubic, CHUNK_SIZE, 1)
192    }
193
194    /// Create a new `MecompDecoder`
195    ///
196    /// # Errors
197    ///
198    /// This function will return an error if the resampler could not be created.
199    #[inline]
200    pub fn new() -> Result<Self, AnalysisError> {
201        // try to generate a resampler first, so we can return an error if it fails (if it fails, it's likely all future calls will too)
202        let first = Self::generate_resampler()?;
203
204        let pool_size = std::thread::available_parallelism().map_or(1, NonZeroUsize::get);
205        let resampler = Pool::new(pool_size, Self::generate_resampler);
206        resampler.attach(Ok(first));
207
208        Ok(Self { resampler })
209    }
210
211    /// we need to collapse the audio source into one channel
212    /// channels are interleaved, so if we have 2 channels, `[1, 2, 3, 4]` and `[5, 6, 7, 8]`,
213    /// they will be stored as `[1, 5, 2, 6, 3, 7, 4, 8]`
214    ///
215    /// For stereo sound, we can make this mono by averaging the channels and multiplying by the square root of 2,
216    /// This recovers the exact behavior of ffmpeg when converting stereo to mono, however for 2.1 and 5.1 surround sound,
217    /// ffmpeg might be doing something different, and I'm not sure what that is (don't have a 5.1 surround sound file to test with)
218    ///
219    /// TODO: Figure out how ffmpeg does it for 2.1 and 5.1 surround sound, and do it the same way
220    #[inline]
221    #[doc(hidden)]
222    pub fn into_mono_samples(
223        source: Vec<f32>,
224        num_channels: usize,
225    ) -> Result<Vec<f32>, AnalysisError> {
226        match num_channels {
227            // no channels
228            0 => Err(AnalysisError::DecodeError(Error::DecodeError(
229                "The audio source has no channels",
230            ))),
231            // mono
232            1 => Ok(source),
233            // stereo
234            2 => Ok(source
235                .chunks_exact(2)
236                .map(|chunk| (chunk[0] + chunk[1]) * SQRT_2 / 2.)
237                .collect()),
238            // 2.1 or 5.1 surround
239            _ => {
240                log::warn!(
241                    "The audio source has more than 2 channels (might be 2.1 or 5.1 surround sound), will collapse to mono by averaging the channels"
242                );
243
244                #[allow(clippy::cast_precision_loss)]
245                let num_channels_f32 = num_channels as f32;
246                let mono_samples = source
247                    .chunks_exact(num_channels)
248                    .map(|chunk| chunk.iter().sum::<f32>() / num_channels_f32)
249                    .collect();
250
251                Ok(mono_samples)
252            }
253        }
254    }
255
256    /// Resample the given mono samples to 22050 Hz
257    #[inline]
258    #[doc(hidden)]
259    pub fn resample_mono_samples(
260        &self,
261        mut samples: Vec<f32>,
262        sample_rate: u32,
263        total_duration: Duration,
264    ) -> Result<Vec<f32>, AnalysisError> {
265        if sample_rate == SAMPLE_RATE {
266            samples.shrink_to_fit();
267            return Ok(samples);
268        }
269
270        let mut resampled_frames = Vec::with_capacity(
271            (usize::try_from(total_duration.as_secs()).unwrap_or(usize::MAX) + 1)
272                * SAMPLE_RATE as usize,
273        );
274
275        let (pool, resampler) = self.resampler.pull(Self::generate_resampler).detach();
276        let mut resampler = resampler?;
277        resampler.set_resample_ratio(f64::from(SAMPLE_RATE) / f64::from(sample_rate), false)?;
278
279        let delay = resampler.output_delay();
280
281        let new_length = samples.len() * SAMPLE_RATE as usize / sample_rate as usize;
282        let mut output_buffer = resampler.output_buffer_allocate(true);
283
284        // chunks of frames, each being CHUNKSIZE long.
285        let sample_chunks = samples.chunks_exact(CHUNK_SIZE);
286        let remainder = sample_chunks.remainder();
287
288        for chunk in sample_chunks {
289            debug_assert!(resampler.input_frames_next() == CHUNK_SIZE);
290
291            let (_, output_written) =
292                resampler.process_into_buffer(&[chunk], output_buffer.as_mut_slice(), None)?;
293            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
294        }
295
296        // process the remainder
297        if !remainder.is_empty() {
298            let (_, output_written) = resampler.process_partial_into_buffer(
299                Some(&[remainder]),
300                output_buffer.as_mut_slice(),
301                None,
302            )?;
303            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
304        }
305
306        // flush final samples from resampler
307        if resampled_frames.len() < new_length + delay {
308            let (_, output_written) = resampler.process_partial_into_buffer(
309                Option::<&[&[f32]]>::None,
310                output_buffer.as_mut_slice(),
311                None,
312            )?;
313            resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
314        }
315
316        resampler.reset();
317        pool.attach(Ok(resampler));
318
319        Ok(resampled_frames[delay..new_length + delay].to_vec())
320    }
321}
322
323impl Decoder for MecompDecoder {
324    /// A function that should decode and resample a song, optionally
325    /// extracting the song's metadata such as the artist, the album, etc.
326    ///
327    /// The output sample array should be resampled to f32le, one channel, with a sampling rate
328    /// of 22050 Hz. Anything other than that will yield wrong results.
329    #[allow(clippy::missing_inline_in_public_items)]
330    fn decode(&self, path: &std::path::Path) -> AnalysisResult<ResampledAudio> {
331        // open the file
332        let file = File::open(path)?;
333        // create the media source stream
334        let mss = MediaSourceStream::new(Box::new(file), MediaSourceStreamOptions::default());
335
336        let source = SymphoniaSource::new(mss)?;
337
338        // Convert the audio source into a mono channel
339        let sample_rate = source.spec.rate;
340        let Some(total_duration) = source.total_duration else {
341            return Err(AnalysisError::IndeterminantDuration);
342        };
343        let num_channels = source.channels();
344
345        let mono_sample_array =
346            Self::into_mono_samples(source.into_iter().collect(), num_channels)?;
347
348        // then we need to resample the audio source into 22050 Hz
349        let resampled_array =
350            self.resample_mono_samples(mono_sample_array, sample_rate, total_duration)?;
351
352        Ok(ResampledAudio {
353            path: path.to_owned(),
354            samples: resampled_array,
355        })
356    }
357}
358
/// Tests for the mecomp decoder; they read audio fixtures from the `data/` directory.
#[cfg(test)]
mod tests {
    use crate::NUMBER_FEATURES;

    use super::{Decoder as DecoderTrait, MecompDecoder as Decoder};
    use adler32::RollingAdler32;
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::path::Path;

    /// Decode `path` and assert that the Adler-32 hash over the little-endian bytes of every
    /// decoded sample equals `expected_hash`.
    fn verify_decoding_output(path: &Path, expected_hash: u32) {
        let decoder = Decoder::new().unwrap();
        let song = decoder.decode(path).unwrap();
        let mut hasher = RollingAdler32::new();
        for sample in &song.samples {
            hasher.update_buffer(&sample.to_le_bytes());
        }

        assert_eq!(expected_hash, hasher.hash());
    }

    // expected hash Obtained through
    // ffmpeg -i data/s16_stereo_22_5kHz.flac -ar 22050 -ac 1 -c:a pcm_f32le -f hash -hash adler32 -
    #[rstest]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_stereo(Path::new("data/s32_stereo_44_1_kHz.flac"), 0xbbcb_a1cf)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mono(Path::new("data/s32_mono_44_1_kHz.flac"), 0xa0f8_b8af)]
    #[case::decode_stereo(Path::new("data/s16_stereo_22_5kHz.flac"), 0x1d7b_2d6d)]
    #[case::decode_mono(Path::new("data/s16_mono_22_5kHz.flac"), 0x5e01_930b)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mp3(Path::new("data/s32_stereo_44_1_kHz.mp3"), 0x69ca_6906)]
    #[case::decode_wav(Path::new("data/piano.wav"), 0xde83_1e82)]
    fn test_decode(#[case] path: &Path, #[case] expected_hash: u32) {
        verify_decoding_output(path, expected_hash);
    }

    /// Decoding the `no_channel.wav` fixture must succeed rather than panic.
    #[test]
    fn test_dont_panic_no_channel_layout() {
        let path = Path::new("data/no_channel.wav");
        Decoder::new().unwrap().decode(path).unwrap();
    }

    /// The returned sample vector must not carry spare capacity: its length has to equal its
    /// capacity, both when no resampling is needed and when the audio is resampled.
    #[test]
    fn test_decode_right_capacity_vec() {
        let path = Path::new("data/s16_mono_22_5kHz.flac");
        let song = Decoder::new().unwrap().decode(path).unwrap();
        let sample_array = song.samples;
        assert_eq!(
            sample_array.len(), // + SAMPLE_RATE as usize, // The + SAMPLE_RATE is because bliss-rs would add an extra second as a buffer, we don't need to because we know the exact length of the song
            sample_array.capacity()
        );

        let path = Path::new("data/s32_stereo_44_1_kHz.flac");
        let song = Decoder::new().unwrap().decode(path).unwrap();
        let sample_array = song.samples;
        assert_eq!(
            sample_array.len(), // + SAMPLE_RATE as usize,
            sample_array.capacity()
        );

        // NOTE: originally used the .ogg file, but it was failing to decode with `DecodeError(IoError("end of stream"))`
        let path = Path::new("data/capacity_fix.wav");
        let song = Decoder::new().unwrap().decode(path).unwrap();
        let sample_array = song.samples;
        assert_eq!(
            sample_array.len(), // + SAMPLE_RATE as usize,
            sample_array.capacity()
        );
    }

    /// Fixture path paired with the analysis features expected for it; `test_analyze` compares
    /// against these values with a tolerance of 0.01.
    const PATH_AND_EXPECTED_ANALYSIS: (&str, [f64; NUMBER_FEATURES]) = (
        "data/s16_mono_22_5kHz.flac",
        [
            0.384_638_9,
            -0.849_141_,
            -0.754_810_45,
            -0.879_074_8,
            -0.632_582_66,
            -0.725_895_9,
            -0.775_738_,
            -0.814_672_6,
            0.271_672_6,
            0.257_790_57,
            -0.356_619_36,
            -0.635_786_53,
            -0.295_936_82,
            0.064_213_04,
            0.218_524_58,
            -0.581_239,
            -0.946_683_5,
            -0.948_115_3,
            -0.982_094_5,
            -0.959_689_74,
        ],
    );

    /// Analyzing the fixture must produce features within 0.01 of the expected values.
    #[test]
    fn test_analyze() {
        let (path, expected_analysis) = PATH_AND_EXPECTED_ANALYSIS;
        let analysis = Decoder::new()
            .unwrap()
            .analyze_path(Path::new(path))
            .unwrap();
        for (x, y) in analysis.as_vec().iter().zip(expected_analysis) {
            assert!(
                0.01 > (x - y).abs(),
                "Expected {x} to be within 0.01 of {y}, but it was not"
            );
        }
    }
}