1use std::{f32::consts::SQRT_2, fs::File, num::NonZeroUsize, time::Duration};
4
5use object_pool::Pool;
6use rubato::{FastFixedIn, Resampler, ResamplerConstructionError};
7use symphonia::{
8 core::{
9 audio::{AudioBufferRef, SampleBuffer, SignalSpec},
10 codecs::{CODEC_TYPE_NULL, DecoderOptions},
11 errors::Error,
12 formats::{FormatOptions, FormatReader},
13 io::{MediaSourceStream, MediaSourceStreamOptions},
14 meta::MetadataOptions,
15 probe::Hint,
16 units,
17 },
18 default::get_probe,
19};
20
21use crate::{ResampledAudio, SAMPLE_RATE, errors::AnalysisError, errors::AnalysisResult};
22
23use super::Decoder;
24
/// Number of `DecodeError`s tolerated while searching for the next decodable
/// packet before decoding is abandoned.
const MAX_DECODE_RETRIES: usize = 3;
/// Number of input frames handed to the resampler per processing call.
const CHUNK_SIZE: usize = 4 * 1024;
27
28#[doc(hidden)]
30pub struct SymphoniaSource {
31 decoder: Box<dyn symphonia::core::codecs::Decoder>,
32 current_span_offset: usize,
33 format: Box<dyn FormatReader>,
34 total_duration: Option<Duration>,
35 buffer: SampleBuffer<f32>,
36 spec: SignalSpec,
37}
38
39impl SymphoniaSource {
40 pub fn new(mss: MediaSourceStream) -> Result<Self, Error> {
47 Self::init(mss)?.ok_or(Error::DecodeError("No Streams"))
48 }
49
50 fn init(mss: MediaSourceStream) -> symphonia::core::errors::Result<Option<Self>> {
51 let hint = Hint::new();
52 let format_opts = FormatOptions {
53 enable_gapless: true,
54 ..Default::default()
55 };
56 let metadata_opts = MetadataOptions::default();
57 let mut probed_format = get_probe()
58 .format(&hint, mss, &format_opts, &metadata_opts)?
59 .format;
60
61 let Some(stream) = probed_format.default_track() else {
62 return Ok(None);
63 };
64
65 let track = probed_format
67 .tracks()
68 .iter()
69 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
70 .ok_or(Error::Unsupported("No track with supported codec"))?;
71
72 let track_id = track.id;
73
74 let mut decoder = symphonia::default::get_codecs()
75 .make(&track.codec_params, &DecoderOptions::default())?;
76 let total_duration = stream
77 .codec_params
78 .time_base
79 .zip(stream.codec_params.n_frames)
80 .map(|(base, spans)| base.calc_time(spans).into());
81
82 let mut decode_errors: usize = 0;
83 let decoded_audio = loop {
84 let current_span = probed_format.next_packet()?;
85
86 if current_span.track_id() != track_id {
88 continue;
89 }
90
91 match decoder.decode(¤t_span) {
92 Ok(audio) => break audio,
93 Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
94 decode_errors += 1;
95 }
96 Err(e) => return Err(e),
97 }
98 };
99
100 let spec = decoded_audio.spec().to_owned();
101 let buffer = Self::get_buffer(decoded_audio, spec);
102 Ok(Some(Self {
103 decoder,
104 current_span_offset: 0,
105 format: probed_format,
106 total_duration,
107 buffer,
108 spec,
109 }))
110 }
111
112 #[inline]
113 fn get_buffer(decoded: AudioBufferRef<'_>, spec: SignalSpec) -> SampleBuffer<f32> {
114 let duration = units::Duration::from(decoded.capacity() as u64);
115 let mut buffer = SampleBuffer::<f32>::new(duration, spec);
116 buffer.copy_interleaved_ref(decoded);
117 buffer
118 }
119
120 #[inline]
121 #[must_use]
122 pub const fn total_duration(&self) -> Option<Duration> {
123 self.total_duration
124 }
125
126 #[inline]
127 #[must_use]
128 pub const fn sample_rate(&self) -> u32 {
129 self.spec.rate
130 }
131
132 #[inline]
133 #[must_use]
134 pub fn channels(&self) -> usize {
135 self.spec.channels.count()
136 }
137}
138
139impl Iterator for SymphoniaSource {
140 type Item = f32;
141
142 fn size_hint(&self) -> (usize, Option<usize>) {
143 (
144 self.buffer.samples().len(),
145 self.total_duration.map(|dur| {
146 usize::try_from(
147 (dur.as_secs() + 1)
148 * u64::from(self.spec.rate)
149 * self.spec.channels.count() as u64,
150 )
151 .unwrap_or(usize::MAX)
152 }),
153 )
154 }
155
156 fn next(&mut self) -> Option<Self::Item> {
157 if self.current_span_offset < self.buffer.len() {
158 let sample = self.buffer.samples().get(self.current_span_offset);
159 self.current_span_offset += 1;
160
161 return sample.copied();
162 }
163
164 let mut decode_errors = 0;
165 let decoded = loop {
166 let packet = self.format.next_packet().ok()?;
167 match self.decoder.decode(&packet) {
168 Ok(decoded) if decoded.frames() > 0 => break decoded,
174 Ok(_) => {}
175 Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
176 decode_errors += 1;
177 }
178 Err(_) => return None,
179 }
180 };
181
182 decoded.spec().clone_into(&mut self.spec);
183 self.buffer = Self::get_buffer(decoded, self.spec);
184 self.current_span_offset = 1;
185 self.buffer.samples().first().copied()
186 }
187}
188
189#[allow(clippy::module_name_repetitions)]
190pub struct MecompDecoder<R = FastFixedIn<f32>> {
191 resampler: Pool<Result<R, ResamplerConstructionError>>,
192}
193
194impl MecompDecoder {
195 #[inline]
196 fn generate_resampler() -> Result<FastFixedIn<f32>, ResamplerConstructionError> {
197 FastFixedIn::new(1.0, 10.0, rubato::PolynomialDegree::Cubic, CHUNK_SIZE, 1)
198 }
199
200 #[inline]
206 pub fn new() -> Result<Self, AnalysisError> {
207 let first = Self::generate_resampler()?;
209
210 let pool_size = std::thread::available_parallelism().map_or(1, NonZeroUsize::get);
211 let resampler = Pool::new(pool_size, Self::generate_resampler);
212 resampler.attach(Ok(first));
213
214 Ok(Self { resampler })
215 }
216
217 #[inline]
227 #[doc(hidden)]
228 pub fn into_mono_samples(
229 source: Vec<f32>,
230 num_channels: usize,
231 ) -> Result<Vec<f32>, AnalysisError> {
232 match num_channels {
233 0 => Err(AnalysisError::DecodeError(Error::DecodeError(
235 "The audio source has no channels",
236 ))),
237 1 => Ok(source),
239 2 => Ok(source
241 .chunks_exact(2)
242 .map(|chunk| (chunk[0] + chunk[1]) * SQRT_2 / 2.)
243 .collect()),
244 _ => {
246 log::warn!(
247 "The audio source has more than 2 channels (might be 2.1 or 5.1 surround sound), will collapse to mono by averaging the channels"
248 );
249
250 #[allow(clippy::cast_precision_loss)]
251 let num_channels_f32 = num_channels as f32;
252 let mono_samples = source
253 .chunks_exact(num_channels)
254 .map(|chunk| chunk.iter().sum::<f32>() / num_channels_f32)
255 .collect();
256
257 Ok(mono_samples)
258 }
259 }
260 }
261
262 #[inline]
264 #[doc(hidden)]
265 pub fn resample_mono_samples(
266 &self,
267 mut samples: Vec<f32>,
268 sample_rate: u32,
269 total_duration: Duration,
270 ) -> Result<Vec<f32>, AnalysisError> {
271 if sample_rate == SAMPLE_RATE {
272 samples.shrink_to_fit();
273 return Ok(samples);
274 }
275
276 let mut resampled_frames = Vec::with_capacity(
277 usize::try_from((total_duration.as_secs() + 1) * u64::from(SAMPLE_RATE))
278 .unwrap_or(usize::MAX),
279 );
280
281 let (pool, resampler) = self.resampler.pull(Self::generate_resampler).detach();
282 let mut resampler = resampler?;
283 resampler.set_resample_ratio(f64::from(SAMPLE_RATE) / f64::from(sample_rate), false)?;
284
285 let delay = resampler.output_delay();
286
287 let new_length = samples.len() * SAMPLE_RATE as usize / sample_rate as usize;
288 let mut output_buffer = resampler.output_buffer_allocate(true);
289
290 let sample_chunks = samples.chunks_exact(CHUNK_SIZE);
292 let remainder = sample_chunks.remainder();
293
294 for chunk in sample_chunks {
295 debug_assert!(resampler.input_frames_next() == CHUNK_SIZE);
296
297 let (_, output_written) =
298 resampler.process_into_buffer(&[chunk], output_buffer.as_mut_slice(), None)?;
299 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
300 }
301
302 if !remainder.is_empty() {
304 let (_, output_written) = resampler.process_partial_into_buffer(
305 Some(&[remainder]),
306 output_buffer.as_mut_slice(),
307 None,
308 )?;
309 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
310 }
311
312 if resampled_frames.len() < new_length + delay {
314 let (_, output_written) = resampler.process_partial_into_buffer(
315 Option::<&[&[f32]]>::None,
316 output_buffer.as_mut_slice(),
317 None,
318 )?;
319 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
320 }
321
322 resampler.reset();
323 pool.attach(Ok(resampler));
324
325 Ok(resampled_frames[delay..new_length + delay].to_vec())
326 }
327}
328
329impl Decoder for MecompDecoder {
330 #[allow(clippy::missing_inline_in_public_items)]
336 fn decode(&self, path: &std::path::Path) -> AnalysisResult<ResampledAudio> {
337 let file = File::open(path)?;
339 let mss = MediaSourceStream::new(Box::new(file), MediaSourceStreamOptions::default());
341
342 let source = SymphoniaSource::new(mss)?;
343
344 let sample_rate = source.spec.rate;
346 let Some(total_duration) = source.total_duration else {
347 return Err(AnalysisError::IndeterminantDuration);
348 };
349 let num_channels = source.channels();
350
351 let mono_sample_array =
352 Self::into_mono_samples(source.into_iter().collect(), num_channels)?;
353
354 let resampled_array =
356 self.resample_mono_samples(mono_sample_array, sample_rate, total_duration)?;
357
358 Ok(ResampledAudio {
359 path: path.to_owned(),
360 samples: resampled_array,
361 })
362 }
363}
364
#[cfg(test)]
mod tests {
    use crate::NUMBER_FEATURES;

