1use std::{f32::consts::SQRT_2, fs::File, num::NonZeroUsize, time::Duration};
4
5use object_pool::Pool;
6use rubato::{FastFixedIn, Resampler, ResamplerConstructionError};
7use symphonia::{
8 core::{
9 audio::{AudioBufferRef, SampleBuffer, SignalSpec},
10 codecs::{DecoderOptions, CODEC_TYPE_NULL},
11 errors::Error,
12 formats::{FormatOptions, FormatReader},
13 io::{MediaSourceStream, MediaSourceStreamOptions},
14 meta::MetadataOptions,
15 probe::Hint,
16 units,
17 },
18 default::get_probe,
19};
20
21use crate::{errors::AnalysisError, errors::AnalysisResult, ResampledAudio, SAMPLE_RATE};
22
23use super::Decoder;
24
/// Number of `Error::DecodeError`s that are skipped while searching for a
/// decodable packet before the next such error is treated as fatal.
const MAX_DECODE_RETRIES: usize = 3;
/// Number of input frames passed to the resampler per processing call; the
/// pooled resamplers are constructed with this chunk size.
const CHUNK_SIZE: usize = 4096;
27
/// Audio source backed by symphonia that yields interleaved `f32` samples
/// through its `Iterator` implementation.
#[doc(hidden)]
pub struct SymphoniaSource {
    /// Codec decoder that turns packets into audio buffers.
    decoder: Box<dyn symphonia::core::codecs::Decoder>,
    /// Index into `buffer` of the next sample to hand out.
    current_span_offset: usize,
    /// Format reader (demuxer) that supplies packets.
    format: Box<dyn FormatReader>,
    /// Total duration computed from the stream's time base and frame count,
    /// or `None` when either of them is unknown.
    total_duration: Option<Duration>,
    /// Interleaved samples of the most recently decoded packet.
    buffer: SampleBuffer<f32>,
    /// Signal specification (rate and channels) of the most recently decoded packet.
    spec: SignalSpec,
}
38
39impl SymphoniaSource {
40 pub fn new(mss: MediaSourceStream) -> Result<Self, Error> {
47 Self::init(mss)?.ok_or(Error::DecodeError("No Streams"))
48 }
49
50 fn init(mss: MediaSourceStream) -> symphonia::core::errors::Result<Option<Self>> {
51 let hint = Hint::new();
52 let format_opts = FormatOptions::default();
53 let metadata_opts = MetadataOptions::default();
54 let mut probed_format = get_probe()
55 .format(&hint, mss, &format_opts, &metadata_opts)?
56 .format;
57
58 let Some(stream) = probed_format.default_track() else {
59 return Ok(None);
60 };
61
62 let track = probed_format
64 .tracks()
65 .iter()
66 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
67 .ok_or(Error::Unsupported("No track with supported codec"))?;
68
69 let track_id = track.id;
70
71 let mut decoder = symphonia::default::get_codecs()
72 .make(&track.codec_params, &DecoderOptions::default())?;
73 let total_duration = stream
74 .codec_params
75 .time_base
76 .zip(stream.codec_params.n_frames)
77 .map(|(base, spans)| base.calc_time(spans).into());
78
79 let mut decode_errors: usize = 0;
80 let decoded_audio = loop {
81 let current_span = probed_format.next_packet()?;
82
83 if current_span.track_id() != track_id {
85 continue;
86 }
87
88 match decoder.decode(¤t_span) {
89 Ok(audio) => break audio,
90 Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
91 decode_errors += 1;
92 }
93 Err(e) => return Err(e),
94 }
95 };
96
97 let spec = decoded_audio.spec().to_owned();
98 let buffer = Self::get_buffer(decoded_audio, spec);
99 Ok(Some(Self {
100 decoder,
101 current_span_offset: 0,
102 format: probed_format,
103 total_duration,
104 buffer,
105 spec,
106 }))
107 }
108
109 #[inline]
110 fn get_buffer(decoded: AudioBufferRef, spec: SignalSpec) -> SampleBuffer<f32> {
111 let duration = units::Duration::from(decoded.capacity() as u64);
112 let mut buffer = SampleBuffer::<f32>::new(duration, spec);
113 buffer.copy_interleaved_ref(decoded);
114 buffer
115 }
116
117 #[inline]
118 #[must_use]
119 pub const fn total_duration(&self) -> Option<Duration> {
120 self.total_duration
121 }
122
123 #[inline]
124 #[must_use]
125 pub const fn sample_rate(&self) -> u32 {
126 self.spec.rate
127 }
128
129 #[inline]
130 #[must_use]
131 pub fn channels(&self) -> usize {
132 self.spec.channels.count()
133 }
134}
135
136impl Iterator for SymphoniaSource {
137 type Item = f32;
138
139 fn size_hint(&self) -> (usize, Option<usize>) {
140 (
141 self.buffer.samples().len(),
142 self.total_duration.map(|dur| {
143 (usize::try_from(dur.as_secs()).unwrap_or(usize::MAX) + 1)
144 * self.spec.rate as usize
145 * self.spec.channels.count()
146 }),
147 )
148 }
149
150 fn next(&mut self) -> Option<Self::Item> {
151 if self.current_span_offset < self.buffer.len() {
152 let sample = self.buffer.samples().get(self.current_span_offset);
153 self.current_span_offset += 1;
154
155 return sample.copied();
156 }
157
158 let mut decode_errors = 0;
159 let decoded = loop {
160 let packet = self.format.next_packet().ok()?;
161 match self.decoder.decode(&packet) {
162 Ok(decoded) if decoded.frames() > 0 => break decoded,
168 Ok(_) => {}
169 Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
170 decode_errors += 1;
171 }
172 Err(_) => return None,
173 }
174 };
175
176 decoded.spec().clone_into(&mut self.spec);
177 self.buffer = Self::get_buffer(decoded, self.spec);
178 self.current_span_offset = 1;
179 self.buffer.samples().first().copied()
180 }
181}
182
/// Decoder that owns a pool of resamplers of type `R`
/// (`FastFixedIn<f32>` unless specified otherwise).
#[allow(clippy::module_name_repetitions)]
pub struct MecompDecoder<R = FastFixedIn<f32>> {
    /// Pool of resampler construction results; an entry is `Err` when building
    /// that resampler failed.
    resampler: Pool<Result<R, ResamplerConstructionError>>,
}
187
188impl MecompDecoder {
189 #[inline]
190 fn generate_resampler() -> Result<FastFixedIn<f32>, ResamplerConstructionError> {
191 FastFixedIn::new(1.0, 10.0, rubato::PolynomialDegree::Cubic, CHUNK_SIZE, 1)
192 }
193
194 #[inline]
200 pub fn new() -> Result<Self, AnalysisError> {
201 let first = Self::generate_resampler()?;
203
204 let pool_size = std::thread::available_parallelism().map_or(1, NonZeroUsize::get);
205 let resampler = Pool::new(pool_size, Self::generate_resampler);
206 resampler.attach(Ok(first));
207
208 Ok(Self { resampler })
209 }
210
211 #[inline]
221 #[doc(hidden)]
222 pub fn into_mono_samples(
223 source: Vec<f32>,
224 num_channels: usize,
225 ) -> Result<Vec<f32>, AnalysisError> {
226 match num_channels {
227 0 => Err(AnalysisError::DecodeError(Error::DecodeError(
229 "The audio source has no channels",
230 ))),
231 1 => Ok(source),
233 2 => Ok(source
235 .chunks_exact(2)
236 .map(|chunk| (chunk[0] + chunk[1]) * SQRT_2 / 2.)
237 .collect()),
238 _ => {
240 log::warn!("The audio source has more than 2 channels (might be 2.1 or 5.1 surround sound), will collapse to mono by averaging the channels");
241
242 #[allow(clippy::cast_precision_loss)]
243 let num_channels_f32 = num_channels as f32;
244 let mono_samples = source
245 .chunks_exact(num_channels)
246 .map(|chunk| chunk.iter().sum::<f32>() / num_channels_f32)
247 .collect();
248
249 Ok(mono_samples)
250 }
251 }
252 }
253
254 #[inline]
256 #[doc(hidden)]
257 pub fn resample_mono_samples(
258 &self,
259 mut samples: Vec<f32>,
260 sample_rate: u32,
261 total_duration: Duration,
262 ) -> Result<Vec<f32>, AnalysisError> {
263 if sample_rate == SAMPLE_RATE {
264 samples.shrink_to_fit();
265 return Ok(samples);
266 }
267
268 let mut resampled_frames = Vec::with_capacity(
269 (usize::try_from(total_duration.as_secs()).unwrap_or(usize::MAX) + 1)
270 * SAMPLE_RATE as usize,
271 );
272
273 let (pool, resampler) = self.resampler.pull(Self::generate_resampler).detach();
274 let mut resampler = resampler?;
275 resampler.set_resample_ratio(f64::from(SAMPLE_RATE) / f64::from(sample_rate), false)?;
276
277 let delay = resampler.output_delay();
278
279 let new_length = samples.len() * SAMPLE_RATE as usize / sample_rate as usize;
280 let mut output_buffer = resampler.output_buffer_allocate(true);
281
282 let sample_chunks = samples.chunks_exact(CHUNK_SIZE);
284 let remainder = sample_chunks.remainder();
285
286 for chunk in sample_chunks {
287 debug_assert!(resampler.input_frames_next() == CHUNK_SIZE);
288
289 let (_, output_written) =
290 resampler.process_into_buffer(&[chunk], output_buffer.as_mut_slice(), None)?;
291 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
292 }
293
294 if !remainder.is_empty() {
296 let (_, output_written) = resampler.process_partial_into_buffer(
297 Some(&[remainder]),
298 output_buffer.as_mut_slice(),
299 None,
300 )?;
301 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
302 }
303
304 if resampled_frames.len() < new_length + delay {
306 let (_, output_written) = resampler.process_partial_into_buffer(
307 Option::<&[&[f32]]>::None,
308 output_buffer.as_mut_slice(),
309 None,
310 )?;
311 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
312 }
313
314 resampler.reset();
315 pool.attach(Ok(resampler));
316
317 Ok(resampled_frames[delay..new_length + delay].to_vec())
318 }
319}
320
321impl Decoder for MecompDecoder {
322 #[allow(clippy::missing_inline_in_public_items)]
328 fn decode(&self, path: &std::path::Path) -> AnalysisResult<ResampledAudio> {
329 let file = File::open(path)?;
331 let mss = MediaSourceStream::new(Box::new(file), MediaSourceStreamOptions::default());
333
334 let source = SymphoniaSource::new(mss)?;
335
336 let sample_rate = source.spec.rate;
338 let Some(total_duration) = source.total_duration else {
339 return Err(AnalysisError::IndeterminantDuration);
340 };
341 let num_channels = source.channels();
342
343 let mono_sample_array =
344 Self::into_mono_samples(source.into_iter().collect(), num_channels)?;
345
346 let resampled_array =
348 self.resample_mono_samples(mono_sample_array, sample_rate, total_duration)?;
349
350 Ok(ResampledAudio {
351 path: path.to_owned(),
352 samples: resampled_array,
353 })
354 }
355}
356
#[cfg(test)]
mod tests {
    use crate::NUMBER_FEATURES;

