1use std::{f32::consts::SQRT_2, fs::File, num::NonZeroUsize, time::Duration};
4
5use object_pool::Pool;
6use rubato::{FastFixedIn, Resampler, ResamplerConstructionError};
7use symphonia::{
8 core::{
9 audio::{AudioBufferRef, SampleBuffer, SignalSpec},
10 codecs::{CODEC_TYPE_NULL, DecoderOptions},
11 errors::Error,
12 formats::{FormatOptions, FormatReader},
13 io::{MediaSourceStream, MediaSourceStreamOptions},
14 meta::MetadataOptions,
15 probe::Hint,
16 units,
17 },
18 default::get_probe,
19};
20
21use crate::{ResampledAudio, SAMPLE_RATE, errors::AnalysisError, errors::AnalysisResult};
22
23use super::Decoder;
24
/// How many `Error::DecodeError`s a single decode loop tolerates before giving up.
const MAX_DECODE_RETRIES: usize = 3;
/// Number of input frames handed to the resampler per processing call.
const CHUNK_SIZE: usize = 4096;
27
28#[doc(hidden)]
30pub struct SymphoniaSource {
31 decoder: Box<dyn symphonia::core::codecs::Decoder>,
32 current_span_offset: usize,
33 format: Box<dyn FormatReader>,
34 total_duration: Option<Duration>,
35 buffer: SampleBuffer<f32>,
36 spec: SignalSpec,
37}
38
39impl SymphoniaSource {
40 pub fn new(mss: MediaSourceStream) -> Result<Self, Error> {
47 Self::init(mss)?.ok_or(Error::DecodeError("No Streams"))
48 }
49
50 fn init(mss: MediaSourceStream) -> symphonia::core::errors::Result<Option<Self>> {
51 let hint = Hint::new();
52 let format_opts = FormatOptions::default();
53 let metadata_opts = MetadataOptions::default();
54 let mut probed_format = get_probe()
55 .format(&hint, mss, &format_opts, &metadata_opts)?
56 .format;
57
58 let Some(stream) = probed_format.default_track() else {
59 return Ok(None);
60 };
61
62 let track = probed_format
64 .tracks()
65 .iter()
66 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
67 .ok_or(Error::Unsupported("No track with supported codec"))?;
68
69 let track_id = track.id;
70
71 let mut decoder = symphonia::default::get_codecs()
72 .make(&track.codec_params, &DecoderOptions::default())?;
73 let total_duration = stream
74 .codec_params
75 .time_base
76 .zip(stream.codec_params.n_frames)
77 .map(|(base, spans)| base.calc_time(spans).into());
78
79 let mut decode_errors: usize = 0;
80 let decoded_audio = loop {
81 let current_span = probed_format.next_packet()?;
82
83 if current_span.track_id() != track_id {
85 continue;
86 }
87
88 match decoder.decode(¤t_span) {
89 Ok(audio) => break audio,
90 Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
91 decode_errors += 1;
92 }
93 Err(e) => return Err(e),
94 }
95 };
96
97 let spec = decoded_audio.spec().to_owned();
98 let buffer = Self::get_buffer(decoded_audio, spec);
99 Ok(Some(Self {
100 decoder,
101 current_span_offset: 0,
102 format: probed_format,
103 total_duration,
104 buffer,
105 spec,
106 }))
107 }
108
109 #[inline]
110 fn get_buffer(decoded: AudioBufferRef, spec: SignalSpec) -> SampleBuffer<f32> {
111 let duration = units::Duration::from(decoded.capacity() as u64);
112 let mut buffer = SampleBuffer::<f32>::new(duration, spec);
113 buffer.copy_interleaved_ref(decoded);
114 buffer
115 }
116
117 #[inline]
118 #[must_use]
119 pub const fn total_duration(&self) -> Option<Duration> {
120 self.total_duration
121 }
122
123 #[inline]
124 #[must_use]
125 pub const fn sample_rate(&self) -> u32 {
126 self.spec.rate
127 }
128
129 #[inline]
130 #[must_use]
131 pub fn channels(&self) -> usize {
132 self.spec.channels.count()
133 }
134}
135
136impl Iterator for SymphoniaSource {
137 type Item = f32;
138
139 fn size_hint(&self) -> (usize, Option<usize>) {
140 (
141 self.buffer.samples().len(),
142 self.total_duration.map(|dur| {
143 (usize::try_from(dur.as_secs()).unwrap_or(usize::MAX) + 1)
144 * self.spec.rate as usize
145 * self.spec.channels.count()
146 }),
147 )
148 }
149
150 fn next(&mut self) -> Option<Self::Item> {
151 if self.current_span_offset < self.buffer.len() {
152 let sample = self.buffer.samples().get(self.current_span_offset);
153 self.current_span_offset += 1;
154
155 return sample.copied();
156 }
157
158 let mut decode_errors = 0;
159 let decoded = loop {
160 let packet = self.format.next_packet().ok()?;
161 match self.decoder.decode(&packet) {
162 Ok(decoded) if decoded.frames() > 0 => break decoded,
168 Ok(_) => {}
169 Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
170 decode_errors += 1;
171 }
172 Err(_) => return None,
173 }
174 };
175
176 decoded.spec().clone_into(&mut self.spec);
177 self.buffer = Self::get_buffer(decoded, self.spec);
178 self.current_span_offset = 1;
179 self.buffer.samples().first().copied()
180 }
181}
182
183#[allow(clippy::module_name_repetitions)]
184pub struct MecompDecoder<R = FastFixedIn<f32>> {
185 resampler: Pool<Result<R, ResamplerConstructionError>>,
186}
187
188impl MecompDecoder {
189 #[inline]
190 fn generate_resampler() -> Result<FastFixedIn<f32>, ResamplerConstructionError> {
191 FastFixedIn::new(1.0, 10.0, rubato::PolynomialDegree::Cubic, CHUNK_SIZE, 1)
192 }
193
194 #[inline]
200 pub fn new() -> Result<Self, AnalysisError> {
201 let first = Self::generate_resampler()?;
203
204 let pool_size = std::thread::available_parallelism().map_or(1, NonZeroUsize::get);
205 let resampler = Pool::new(pool_size, Self::generate_resampler);
206 resampler.attach(Ok(first));
207
208 Ok(Self { resampler })
209 }
210
211 #[inline]
221 #[doc(hidden)]
222 pub fn into_mono_samples(
223 source: Vec<f32>,
224 num_channels: usize,
225 ) -> Result<Vec<f32>, AnalysisError> {
226 match num_channels {
227 0 => Err(AnalysisError::DecodeError(Error::DecodeError(
229 "The audio source has no channels",
230 ))),
231 1 => Ok(source),
233 2 => Ok(source
235 .chunks_exact(2)
236 .map(|chunk| (chunk[0] + chunk[1]) * SQRT_2 / 2.)
237 .collect()),
238 _ => {
240 log::warn!(
241 "The audio source has more than 2 channels (might be 2.1 or 5.1 surround sound), will collapse to mono by averaging the channels"
242 );
243
244 #[allow(clippy::cast_precision_loss)]
245 let num_channels_f32 = num_channels as f32;
246 let mono_samples = source
247 .chunks_exact(num_channels)
248 .map(|chunk| chunk.iter().sum::<f32>() / num_channels_f32)
249 .collect();
250
251 Ok(mono_samples)
252 }
253 }
254 }
255
256 #[inline]
258 #[doc(hidden)]
259 pub fn resample_mono_samples(
260 &self,
261 mut samples: Vec<f32>,
262 sample_rate: u32,
263 total_duration: Duration,
264 ) -> Result<Vec<f32>, AnalysisError> {
265 if sample_rate == SAMPLE_RATE {
266 samples.shrink_to_fit();
267 return Ok(samples);
268 }
269
270 let mut resampled_frames = Vec::with_capacity(
271 (usize::try_from(total_duration.as_secs()).unwrap_or(usize::MAX) + 1)
272 * SAMPLE_RATE as usize,
273 );
274
275 let (pool, resampler) = self.resampler.pull(Self::generate_resampler).detach();
276 let mut resampler = resampler?;
277 resampler.set_resample_ratio(f64::from(SAMPLE_RATE) / f64::from(sample_rate), false)?;
278
279 let delay = resampler.output_delay();
280
281 let new_length = samples.len() * SAMPLE_RATE as usize / sample_rate as usize;
282 let mut output_buffer = resampler.output_buffer_allocate(true);
283
284 let sample_chunks = samples.chunks_exact(CHUNK_SIZE);
286 let remainder = sample_chunks.remainder();
287
288 for chunk in sample_chunks {
289 debug_assert!(resampler.input_frames_next() == CHUNK_SIZE);
290
291 let (_, output_written) =
292 resampler.process_into_buffer(&[chunk], output_buffer.as_mut_slice(), None)?;
293 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
294 }
295
296 if !remainder.is_empty() {
298 let (_, output_written) = resampler.process_partial_into_buffer(
299 Some(&[remainder]),
300 output_buffer.as_mut_slice(),
301 None,
302 )?;
303 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
304 }
305
306 if resampled_frames.len() < new_length + delay {
308 let (_, output_written) = resampler.process_partial_into_buffer(
309 Option::<&[&[f32]]>::None,
310 output_buffer.as_mut_slice(),
311 None,
312 )?;
313 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
314 }
315
316 resampler.reset();
317 pool.attach(Ok(resampler));
318
319 Ok(resampled_frames[delay..new_length + delay].to_vec())
320 }
321}
322
323impl Decoder for MecompDecoder {
324 #[allow(clippy::missing_inline_in_public_items)]
330 fn decode(&self, path: &std::path::Path) -> AnalysisResult<ResampledAudio> {
331 let file = File::open(path)?;
333 let mss = MediaSourceStream::new(Box::new(file), MediaSourceStreamOptions::default());
335
336 let source = SymphoniaSource::new(mss)?;
337
338 let sample_rate = source.spec.rate;
340 let Some(total_duration) = source.total_duration else {
341 return Err(AnalysisError::IndeterminantDuration);
342 };
343 let num_channels = source.channels();
344
345 let mono_sample_array =
346 Self::into_mono_samples(source.into_iter().collect(), num_channels)?;
347
348 let resampled_array =
350 self.resample_mono_samples(mono_sample_array, sample_rate, total_duration)?;
351
352 Ok(ResampledAudio {
353 path: path.to_owned(),
354 samples: resampled_array,
355 })
356 }
357}
358
#[cfg(test)]
mod tests {
    use crate::NUMBER_FEATURES;