    use super::{Decoder as DecoderTrait, MecompDecoder as Decoder};
    use adler32::RollingAdler32;
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::path::Path;

    /// Decodes `path` and compares the Adler-32 checksum of the little-endian
    /// bytes of every sample against `expected_hash`.
    fn verify_decoding_output(path: &Path, expected_hash: u32) {
        let song = Decoder::new().unwrap().decode(path).unwrap();
        let checksum = song
            .samples
            .iter()
            .fold(RollingAdler32::new(), |mut acc, sample| {
                acc.update_buffer(&sample.to_le_bytes());
                acc
            });

        assert_eq!(expected_hash, checksum.hash());
    }

    #[rstest]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_stereo(Path::new("data/s32_stereo_44_1_kHz.flac"), 0xbbcb_a1cf)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mono(Path::new("data/s32_mono_44_1_kHz.flac"), 0xa0f8_b8af)]
    #[case::decode_stereo(Path::new("data/s16_stereo_22_5kHz.flac"), 0x1d7b_2d6d)]
    #[case::decode_mono(Path::new("data/s16_mono_22_5kHz.flac"), 0x5e01_930b)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mp3(Path::new("data/s32_stereo_44_1_kHz.mp3"), 0x69ca_6906)]
    #[case::decode_wav(Path::new("data/piano.wav"), 0xde83_1e82)]
    fn test_decode(#[case] path: &Path, #[case] expected_hash: u32) {
        verify_decoding_output(path, expected_hash);
    }