    use super::{Decoder as DecoderTrait, MecompDecoder as Decoder};
    use adler32::RollingAdler32;
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::path::Path;

    /// Decodes `path` and compares the Adler-32 checksum of the little-endian
    /// sample bytes against `expected_hash`.
    fn verify_decoding_output(path: &Path, expected_hash: u32) {
        let song = Decoder::new().unwrap().decode(path).unwrap();

        let mut hasher = RollingAdler32::new();
        song.samples
            .iter()
            .for_each(|sample| hasher.update_buffer(&sample.to_le_bytes()));

        assert_eq!(expected_hash, hasher.hash());
    }

    #[rstest]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_stereo(Path::new("data/s32_stereo_44_1_kHz.flac"), 0xbbcb_a1cf)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mono(Path::new("data/s32_mono_44_1_kHz.flac"), 0xa0f8b8af)]
    #[case::decode_stereo(Path::new("data/s16_stereo_22_5kHz.flac"), 0x1d7b_2d6d)]
    #[case::decode_mono(Path::new("data/s16_mono_22_5kHz.flac"), 0x5e01_930b)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mp3(Path::new("data/s32_stereo_44_1_kHz.mp3"), 0x69ca_6906)]
    #[case::decode_wav(Path::new("data/piano.wav"), 0xde83_1e82)]
    fn test_decode(#[case] path: &Path, #[case] expected_hash: u32) {
        verify_decoding_output(path, expected_hash);
    }