    use super::{Decoder as DecoderTrait, MecompDecoder as Decoder};
    use adler32::RollingAdler32;
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::path::Path;

    /// Decodes `path` and compares the Adler-32 checksum of the little-endian sample bytes
    /// against `expected_hash`.
    fn verify_decoding_output(path: &Path, expected_hash: u32) {
        let song = Decoder::new().unwrap().decode(path).unwrap();

        let mut checksum = RollingAdler32::new();
        song.samples
            .iter()
            .for_each(|sample| checksum.update_buffer(&sample.to_le_bytes()));

        assert_eq!(expected_hash, checksum.hash());
    }

    #[rstest]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_stereo(Path::new("data/s32_stereo_44_1_kHz.flac"), 0xbbcb_a1cf)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mono(Path::new("data/s32_mono_44_1_kHz.flac"), 0xa0f8_b8af)]
    #[case::decode_stereo(Path::new("data/s16_stereo_22_5kHz.flac"), 0x1d7b_2d6d)]
    #[case::decode_mono(Path::new("data/s16_mono_22_5kHz.flac"), 0x5e01_930b)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mp3(Path::new("data/s32_stereo_44_1_kHz.mp3"), 0x69ca_6906)]
    #[case::decode_wav(Path::new("data/piano.wav"), 0xde83_1e82)]
    fn test_decode(#[case] path: &Path, #[case] expected_hash: u32) {
        verify_decoding_output(path, expected_hash);
    }