    #[test]
    fn test_dont_panic_no_channel_layout() {
        let decoder = Decoder::new().unwrap();
        decoder.decode(Path::new("data/no_channel.wav")).unwrap();
    }

    #[test]
    fn test_decode_right_capacity_vec() {
        let paths = [
            "data/s16_mono_22_5kHz.flac",
            "data/s32_stereo_44_1_kHz.flac",
            "data/capacity_fix.wav",
        ];
        for path in paths {
            let samples = Decoder::new()
                .unwrap()
                .decode(Path::new(path))
                .unwrap()
                .samples;
            assert_eq!(samples.len(), samples.capacity());
        }
    }

    const PATH_AND_EXPECTED_ANALYSIS: (&str, [f64; NUMBER_FEATURES]) = (
        "data/s16_mono_22_5kHz.flac",
        [
            0.384_638_9,
            -0.849_141_,
            -0.754_810_45,
            -0.879_074_8,
            -0.632_582_66,
            -0.725_895_9,
            -0.775_737_9,
            -0.814_672_6,
            0.271_672_6,
            0.257_790_57,
            -0.342_925_13,
            -0.628_034_23,
            -0.280_950_96,
            0.086_864_59,
            0.244_460_82,
            -0.572_325_7,
            0.232_920_65,
            0.199_811_46,
            -0.585_944_06,
            -0.067_842_96,
            -0.060_007_63,
            -0.584_857_17,
            -0.078_803_78,
        ],
    );