    /// Decoding a file without channel layout information must not panic.
    #[test]
    fn test_dont_panic_no_channel_layout() {
        let path = Path::new("data/no_channel.wav");
        Decoder::new().unwrap().decode(path).unwrap();
    }

    /// The returned sample vector must not carry spare capacity.
    #[test]
    fn test_decode_right_capacity_vec() {
        let files = [
            "data/s16_mono_22_5kHz.flac",
            "data/s32_stereo_44_1_kHz.flac",
            "data/capacity_fix.wav",
        ];

        for file in files {
            let sample_array = Decoder::new()
                .unwrap()
                .decode(Path::new(file))
                .unwrap()
                .samples;
            assert_eq!(sample_array.len(), sample_array.capacity());
        }
    }

    /// Reference file together with the feature vector expected for it.
    const PATH_AND_EXPECTED_ANALYSIS: (&str, [f64; NUMBER_FEATURES]) = (
        "data/s16_mono_22_5kHz.flac",
        [
            0.3846389,
            -0.849141,
            -0.75481045,
            -0.8790748,
            -0.63258266,
            -0.7258959,
            -0.775738,
            -0.8146726,
            0.2716726,
            0.25779057,
            -0.35661936,
            -0.63578653,
            -0.29593682,
            0.06421304,
            0.21852458,
            -0.581239,
            -0.9466835,
            -0.9481153,
            -0.9820945,
            -0.95968974,
        ],
    );

    /// Every analysed feature must lie within 0.01 of its reference value.
    #[test]
    fn test_analyze() {
        let (path, expected_analysis) = PATH_AND_EXPECTED_ANALYSIS;
        let analysis = Decoder::new()
            .unwrap()
            .analyze_path(Path::new(path))
            .unwrap();

        let features = analysis.as_vec();
        for (actual, expected) in features.iter().zip(expected_analysis) {
            assert!((actual - expected).abs() < 0.01);
        }
    }
}