    /// Decoding a file without a channel layout must succeed rather than panic.
    #[test]
    fn test_dont_panic_no_channel_layout() {
        let path = Path::new("data/no_channel.wav");
        Decoder::new().unwrap().decode(path).unwrap();
    }

    /// The returned sample vector must not carry spare capacity, with or without resampling.
    #[test]
    fn test_decode_right_capacity_vec() {
        let files = [
            "data/s16_mono_22_5kHz.flac",
            "data/s32_stereo_44_1_kHz.flac",
            "data/capacity_fix.wav",
        ];

        for file in files {
            let song = Decoder::new().unwrap().decode(Path::new(file)).unwrap();
            let sample_array = song.samples;
            assert_eq!(sample_array.len(), sample_array.capacity());
        }
    }

    /// Reference file together with the analysis values expected for it.
    const PATH_AND_EXPECTED_ANALYSIS: (&str, [f64; NUMBER_FEATURES]) = (
        "data/s16_mono_22_5kHz.flac",
        [
            0.384_638_9,
            -0.849_141_,
            -0.754_810_45,
            -0.879_074_8,
            -0.632_582_66,
            -0.725_895_9,
            -0.775_738_,
            -0.814_672_6,
            0.271_672_6,
            0.257_790_57,
            -0.356_619_36,
            -0.635_786_53,
            -0.295_936_82,
            0.064_213_04,
            0.218_524_58,
            -0.581_239,
            -0.946_683_5,
            -0.948_115_3,
            -0.982_094_5,
            -0.959_689_74,
        ],
    );

    /// Every analysed feature must lie within 0.01 of its reference value.
    #[test]
    fn test_analyze() {
        let (path, expected_analysis) = PATH_AND_EXPECTED_ANALYSIS;
        let decoder = Decoder::new().unwrap();
        let analysis = decoder.analyze_path(Path::new(path)).unwrap();

        for (x, y) in analysis.as_vec().iter().zip(expected_analysis) {
            let close_enough = (x - y).abs() < 0.01;
            assert!(
                close_enough,
                "Expected {x} to be within 0.01 of {y}, but it was not"
            );
        }
    }
}