    #[test]
    fn test_analyze() {
        let (path, expected) = PATH_AND_EXPECTED_ANALYSIS;
        let decoder = Decoder::new().unwrap();
        let analysis = decoder.analyze_path(Path::new(path)).unwrap();

        analysis
            .as_vec()
            .iter()
            .zip(expected)
            .for_each(|(x, y)| {
                assert!(
                    1e-5 > (x - y).abs(),
                    "Expected {x} to be within 1e-5 of {y}, but it was not"
                );
            });
    }

    const RESAMPLED_PATH_AND_EXPECTED_ANALYSIS: (&str, [f64; NUMBER_FEATURES]) = (
        "data/s32_stereo_44_1_kHz.flac",
        [
            0.38463664,
            -0.85172224,
            -0.7607465,
            -0.8857495,
            -0.63906085,
            -0.73908424,
            -0.7890965,
            -0.8191868,
            0.33856833,
            0.3246863,
            -0.34292227,
            -0.62803173,
            -0.2809453,
            0.08687115,
            0.2444489,
            -0.5723239,
            0.23292565,
            0.19979525,
            -0.58593845,
            -0.06783122,
            -0.060014784,
            -0.5848569,
            -0.07879859,
        ],
    );

    #[test]
    fn test_analyze_resampled() {
        let (path, expected) = RESAMPLED_PATH_AND_EXPECTED_ANALYSIS;
        let decoder = Decoder::new().unwrap();
        let analysis = decoder.analyze_path(Path::new(path)).unwrap();

        analysis
            .as_vec()
            .iter()
            .zip(expected)
            .for_each(|(x, y)| {
                assert!(
                    0.1 > (x - y).abs(),
                    "Expected {x} to be within 0.1 of {y}, but it was not"
                );
            });
    }
}